| author | 2024-08-11 13:15:50 -0700 |
|---|---|
| committer | 2024-08-11 13:15:50 -0700 |
| commit | 6a3c21fb0b1c126849f2bbff494403bbe901448e (patch) |
| tree | 5d7805524357c2c8a9819c39d2051a4e3633a1d5 /backend/internal/worker |
| parent | 29c6040a51616e9e4cf6c70ee16391b2a3b238c9 (diff) |
| parent | f34b92ded11b07f78575ac62c260a380c468e5ea (diff) |
Merge remote-tracking branch 'backend/main'
Diffstat (limited to 'backend/internal/worker')
| Mode | File | Lines added |
|---|---|---|
| -rw-r--r-- | backend/internal/worker/analyzer/analyzer.go | 142 |
| -rw-r--r-- | backend/internal/worker/auth/auth.go | 239 |
| -rw-r--r-- | backend/internal/worker/scraper/scraper.go | 198 |
| -rw-r--r-- | backend/internal/worker/worker.go | 151 |
4 files changed, 730 insertions, 0 deletions
diff --git a/backend/internal/worker/analyzer/analyzer.go b/backend/internal/worker/analyzer/analyzer.go
new file mode 100644
index 0000000..20621dd
--- /dev/null
+++ b/backend/internal/worker/analyzer/analyzer.go
@@ -0,0 +1,142 @@
+package analyzer
+
+import (
+    "context"
+    "log/slog"
+    "time"
+
+    "github.com/ansg191/ibd-trader-backend/internal/analyzer"
+    "github.com/ansg191/ibd-trader-backend/internal/database"
+    "github.com/ansg191/ibd-trader-backend/internal/redis/taskqueue"
+
+    "github.com/redis/go-redis/v9"
+)
+
+const (
+    Queue         = "analyzer"
+    QueueEncoding = taskqueue.EncodingJSON
+
+    lockTimeout    = 1 * time.Minute
+    dequeueTimeout = 5 * time.Second
+)
+
+func RunAnalyzer(
+    ctx context.Context,
+    redis *redis.Client,
+    analyzer analyzer.Analyzer,
+    db database.Executor,
+    name string,
+) error {
+    queue, err := taskqueue.New(
+        ctx,
+        redis,
+        Queue,
+        name,
+        taskqueue.WithEncoding[TaskInfo](QueueEncoding),
+    )
+    if err != nil {
+        return err
+    }
+
+    for {
+        select {
+        case <-ctx.Done():
+            return ctx.Err()
+        default:
+            waitForTask(ctx, queue, analyzer, db)
+        }
+    }
+}
+
+func waitForTask(
+    ctx context.Context,
+    queue taskqueue.TaskQueue[TaskInfo],
+    analyzer analyzer.Analyzer,
+    db database.Executor,
+) {
+    task, err := queue.Dequeue(ctx, lockTimeout, dequeueTimeout)
+    if err != nil {
+        slog.ErrorContext(ctx, "Failed to dequeue task", "error", err)
+        return
+    }
+    if task == nil {
+        // No task available.
+        return
+    }
+
+    errCh := make(chan error)
+    resCh := make(chan string)
+    defer close(errCh)
+    defer close(resCh)
+    go func() {
+        res, err := analyzeStock(ctx, analyzer, db, task.Data.ID)
+        if err != nil {
+            errCh <- err
+            return
+        }
+        resCh <- res
+    }()
+
+    ticker := time.NewTicker(lockTimeout / 5)
+    defer ticker.Stop()
+
+    for {
+        select {
+        case <-ctx.Done():
+            // Context was canceled. Return early.
+            return
+        case <-ticker.C:
+            // Extend the lock periodically.
+            func() {
+                ctx, cancel := context.WithTimeout(ctx, lockTimeout/5)
+                defer cancel()
+
+                err := queue.Extend(ctx, task.ID)
+                if err != nil {
+                    slog.ErrorContext(ctx, "Failed to extend lock", "error", err)
+                }
+            }()
+        case err = <-errCh:
+            // analyzeStock has errored.
+            slog.ErrorContext(ctx, "Failed to analyze", "error", err)
+            _, err = queue.Return(ctx, task.ID, err)
+            if err != nil {
+                slog.ErrorContext(ctx, "Failed to return task", "error", err)
+                return
+            }
+            return
+        case res := <-resCh:
+            // analyzeStock has completed successfully.
+            slog.DebugContext(ctx, "Analyzed ID", "id", task.Data.ID, "result", res)
+            err = queue.Complete(ctx, task.ID, res)
+            if err != nil {
+                slog.ErrorContext(ctx, "Failed to complete task", "error", err)
+                return
+            }
+            return
+        }
+    }
+}
+
+func analyzeStock(ctx context.Context, a analyzer.Analyzer, db database.Executor, id string) (string, error) {
+    info, err := database.GetStockInfo(ctx, db, id)
+    if err != nil {
+        return "", err
+    }
+
+    analysis, err := a.Analyze(
+        ctx,
+        info.Symbol,
+        info.Price,
+        info.ChartAnalysis,
+    )
+    if err != nil {
+        return "", err
+    }
+
+    return database.AddAnalysis(ctx, db, id, analysis)
+}
+
+type TaskInfo struct {
+    ID string `json:"id"`
+}
diff --git a/backend/internal/worker/auth/auth.go b/backend/internal/worker/auth/auth.go
new file mode 100644
index 0000000..0daa112
--- /dev/null
+++ b/backend/internal/worker/auth/auth.go
@@ -0,0 +1,239 @@
+package auth
+
+import (
+    "context"
+    "database/sql"
+    "errors"
+    "fmt"
+    "log/slog"
+    "time"
+
+    "github.com/ansg191/ibd-trader-backend/internal/database"
+    "github.com/ansg191/ibd-trader-backend/internal/ibd"
+    "github.com/ansg191/ibd-trader-backend/internal/keys"
+    "github.com/ansg191/ibd-trader-backend/internal/leader/manager/ibd/auth"
+    "github.com/ansg191/ibd-trader-backend/internal/redis/taskqueue"
+
+    "github.com/redis/go-redis/v9"
+)
+
+const (
+    lockTimeout    = 1 * time.Minute
+    dequeueTimeout = 5 * time.Second
+)
+
+func RunAuthScraper(
+    ctx context.Context,
+    client *ibd.Client,
+    redis *redis.Client,
+    db database.TransactionExecutor,
+    kms keys.KeyManagementService,
+    name string,
+) error {
+    queue, err := taskqueue.New(
+        ctx,
+        redis,
+        auth.Queue,
+        name,
+        taskqueue.WithEncoding[auth.TaskInfo](auth.QueueEncoding),
+    )
+    if err != nil {
+        return err
+    }
+
+    for {
+        select {
+        case <-ctx.Done():
+            return ctx.Err()
+        default:
+            waitForTask(ctx, queue, client, db, kms)
+        }
+    }
+}
+
+func waitForTask(
+    ctx context.Context,
+    queue taskqueue.TaskQueue[auth.TaskInfo],
+    client *ibd.Client,
+    db database.TransactionExecutor,
+    kms keys.KeyManagementService,
+) {
+    task, err := queue.Dequeue(ctx, lockTimeout, dequeueTimeout)
+    if err != nil {
+        slog.ErrorContext(ctx, "Failed to dequeue task", "error", err)
+        return
+    }
+    if task == nil {
+        // No task available.
+        return
+    }
+    slog.DebugContext(ctx, "Picked up auth task", "task", task.ID, "user", task.Data.UserSubject)
+
+    ch := make(chan error)
+    defer close(ch)
+    go func() {
+        ch <- scrapeCookies(ctx, client, db, kms, task.Data.UserSubject)
+    }()
+
+    ticker := time.NewTicker(lockTimeout / 5)
+    defer ticker.Stop()
+
+    for {
+        select {
+        case <-ctx.Done():
+            // The context was canceled. Return early.
+            return
+        case <-ticker.C:
+            // Extend the lock periodically.
+            func() {
+                ctx, cancel := context.WithTimeout(ctx, lockTimeout/5)
+                defer cancel()
+
+                err := queue.Extend(ctx, task.ID)
+                if err != nil {
+                    slog.ErrorContext(ctx, "Failed to extend lock", "error", err)
+                }
+            }()
+        case err = <-ch:
+            // scrapeCookies has completed.
+            if err != nil {
+                slog.ErrorContext(ctx, "Failed to scrape cookies", "error", err)
+                _, err = queue.Return(ctx, task.ID, err)
+                if err != nil {
+                    slog.ErrorContext(ctx, "Failed to return task", "error", err)
+                    return
+                }
+            } else {
+                err = queue.Complete(ctx, task.ID, "")
+                if err != nil {
+                    slog.ErrorContext(ctx, "Failed to complete task", "error", err)
+                    return
+                }
+                slog.DebugContext(ctx, "Authenticated user", "user", task.Data.UserSubject)
+            }
+            return
+        }
+    }
+}
+
+func scrapeCookies(
+    ctx context.Context,
+    client *ibd.Client,
+    db database.TransactionExecutor,
+    kms keys.KeyManagementService,
+    user string,
+) error {
+    ctx, cancel := context.WithTimeout(ctx, lockTimeout)
+    defer cancel()
+
+    // Check if the user has valid cookies
+    done, err := hasValidCookies(ctx, db, user)
+    if err != nil {
+        return fmt.Errorf("failed to check cookies: %w", err)
+    }
+    if done {
+        return nil
+    }
+
+    // Health check degraded cookies
+    done, err = healthCheckDegradedCookies(ctx, client, db, kms, user)
+    if err != nil {
+        return fmt.Errorf("failed to health check cookies: %w", err)
+    }
+    if done {
+        return nil
+    }
+
+    // No cookies are valid, so scrape new cookies
+    return scrapeNewCookies(ctx, client, db, kms, user)
+}
+
+func hasValidCookies(ctx context.Context, db database.Executor, user string) (bool, error) {
+    // Check if the user has non-degraded cookies
+    row := db.QueryRowContext(ctx, `
+SELECT 1
+FROM ibd_tokens
+WHERE user_subject = $1
+  AND expires_at > NOW()
+  AND degraded = FALSE;`, user)
+
+    var exists bool
+    err := row.Scan(&exists)
+    if errors.Is(err, sql.ErrNoRows) {
+        return false, nil
+    }
+    if err != nil {
+        return false, fmt.Errorf("failed to get non-degraded cookies: %w", err)
+    }
+
+    return true, nil
+}
+
+func healthCheckDegradedCookies(
+    ctx context.Context,
+    client *ibd.Client,
+    db database.Executor,
+    kms keys.KeyManagementService,
+    user string,
+) (bool, error) {
+    // Check if the user has degraded cookies
+    cookies, err := database.GetCookies(ctx, db, kms, user, true)
+    if err != nil {
+        return false, fmt.Errorf("failed to get degraded cookies: %w", err)
+    }
+
+    valid := false
+    for _, cookie := range cookies {
+        slog.DebugContext(ctx, "Health checking cookie", "cookie", cookie.ID)
+
+        // Health check the cookie
+        up, err := client.UserInfo(ctx, cookie.ToHTTPCookie())
+        if err != nil {
+            slog.ErrorContext(ctx, "Failed to health check cookie", "error", err)
+            continue
+        }
+
+        if up.Status != ibd.UserStatusSubscriber {
+            continue
+        }
+
+        // Cookie is valid
+        valid = true
+
+        // Update the cookie
+        err = database.RepairCookie(ctx, db, cookie.ID)
+        if err != nil {
+            slog.ErrorContext(ctx, "Failed to repair cookie", "error", err)
+        }
+    }
+
+    return valid, nil
+}
+
+func scrapeNewCookies(
+    ctx context.Context,
+    client *ibd.Client,
+    db database.TransactionExecutor,
+    kms keys.KeyManagementService,
+    user string,
+) error {
+    // Get the user's credentials
+    username, password, err := database.GetIBDCreds(ctx, db, kms, user)
+    if err != nil {
+        return fmt.Errorf("failed to get IBD credentials: %w", err)
+    }
+
+    // Scrape the user's cookies
+    cookie, err := client.Authenticate(ctx, username, password)
+    if err != nil {
+        return fmt.Errorf("failed to authenticate user: %w", err)
+    }
+
+    // Store the cookie
+    err = database.AddCookie(ctx, db, kms, user, cookie)
+    if err != nil {
+        return fmt.Errorf("failed to store cookie: %w", err)
+    }
+
+    return nil
+}
diff --git a/backend/internal/worker/scraper/scraper.go b/backend/internal/worker/scraper/scraper.go
new file mode 100644
index 0000000..c5c1b6c
--- /dev/null
+++ b/backend/internal/worker/scraper/scraper.go
@@ -0,0 +1,198 @@
+package scraper
+
+import (
+    "context"
+    "errors"
+    "fmt"
+    "log/slog"
+    "time"
+
+    "github.com/ansg191/ibd-trader-backend/internal/database"
+    "github.com/ansg191/ibd-trader-backend/internal/ibd"
+    "github.com/ansg191/ibd-trader-backend/internal/leader/manager/ibd/scrape"
+    "github.com/ansg191/ibd-trader-backend/internal/redis/taskqueue"
+    "github.com/ansg191/ibd-trader-backend/internal/worker/analyzer"
+
+    "github.com/redis/go-redis/v9"
+)
+
+const (
+    lockTimeout    = 1 * time.Minute
+    dequeueTimeout = 5 * time.Second
+)
+
+func RunScraper(
+    ctx context.Context,
+    redis *redis.Client,
+    client *ibd.Client,
+    db database.TransactionExecutor,
+    name string,
+) error {
+    queue, err := taskqueue.New(
+        ctx,
+        redis,
+        scrape.Queue,
+        name,
+        taskqueue.WithEncoding[scrape.TaskInfo](scrape.QueueEncoding),
+    )
+    if err != nil {
+        return err
+    }
+
+    aQueue, err := taskqueue.New(
+        ctx,
+        redis,
+        analyzer.Queue,
+        name,
+        taskqueue.WithEncoding[analyzer.TaskInfo](analyzer.QueueEncoding),
+    )
+    if err != nil {
+        return err
+    }
+
+    for {
+        select {
+        case <-ctx.Done():
+            return ctx.Err()
+        default:
+            waitForTask(ctx, queue, aQueue, client, db)
+        }
+    }
+}
+
+func waitForTask(
+    ctx context.Context,
+    queue taskqueue.TaskQueue[scrape.TaskInfo],
+    aQueue taskqueue.TaskQueue[analyzer.TaskInfo],
+    client *ibd.Client,
+    db database.TransactionExecutor,
+) {
+    task, err := queue.Dequeue(ctx, lockTimeout, dequeueTimeout)
+    if err != nil {
+        slog.ErrorContext(ctx, "Failed to dequeue task", "error", err)
+        return
+    }
+    if task == nil {
+        // No task available.
+        return
+    }
+
+    errCh := make(chan error)
+    resCh := make(chan string)
+    defer close(errCh)
+    defer close(resCh)
+    go func() {
+        res, err := scrapeUrl(ctx, client, db, aQueue, task.Data.Symbol)
+        if err != nil {
+            errCh <- err
+            return
+        }
+        resCh <- res
+    }()
+
+    ticker := time.NewTicker(lockTimeout / 5)
+    defer ticker.Stop()
+
+    for {
+        select {
+        case <-ctx.Done():
+            // Context was canceled. Return early.
+            return
+        case <-ticker.C:
+            // Extend the lock periodically.
+            func() {
+                ctx, cancel := context.WithTimeout(ctx, lockTimeout/5)
+                defer cancel()
+
+                err := queue.Extend(ctx, task.ID)
+                if err != nil {
+                    slog.ErrorContext(ctx, "Failed to extend lock", "error", err)
+                }
+            }()
+        case err = <-errCh:
+            // scrapeUrl has errored.
+            slog.ErrorContext(ctx, "Failed to scrape URL", "error", err)
+            _, err = queue.Return(ctx, task.ID, err)
+            if err != nil {
+                slog.ErrorContext(ctx, "Failed to return task", "error", err)
+                return
+            }
+            return
+        case res := <-resCh:
+            // scrapeUrl has completed successfully.
+            slog.DebugContext(ctx, "Scraped URL", "symbol", task.Data.Symbol)
+            err = queue.Complete(ctx, task.ID, res)
+            if err != nil {
+                slog.ErrorContext(ctx, "Failed to complete task", "error", err)
+                return
+            }
+            return
+        }
+    }
+}
+
+func scrapeUrl(
+    ctx context.Context,
+    client *ibd.Client,
+    db database.TransactionExecutor,
+    aQueue taskqueue.TaskQueue[analyzer.TaskInfo],
+    symbol string,
+) (string, error) {
+    ctx, cancel := context.WithTimeout(ctx, lockTimeout)
+    defer cancel()
+
+    stockUrl, err := getStockUrl(ctx, db, client, symbol)
+    if err != nil {
+        return "", fmt.Errorf("failed to get stock url: %w", err)
+    }
+
+    // Scrape the stock info.
+    info, err := client.StockInfo(ctx, stockUrl)
+    if err != nil {
+        return "", fmt.Errorf("failed to get stock info: %w", err)
+    }
+
+    // Add stock info to the database.
+    id, err := database.AddStockInfo(ctx, db, info)
+    if err != nil {
+        return "", fmt.Errorf("failed to add stock info: %w", err)
+    }
+
+    // Add the stock to the analyzer queue.
+    _, err = aQueue.Enqueue(ctx, analyzer.TaskInfo{ID: id})
+    if err != nil {
+        return "", fmt.Errorf("failed to enqueue analysis task: %w", err)
+    }
+
+    slog.DebugContext(ctx, "Added stock info", "id", id)
+
+    return id, nil
+}
+
+func getStockUrl(ctx context.Context, db database.TransactionExecutor, client *ibd.Client, symbol string) (string, error) {
+    // Get the stock from the database.
+    stock, err := database.GetStock(ctx, db, symbol)
+    if err == nil {
+        return stock.IBDUrl, nil
+    }
+    if !errors.Is(err, database.ErrStockNotFound) {
+        return "", fmt.Errorf("failed to get stock: %w", err)
+    }
+
+    // If stock isn't found in the database, get the stock from IBD.
+    stock, err = client.Search(ctx, symbol)
+    if errors.Is(err, ibd.ErrSymbolNotFound) {
+        return "", fmt.Errorf("symbol not found: %w", err)
+    }
+    if err != nil {
+        return "", fmt.Errorf("failed to search for symbol: %w", err)
+    }
+
+    // Add the stock to the database.
+    err = database.AddStock(ctx, db, stock)
+    if err != nil {
+        return "", fmt.Errorf("failed to add stock: %w", err)
+    }
+
+    return stock.IBDUrl, nil
+}
diff --git a/backend/internal/worker/worker.go b/backend/internal/worker/worker.go
new file mode 100644
index 0000000..6017fb7
--- /dev/null
+++ b/backend/internal/worker/worker.go
@@ -0,0 +1,151 @@
+package worker
+
+import (
+    "context"
+    "crypto/rand"
+    "encoding/base64"
+    "io"
+    "log/slog"
+    "os"
+    "time"
+
+    "github.com/ansg191/ibd-trader-backend/internal/analyzer"
+    "github.com/ansg191/ibd-trader-backend/internal/database"
+    "github.com/ansg191/ibd-trader-backend/internal/ibd"
+    "github.com/ansg191/ibd-trader-backend/internal/keys"
+    "github.com/ansg191/ibd-trader-backend/internal/leader/manager"
+    analyzer2 "github.com/ansg191/ibd-trader-backend/internal/worker/analyzer"
+    "github.com/ansg191/ibd-trader-backend/internal/worker/auth"
+    "github.com/ansg191/ibd-trader-backend/internal/worker/scraper"
+
+    "github.com/redis/go-redis/v9"
+    "golang.org/x/sync/errgroup"
+)
+
+const (
+    HeartbeatInterval = 5 * time.Second
+    HeartbeatTTL      = 30 * time.Second
+)
+
+func StartWorker(
+    ctx context.Context,
+    ibdClient *ibd.Client,
+    client *redis.Client,
+    db database.TransactionExecutor,
+    kms keys.KeyManagementService,
+    a analyzer.Analyzer,
+) error {
+    // Get the worker name.
+    name, err := workerName()
+    if err != nil {
+        return err
+    }
+    slog.InfoContext(ctx, "Starting worker", "worker", name)
+
+    g, ctx := errgroup.WithContext(ctx)
+
+    g.Go(func() error {
+        return workerRegistrationLoop(ctx, client, name)
+    })
+    g.Go(func() error {
+        return scraper.RunScraper(ctx, client, ibdClient, db, name)
+    })
+    g.Go(func() error {
+        return auth.RunAuthScraper(ctx, ibdClient, client, db, kms, name)
+    })
+    g.Go(func() error {
+        return analyzer2.RunAnalyzer(ctx, client, a, db, name)
+    })
+
+    return g.Wait()
+}
+
+func workerRegistrationLoop(ctx context.Context, client *redis.Client, name string) error {
+    sendHeartbeat(ctx, client, name)
+
+    ticker := time.NewTicker(HeartbeatInterval)
+    defer ticker.Stop()
+
+    for {
+        select {
+        case <-ticker.C:
+            sendHeartbeat(ctx, client, name)
+        case <-ctx.Done():
+            removeWorker(ctx, client, name)
+            return ctx.Err()
+        }
+    }
+}
+
+// sendHeartbeat sends a heartbeat for the worker.
+// It ensures that the worker is in the active workers set and its heartbeat exists.
+func sendHeartbeat(ctx context.Context, client *redis.Client, name string) {
+    ctx, cancel := context.WithTimeout(ctx, HeartbeatInterval)
+    defer cancel()
+
+    // Add the worker to the active workers set.
+    if err := client.SAdd(ctx, manager.ActiveWorkersSet, name).Err(); err != nil {
+        slog.ErrorContext(ctx,
+            "Unable to add worker to active workers set",
+            "worker", name,
+            "error", err,
+        )
+        return
+    }
+
+    // Set the worker's heartbeat.
+    heartbeatKey := manager.WorkerHeartbeatKey(name)
+    if err := client.Set(ctx, heartbeatKey, time.Now().Unix(), HeartbeatTTL).Err(); err != nil {
+        slog.ErrorContext(ctx,
+            "Unable to set worker heartbeat",
+            "worker", name,
+            "error", err,
+        )
+        return
+    }
+}
+
+// removeWorker removes the worker from the active workers set.
+func removeWorker(ctx context.Context, client *redis.Client, name string) {
+    if ctx.Err() != nil {
+        // If the context is canceled, create a new uncanceled context.
+        ctx = context.WithoutCancel(ctx)
+    }
+    ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
+    defer cancel()
+
+    // Remove the worker from the active workers set.
+    if err := client.SRem(ctx, manager.ActiveWorkersSet, name).Err(); err != nil {
+        slog.ErrorContext(ctx,
+            "Unable to remove worker from active workers set",
+            "worker", name,
+            "error", err,
+        )
+        return
+    }
+
+    // Remove the worker's heartbeat.
+    heartbeatKey := manager.WorkerHeartbeatKey(name)
+    if err := client.Del(ctx, heartbeatKey).Err(); err != nil {
+        slog.ErrorContext(ctx,
+            "Unable to remove worker heartbeat",
+            "worker", name,
+            "error", err,
+        )
+        return
+    }
+}
+
+func workerName() (string, error) {
+    hostname, err := os.Hostname()
+    if err != nil {
+        return "", err
+    }
+
+    bytes := make([]byte, 12)
+    if _, err = io.ReadFull(rand.Reader, bytes); err != nil {
+        return "", err
+    }
+
+    return hostname + "-" + base64.URLEncoding.EncodeToString(bytes), nil
+}
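For orientation, here is a minimal sketch of how the new `worker.StartWorker` entry point might be wired up from a `cmd/worker` main package inside the same module. It is an illustration only: the Redis address and the nil placeholder dependencies are assumptions, and the project's real constructors for `ibd.Client`, `database.TransactionExecutor`, `keys.KeyManagementService`, and `analyzer.Analyzer` are not part of this merge; only `StartWorker` and its parameter types come from the files added above.

```go
package main

import (
	"context"
	"log/slog"
	"os"
	"os/signal"
	"syscall"

	"github.com/ansg191/ibd-trader-backend/internal/analyzer"
	"github.com/ansg191/ibd-trader-backend/internal/database"
	"github.com/ansg191/ibd-trader-backend/internal/ibd"
	"github.com/ansg191/ibd-trader-backend/internal/keys"
	"github.com/ansg191/ibd-trader-backend/internal/worker"

	"github.com/redis/go-redis/v9"
)

func main() {
	// Cancel the context on SIGINT/SIGTERM; cancellation is what stops the
	// scraper/auth/analyzer loops and lets removeWorker drop the heartbeat.
	ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
	defer stop()

	rdb := redis.NewClient(&redis.Options{Addr: "localhost:6379"}) // address is an assumption

	// Placeholders: construct these from the project's own packages (not shown in this diff).
	var (
		ibdClient *ibd.Client
		db        database.TransactionExecutor
		kms       keys.KeyManagementService
		a         analyzer.Analyzer
	)

	// StartWorker registers the worker, starts its heartbeat loop, and runs the
	// scraper, auth, and analyzer queue consumers until the context is canceled.
	if err := worker.StartWorker(ctx, ibdClient, rdb, db, kms, a); err != nil {
		slog.Error("worker exited", "error", err)
		os.Exit(1)
	}
}
```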