aboutsummaryrefslogtreecommitdiff
path: root/backend/internal/worker/worker.go
blob: 6017fb7fa9fd22ac96e725912ebd40277eab4fe4 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
package worker

import (
	"context"
	"crypto/rand"
	"encoding/base64"
	"io"
	"log/slog"
	"os"
	"time"

	"github.com/ansg191/ibd-trader-backend/internal/analyzer"
	"github.com/ansg191/ibd-trader-backend/internal/database"
	"github.com/ansg191/ibd-trader-backend/internal/ibd"
	"github.com/ansg191/ibd-trader-backend/internal/keys"
	"github.com/ansg191/ibd-trader-backend/internal/leader/manager"
	analyzer2 "github.com/ansg191/ibd-trader-backend/internal/worker/analyzer"
	"github.com/ansg191/ibd-trader-backend/internal/worker/auth"
	"github.com/ansg191/ibd-trader-backend/internal/worker/scraper"

	"github.com/redis/go-redis/v9"
	"golang.org/x/sync/errgroup"
)

// Worker heartbeat timing.
//
// HeartbeatInterval is the period between two heartbeats sent by the worker
// registration loop. It is also used as the timeout for the Redis calls made
// by a single heartbeat.
//
// HeartbeatTTL is the expiry placed on the worker's heartbeat key every time
// the key is written.
const (
	HeartbeatInterval = 5 * time.Second
	HeartbeatTTL      = 30 * time.Second
)

// StartWorker runs all worker subsystems and blocks until they have stopped.
//
// It first generates a unique name for this worker; if that fails the error is
// returned and nothing is started. Otherwise the registration loop, the
// scraper, the auth scraper and the analyzer are each started in their own
// goroutine inside a single errgroup. They share a context derived from ctx,
// which the errgroup cancels once any of them returns a non-nil error.
//
// The returned error is the first non-nil error reported by a subsystem, or
// nil if all of them returned nil.
func StartWorker(
	ctx context.Context,
	ibdClient *ibd.Client,
	client *redis.Client,
	db database.TransactionExecutor,
	kms keys.KeyManagementService,
	a analyzer.Analyzer,
) error {
	name, err := workerName()
	if err != nil {
		return err
	}
	slog.InfoContext(ctx, "Starting worker", "worker", name)

	group, groupCtx := errgroup.WithContext(ctx)

	// Every subsystem receives the group context so that one failure stops
	// the others.
	tasks := []func() error{
		func() error { return workerRegistrationLoop(groupCtx, client, name) },
		func() error { return scraper.RunScraper(groupCtx, client, ibdClient, db, name) },
		func() error { return auth.RunAuthScraper(groupCtx, ibdClient, client, db, kms, name) },
		func() error { return analyzer2.RunAnalyzer(groupCtx, client, a, db, name) },
	}
	for _, task := range tasks {
		group.Go(task)
	}

	return group.Wait()
}

// workerRegistrationLoop keeps the named worker registered until ctx is done.
//
// One heartbeat is sent immediately, then another on every tick of a
// HeartbeatInterval ticker. Once ctx is done the worker is removed again and
// ctx.Err() is returned.
func workerRegistrationLoop(ctx context.Context, client *redis.Client, name string) error {
	// Register right away instead of waiting for the first tick.
	sendHeartbeat(ctx, client, name)

	beat := time.NewTicker(HeartbeatInterval)
	defer beat.Stop()

	done := ctx.Done()
	for {
		select {
		case <-done:
			removeWorker(ctx, client, name)
			return ctx.Err()
		case <-beat.C:
		}

		sendHeartbeat(ctx, client, name)
	}
}

// sendHeartbeat registers one heartbeat for the worker called name.
//
// It adds name to the active workers set and then writes the current Unix time
// to the worker's heartbeat key with an expiry of HeartbeatTTL. Both Redis
// calls share a single timeout of HeartbeatInterval. Failures are logged and
// not returned; if adding to the set fails, the heartbeat key is not written.
func sendHeartbeat(ctx context.Context, client *redis.Client, name string) {
	hbCtx, stop := context.WithTimeout(ctx, HeartbeatInterval)
	defer stop()

	// Step 1: make sure the worker is a member of the active workers set.
	err := client.SAdd(hbCtx, manager.ActiveWorkersSet, name).Err()
	if err != nil {
		slog.ErrorContext(hbCtx,
			"Unable to add worker to active workers set",
			"worker", name,
			"error", err,
		)
		return
	}

	// Step 2: refresh the heartbeat key and its expiry.
	key := manager.WorkerHeartbeatKey(name)
	now := time.Now().Unix()
	if err = client.Set(hbCtx, key, now, HeartbeatTTL).Err(); err != nil {
		slog.ErrorContext(hbCtx,
			"Unable to set worker heartbeat",
			"worker", name,
			"error", err,
		)
	}
}

// removeWorker deregisters the worker called name.
//
// It removes name from the active workers set and deletes the worker's
// heartbeat key. The two steps are independent best-effort cleanups: a failure
// of either one is logged, and the heartbeat key is deleted even when the set
// removal failed, so the key does not linger until its TTL expires.
//
// If ctx is already done, the cleanup runs on a context detached from its
// cancellation so the Redis calls can still be issued. In every case the whole
// cleanup is bounded by a 10 second timeout.
func removeWorker(ctx context.Context, client *redis.Client, name string) {
	if ctx.Err() != nil {
		// The caller's context is already canceled; keep its values but drop
		// the cancellation so the cleanup below is not rejected immediately.
		ctx = context.WithoutCancel(ctx)
	}
	ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
	defer cancel()

	// Remove the worker from the active workers set. Do not return on
	// failure: the heartbeat key should still be deleted.
	if err := client.SRem(ctx, manager.ActiveWorkersSet, name).Err(); err != nil {
		slog.ErrorContext(ctx,
			"Unable to remove worker from active workers set",
			"worker", name,
			"error", err,
		)
	}

	// Remove the worker's heartbeat.
	heartbeatKey := manager.WorkerHeartbeatKey(name)
	if err := client.Del(ctx, heartbeatKey).Err(); err != nil {
		slog.ErrorContext(ctx,
			"Unable to remove worker heartbeat",
			"worker", name,
			"error", err,
		)
	}
}

// workerName builds a unique name for this worker process.
//
// The name has the form "<hostname>-<suffix>", where suffix is the URL-safe
// base64 encoding of 12 bytes read from crypto/rand. An error is returned if
// the hostname cannot be determined or the random bytes cannot be read.
func workerName() (string, error) {
	host, err := os.Hostname()
	if err != nil {
		return "", err
	}

	var nonce [12]byte
	if _, err := io.ReadFull(rand.Reader, nonce[:]); err != nil {
		return "", err
	}

	suffix := base64.URLEncoding.EncodeToString(nonce[:])
	return host + "-" + suffix, nil
}