aboutsummaryrefslogtreecommitdiff
path: root/backend/internal/worker
diff options
context:
space:
mode:
authorGravatar Anshul Gupta <ansg191@anshulg.com> 2024-08-05 18:55:10 -0700
committerGravatar Anshul Gupta <ansg191@anshulg.com> 2024-08-05 18:55:19 -0700
commitb96fcd1a54a46a95f98467b49a051564bc21c23c (patch)
tree93caeeb05f8d6310e241095608ea2428c749b18c /backend/internal/worker
downloadibd-trader-b96fcd1a54a46a95f98467b49a051564bc21c23c.tar.gz
ibd-trader-b96fcd1a54a46a95f98467b49a051564bc21c23c.tar.zst
ibd-trader-b96fcd1a54a46a95f98467b49a051564bc21c23c.zip
Initial Commit
Diffstat (limited to 'backend/internal/worker')
-rw-r--r--backend/internal/worker/analyzer/analyzer.go135
-rw-r--r--backend/internal/worker/auth/auth.go228
-rw-r--r--backend/internal/worker/scraper/scraper.go191
-rw-r--r--backend/internal/worker/worker.go149
4 files changed, 703 insertions, 0 deletions
diff --git a/backend/internal/worker/analyzer/analyzer.go b/backend/internal/worker/analyzer/analyzer.go
new file mode 100644
index 0000000..924e571
--- /dev/null
+++ b/backend/internal/worker/analyzer/analyzer.go
@@ -0,0 +1,135 @@
+package analyzer
+
+import (
+ "context"
+ "log/slog"
+ "time"
+
+ "ibd-trader/internal/analyzer"
+ "ibd-trader/internal/database"
+ "ibd-trader/internal/redis/taskqueue"
+
+ "github.com/redis/go-redis/v9"
+)
+
const (
	// Queue is the name of the analyzer task queue.
	Queue = "analyzer"
	// QueueEncoding is the serialization used for queued TaskInfo payloads.
	QueueEncoding = taskqueue.EncodingJSON

	// lockTimeout is how long a dequeued task stays locked before the lock
	// must be extended (or the task is considered abandoned).
	lockTimeout = 1 * time.Minute
	// dequeueTimeout bounds how long a single Dequeue call blocks waiting
	// for a task to become available.
	dequeueTimeout = 5 * time.Second
)
+
+func RunAnalyzer(
+ ctx context.Context,
+ redis *redis.Client,
+ analyzer analyzer.Analyzer,
+ db database.StockStore,
+ name string,
+) error {
+ queue, err := taskqueue.New(
+ ctx,
+ redis,
+ Queue,
+ name,
+ taskqueue.WithEncoding[TaskInfo](QueueEncoding),
+ )
+ if err != nil {
+ return err
+ }
+
+ for {
+ select {
+ case <-ctx.Done():
+ return ctx.Err()
+ default:
+ waitForTask(ctx, queue, analyzer, db)
+ }
+ }
+}
+
+func waitForTask(
+ ctx context.Context,
+ queue taskqueue.TaskQueue[TaskInfo],
+ analyzer analyzer.Analyzer,
+ db database.StockStore,
+) {
+ task, err := queue.Dequeue(ctx, lockTimeout, dequeueTimeout)
+ if err != nil {
+ slog.ErrorContext(ctx, "Failed to dequeue task", "error", err)
+ return
+ }
+ if task == nil {
+ // No task available.
+ return
+ }
+
+ ch := make(chan error)
+ defer close(ch)
+ go func() {
+ ch <- analyzeStock(ctx, analyzer, db, task.Data.ID)
+ }()
+
+ ticker := time.NewTicker(lockTimeout / 5)
+ defer ticker.Stop()
+
+ for {
+ select {
+ case <-ctx.Done():
+ // Context was canceled. Return early.
+ return
+ case <-ticker.C:
+ // Extend the lock periodically.
+ func() {
+ ctx, cancel := context.WithTimeout(ctx, lockTimeout/5)
+ defer cancel()
+
+ err := queue.Extend(ctx, task.ID)
+ if err != nil {
+ slog.ErrorContext(ctx, "Failed to extend lock", "error", err)
+ }
+ }()
+ case err = <-ch:
+ // scrapeUrl has completed.
+ if err != nil {
+ slog.ErrorContext(ctx, "Failed to analyze", "error", err)
+ _, err = queue.Return(ctx, task.ID, err)
+ if err != nil {
+ slog.ErrorContext(ctx, "Failed to return task", "error", err)
+ return
+ }
+ } else {
+ slog.DebugContext(ctx, "Analyzed ID", "id", task.Data.ID)
+ err = queue.Complete(ctx, task.ID, nil)
+ if err != nil {
+ slog.ErrorContext(ctx, "Failed to complete task", "error", err)
+ return
+ }
+ }
+ return
+ }
+ }
+}
+
+func analyzeStock(ctx context.Context, a analyzer.Analyzer, db database.StockStore, id string) error {
+ info, err := db.GetStockInfo(ctx, id)
+ if err != nil {
+ return err
+ }
+
+ analysis, err := a.Analyze(
+ ctx,
+ info.Symbol,
+ info.Price,
+ info.ChartAnalysis,
+ )
+ if err != nil {
+ return err
+ }
+
+ return db.AddAnalysis(ctx, id, analysis)
+}
+
// TaskInfo is the payload of an analyzer queue task.
type TaskInfo struct {
	// ID of the stock-info record to analyze.
	ID string `json:"id"`
}
diff --git a/backend/internal/worker/auth/auth.go b/backend/internal/worker/auth/auth.go
new file mode 100644
index 0000000..e1c6661
--- /dev/null
+++ b/backend/internal/worker/auth/auth.go
@@ -0,0 +1,228 @@
+package auth
+
+import (
+ "context"
+ "fmt"
+ "log/slog"
+ "time"
+
+ "ibd-trader/internal/database"
+ "ibd-trader/internal/ibd"
+ "ibd-trader/internal/leader/manager/ibd/auth"
+ "ibd-trader/internal/redis/taskqueue"
+
+ "github.com/redis/go-redis/v9"
+)
+
const (
	// lockTimeout is how long a dequeued task stays locked before the lock
	// must be extended (or the task is considered abandoned).
	lockTimeout = 1 * time.Minute
	// dequeueTimeout bounds how long a single Dequeue call blocks waiting
	// for a task to become available.
	dequeueTimeout = 5 * time.Second
)
+
+func RunAuthScraper(
+ ctx context.Context,
+ client *ibd.Client,
+ redis *redis.Client,
+ users database.UserStore,
+ cookies database.CookieStore,
+ name string,
+) error {
+ queue, err := taskqueue.New(
+ ctx,
+ redis,
+ auth.Queue,
+ name,
+ taskqueue.WithEncoding[auth.TaskInfo](auth.QueueEncoding),
+ )
+ if err != nil {
+ return err
+ }
+
+ for {
+ select {
+ case <-ctx.Done():
+ return ctx.Err()
+ default:
+ waitForTask(ctx, queue, client, users, cookies)
+ }
+ }
+}
+
+func waitForTask(
+ ctx context.Context,
+ queue taskqueue.TaskQueue[auth.TaskInfo],
+ client *ibd.Client,
+ users database.UserStore,
+ cookies database.CookieStore,
+) {
+ task, err := queue.Dequeue(ctx, lockTimeout, dequeueTimeout)
+ if err != nil {
+ slog.ErrorContext(ctx, "Failed to dequeue task", "error", err)
+ return
+ }
+ if task == nil {
+ // No task available.
+ return
+ }
+ slog.DebugContext(ctx, "Picked up auth task", "task", task.ID, "user", task.Data.UserSubject)
+
+ ch := make(chan error)
+ defer close(ch)
+ go func() {
+ ch <- scrapeCookies(ctx, client, users, cookies, task.Data.UserSubject)
+ }()
+
+ ticker := time.NewTicker(lockTimeout / 5)
+ defer ticker.Stop()
+
+ for {
+ select {
+ case <-ctx.Done():
+ // The context was canceled. Return early.
+ return
+ case <-ticker.C:
+ // Extend the lock periodically.
+ func() {
+ ctx, cancel := context.WithTimeout(ctx, lockTimeout/5)
+ defer cancel()
+
+ err := queue.Extend(ctx, task.ID)
+ if err != nil {
+ slog.ErrorContext(ctx, "Failed to extend lock", "error", err)
+ }
+ }()
+ case err = <-ch:
+ // scrapeCookies has completed.
+ if err != nil {
+ slog.ErrorContext(ctx, "Failed to scrape cookies", "error", err)
+ _, err = queue.Return(ctx, task.ID, err)
+ if err != nil {
+ slog.ErrorContext(ctx, "Failed to return task", "error", err)
+ return
+ }
+ } else {
+ err = queue.Complete(ctx, task.ID, nil)
+ if err != nil {
+ slog.ErrorContext(ctx, "Failed to complete task", "error", err)
+ return
+ }
+ slog.DebugContext(ctx, "Authenticated user", "user", task.Data.UserSubject)
+ }
+ return
+ }
+ }
+}
+
+func scrapeCookies(
+ ctx context.Context,
+ client *ibd.Client,
+ users database.UserStore,
+ store database.CookieStore,
+ user string,
+) error {
+ ctx, cancel := context.WithTimeout(ctx, lockTimeout)
+ defer cancel()
+
+ // Check if the user has valid cookies
+ done, err := hasValidCookies(ctx, store, user)
+ if err != nil {
+ return fmt.Errorf("failed to check cookies: %w", err)
+ }
+ if done {
+ return nil
+ }
+
+ // Health check degraded cookies
+ done, err = healthCheckDegradedCookies(ctx, client, store, user)
+ if err != nil {
+ return fmt.Errorf("failed to health check cookies: %w", err)
+ }
+ if done {
+ return nil
+ }
+
+ // No cookies are valid, so scrape new cookies
+ return scrapeNewCookies(ctx, client, users, store, user)
+}
+
+func hasValidCookies(ctx context.Context, store database.CookieStore, user string) (bool, error) {
+ // Check if the user has non-degraded cookies
+ cookies, err := store.GetCookies(ctx, user, false)
+ if err != nil {
+ return false, fmt.Errorf("failed to get non-degraded cookies: %w", err)
+ }
+
+ // If the user has non-degraded cookies, return true
+ if cookies != nil && len(cookies) > 0 {
+ return true, nil
+ }
+ return false, nil
+}
+
+func healthCheckDegradedCookies(
+ ctx context.Context,
+ client *ibd.Client,
+ store database.CookieStore,
+ user string,
+) (bool, error) {
+ // Check if the user has degraded cookies
+ cookies, err := store.GetCookies(ctx, user, true)
+ if err != nil {
+ return false, fmt.Errorf("failed to get degraded cookies: %w", err)
+ }
+
+ valid := false
+ for _, cookie := range cookies {
+ slog.DebugContext(ctx, "Health checking cookie", "cookie", cookie.ID)
+
+ // Health check the cookie
+ up, err := client.UserInfo(ctx, cookie.ToHTTPCookie())
+ if err != nil {
+ slog.ErrorContext(ctx, "Failed to health check cookie", "error", err)
+ continue
+ }
+
+ if up.Status != ibd.UserStatusSubscriber {
+ continue
+ }
+
+ // Cookie is valid
+ valid = true
+
+ // Update the cookie
+ err = store.RepairCookie(ctx, cookie.ID)
+ if err != nil {
+ slog.ErrorContext(ctx, "Failed to repair cookie", "error", err)
+ }
+ }
+
+ return valid, nil
+}
+
+func scrapeNewCookies(
+ ctx context.Context,
+ client *ibd.Client,
+ users database.UserStore,
+ store database.CookieStore,
+ user string,
+) error {
+ // Get the user's credentials
+ username, password, err := users.GetIBDCreds(ctx, user)
+ if err != nil {
+ return fmt.Errorf("failed to get IBD credentials: %w", err)
+ }
+
+ // Scrape the user's cookies
+ cookie, err := client.Authenticate(ctx, username, password)
+ if err != nil {
+ return fmt.Errorf("failed to authenticate user: %w", err)
+ }
+
+ // Store the cookie
+ err = store.AddCookie(ctx, user, cookie)
+ if err != nil {
+ return fmt.Errorf("failed to store cookie: %w", err)
+ }
+
+ return nil
+}
diff --git a/backend/internal/worker/scraper/scraper.go b/backend/internal/worker/scraper/scraper.go
new file mode 100644
index 0000000..a83d9ae
--- /dev/null
+++ b/backend/internal/worker/scraper/scraper.go
@@ -0,0 +1,191 @@
+package scraper
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "log/slog"
+ "time"
+
+ "ibd-trader/internal/database"
+ "ibd-trader/internal/ibd"
+ "ibd-trader/internal/leader/manager/ibd/scrape"
+ "ibd-trader/internal/redis/taskqueue"
+ "ibd-trader/internal/worker/analyzer"
+
+ "github.com/redis/go-redis/v9"
+)
+
const (
	// lockTimeout is how long a dequeued task stays locked before the lock
	// must be extended (or the task is considered abandoned).
	lockTimeout = 1 * time.Minute
	// dequeueTimeout bounds how long a single Dequeue call blocks waiting
	// for a task to become available.
	dequeueTimeout = 5 * time.Second
)
+
+func RunScraper(
+ ctx context.Context,
+ redis *redis.Client,
+ client *ibd.Client,
+ store database.StockStore,
+ name string,
+) error {
+ queue, err := taskqueue.New(
+ ctx,
+ redis,
+ scrape.Queue,
+ name,
+ taskqueue.WithEncoding[scrape.TaskInfo](scrape.QueueEncoding),
+ )
+ if err != nil {
+ return err
+ }
+
+ aQueue, err := taskqueue.New(
+ ctx,
+ redis,
+ analyzer.Queue,
+ name,
+ taskqueue.WithEncoding[analyzer.TaskInfo](analyzer.QueueEncoding),
+ )
+ if err != nil {
+ return err
+ }
+
+ for {
+ select {
+ case <-ctx.Done():
+ return ctx.Err()
+ default:
+ waitForTask(ctx, queue, aQueue, client, store)
+ }
+ }
+}
+
+func waitForTask(
+ ctx context.Context,
+ queue taskqueue.TaskQueue[scrape.TaskInfo],
+ aQueue taskqueue.TaskQueue[analyzer.TaskInfo],
+ client *ibd.Client,
+ store database.StockStore,
+) {
+ task, err := queue.Dequeue(ctx, lockTimeout, dequeueTimeout)
+ if err != nil {
+ slog.ErrorContext(ctx, "Failed to dequeue task", "error", err)
+ return
+ }
+ if task == nil {
+ // No task available.
+ return
+ }
+
+ ch := make(chan error)
+ go func() {
+ defer close(ch)
+ ch <- scrapeUrl(ctx, client, store, aQueue, task.Data.Symbol)
+ }()
+
+ ticker := time.NewTicker(lockTimeout / 5)
+ defer ticker.Stop()
+
+ for {
+ select {
+ case <-ctx.Done():
+ // Context was canceled. Return early.
+ return
+ case <-ticker.C:
+ // Extend the lock periodically.
+ func() {
+ ctx, cancel := context.WithTimeout(ctx, lockTimeout/5)
+ defer cancel()
+
+ err := queue.Extend(ctx, task.ID)
+ if err != nil {
+ slog.ErrorContext(ctx, "Failed to extend lock", "error", err)
+ }
+ }()
+ case err = <-ch:
+ // scrapeUrl has completed.
+ if err != nil {
+ slog.ErrorContext(ctx, "Failed to scrape URL", "error", err)
+ _, err = queue.Return(ctx, task.ID, err)
+ if err != nil {
+ slog.ErrorContext(ctx, "Failed to return task", "error", err)
+ return
+ }
+ } else {
+ slog.DebugContext(ctx, "Scraped URL", "symbol", task.Data.Symbol)
+ err = queue.Complete(ctx, task.ID, nil)
+ if err != nil {
+ slog.ErrorContext(ctx, "Failed to complete task", "error", err)
+ return
+ }
+ }
+ return
+ }
+ }
+}
+
+func scrapeUrl(
+ ctx context.Context,
+ client *ibd.Client,
+ store database.StockStore,
+ aQueue taskqueue.TaskQueue[analyzer.TaskInfo],
+ symbol string,
+) error {
+ ctx, cancel := context.WithTimeout(ctx, lockTimeout)
+ defer cancel()
+
+ stockUrl, err := getStockUrl(ctx, store, client, symbol)
+ if err != nil {
+ return fmt.Errorf("failed to get stock url: %w", err)
+ }
+
+ // Scrape the stock info.
+ info, err := client.StockInfo(ctx, stockUrl)
+ if err != nil {
+ return fmt.Errorf("failed to get stock info: %w", err)
+ }
+
+ // Add stock info to the database.
+ id, err := store.AddStockInfo(ctx, info)
+ if err != nil {
+ return fmt.Errorf("failed to add stock info: %w", err)
+ }
+
+ // Add the stock to the analyzer queue.
+ _, err = aQueue.Enqueue(ctx, analyzer.TaskInfo{ID: id})
+ if err != nil {
+ return fmt.Errorf("failed to enqueue analysis task: %w", err)
+ }
+
+ slog.DebugContext(ctx, "Added stock info", "id", id)
+
+ return nil
+}
+
+func getStockUrl(ctx context.Context, store database.StockStore, client *ibd.Client, symbol string) (string, error) {
+ // Get the stock from the database.
+ stock, err := store.GetStock(ctx, symbol)
+ if err == nil {
+ return stock.IBDUrl, nil
+ }
+ if !errors.Is(err, database.ErrStockNotFound) {
+ return "", fmt.Errorf("failed to get stock: %w", err)
+ }
+
+ // If stock isn't found in the database, get the stock from IBD.
+ stock, err = client.Search(ctx, symbol)
+ if errors.Is(err, ibd.ErrSymbolNotFound) {
+ return "", fmt.Errorf("symbol not found: %w", err)
+ }
+ if err != nil {
+ return "", fmt.Errorf("failed to search for symbol: %w", err)
+ }
+
+ // Add the stock to the database.
+ err = store.AddStock(ctx, stock)
+ if err != nil {
+ return "", fmt.Errorf("failed to add stock: %w", err)
+ }
+
+ return stock.IBDUrl, nil
+}
diff --git a/backend/internal/worker/worker.go b/backend/internal/worker/worker.go
new file mode 100644
index 0000000..5858115
--- /dev/null
+++ b/backend/internal/worker/worker.go
@@ -0,0 +1,149 @@
+package worker
+
+import (
+ "context"
+ "crypto/rand"
+ "encoding/base64"
+ "io"
+ "log/slog"
+ "os"
+ "time"
+
+ "ibd-trader/internal/analyzer"
+ "ibd-trader/internal/database"
+ "ibd-trader/internal/ibd"
+ "ibd-trader/internal/leader/manager"
+ analyzer2 "ibd-trader/internal/worker/analyzer"
+ "ibd-trader/internal/worker/auth"
+ "ibd-trader/internal/worker/scraper"
+
+ "github.com/redis/go-redis/v9"
+ "golang.org/x/sync/errgroup"
+)
+
const (
	// HeartbeatInterval is how often a worker refreshes its heartbeat.
	HeartbeatInterval = 5 * time.Second
	// HeartbeatTTL is how long a heartbeat key lives without a refresh;
	// a worker silent for this long can be treated as dead.
	HeartbeatTTL = 30 * time.Second
)
+
+func StartWorker(
+ ctx context.Context,
+ ibdClient *ibd.Client,
+ client *redis.Client,
+ db database.Database,
+ a analyzer.Analyzer,
+) error {
+ // Get the worker name.
+ name, err := workerName()
+ if err != nil {
+ return err
+ }
+ slog.InfoContext(ctx, "Starting worker", "worker", name)
+
+ g, ctx := errgroup.WithContext(ctx)
+
+ g.Go(func() error {
+ return workerRegistrationLoop(ctx, client, name)
+ })
+ g.Go(func() error {
+ return scraper.RunScraper(ctx, client, ibdClient, db, name)
+ })
+ g.Go(func() error {
+ return auth.RunAuthScraper(ctx, ibdClient, client, db, db, name)
+ })
+ g.Go(func() error {
+ return analyzer2.RunAnalyzer(ctx, client, a, db, name)
+ })
+
+ return g.Wait()
+}
+
+func workerRegistrationLoop(ctx context.Context, client *redis.Client, name string) error {
+ sendHeartbeat(ctx, client, name)
+
+ ticker := time.NewTicker(HeartbeatInterval)
+ defer ticker.Stop()
+
+ for {
+ select {
+ case <-ticker.C:
+ sendHeartbeat(ctx, client, name)
+ case <-ctx.Done():
+ removeWorker(ctx, client, name)
+ return ctx.Err()
+ }
+ }
+}
+
+// sendHeartbeat sends a heartbeat for the worker.
+// It ensures that the worker is in the active workers set and its heartbeat exists.
+func sendHeartbeat(ctx context.Context, client *redis.Client, name string) {
+ ctx, cancel := context.WithTimeout(ctx, HeartbeatInterval)
+ defer cancel()
+
+ // Add the worker to the active workers set.
+ if err := client.SAdd(ctx, manager.ActiveWorkersSet, name).Err(); err != nil {
+ slog.ErrorContext(ctx,
+ "Unable to add worker to active workers set",
+ "worker", name,
+ "error", err,
+ )
+ return
+ }
+
+ // Set the worker's heartbeat.
+ heartbeatKey := manager.WorkerHeartbeatKey(name)
+ if err := client.Set(ctx, heartbeatKey, time.Now().Unix(), HeartbeatTTL).Err(); err != nil {
+ slog.ErrorContext(ctx,
+ "Unable to set worker heartbeat",
+ "worker", name,
+ "error", err,
+ )
+ return
+ }
+}
+
+// removeWorker removes the worker from the active workers set.
+func removeWorker(ctx context.Context, client *redis.Client, name string) {
+ if ctx.Err() != nil {
+ // If the context is canceled, create a new uncanceled context.
+ ctx = context.WithoutCancel(ctx)
+ }
+ ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
+ defer cancel()
+
+ // Remove the worker from the active workers set.
+ if err := client.SRem(ctx, manager.ActiveWorkersSet, name).Err(); err != nil {
+ slog.ErrorContext(ctx,
+ "Unable to remove worker from active workers set",
+ "worker", name,
+ "error", err,
+ )
+ return
+ }
+
+ // Remove the worker's heartbeat.
+ heartbeatKey := manager.WorkerHeartbeatKey(name)
+ if err := client.Del(ctx, heartbeatKey).Err(); err != nil {
+ slog.ErrorContext(ctx,
+ "Unable to remove worker heartbeat",
+ "worker", name,
+ "error", err,
+ )
+ return
+ }
+}
+
// workerName builds a unique identity for this worker of the form
// "<hostname>-<16 chars of URL-safe base64>", derived from 12 random bytes.
func workerName() (string, error) {
	hostname, err := os.Hostname()
	if err != nil {
		return "", err
	}

	// 12 random bytes encode to exactly 16 base64 characters, no padding.
	suffix := make([]byte, 12)
	if _, err = io.ReadFull(rand.Reader, suffix); err != nil {
		return "", err
	}

	return hostname + "-" + base64.URLEncoding.EncodeToString(suffix), nil
}