aboutsummaryrefslogtreecommitdiff
path: root/backend/internal/database/cookies.go
diff options
context:
space:
mode:
Diffstat (limited to 'backend/internal/database/cookies.go')
-rw-r--r--backend/internal/database/cookies.go93
1 files changed, 60 insertions, 33 deletions
diff --git a/backend/internal/database/cookies.go b/backend/internal/database/cookies.go
index 8bed854..d652b65 100644
--- a/backend/internal/database/cookies.go
+++ b/backend/internal/database/cookies.go
@@ -11,29 +11,21 @@ import (
"github.com/ansg191/ibd-trader-backend/internal/keys"
)
-type CookieStore interface {
- CookieSource
- AddCookie(ctx context.Context, subject string, cookie *http.Cookie) error
- RepairCookie(ctx context.Context, id uint) error
-}
-
-type CookieSource interface {
- GetAnyCookie(ctx context.Context) (*IBDCookie, error)
- GetCookies(ctx context.Context, subject string, degraded bool) ([]IBDCookie, error)
- ReportCookieFailure(ctx context.Context, id uint) error
-}
-
-func (d *database) GetAnyCookie(ctx context.Context) (*IBDCookie, error) {
- row, err := d.queryRow(ctx, d.db, "cookies/get_any_cookie")
- if err != nil {
- return nil, fmt.Errorf("unable to get any ibd cookie: %w", err)
- }
+func GetAnyCookie(ctx context.Context, exec Executor, kms keys.KeyManagementService) (*IBDCookie, error) {
+ row := exec.QueryRowContext(ctx, `
+SELECT ibd_tokens.id, token, encrypted_key, kms_key_name, expires_at
+FROM ibd_tokens
+ INNER JOIN keys ON encryption_key = keys.id
+WHERE expires_at > NOW()
+ AND degraded = FALSE
+ORDER BY random()
+LIMIT 1;`)
var id uint
var encryptedToken, encryptedKey []byte
var keyName string
var expiry time.Time
- err = row.Scan(&id, &encryptedToken, &encryptedKey, &keyName, &expiry)
+ err := row.Scan(&id, &encryptedToken, &encryptedKey, &keyName, &expiry)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, nil
@@ -41,7 +33,11 @@ func (d *database) GetAnyCookie(ctx context.Context) (*IBDCookie, error) {
return nil, fmt.Errorf("unable to scan sql row into ibd cookie: %w", err)
}
- token, err := keys.Decrypt(ctx, d.kms, keyName, encryptedToken, encryptedKey)
+ // Set the expiry to UTC explicitly.
+ // For some reason, the expiry time is set to location="".
+ expiry = expiry.UTC()
+
+ token, err := keys.Decrypt(ctx, kms, keyName, encryptedToken, encryptedKey)
if err != nil {
return nil, fmt.Errorf("unable to decrypt token: %w", err)
}
@@ -51,24 +47,41 @@ func (d *database) GetAnyCookie(ctx context.Context) (*IBDCookie, error) {
}, nil
}
-func (d *database) GetCookies(ctx context.Context, subject string, degraded bool) ([]IBDCookie, error) {
- row, err := d.query(ctx, d.db, "cookies/get_cookies", subject, degraded)
+func GetCookies(
+ ctx context.Context,
+ exec Executor,
+ kms keys.KeyManagementService,
+ subject string,
+ degraded bool,
+) ([]IBDCookie, error) {
+ rows, err := exec.QueryContext(ctx, `
+SELECT ibd_tokens.id, token, encrypted_key, kms_key_name, expires_at
+FROM ibd_tokens
+ INNER JOIN keys ON encryption_key = keys.id
+WHERE user_subject = $1
+ AND expires_at > NOW()
+ AND degraded = $2
+ORDER BY expires_at DESC;`, subject, degraded)
if err != nil {
return nil, fmt.Errorf("unable to get ibd cookies: %w", err)
}
cookies := make([]IBDCookie, 0)
- for row.Next() {
+ for rows.Next() {
var id uint
var encryptedToken, encryptedKey []byte
var keyName string
var expiry time.Time
- err = row.Scan(&id, &encryptedToken, &encryptedKey, &keyName, &expiry)
+ err = rows.Scan(&id, &encryptedToken, &encryptedKey, &keyName, &expiry)
if err != nil {
return nil, fmt.Errorf("unable to scan sql row into ibd cookie: %w", err)
}
- token, err := keys.Decrypt(ctx, d.kms, keyName, encryptedToken, encryptedKey)
+ // Set the expiry to UTC explicitly.
+ // For some reason, the expiry time is set to location="".
+ expiry = expiry.UTC()
+
+ token, err := keys.Decrypt(ctx, kms, keyName, encryptedToken, encryptedKey)
if err != nil {
return nil, fmt.Errorf("unable to decrypt token: %w", err)
}
@@ -83,9 +96,15 @@ func (d *database) GetCookies(ctx context.Context, subject string, degraded bool
return cookies, nil
}
-func (d *database) AddCookie(ctx context.Context, subject string, cookie *http.Cookie) error {
+func AddCookie(
+ ctx context.Context,
+ exec Executor,
+ kms keys.KeyManagementService,
+ subject string,
+ cookie *http.Cookie,
+) error {
// Get the key ID for the user
- user, err := d.GetUser(ctx, subject)
+ user, err := GetUser(ctx, exec, subject)
if err != nil {
return fmt.Errorf("unable to get user: %w", err)
}
@@ -94,19 +113,21 @@ func (d *database) AddCookie(ctx context.Context, subject string, cookie *http.C
}
// Get the key
- key, err := d.GetKey(ctx, *user.EncryptionKeyID)
+ key, err := GetKey(ctx, exec, *user.EncryptionKeyID)
if err != nil {
return fmt.Errorf("unable to get key: %w", err)
}
// Encrypt the token
- encryptedToken, err := keys.EncryptWithKey(ctx, d.kms, key.Name, key.Key, []byte(cookie.Value))
+ encryptedToken, err := keys.EncryptWithKey(ctx, kms, key.Name, key.Key, []byte(cookie.Value))
if err != nil {
return fmt.Errorf("unable to encrypt token: %w", err)
}
// Add the cookie to the database
- _, err = d.exec(ctx, d.db, "cookies/add_cookie", encryptedToken, cookie.Expires, subject, key.Id)
+ _, err = exec.ExecContext(ctx, `
+INSERT INTO ibd_tokens (token, expires_at, user_subject, encryption_key)
+VALUES ($1, $2, $3, $4)`, encryptedToken, cookie.Expires, subject, key.Id)
if err != nil {
return fmt.Errorf("unable to add cookie: %w", err)
}
@@ -114,16 +135,22 @@ func (d *database) AddCookie(ctx context.Context, subject string, cookie *http.C
return nil
}
-func (d *database) ReportCookieFailure(ctx context.Context, id uint) error {
- _, err := d.exec(ctx, d.db, "cookies/set_cookie_degraded", true, id)
+func ReportCookieFailure(ctx context.Context, exec Executor, id uint) error {
+ _, err := exec.ExecContext(ctx, `
+UPDATE ibd_tokens
+SET degraded = TRUE
+WHERE id = $1;`, id)
if err != nil {
return fmt.Errorf("unable to report cookie failure: %w", err)
}
return nil
}
-func (d *database) RepairCookie(ctx context.Context, id uint) error {
- _, err := d.exec(ctx, d.db, "cookies/set_cookie_degraded", false, id)
+func RepairCookie(ctx context.Context, exec Executor, id uint) error {
+ _, err := exec.ExecContext(ctx, `
+UPDATE ibd_tokens
+SET degraded = FALSE
+WHERE id = $1;`, id)
if err != nil {
return fmt.Errorf("unable to report cookie failure: %w", err)
}