package database

import (
	"context"
	"database/sql"
	"errors"
	"fmt"
	"net/http"
	"time"

	"github.com/ansg191/ibd-trader-backend/internal/keys"
)

// GetAnyCookie returns one randomly chosen IBD cookie, from any user, that
// has not expired and is not marked as degraded.
//
// If no such cookie exists, GetAnyCookie returns (nil, nil).
func GetAnyCookie(ctx context.Context, exec Executor, kms keys.KeyManagementService) (*IBDCookie, error) {
	row := exec.QueryRowContext(ctx, `
SELECT ibd_tokens.id, token, encrypted_key, kms_key_name, expires_at
FROM ibd_tokens
INNER JOIN keys ON encryption_key = keys.id
WHERE expires_at > NOW()
  AND degraded = FALSE
ORDER BY random()
LIMIT 1;`)

	cookie, err := scanIBDCookie(ctx, kms, row.Scan)
	if err != nil {
		// scanIBDCookie wraps scan errors with %w, so "no rows" is still detectable.
		if errors.Is(err, sql.ErrNoRows) {
			return nil, nil
		}
		return nil, err
	}
	return &cookie, nil
}

// GetCookies returns the unexpired cookies belonging to subject whose
// degraded flag equals degraded, ordered by expiry with the latest first.
//
// When nothing matches, GetCookies returns an empty, non-nil slice.
func GetCookies(
	ctx context.Context,
	exec Executor,
	kms keys.KeyManagementService,
	subject string,
	degraded bool,
) ([]IBDCookie, error) {
	rows, err := exec.QueryContext(ctx, `
SELECT ibd_tokens.id, token, encrypted_key, kms_key_name, expires_at
FROM ibd_tokens
INNER JOIN keys ON encryption_key = keys.id
WHERE user_subject = $1
  AND expires_at > NOW()
  AND degraded = $2
ORDER BY expires_at DESC;`, subject, degraded)
	if err != nil {
		return nil, fmt.Errorf("unable to get ibd cookies: %w", err)
	}
	// Close releases the connection back to the pool, including when the loop
	// below returns early. Its error carries nothing useful beyond rows.Err.
	defer rows.Close()

	cookies := make([]IBDCookie, 0)
	for rows.Next() {
		cookie, err := scanIBDCookie(ctx, kms, rows.Scan)
		if err != nil {
			return nil, err
		}
		cookies = append(cookies, cookie)
	}
	// rows.Next returns false both when rows are exhausted and when iteration
	// fails. Err tells the two apart, so a failure is not reported as a short list.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("unable to iterate ibd cookies: %w", err)
	}
	return cookies, nil
}

// scanIBDCookie reads one ibd_tokens row via scan and decrypts its token
// using the row's KMS key. The row's columns are, in order: id, token,
// encrypted_key, kms_key_name, expires_at.
//
// Errors from scan are wrapped with %w, so callers can still test for
// sql.ErrNoRows with errors.Is.
func scanIBDCookie(
	ctx context.Context,
	kms keys.KeyManagementService,
	scan func(dest ...interface{}) error,
) (IBDCookie, error) {
	var id uint
	var encryptedToken, encryptedKey []byte
	var keyName string
	var expiry time.Time

	if err := scan(&id, &encryptedToken, &encryptedKey, &keyName, &expiry); err != nil {
		return IBDCookie{}, fmt.Errorf("unable to scan sql row into ibd cookie: %w", err)
	}

	token, err := keys.Decrypt(ctx, kms, keyName, encryptedToken, encryptedKey)
	if err != nil {
		return IBDCookie{}, fmt.Errorf("unable to decrypt token: %w", err)
	}

	// The scanned expiry has been observed with location="", so normalize it
	// to UTC explicitly.
	return IBDCookie{
		ID:     id,
		Token:  string(token),
		Expiry: expiry.UTC(),
	}, nil
}

// AddCookie encrypts cookie.Value with the encryption key of the user
// identified by subject. It then stores the result in ibd_tokens, using
// cookie.Expires as the expiry.
//
// The user lookup, the key lookup and the insert all run in one transaction.
// An error is returned if the user has no encryption key.
func AddCookie(
	ctx context.Context,
	exec TransactionExecutor,
	kms keys.KeyManagementService,
	subject string,
	cookie *http.Cookie,
) error {
	tx, err := exec.BeginTx(ctx, nil)
	if err != nil {
		return fmt.Errorf("unable to begin transaction: %w", err)
	}
	// Roll back on every early return. After a successful Commit, Rollback
	// only returns sql.ErrTxDone, which is deliberately ignored.
	defer func() { _ = tx.Rollback() }()

	// Get the key ID for the user
	user, err := GetUser(ctx, tx, subject)
	if err != nil {
		return fmt.Errorf("unable to get user: %w", err)
	}
	if user.EncryptionKeyID == nil {
		return errors.New("user does not have an encryption key")
	}

	// Get the key
	var keyName string
	var key []byte
	err = tx.QueryRowContext(ctx, `
SELECT kms_key_name, encrypted_key
FROM keys
WHERE id = $1;`,
		*user.EncryptionKeyID,
	).Scan(&keyName, &key)
	if err != nil {
		return fmt.Errorf("unable to get key: %w", err)
	}

	// Encrypt the token
	encryptedToken, err := keys.EncryptWithKey(ctx, kms, keyName, key, []byte(cookie.Value))
	if err != nil {
		return fmt.Errorf("unable to encrypt token: %w", err)
	}

	// Insert inside the transaction. Using exec here would need a second
	// connection while tx still holds the first, and that deadlocks on a
	// single-connection pool.
	_, err = tx.ExecContext(ctx, `
INSERT INTO ibd_tokens (token, expires_at, user_subject, encryption_key)
VALUES ($1, $2, $3, $4)`,
		encryptedToken, cookie.Expires, subject, *user.EncryptionKeyID)
	if err != nil {
		return fmt.Errorf("unable to add cookie: %w", err)
	}

	if err := tx.Commit(); err != nil {
		return fmt.Errorf("unable to commit cookie: %w", err)
	}
	return nil
}

// ReportCookieFailure sets degraded = TRUE on the ibd_tokens row with the
// given id.
//
// The UPDATE's affected-row count is not checked, so an unknown id is not an
// error.
func ReportCookieFailure(ctx context.Context, exec Executor, id uint) error {
	_, err := exec.ExecContext(ctx, `
UPDATE ibd_tokens
SET degraded = TRUE
WHERE id = $1;`, id)
	if err != nil {
		return fmt.Errorf("unable to report cookie failure: %w", err)
	}
	return nil
}

// RepairCookie sets degraded = FALSE on the ibd_tokens row with the given id.
//
// The UPDATE's affected-row count is not checked, so an unknown id is not an
// error.
func RepairCookie(ctx context.Context, exec Executor, id uint) error {
	_, err := exec.ExecContext(ctx, `
UPDATE ibd_tokens
SET degraded = FALSE
WHERE id = $1;`, id)
	if err != nil {
		return fmt.Errorf("unable to repair cookie: %w", err)
	}
	return nil
}

// IBDCookie is a decrypted IBD authentication token together with its
// expiry.
type IBDCookie struct {
	ID     uint      // ibd_tokens.id of the row the cookie was loaded from
	Token  string    // decrypted token; sent as the .ASPXAUTH cookie value
	Expiry time.Time // expiry; normalized to UTC when loaded by this package
}

// ToHTTPCookie builds the .ASPXAUTH cookie for investors.com from c.
func (c *IBDCookie) ToHTTPCookie() *http.Cookie {
	return &http.Cookie{
		Name:     ".ASPXAUTH",
		Value:    c.Token,
		Path:     "/",
		Domain:   "investors.com",
		Expires:  c.Expiry,
		Secure:   true,
		HttpOnly: false,
		SameSite: http.SameSiteLaxMode,
	}
}