package database

import (
	"context"
	"database/sql"
	"errors"
	"fmt"
	"net/http"
	"time"

	"ibd-trader/internal/keys"
)

// CookieStore is a CookieSource that can also persist new IBD cookies and
// clear the degraded mark on stored ones.
type CookieStore interface {
	CookieSource
	// AddCookie encrypts the cookie's value with the encryption key of the
	// user identified by subject and stores it.
	AddCookie(ctx context.Context, subject string, cookie *http.Cookie) error
	// RepairCookie clears the degraded mark on the cookie with the given ID.
	RepairCookie(ctx context.Context, id uint) error
}

// CookieSource gives read access to stored IBD cookies and lets callers mark
// a cookie as degraded when it stops working.
type CookieSource interface {
	// GetAnyCookie returns one stored cookie, or (nil, nil) if there is none.
	GetAnyCookie(ctx context.Context) (*IBDCookie, error)
	// GetCookies returns the stored cookies selected by subject and degraded.
	GetCookies(ctx context.Context, subject string, degraded bool) ([]IBDCookie, error)
	// ReportCookieFailure marks the cookie with the given ID as degraded.
	ReportCookieFailure(ctx context.Context, id uint) error
}

// GetAnyCookie loads a single cookie via the "cookies/get_any_cookie" query and
// decrypts its token. Which cookie is picked is decided by that query.
//
// It returns (nil, nil) when the query yields no rows, so callers must check
// the cookie for nil even when err is nil.
func (d *database) GetAnyCookie(ctx context.Context) (*IBDCookie, error) {
	row, err := d.queryRow(ctx, d.db, "cookies/get_any_cookie")
	if err != nil {
		return nil, fmt.Errorf("unable to get any ibd cookie: %w", err)
	}

	var id uint
	var encryptedToken, encryptedKey []byte
	var keyName string
	var expiry time.Time
	if err := row.Scan(&id, &encryptedToken, &encryptedKey, &keyName, &expiry); err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			return nil, nil
		}
		return nil, fmt.Errorf("unable to scan sql row into ibd cookie: %w", err)
	}

	token, err := keys.Decrypt(ctx, d.kms, keyName, encryptedToken, encryptedKey)
	if err != nil {
		return nil, fmt.Errorf("unable to decrypt token: %w", err)
	}

	// ID must be set so the caller can later pass this cookie to
	// ReportCookieFailure or RepairCookie.
	return &IBDCookie{
		ID:     id,
		Token:  string(token),
		Expiry: expiry,
	}, nil
}

// GetCookies runs the "cookies/get_cookies" query with subject and degraded as
// arguments and decrypts the token of every returned row.
// NOTE(review): presumably the query filters on both arguments — confirm
// against the SQL.
//
// It returns an empty, non-nil slice when no rows match. Any scan, decryption
// or iteration error aborts the whole call.
func (d *database) GetCookies(ctx context.Context, subject string, degraded bool) ([]IBDCookie, error) {
	rows, err := d.query(ctx, d.db, "cookies/get_cookies", subject, degraded)
	if err != nil {
		return nil, fmt.Errorf("unable to get ibd cookies: %w", err)
	}
	// Release the underlying connection on every return path, including the
	// early returns inside the loop. The Close error adds nothing once rows.Err
	// has been checked below.
	defer rows.Close()

	cookies := make([]IBDCookie, 0)
	for rows.Next() {
		var id uint
		var encryptedToken, encryptedKey []byte
		var keyName string
		var expiry time.Time
		if err := rows.Scan(&id, &encryptedToken, &encryptedKey, &keyName, &expiry); err != nil {
			return nil, fmt.Errorf("unable to scan sql row into ibd cookie: %w", err)
		}

		token, err := keys.Decrypt(ctx, d.kms, keyName, encryptedToken, encryptedKey)
		if err != nil {
			return nil, fmt.Errorf("unable to decrypt token: %w", err)
		}

		cookies = append(cookies, IBDCookie{
			ID:     id,
			Token:  string(token),
			Expiry: expiry,
		})
	}
	// Next returns false both when the rows are exhausted and when iteration
	// failed; Err distinguishes the two.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("unable to iterate ibd cookie rows: %w", err)
	}

	return cookies, nil
}

// AddCookie encrypts cookie.Value with the encryption key assigned to the user
// identified by subject. It then stores the result via the "cookies/add_cookie"
// query, together with cookie.Expires, subject and the key's ID.
//
// It returns an error if cookie is nil, the user cannot be loaded, the user
// has no encryption key, or encryption or the insert fails.
func (d *database) AddCookie(ctx context.Context, subject string, cookie *http.Cookie) error {
	if cookie == nil {
		return errors.New("cookie must not be nil")
	}

	// Resolve which key this user's secrets are encrypted with.
	user, err := d.GetUser(ctx, subject)
	if err != nil {
		return fmt.Errorf("unable to get user: %w", err)
	}
	if user.EncryptionKeyID == nil {
		return errors.New("user does not have an encryption key")
	}

	key, err := d.GetKey(ctx, *user.EncryptionKeyID)
	if err != nil {
		return fmt.Errorf("unable to get key: %w", err)
	}

	encryptedToken, err := keys.EncryptWithKey(ctx, d.kms, key.Name, key.Key, []byte(cookie.Value))
	if err != nil {
		return fmt.Errorf("unable to encrypt token: %w", err)
	}

	if _, err := d.exec(ctx, d.db, "cookies/add_cookie", encryptedToken, cookie.Expires, subject, key.Id); err != nil {
		return fmt.Errorf("unable to add cookie: %w", err)
	}
	return nil
}

// ReportCookieFailure marks the cookie with the given ID as degraded. It does
// this by running the "cookies/set_cookie_degraded" query with true.
func (d *database) ReportCookieFailure(ctx context.Context, id uint) error {
	if _, err := d.exec(ctx, d.db, "cookies/set_cookie_degraded", true, id); err != nil {
		return fmt.Errorf("unable to report cookie failure: %w", err)
	}
	return nil
}

// RepairCookie clears the degraded mark on the cookie with the given ID. It
// does this by running the "cookies/set_cookie_degraded" query with false.
func (d *database) RepairCookie(ctx context.Context, id uint) error {
	if _, err := d.exec(ctx, d.db, "cookies/set_cookie_degraded", false, id); err != nil {
		return fmt.Errorf("unable to repair cookie: %w", err)
	}
	return nil
}

// IBDCookie is a stored IBD session cookie with its token already decrypted.
type IBDCookie struct {
	ID     uint      // identifier used by ReportCookieFailure / RepairCookie
	Token  string    // decrypted cookie value
	Expiry time.Time // expiry as stored alongside the token
}

// ToHTTPCookie builds the ".ASPXAUTH" cookie for the investors.com domain,
// carrying the decrypted token and stored expiry. The cookie is marked Secure
// and SameSite=Lax, and HttpOnly is left false.
func (c *IBDCookie) ToHTTPCookie() *http.Cookie {
	return &http.Cookie{
		Name:     ".ASPXAUTH",
		Value:    c.Token,
		Path:     "/",
		Domain:   "investors.com",
		Expires:  c.Expiry,
		Secure:   true,
		HttpOnly: false,
		SameSite: http.SameSiteLaxMode,
	}
}