diff options
author | 2024-08-07 18:56:01 -0700 | |
---|---|---|
committer | 2024-08-07 18:56:01 -0700 | |
commit | 08993e2f8497341079010d3d06361c99492c4c07 (patch) | |
tree | c65d6d571c928410faace1fa51c2ea3f49fce003 /backend/internal/worker/auth | |
parent | 3de4ebb7560851ccbefe296c197456fe80c22901 (diff) | |
parent | b8aef1a7fb24815c7d93bc30c7b289b4f5896779 (diff) | |
download | ibd-trader-08993e2f8497341079010d3d06361c99492c4c07.tar.gz ibd-trader-08993e2f8497341079010d3d06361c99492c4c07.tar.zst ibd-trader-08993e2f8497341079010d3d06361c99492c4c07.zip |
Merge pull request #1 from ansg191/refactor-database
Diffstat (limited to 'backend/internal/worker/auth')
-rw-r--r-- | backend/internal/worker/auth/auth.go | 61 |
1 files changed, 36 insertions, 25 deletions
diff --git a/backend/internal/worker/auth/auth.go b/backend/internal/worker/auth/auth.go index 1f591fe..579a180 100644 --- a/backend/internal/worker/auth/auth.go +++ b/backend/internal/worker/auth/auth.go @@ -2,12 +2,15 @@ package auth import ( "context" + "database/sql" + "errors" "fmt" "log/slog" "time" "github.com/ansg191/ibd-trader-backend/internal/database" "github.com/ansg191/ibd-trader-backend/internal/ibd" + "github.com/ansg191/ibd-trader-backend/internal/keys" "github.com/ansg191/ibd-trader-backend/internal/leader/manager/ibd/auth" "github.com/ansg191/ibd-trader-backend/internal/redis/taskqueue" @@ -23,8 +26,8 @@ func RunAuthScraper( ctx context.Context, client *ibd.Client, redis *redis.Client, - users database.UserStore, - cookies database.CookieStore, + db database.Executor, + kms keys.KeyManagementService, name string, ) error { queue, err := taskqueue.New( @@ -43,7 +46,7 @@ func RunAuthScraper( case <-ctx.Done(): return ctx.Err() default: - waitForTask(ctx, queue, client, users, cookies) + waitForTask(ctx, queue, client, db, kms) } } } @@ -52,8 +55,8 @@ func waitForTask( ctx context.Context, queue taskqueue.TaskQueue[auth.TaskInfo], client *ibd.Client, - users database.UserStore, - cookies database.CookieStore, + db database.Executor, + kms keys.KeyManagementService, ) { task, err := queue.Dequeue(ctx, lockTimeout, dequeueTimeout) if err != nil { @@ -69,7 +72,7 @@ func waitForTask( ch := make(chan error) defer close(ch) go func() { - ch <- scrapeCookies(ctx, client, users, cookies, task.Data.UserSubject) + ch <- scrapeCookies(ctx, client, db, kms, task.Data.UserSubject) }() ticker := time.NewTicker(lockTimeout / 5) @@ -116,15 +119,15 @@ func waitForTask( func scrapeCookies( ctx context.Context, client *ibd.Client, - users database.UserStore, - store database.CookieStore, + db database.Executor, + kms keys.KeyManagementService, user string, ) error { ctx, cancel := context.WithTimeout(ctx, lockTimeout) defer cancel() // Check if the user has valid cookies - done, err := hasValidCookies(ctx, store, user) + done, err := hasValidCookies(ctx, db, user) if err != nil { return fmt.Errorf("failed to check cookies: %w", err) } @@ -133,7 +136,7 @@ func scrapeCookies( } // Health check degraded cookies - done, err = healthCheckDegradedCookies(ctx, client, store, user) + done, err = healthCheckDegradedCookies(ctx, client, db, kms, user) if err != nil { return fmt.Errorf("failed to health check cookies: %w", err) } @@ -142,31 +145,39 @@ func scrapeCookies( } // No cookies are valid, so scrape new cookies - return scrapeNewCookies(ctx, client, users, store, user) + return scrapeNewCookies(ctx, client, db, kms, user) } -func hasValidCookies(ctx context.Context, store database.CookieStore, user string) (bool, error) { +func hasValidCookies(ctx context.Context, db database.Executor, user string) (bool, error) { // Check if the user has non-degraded cookies - cookies, err := store.GetCookies(ctx, user, false) + row := db.QueryRowContext(ctx, ` +SELECT 1 +FROM ibd_tokens +WHERE user_subject = $1 + AND expires_at > NOW() + AND degraded = FALSE;`, user) + + var exists bool + err := row.Scan(&exists) + if errors.Is(err, sql.ErrNoRows) { + return false, nil + } if err != nil { return false, fmt.Errorf("failed to get non-degraded cookies: %w", err) } - // If the user has non-degraded cookies, return true - if len(cookies) > 0 { - return true, nil - } - return false, nil + return true, nil } func healthCheckDegradedCookies( ctx context.Context, client *ibd.Client, - store database.CookieStore, + db database.Executor, + kms keys.KeyManagementService, user string, ) (bool, error) { // Check if the user has degraded cookies - cookies, err := store.GetCookies(ctx, user, true) + cookies, err := database.GetCookies(ctx, db, kms, user, true) if err != nil { return false, fmt.Errorf("failed to get degraded cookies: %w", err) } @@ -190,7 +201,7 @@ func healthCheckDegradedCookies( valid = true // Update the cookie - err = store.RepairCookie(ctx, cookie.ID) + err = database.RepairCookie(ctx, db, cookie.ID) if err != nil { slog.ErrorContext(ctx, "Failed to repair cookie", "error", err) } @@ -202,12 +213,12 @@ func healthCheckDegradedCookies( func scrapeNewCookies( ctx context.Context, client *ibd.Client, - users database.UserStore, - store database.CookieStore, + db database.Executor, + kms keys.KeyManagementService, user string, ) error { // Get the user's credentials - username, password, err := users.GetIBDCreds(ctx, user) + username, password, err := database.GetIBDCreds(ctx, db, kms, user) if err != nil { return fmt.Errorf("failed to get IBD credentials: %w", err) } @@ -219,7 +230,7 @@ func scrapeNewCookies( } // Store the cookie - err = store.AddCookie(ctx, user, cookie) + err = database.AddCookie(ctx, db, kms, user, cookie) if err != nil { return fmt.Errorf("failed to store cookie: %w", err) } |