diff options
Diffstat (limited to 'backend/internal/database/cookies.go')
-rw-r--r-- | backend/internal/database/cookies.go | 189 |
1 files changed, 189 insertions, 0 deletions
diff --git a/backend/internal/database/cookies.go b/backend/internal/database/cookies.go new file mode 100644 index 0000000..3ea21d0 --- /dev/null +++ b/backend/internal/database/cookies.go @@ -0,0 +1,189 @@ +package database + +import ( + "context" + "database/sql" + "errors" + "fmt" + "net/http" + "time" + + "github.com/ansg191/ibd-trader-backend/internal/keys" +) + +func GetAnyCookie(ctx context.Context, exec Executor, kms keys.KeyManagementService) (*IBDCookie, error) { + row := exec.QueryRowContext(ctx, ` +SELECT ibd_tokens.id, token, encrypted_key, kms_key_name, expires_at +FROM ibd_tokens + INNER JOIN keys ON encryption_key = keys.id +WHERE expires_at > NOW() + AND degraded = FALSE +ORDER BY random() +LIMIT 1;`) + + var id uint + var encryptedToken, encryptedKey []byte + var keyName string + var expiry time.Time + err := row.Scan(&id, &encryptedToken, &encryptedKey, &keyName, &expiry) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, nil + } + return nil, fmt.Errorf("unable to scan sql row into ibd cookie: %w", err) + } + + // Set the expiry to UTC explicitly. + // For some reason, the expiry time is set to location="". + expiry = expiry.UTC() + + token, err := keys.Decrypt(ctx, kms, keyName, encryptedToken, encryptedKey) + if err != nil { + return nil, fmt.Errorf("unable to decrypt token: %w", err) + } + return &IBDCookie{ + Token: string(token), + Expiry: expiry, + }, nil +} + +func GetCookies( + ctx context.Context, + exec Executor, + kms keys.KeyManagementService, + subject string, + degraded bool, +) ([]IBDCookie, error) { + rows, err := exec.QueryContext(ctx, ` +SELECT ibd_tokens.id, token, encrypted_key, kms_key_name, expires_at +FROM ibd_tokens + INNER JOIN keys ON encryption_key = keys.id +WHERE user_subject = $1 + AND expires_at > NOW() + AND degraded = $2 +ORDER BY expires_at DESC;`, subject, degraded) + if err != nil { + return nil, fmt.Errorf("unable to get ibd cookies: %w", err) + } + + cookies := make([]IBDCookie, 0) + for rows.Next() { + var id uint + var encryptedToken, encryptedKey []byte + var keyName string + var expiry time.Time + err = rows.Scan(&id, &encryptedToken, &encryptedKey, &keyName, &expiry) + if err != nil { + return nil, fmt.Errorf("unable to scan sql row into ibd cookie: %w", err) + } + + // Set the expiry to UTC explicitly. + // For some reason, the expiry time is set to location="". + expiry = expiry.UTC() + + token, err := keys.Decrypt(ctx, kms, keyName, encryptedToken, encryptedKey) + if err != nil { + return nil, fmt.Errorf("unable to decrypt token: %w", err) + } + cookie := IBDCookie{ + ID: id, + Token: string(token), + Expiry: expiry, + } + cookies = append(cookies, cookie) + } + + return cookies, nil +} + +func AddCookie( + ctx context.Context, + exec TransactionExecutor, + kms keys.KeyManagementService, + subject string, + cookie *http.Cookie, +) error { + tx, err := exec.BeginTx(ctx, nil) + if err != nil { + return err + } + + // Get the key ID for the user + user, err := GetUser(ctx, tx, subject) + if err != nil { + return fmt.Errorf("unable to get user: %w", err) + } + if user.EncryptionKeyID == nil { + return errors.New("user does not have an encryption key") + } + + // Get the key + var keyName string + var key []byte + err = tx.QueryRowContext(ctx, ` +SELECT kms_key_name, encrypted_key +FROM keys +WHERE id = $1;`, + *user.EncryptionKeyID, + ).Scan(&keyName, &key) + if err != nil { + return fmt.Errorf("unable to get key: %w", err) + } + + // Encrypt the token + encryptedToken, err := keys.EncryptWithKey(ctx, kms, keyName, key, []byte(cookie.Value)) + if err != nil { + return fmt.Errorf("unable to encrypt token: %w", err) + } + + // Add the cookie to the database + _, err = exec.ExecContext(ctx, ` +INSERT INTO ibd_tokens (token, expires_at, user_subject, encryption_key) +VALUES ($1, $2, $3, $4)`, encryptedToken, cookie.Expires, subject, *user.EncryptionKeyID) + if err != nil { + return fmt.Errorf("unable to add cookie: %w", err) + } + + return nil +} + +func ReportCookieFailure(ctx context.Context, exec Executor, id uint) error { + _, err := exec.ExecContext(ctx, ` +UPDATE ibd_tokens +SET degraded = TRUE +WHERE id = $1;`, id) + if err != nil { + return fmt.Errorf("unable to report cookie failure: %w", err) + } + return nil +} + +func RepairCookie(ctx context.Context, exec Executor, id uint) error { + _, err := exec.ExecContext(ctx, ` +UPDATE ibd_tokens +SET degraded = FALSE +WHERE id = $1;`, id) + if err != nil { + return fmt.Errorf("unable to report cookie failure: %w", err) + } + return nil +} + +type IBDCookie struct { + ID uint + Token string + Expiry time.Time +} + +func (c *IBDCookie) ToHTTPCookie() *http.Cookie { + return &http.Cookie{ + Name: ".ASPXAUTH", + Value: c.Token, + Path: "/", + Domain: "investors.com", + Expires: c.Expiry, + Secure: true, + HttpOnly: false, + SameSite: http.SameSiteLaxMode, + } +} |