diff options
author | 2024-08-07 18:56:01 -0700 | |
---|---|---|
committer | 2024-08-07 18:56:01 -0700 | |
commit | 08993e2f8497341079010d3d06361c99492c4c07 (patch) | |
tree | c65d6d571c928410faace1fa51c2ea3f49fce003 /backend/internal/worker | |
parent | 3de4ebb7560851ccbefe296c197456fe80c22901 (diff) | |
parent | b8aef1a7fb24815c7d93bc30c7b289b4f5896779 (diff) | |
download | ibd-trader-08993e2f8497341079010d3d06361c99492c4c07.tar.gz ibd-trader-08993e2f8497341079010d3d06361c99492c4c07.tar.zst ibd-trader-08993e2f8497341079010d3d06361c99492c4c07.zip |
Merge pull request #1 from ansg191/refactor-database
Diffstat (limited to 'backend/internal/worker')
-rw-r--r-- | backend/internal/worker/analyzer/analyzer.go | 10 | ||||
-rw-r--r-- | backend/internal/worker/auth/auth.go | 61 | ||||
-rw-r--r-- | backend/internal/worker/scraper/scraper.go | 20 | ||||
-rw-r--r-- | backend/internal/worker/worker.go | 6 |
4 files changed, 55 insertions, 42 deletions
diff --git a/backend/internal/worker/analyzer/analyzer.go b/backend/internal/worker/analyzer/analyzer.go index 79a35ee..ea8069e 100644 --- a/backend/internal/worker/analyzer/analyzer.go +++ b/backend/internal/worker/analyzer/analyzer.go @@ -24,7 +24,7 @@ func RunAnalyzer( ctx context.Context, redis *redis.Client, analyzer analyzer.Analyzer, - db database.StockStore, + db database.Executor, name string, ) error { queue, err := taskqueue.New( @@ -52,7 +52,7 @@ func waitForTask( ctx context.Context, queue taskqueue.TaskQueue[TaskInfo], analyzer analyzer.Analyzer, - db database.StockStore, + db database.Executor, ) { task, err := queue.Dequeue(ctx, lockTimeout, dequeueTimeout) if err != nil { @@ -111,8 +111,8 @@ func waitForTask( } } -func analyzeStock(ctx context.Context, a analyzer.Analyzer, db database.StockStore, id string) error { - info, err := db.GetStockInfo(ctx, id) +func analyzeStock(ctx context.Context, a analyzer.Analyzer, db database.Executor, id string) error { + info, err := database.GetStockInfo(ctx, db, id) if err != nil { return err } @@ -127,7 +127,7 @@ func analyzeStock(ctx context.Context, a analyzer.Analyzer, db database.StockSto return err } - return db.AddAnalysis(ctx, id, analysis) + return database.AddAnalysis(ctx, db, id, analysis) } type TaskInfo struct { diff --git a/backend/internal/worker/auth/auth.go b/backend/internal/worker/auth/auth.go index 1f591fe..579a180 100644 --- a/backend/internal/worker/auth/auth.go +++ b/backend/internal/worker/auth/auth.go @@ -2,12 +2,15 @@ package auth import ( "context" + "database/sql" + "errors" "fmt" "log/slog" "time" "github.com/ansg191/ibd-trader-backend/internal/database" "github.com/ansg191/ibd-trader-backend/internal/ibd" + "github.com/ansg191/ibd-trader-backend/internal/keys" "github.com/ansg191/ibd-trader-backend/internal/leader/manager/ibd/auth" "github.com/ansg191/ibd-trader-backend/internal/redis/taskqueue" @@ -23,8 +26,8 @@ func RunAuthScraper( ctx context.Context, client *ibd.Client, redis *redis.Client, - users database.UserStore, - cookies database.CookieStore, + db database.Executor, + kms keys.KeyManagementService, name string, ) error { queue, err := taskqueue.New( @@ -43,7 +46,7 @@ func RunAuthScraper( case <-ctx.Done(): return ctx.Err() default: - waitForTask(ctx, queue, client, users, cookies) + waitForTask(ctx, queue, client, db, kms) } } } @@ -52,8 +55,8 @@ func waitForTask( ctx context.Context, queue taskqueue.TaskQueue[auth.TaskInfo], client *ibd.Client, - users database.UserStore, - cookies database.CookieStore, + db database.Executor, + kms keys.KeyManagementService, ) { task, err := queue.Dequeue(ctx, lockTimeout, dequeueTimeout) if err != nil { @@ -69,7 +72,7 @@ func waitForTask( ch := make(chan error) defer close(ch) go func() { - ch <- scrapeCookies(ctx, client, users, cookies, task.Data.UserSubject) + ch <- scrapeCookies(ctx, client, db, kms, task.Data.UserSubject) }() ticker := time.NewTicker(lockTimeout / 5) @@ -116,15 +119,15 @@ func waitForTask( func scrapeCookies( ctx context.Context, client *ibd.Client, - users database.UserStore, - store database.CookieStore, + db database.Executor, + kms keys.KeyManagementService, user string, ) error { ctx, cancel := context.WithTimeout(ctx, lockTimeout) defer cancel() // Check if the user has valid cookies - done, err := hasValidCookies(ctx, store, user) + done, err := hasValidCookies(ctx, db, user) if err != nil { return fmt.Errorf("failed to check cookies: %w", err) } @@ -133,7 +136,7 @@ func scrapeCookies( } // Health check degraded cookies - done, err = healthCheckDegradedCookies(ctx, client, store, user) + done, err = healthCheckDegradedCookies(ctx, client, db, kms, user) if err != nil { return fmt.Errorf("failed to health check cookies: %w", err) } @@ -142,31 +145,39 @@ func scrapeCookies( } // No cookies are valid, so scrape new cookies - return scrapeNewCookies(ctx, client, users, store, user) + return scrapeNewCookies(ctx, client, db, kms, user) } -func hasValidCookies(ctx context.Context, store database.CookieStore, user string) (bool, error) { +func hasValidCookies(ctx context.Context, db database.Executor, user string) (bool, error) { // Check if the user has non-degraded cookies - cookies, err := store.GetCookies(ctx, user, false) + row := db.QueryRowContext(ctx, ` +SELECT 1 +FROM ibd_tokens +WHERE user_subject = $1 + AND expires_at > NOW() + AND degraded = FALSE;`, user) + + var exists bool + err := row.Scan(&exists) + if errors.Is(err, sql.ErrNoRows) { + return false, nil + } if err != nil { return false, fmt.Errorf("failed to get non-degraded cookies: %w", err) } - // If the user has non-degraded cookies, return true - if len(cookies) > 0 { - return true, nil - } - return false, nil + return true, nil } func healthCheckDegradedCookies( ctx context.Context, client *ibd.Client, - store database.CookieStore, + db database.Executor, + kms keys.KeyManagementService, user string, ) (bool, error) { // Check if the user has degraded cookies - cookies, err := store.GetCookies(ctx, user, true) + cookies, err := database.GetCookies(ctx, db, kms, user, true) if err != nil { return false, fmt.Errorf("failed to get degraded cookies: %w", err) } @@ -190,7 +201,7 @@ func healthCheckDegradedCookies( valid = true // Update the cookie - err = store.RepairCookie(ctx, cookie.ID) + err = database.RepairCookie(ctx, db, cookie.ID) if err != nil { slog.ErrorContext(ctx, "Failed to repair cookie", "error", err) } @@ -202,12 +213,12 @@ func healthCheckDegradedCookies( func scrapeNewCookies( ctx context.Context, client *ibd.Client, - users database.UserStore, - store database.CookieStore, + db database.Executor, + kms keys.KeyManagementService, user string, ) error { // Get the user's credentials - username, password, err := users.GetIBDCreds(ctx, user) + username, password, err := database.GetIBDCreds(ctx, db, kms, user) if err != nil { return fmt.Errorf("failed to get IBD credentials: %w", err) } @@ -219,7 +230,7 @@ func scrapeNewCookies( } // Store the cookie - err = store.AddCookie(ctx, user, cookie) + err = database.AddCookie(ctx, db, kms, user, cookie) if err != nil { return fmt.Errorf("failed to store cookie: %w", err) } diff --git a/backend/internal/worker/scraper/scraper.go b/backend/internal/worker/scraper/scraper.go index ec71d62..4788834 100644 --- a/backend/internal/worker/scraper/scraper.go +++ b/backend/internal/worker/scraper/scraper.go @@ -25,7 +25,7 @@ func RunScraper( ctx context.Context, redis *redis.Client, client *ibd.Client, - store database.StockStore, + db database.TransactionExecutor, name string, ) error { queue, err := taskqueue.New( @@ -55,7 +55,7 @@ func RunScraper( case <-ctx.Done(): return ctx.Err() default: - waitForTask(ctx, queue, aQueue, client, store) + waitForTask(ctx, queue, aQueue, client, db) } } } @@ -65,7 +65,7 @@ func waitForTask( queue taskqueue.TaskQueue[scrape.TaskInfo], aQueue taskqueue.TaskQueue[analyzer.TaskInfo], client *ibd.Client, - store database.StockStore, + db database.TransactionExecutor, ) { task, err := queue.Dequeue(ctx, lockTimeout, dequeueTimeout) if err != nil { @@ -80,7 +80,7 @@ func waitForTask( ch := make(chan error) go func() { defer close(ch) - ch <- scrapeUrl(ctx, client, store, aQueue, task.Data.Symbol) + ch <- scrapeUrl(ctx, client, db, aQueue, task.Data.Symbol) }() ticker := time.NewTicker(lockTimeout / 5) @@ -127,14 +127,14 @@ func waitForTask( func scrapeUrl( ctx context.Context, client *ibd.Client, - store database.StockStore, + db database.TransactionExecutor, aQueue taskqueue.TaskQueue[analyzer.TaskInfo], symbol string, ) error { ctx, cancel := context.WithTimeout(ctx, lockTimeout) defer cancel() - stockUrl, err := getStockUrl(ctx, store, client, symbol) + stockUrl, err := getStockUrl(ctx, db, client, symbol) if err != nil { return fmt.Errorf("failed to get stock url: %w", err) } @@ -146,7 +146,7 @@ func scrapeUrl( } // Add stock info to the database. - id, err := store.AddStockInfo(ctx, info) + id, err := database.AddStockInfo(ctx, db, info) if err != nil { return fmt.Errorf("failed to add stock info: %w", err) } @@ -162,9 +162,9 @@ func scrapeUrl( return nil } -func getStockUrl(ctx context.Context, store database.StockStore, client *ibd.Client, symbol string) (string, error) { +func getStockUrl(ctx context.Context, db database.TransactionExecutor, client *ibd.Client, symbol string) (string, error) { // Get the stock from the database. - stock, err := store.GetStock(ctx, symbol) + stock, err := database.GetStock(ctx, db, symbol) if err == nil { return stock.IBDUrl, nil } @@ -182,7 +182,7 @@ func getStockUrl(ctx context.Context, store database.StockStore, client *ibd.Cli } // Add the stock to the database. - err = store.AddStock(ctx, stock) + err = database.AddStock(ctx, db, stock) if err != nil { return "", fmt.Errorf("failed to add stock: %w", err) } diff --git a/backend/internal/worker/worker.go b/backend/internal/worker/worker.go index 3d7e9c8..6017fb7 100644 --- a/backend/internal/worker/worker.go +++ b/backend/internal/worker/worker.go @@ -12,6 +12,7 @@ import ( "github.com/ansg191/ibd-trader-backend/internal/analyzer" "github.com/ansg191/ibd-trader-backend/internal/database" "github.com/ansg191/ibd-trader-backend/internal/ibd" + "github.com/ansg191/ibd-trader-backend/internal/keys" "github.com/ansg191/ibd-trader-backend/internal/leader/manager" analyzer2 "github.com/ansg191/ibd-trader-backend/internal/worker/analyzer" "github.com/ansg191/ibd-trader-backend/internal/worker/auth" @@ -30,7 +31,8 @@ func StartWorker( ctx context.Context, ibdClient *ibd.Client, client *redis.Client, - db database.Database, + db database.TransactionExecutor, + kms keys.KeyManagementService, a analyzer.Analyzer, ) error { // Get the worker name. @@ -49,7 +51,7 @@ func StartWorker( return scraper.RunScraper(ctx, client, ibdClient, db, name) }) g.Go(func() error { - return auth.RunAuthScraper(ctx, ibdClient, client, db, db, name) + return auth.RunAuthScraper(ctx, ibdClient, client, db, kms, name) }) g.Go(func() error { return analyzer2.RunAnalyzer(ctx, client, a, db, name) |