diff options
author | 2024-08-07 17:48:57 -0700 | |
---|---|---|
committer | 2024-08-07 18:48:10 -0700 | |
commit | e9ee45b9d2bd494332dcf8b2073714f92fd0738d (patch) | |
tree | d34af1af84984409d27003981538f13cde4ba218 /backend/internal/database/cookies.go | |
parent | 3de4ebb7560851ccbefe296c197456fe80c22901 (diff) | |
download | ibd-trader-e9ee45b9d2bd494332dcf8b2073714f92fd0738d.tar.gz ibd-trader-e9ee45b9d2bd494332dcf8b2073714f92fd0738d.tar.zst ibd-trader-e9ee45b9d2bd494332dcf8b2073714f92fd0738d.zip |
Refactor DB to remove restrictive query system
Diffstat (limited to 'backend/internal/database/cookies.go')
-rw-r--r-- | backend/internal/database/cookies.go | 93 |
1 files changed, 60 insertions, 33 deletions
diff --git a/backend/internal/database/cookies.go b/backend/internal/database/cookies.go index 8bed854..d652b65 100644 --- a/backend/internal/database/cookies.go +++ b/backend/internal/database/cookies.go @@ -11,29 +11,21 @@ import ( "github.com/ansg191/ibd-trader-backend/internal/keys" ) -type CookieStore interface { - CookieSource - AddCookie(ctx context.Context, subject string, cookie *http.Cookie) error - RepairCookie(ctx context.Context, id uint) error -} - -type CookieSource interface { - GetAnyCookie(ctx context.Context) (*IBDCookie, error) - GetCookies(ctx context.Context, subject string, degraded bool) ([]IBDCookie, error) - ReportCookieFailure(ctx context.Context, id uint) error -} - -func (d *database) GetAnyCookie(ctx context.Context) (*IBDCookie, error) { - row, err := d.queryRow(ctx, d.db, "cookies/get_any_cookie") - if err != nil { - return nil, fmt.Errorf("unable to get any ibd cookie: %w", err) - } +func GetAnyCookie(ctx context.Context, exec Executor, kms keys.KeyManagementService) (*IBDCookie, error) { + row := exec.QueryRowContext(ctx, ` +SELECT ibd_tokens.id, token, encrypted_key, kms_key_name, expires_at +FROM ibd_tokens + INNER JOIN keys ON encryption_key = keys.id +WHERE expires_at > NOW() + AND degraded = FALSE +ORDER BY random() +LIMIT 1;`) var id uint var encryptedToken, encryptedKey []byte var keyName string var expiry time.Time - err = row.Scan(&id, &encryptedToken, &encryptedKey, &keyName, &expiry) + err := row.Scan(&id, &encryptedToken, &encryptedKey, &keyName, &expiry) if err != nil { if errors.Is(err, sql.ErrNoRows) { return nil, nil @@ -41,7 +33,11 @@ func (d *database) GetAnyCookie(ctx context.Context) (*IBDCookie, error) { return nil, fmt.Errorf("unable to scan sql row into ibd cookie: %w", err) } - token, err := keys.Decrypt(ctx, d.kms, keyName, encryptedToken, encryptedKey) + // Set the expiry to UTC explicitly. + // For some reason, the expiry time is set to location="". + expiry = expiry.UTC() + + token, err := keys.Decrypt(ctx, kms, keyName, encryptedToken, encryptedKey) if err != nil { return nil, fmt.Errorf("unable to decrypt token: %w", err) } @@ -51,24 +47,41 @@ func (d *database) GetAnyCookie(ctx context.Context) (*IBDCookie, error) { }, nil } -func (d *database) GetCookies(ctx context.Context, subject string, degraded bool) ([]IBDCookie, error) { - row, err := d.query(ctx, d.db, "cookies/get_cookies", subject, degraded) +func GetCookies( + ctx context.Context, + exec Executor, + kms keys.KeyManagementService, + subject string, + degraded bool, +) ([]IBDCookie, error) { + rows, err := exec.QueryContext(ctx, ` +SELECT ibd_tokens.id, token, encrypted_key, kms_key_name, expires_at +FROM ibd_tokens + INNER JOIN keys ON encryption_key = keys.id +WHERE user_subject = $1 + AND expires_at > NOW() + AND degraded = $2 +ORDER BY expires_at DESC;`, subject, degraded) if err != nil { return nil, fmt.Errorf("unable to get ibd cookies: %w", err) } cookies := make([]IBDCookie, 0) - for row.Next() { + for rows.Next() { var id uint var encryptedToken, encryptedKey []byte var keyName string var expiry time.Time - err = row.Scan(&id, &encryptedToken, &encryptedKey, &keyName, &expiry) + err = rows.Scan(&id, &encryptedToken, &encryptedKey, &keyName, &expiry) if err != nil { return nil, fmt.Errorf("unable to scan sql row into ibd cookie: %w", err) } - token, err := keys.Decrypt(ctx, d.kms, keyName, encryptedToken, encryptedKey) + // Set the expiry to UTC explicitly. + // For some reason, the expiry time is set to location="". + expiry = expiry.UTC() + + token, err := keys.Decrypt(ctx, kms, keyName, encryptedToken, encryptedKey) if err != nil { return nil, fmt.Errorf("unable to decrypt token: %w", err) } @@ -83,9 +96,15 @@ func (d *database) GetCookies(ctx context.Context, subject string, degraded bool return cookies, nil } -func (d *database) AddCookie(ctx context.Context, subject string, cookie *http.Cookie) error { +func AddCookie( + ctx context.Context, + exec Executor, + kms keys.KeyManagementService, + subject string, + cookie *http.Cookie, +) error { // Get the key ID for the user - user, err := d.GetUser(ctx, subject) + user, err := GetUser(ctx, exec, subject) if err != nil { return fmt.Errorf("unable to get user: %w", err) } @@ -94,19 +113,21 @@ func (d *database) AddCookie(ctx context.Context, subject string, cookie *http.C } // Get the key - key, err := d.GetKey(ctx, *user.EncryptionKeyID) + key, err := GetKey(ctx, exec, *user.EncryptionKeyID) if err != nil { return fmt.Errorf("unable to get key: %w", err) } // Encrypt the token - encryptedToken, err := keys.EncryptWithKey(ctx, d.kms, key.Name, key.Key, []byte(cookie.Value)) + encryptedToken, err := keys.EncryptWithKey(ctx, kms, key.Name, key.Key, []byte(cookie.Value)) if err != nil { return fmt.Errorf("unable to encrypt token: %w", err) } // Add the cookie to the database - _, err = d.exec(ctx, d.db, "cookies/add_cookie", encryptedToken, cookie.Expires, subject, key.Id) + _, err = exec.ExecContext(ctx, ` +INSERT INTO ibd_tokens (token, expires_at, user_subject, encryption_key) +VALUES ($1, $2, $3, $4)`, encryptedToken, cookie.Expires, subject, key.Id) if err != nil { return fmt.Errorf("unable to add cookie: %w", err) } @@ -114,16 +135,22 @@ func (d *database) AddCookie(ctx context.Context, subject string, cookie *http.C return nil } -func (d *database) ReportCookieFailure(ctx context.Context, id uint) error { - _, err := d.exec(ctx, d.db, "cookies/set_cookie_degraded", true, id) +func ReportCookieFailure(ctx context.Context, exec Executor, id uint) error { + _, err := exec.ExecContext(ctx, ` +UPDATE ibd_tokens +SET degraded = TRUE +WHERE id = $1;`, id) if err != nil { return fmt.Errorf("unable to report cookie failure: %w", err) } return nil } -func (d *database) RepairCookie(ctx context.Context, id uint) error { - _, err := d.exec(ctx, d.db, "cookies/set_cookie_degraded", false, id) +func RepairCookie(ctx context.Context, exec Executor, id uint) error { + _, err := exec.ExecContext(ctx, ` +UPDATE ibd_tokens +SET degraded = FALSE +WHERE id = $1;`, id) if err != nil { return fmt.Errorf("unable to report cookie failure: %w", err) } |