author    Anshul Gupta <ansg191@anshulg.com>  2024-08-07 18:56:01 -0700
committer GitHub <noreply@github.com>         2024-08-07 18:56:01 -0700
commit    08993e2f8497341079010d3d06361c99492c4c07 (patch)
tree      c65d6d571c928410faace1fa51c2ea3f49fce003 /backend/internal/worker
parent    3de4ebb7560851ccbefe296c197456fe80c22901 (diff)
parent    b8aef1a7fb24815c7d93bc30c7b289b4f5896779 (diff)
Merge pull request #1 from ansg191/refactor-database
Diffstat (limited to 'backend/internal/worker')
-rw-r--r--  backend/internal/worker/analyzer/analyzer.go  10
-rw-r--r--  backend/internal/worker/auth/auth.go          61
-rw-r--r--  backend/internal/worker/scraper/scraper.go    20
-rw-r--r--  backend/internal/worker/worker.go              6
4 files changed, 55 insertions(+), 42 deletions(-)
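
Note: this merge drops the per-entity store interfaces (StockStore, UserStore, CookieStore) in favor of package-level database helpers that take a shared executor, and threads a keys.KeyManagementService through the workers for the encrypted cookie and credential paths. The sketch below shows the contract these executors presumably expose, inferred only from the db.QueryRowContext call in auth.go further down; the method sets are assumptions, not taken from internal/database itself.

package sketch

import (
	"context"
	"database/sql"
)

// Executor mirrors the name used throughout this diff. The method set is an
// assumption inferred from the db.QueryRowContext call in auth.go; it is the
// usual subset satisfied by *sql.DB, *sql.Tx, and *sql.Conn alike.
type Executor interface {
	ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
	QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
	QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row
}

// TransactionExecutor is likewise assumed: an Executor that can also open a
// transaction, which would explain why the scraper and worker take the wider type.
type TransactionExecutor interface {
	Executor
	BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error)
}
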
diff --git a/backend/internal/worker/analyzer/analyzer.go b/backend/internal/worker/analyzer/analyzer.go
index 79a35ee..ea8069e 100644
--- a/backend/internal/worker/analyzer/analyzer.go
+++ b/backend/internal/worker/analyzer/analyzer.go
@@ -24,7 +24,7 @@ func RunAnalyzer(
ctx context.Context,
redis *redis.Client,
analyzer analyzer.Analyzer,
- db database.StockStore,
+ db database.Executor,
name string,
) error {
queue, err := taskqueue.New(
@@ -52,7 +52,7 @@ func waitForTask(
ctx context.Context,
queue taskqueue.TaskQueue[TaskInfo],
analyzer analyzer.Analyzer,
- db database.StockStore,
+ db database.Executor,
) {
task, err := queue.Dequeue(ctx, lockTimeout, dequeueTimeout)
if err != nil {
@@ -111,8 +111,8 @@ func waitForTask(
}
}
-func analyzeStock(ctx context.Context, a analyzer.Analyzer, db database.StockStore, id string) error {
- info, err := db.GetStockInfo(ctx, id)
+func analyzeStock(ctx context.Context, a analyzer.Analyzer, db database.Executor, id string) error {
+ info, err := database.GetStockInfo(ctx, db, id)
if err != nil {
return err
}
@@ -127,7 +127,7 @@ func analyzeStock(ctx context.Context, a analyzer.Analyzer, db database.StockSto
return err
}
- return db.AddAnalysis(ctx, id, analysis)
+ return database.AddAnalysis(ctx, db, id, analysis)
}
type TaskInfo struct {
diff --git a/backend/internal/worker/auth/auth.go b/backend/internal/worker/auth/auth.go
index 1f591fe..579a180 100644
--- a/backend/internal/worker/auth/auth.go
+++ b/backend/internal/worker/auth/auth.go
@@ -2,12 +2,15 @@ package auth
import (
"context"
+ "database/sql"
+ "errors"
"fmt"
"log/slog"
"time"
"github.com/ansg191/ibd-trader-backend/internal/database"
"github.com/ansg191/ibd-trader-backend/internal/ibd"
+ "github.com/ansg191/ibd-trader-backend/internal/keys"
"github.com/ansg191/ibd-trader-backend/internal/leader/manager/ibd/auth"
"github.com/ansg191/ibd-trader-backend/internal/redis/taskqueue"
@@ -23,8 +26,8 @@ func RunAuthScraper(
ctx context.Context,
client *ibd.Client,
redis *redis.Client,
- users database.UserStore,
- cookies database.CookieStore,
+ db database.Executor,
+ kms keys.KeyManagementService,
name string,
) error {
queue, err := taskqueue.New(
@@ -43,7 +46,7 @@ func RunAuthScraper(
case <-ctx.Done():
return ctx.Err()
default:
- waitForTask(ctx, queue, client, users, cookies)
+ waitForTask(ctx, queue, client, db, kms)
}
}
}
@@ -52,8 +55,8 @@ func waitForTask(
ctx context.Context,
queue taskqueue.TaskQueue[auth.TaskInfo],
client *ibd.Client,
- users database.UserStore,
- cookies database.CookieStore,
+ db database.Executor,
+ kms keys.KeyManagementService,
) {
task, err := queue.Dequeue(ctx, lockTimeout, dequeueTimeout)
if err != nil {
@@ -69,7 +72,7 @@ func waitForTask(
ch := make(chan error)
defer close(ch)
go func() {
- ch <- scrapeCookies(ctx, client, users, cookies, task.Data.UserSubject)
+ ch <- scrapeCookies(ctx, client, db, kms, task.Data.UserSubject)
}()
ticker := time.NewTicker(lockTimeout / 5)
@@ -116,15 +119,15 @@ func waitForTask(
func scrapeCookies(
ctx context.Context,
client *ibd.Client,
- users database.UserStore,
- store database.CookieStore,
+ db database.Executor,
+ kms keys.KeyManagementService,
user string,
) error {
ctx, cancel := context.WithTimeout(ctx, lockTimeout)
defer cancel()
// Check if the user has valid cookies
- done, err := hasValidCookies(ctx, store, user)
+ done, err := hasValidCookies(ctx, db, user)
if err != nil {
return fmt.Errorf("failed to check cookies: %w", err)
}
@@ -133,7 +136,7 @@ func scrapeCookies(
}
// Health check degraded cookies
- done, err = healthCheckDegradedCookies(ctx, client, store, user)
+ done, err = healthCheckDegradedCookies(ctx, client, db, kms, user)
if err != nil {
return fmt.Errorf("failed to health check cookies: %w", err)
}
@@ -142,31 +145,39 @@ func scrapeCookies(
}
// No cookies are valid, so scrape new cookies
- return scrapeNewCookies(ctx, client, users, store, user)
+ return scrapeNewCookies(ctx, client, db, kms, user)
}
-func hasValidCookies(ctx context.Context, store database.CookieStore, user string) (bool, error) {
+func hasValidCookies(ctx context.Context, db database.Executor, user string) (bool, error) {
// Check if the user has non-degraded cookies
- cookies, err := store.GetCookies(ctx, user, false)
+ row := db.QueryRowContext(ctx, `
+SELECT 1
+FROM ibd_tokens
+WHERE user_subject = $1
+ AND expires_at > NOW()
+ AND degraded = FALSE;`, user)
+
+ var exists bool
+ err := row.Scan(&exists)
+ if errors.Is(err, sql.ErrNoRows) {
+ return false, nil
+ }
if err != nil {
return false, fmt.Errorf("failed to get non-degraded cookies: %w", err)
}
- // If the user has non-degraded cookies, return true
- if len(cookies) > 0 {
- return true, nil
- }
- return false, nil
+ return true, nil
}
func healthCheckDegradedCookies(
ctx context.Context,
client *ibd.Client,
- store database.CookieStore,
+ db database.Executor,
+ kms keys.KeyManagementService,
user string,
) (bool, error) {
// Check if the user has degraded cookies
- cookies, err := store.GetCookies(ctx, user, true)
+ cookies, err := database.GetCookies(ctx, db, kms, user, true)
if err != nil {
return false, fmt.Errorf("failed to get degraded cookies: %w", err)
}
@@ -190,7 +201,7 @@ func healthCheckDegradedCookies(
valid = true
// Update the cookie
- err = store.RepairCookie(ctx, cookie.ID)
+ err = database.RepairCookie(ctx, db, cookie.ID)
if err != nil {
slog.ErrorContext(ctx, "Failed to repair cookie", "error", err)
}
@@ -202,12 +213,12 @@ func healthCheckDegradedCookies(
func scrapeNewCookies(
ctx context.Context,
client *ibd.Client,
- users database.UserStore,
- store database.CookieStore,
+ db database.Executor,
+ kms keys.KeyManagementService,
user string,
) error {
// Get the user's credentials
- username, password, err := users.GetIBDCreds(ctx, user)
+ username, password, err := database.GetIBDCreds(ctx, db, kms, user)
if err != nil {
return fmt.Errorf("failed to get IBD credentials: %w", err)
}
@@ -219,7 +230,7 @@ func scrapeNewCookies(
}
// Store the cookie
- err = store.AddCookie(ctx, user, cookie)
+ err = database.AddCookie(ctx, db, kms, user, cookie)
if err != nil {
return fmt.Errorf("failed to store cookie: %w", err)
}
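
Note: the rewritten hasValidCookies above replaces a CookieStore call with a direct existence query and maps sql.ErrNoRows to an ordinary "no valid cookies" result rather than an error. A self-contained sketch of that pattern, using a placeholder table name (some_table is illustrative, not from the repo):

package sketch

import (
	"context"
	"database/sql"
	"errors"

	"github.com/ansg191/ibd-trader-backend/internal/database"
)

// rowExists reports whether any row matches the subject. "No rows" is mapped
// to false rather than an error; the table and column names are placeholders.
func rowExists(ctx context.Context, db database.Executor, subject string) (bool, error) {
	row := db.QueryRowContext(ctx,
		`SELECT 1 FROM some_table WHERE user_subject = $1;`, subject)

	var one int
	err := row.Scan(&one)
	if errors.Is(err, sql.ErrNoRows) {
		return false, nil // absence is a normal outcome, not a failure
	}
	if err != nil {
		return false, err // real query or scan error
	}
	return true, nil
}
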
diff --git a/backend/internal/worker/scraper/scraper.go b/backend/internal/worker/scraper/scraper.go
index ec71d62..4788834 100644
--- a/backend/internal/worker/scraper/scraper.go
+++ b/backend/internal/worker/scraper/scraper.go
@@ -25,7 +25,7 @@ func RunScraper(
ctx context.Context,
redis *redis.Client,
client *ibd.Client,
- store database.StockStore,
+ db database.TransactionExecutor,
name string,
) error {
queue, err := taskqueue.New(
@@ -55,7 +55,7 @@ func RunScraper(
case <-ctx.Done():
return ctx.Err()
default:
- waitForTask(ctx, queue, aQueue, client, store)
+ waitForTask(ctx, queue, aQueue, client, db)
}
}
}
@@ -65,7 +65,7 @@ func waitForTask(
queue taskqueue.TaskQueue[scrape.TaskInfo],
aQueue taskqueue.TaskQueue[analyzer.TaskInfo],
client *ibd.Client,
- store database.StockStore,
+ db database.TransactionExecutor,
) {
task, err := queue.Dequeue(ctx, lockTimeout, dequeueTimeout)
if err != nil {
@@ -80,7 +80,7 @@ func waitForTask(
ch := make(chan error)
go func() {
defer close(ch)
- ch <- scrapeUrl(ctx, client, store, aQueue, task.Data.Symbol)
+ ch <- scrapeUrl(ctx, client, db, aQueue, task.Data.Symbol)
}()
ticker := time.NewTicker(lockTimeout / 5)
@@ -127,14 +127,14 @@ func waitForTask(
func scrapeUrl(
ctx context.Context,
client *ibd.Client,
- store database.StockStore,
+ db database.TransactionExecutor,
aQueue taskqueue.TaskQueue[analyzer.TaskInfo],
symbol string,
) error {
ctx, cancel := context.WithTimeout(ctx, lockTimeout)
defer cancel()
- stockUrl, err := getStockUrl(ctx, store, client, symbol)
+ stockUrl, err := getStockUrl(ctx, db, client, symbol)
if err != nil {
return fmt.Errorf("failed to get stock url: %w", err)
}
@@ -146,7 +146,7 @@ func scrapeUrl(
}
// Add stock info to the database.
- id, err := store.AddStockInfo(ctx, info)
+ id, err := database.AddStockInfo(ctx, db, info)
if err != nil {
return fmt.Errorf("failed to add stock info: %w", err)
}
@@ -162,9 +162,9 @@ func scrapeUrl(
return nil
}
-func getStockUrl(ctx context.Context, store database.StockStore, client *ibd.Client, symbol string) (string, error) {
+func getStockUrl(ctx context.Context, db database.TransactionExecutor, client *ibd.Client, symbol string) (string, error) {
// Get the stock from the database.
- stock, err := store.GetStock(ctx, symbol)
+ stock, err := database.GetStock(ctx, db, symbol)
if err == nil {
return stock.IBDUrl, nil
}
@@ -182,7 +182,7 @@ func getStockUrl(ctx context.Context, store database.StockStore, client *ibd.Cli
}
// Add the stock to the database.
- err = store.AddStock(ctx, stock)
+ err = database.AddStock(ctx, db, stock)
if err != nil {
return "", fmt.Errorf("failed to add stock: %w", err)
}
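
Note: the scraper now takes database.TransactionExecutor rather than the narrower Executor, presumably so multi-statement writes such as AddStockInfo can run inside a transaction. Under that assumption, and the BeginTx shape sketched earlier (which this diff does not show), a generic wrapper could look like the following; withTx is an illustrative helper, not something this merge adds.

package sketch

import (
	"context"
	"database/sql"
	"fmt"

	"github.com/ansg191/ibd-trader-backend/internal/database"
)

// withTx runs fn inside a transaction and commits on success or rolls back on
// error. It assumes TransactionExecutor exposes a BeginTx method; the real
// interface in internal/database may differ.
func withTx(ctx context.Context, db database.TransactionExecutor, fn func(*sql.Tx) error) error {
	tx, err := db.BeginTx(ctx, nil)
	if err != nil {
		return fmt.Errorf("begin transaction: %w", err)
	}
	if err := fn(tx); err != nil {
		_ = tx.Rollback() // best-effort rollback; return the original error
		return err
	}
	return tx.Commit()
}
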
diff --git a/backend/internal/worker/worker.go b/backend/internal/worker/worker.go
index 3d7e9c8..6017fb7 100644
--- a/backend/internal/worker/worker.go
+++ b/backend/internal/worker/worker.go
@@ -12,6 +12,7 @@ import (
"github.com/ansg191/ibd-trader-backend/internal/analyzer"
"github.com/ansg191/ibd-trader-backend/internal/database"
"github.com/ansg191/ibd-trader-backend/internal/ibd"
+ "github.com/ansg191/ibd-trader-backend/internal/keys"
"github.com/ansg191/ibd-trader-backend/internal/leader/manager"
analyzer2 "github.com/ansg191/ibd-trader-backend/internal/worker/analyzer"
"github.com/ansg191/ibd-trader-backend/internal/worker/auth"
@@ -30,7 +31,8 @@ func StartWorker(
ctx context.Context,
ibdClient *ibd.Client,
client *redis.Client,
- db database.Database,
+ db database.TransactionExecutor,
+ kms keys.KeyManagementService,
a analyzer.Analyzer,
) error {
// Get the worker name.
@@ -49,7 +51,7 @@ func StartWorker(
return scraper.RunScraper(ctx, client, ibdClient, db, name)
})
g.Go(func() error {
- return auth.RunAuthScraper(ctx, ibdClient, client, db, db, name)
+ return auth.RunAuthScraper(ctx, ibdClient, client, db, kms, name)
})
g.Go(func() error {
return analyzer2.RunAnalyzer(ctx, client, a, db, name)
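
Note: StartWorker keeps fanning the scraper, auth scraper, and analyzer out through g.Go, which looks like an errgroup; its import is outside the visible hunks, so treat that as an assumption. A standalone sketch of the same fan-out pattern:

package sketch

import (
	"context"

	"golang.org/x/sync/errgroup"
)

// runAll mirrors the g.Go calls in StartWorker: each worker loop runs until it
// errors or the shared context is cancelled, and the first error wins.
// errgroup is assumed; this diff does not show worker.go's full import block.
func runAll(ctx context.Context, workers ...func(context.Context) error) error {
	g, ctx := errgroup.WithContext(ctx)
	for _, w := range workers {
		w := w // capture the loop variable (needed before Go 1.22)
		g.Go(func() error { return w(ctx) })
	}
	return g.Wait()
}
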