aboutsummaryrefslogtreecommitdiff
path: root/backend/internal/worker/auth
diff options
context:
space:
mode:
authorGravatar Anshul Gupta <ansg191@anshulg.com> 2024-08-07 18:56:01 -0700
committerGravatar GitHub <noreply@github.com> 2024-08-07 18:56:01 -0700
commit08993e2f8497341079010d3d06361c99492c4c07 (patch)
treec65d6d571c928410faace1fa51c2ea3f49fce003 /backend/internal/worker/auth
parent3de4ebb7560851ccbefe296c197456fe80c22901 (diff)
parentb8aef1a7fb24815c7d93bc30c7b289b4f5896779 (diff)
downloadibd-trader-08993e2f8497341079010d3d06361c99492c4c07.tar.gz
ibd-trader-08993e2f8497341079010d3d06361c99492c4c07.tar.zst
ibd-trader-08993e2f8497341079010d3d06361c99492c4c07.zip
Merge pull request #1 from ansg191/refactor-database
Diffstat (limited to 'backend/internal/worker/auth')
-rw-r--r--backend/internal/worker/auth/auth.go61
1 files changed, 36 insertions, 25 deletions
diff --git a/backend/internal/worker/auth/auth.go b/backend/internal/worker/auth/auth.go
index 1f591fe..579a180 100644
--- a/backend/internal/worker/auth/auth.go
+++ b/backend/internal/worker/auth/auth.go
@@ -2,12 +2,15 @@ package auth
import (
"context"
+ "database/sql"
+ "errors"
"fmt"
"log/slog"
"time"
"github.com/ansg191/ibd-trader-backend/internal/database"
"github.com/ansg191/ibd-trader-backend/internal/ibd"
+ "github.com/ansg191/ibd-trader-backend/internal/keys"
"github.com/ansg191/ibd-trader-backend/internal/leader/manager/ibd/auth"
"github.com/ansg191/ibd-trader-backend/internal/redis/taskqueue"
@@ -23,8 +26,8 @@ func RunAuthScraper(
ctx context.Context,
client *ibd.Client,
redis *redis.Client,
- users database.UserStore,
- cookies database.CookieStore,
+ db database.Executor,
+ kms keys.KeyManagementService,
name string,
) error {
queue, err := taskqueue.New(
@@ -43,7 +46,7 @@ func RunAuthScraper(
case <-ctx.Done():
return ctx.Err()
default:
- waitForTask(ctx, queue, client, users, cookies)
+ waitForTask(ctx, queue, client, db, kms)
}
}
}
@@ -52,8 +55,8 @@ func waitForTask(
ctx context.Context,
queue taskqueue.TaskQueue[auth.TaskInfo],
client *ibd.Client,
- users database.UserStore,
- cookies database.CookieStore,
+ db database.Executor,
+ kms keys.KeyManagementService,
) {
task, err := queue.Dequeue(ctx, lockTimeout, dequeueTimeout)
if err != nil {
@@ -69,7 +72,7 @@ func waitForTask(
ch := make(chan error)
defer close(ch)
go func() {
- ch <- scrapeCookies(ctx, client, users, cookies, task.Data.UserSubject)
+ ch <- scrapeCookies(ctx, client, db, kms, task.Data.UserSubject)
}()
ticker := time.NewTicker(lockTimeout / 5)
@@ -116,15 +119,15 @@ func waitForTask(
func scrapeCookies(
ctx context.Context,
client *ibd.Client,
- users database.UserStore,
- store database.CookieStore,
+ db database.Executor,
+ kms keys.KeyManagementService,
user string,
) error {
ctx, cancel := context.WithTimeout(ctx, lockTimeout)
defer cancel()
// Check if the user has valid cookies
- done, err := hasValidCookies(ctx, store, user)
+ done, err := hasValidCookies(ctx, db, user)
if err != nil {
return fmt.Errorf("failed to check cookies: %w", err)
}
@@ -133,7 +136,7 @@ func scrapeCookies(
}
// Health check degraded cookies
- done, err = healthCheckDegradedCookies(ctx, client, store, user)
+ done, err = healthCheckDegradedCookies(ctx, client, db, kms, user)
if err != nil {
return fmt.Errorf("failed to health check cookies: %w", err)
}
@@ -142,31 +145,39 @@ func scrapeCookies(
}
// No cookies are valid, so scrape new cookies
- return scrapeNewCookies(ctx, client, users, store, user)
+ return scrapeNewCookies(ctx, client, db, kms, user)
}
-func hasValidCookies(ctx context.Context, store database.CookieStore, user string) (bool, error) {
+func hasValidCookies(ctx context.Context, db database.Executor, user string) (bool, error) {
// Check if the user has non-degraded cookies
- cookies, err := store.GetCookies(ctx, user, false)
+ row := db.QueryRowContext(ctx, `
+SELECT 1
+FROM ibd_tokens
+WHERE user_subject = $1
+ AND expires_at > NOW()
+ AND degraded = FALSE;`, user)
+
+ var exists bool
+ err := row.Scan(&exists)
+ if errors.Is(err, sql.ErrNoRows) {
+ return false, nil
+ }
if err != nil {
return false, fmt.Errorf("failed to get non-degraded cookies: %w", err)
}
- // If the user has non-degraded cookies, return true
- if len(cookies) > 0 {
- return true, nil
- }
- return false, nil
+ return true, nil
}
func healthCheckDegradedCookies(
ctx context.Context,
client *ibd.Client,
- store database.CookieStore,
+ db database.Executor,
+ kms keys.KeyManagementService,
user string,
) (bool, error) {
// Check if the user has degraded cookies
- cookies, err := store.GetCookies(ctx, user, true)
+ cookies, err := database.GetCookies(ctx, db, kms, user, true)
if err != nil {
return false, fmt.Errorf("failed to get degraded cookies: %w", err)
}
@@ -190,7 +201,7 @@ func healthCheckDegradedCookies(
valid = true
// Update the cookie
- err = store.RepairCookie(ctx, cookie.ID)
+ err = database.RepairCookie(ctx, db, cookie.ID)
if err != nil {
slog.ErrorContext(ctx, "Failed to repair cookie", "error", err)
}
@@ -202,12 +213,12 @@ func healthCheckDegradedCookies(
func scrapeNewCookies(
ctx context.Context,
client *ibd.Client,
- users database.UserStore,
- store database.CookieStore,
+ db database.Executor,
+ kms keys.KeyManagementService,
user string,
) error {
// Get the user's credentials
- username, password, err := users.GetIBDCreds(ctx, user)
+ username, password, err := database.GetIBDCreds(ctx, db, kms, user)
if err != nil {
return fmt.Errorf("failed to get IBD credentials: %w", err)
}
@@ -219,7 +230,7 @@ func scrapeNewCookies(
}
// Store the cookie
- err = store.AddCookie(ctx, user, cookie)
+ err = database.AddCookie(ctx, db, kms, user, cookie)
if err != nil {
return fmt.Errorf("failed to store cookie: %w", err)
}