diff options
author | 2024-08-05 18:55:10 -0700 | |
---|---|---|
committer | 2024-08-05 18:55:19 -0700 | |
commit | b96fcd1a54a46a95f98467b49a051564bc21c23c (patch) | |
tree | 93caeeb05f8d6310e241095608ea2428c749b18c /backend/internal/database/cookies.go | |
download | ibd-trader-b96fcd1a54a46a95f98467b49a051564bc21c23c.tar.gz ibd-trader-b96fcd1a54a46a95f98467b49a051564bc21c23c.tar.zst ibd-trader-b96fcd1a54a46a95f98467b49a051564bc21c23c.zip |
Initial Commit
Diffstat (limited to 'backend/internal/database/cookies.go')
-rw-r--r-- | backend/internal/database/cookies.go | 150 |
1 files changed, 150 insertions, 0 deletions
diff --git a/backend/internal/database/cookies.go b/backend/internal/database/cookies.go new file mode 100644 index 0000000..cb38272 --- /dev/null +++ b/backend/internal/database/cookies.go @@ -0,0 +1,150 @@ +package database + +import ( + "context" + "database/sql" + "errors" + "fmt" + "net/http" + "time" + + "ibd-trader/internal/keys" +) + +type CookieStore interface { + CookieSource + AddCookie(ctx context.Context, subject string, cookie *http.Cookie) error + RepairCookie(ctx context.Context, id uint) error +} + +type CookieSource interface { + GetAnyCookie(ctx context.Context) (*IBDCookie, error) + GetCookies(ctx context.Context, subject string, degraded bool) ([]IBDCookie, error) + ReportCookieFailure(ctx context.Context, id uint) error +} + +func (d *database) GetAnyCookie(ctx context.Context) (*IBDCookie, error) { + row, err := d.queryRow(ctx, d.db, "cookies/get_any_cookie") + if err != nil { + return nil, fmt.Errorf("unable to get any ibd cookie: %w", err) + } + + var id uint + var encryptedToken, encryptedKey []byte + var keyName string + var expiry time.Time + err = row.Scan(&id, &encryptedToken, &encryptedKey, &keyName, &expiry) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, nil + } + return nil, fmt.Errorf("unable to scan sql row into ibd cookie: %w", err) + } + + token, err := keys.Decrypt(ctx, d.kms, keyName, encryptedToken, encryptedKey) + if err != nil { + return nil, fmt.Errorf("unable to decrypt token: %w", err) + } + return &IBDCookie{ + Token: string(token), + Expiry: expiry, + }, nil +} + +func (d *database) GetCookies(ctx context.Context, subject string, degraded bool) ([]IBDCookie, error) { + row, err := d.query(ctx, d.db, "cookies/get_cookies", subject, degraded) + if err != nil { + return nil, fmt.Errorf("unable to get ibd cookies: %w", err) + } + + cookies := make([]IBDCookie, 0) + for row.Next() { + var id uint + var encryptedToken, encryptedKey []byte + var keyName string + var expiry time.Time + err = row.Scan(&id, &encryptedToken, &encryptedKey, &keyName, &expiry) + if err != nil { + return nil, fmt.Errorf("unable to scan sql row into ibd cookie: %w", err) + } + + token, err := keys.Decrypt(ctx, d.kms, keyName, encryptedToken, encryptedKey) + if err != nil { + return nil, fmt.Errorf("unable to decrypt token: %w", err) + } + cookie := IBDCookie{ + ID: id, + Token: string(token), + Expiry: expiry, + } + cookies = append(cookies, cookie) + } + + return cookies, nil +} + +func (d *database) AddCookie(ctx context.Context, subject string, cookie *http.Cookie) error { + // Get the key ID for the user + user, err := d.GetUser(ctx, subject) + if err != nil { + return fmt.Errorf("unable to get user: %w", err) + } + if user.EncryptionKeyID == nil { + return errors.New("user does not have an encryption key") + } + + // Get the key + key, err := d.GetKey(ctx, *user.EncryptionKeyID) + if err != nil { + return fmt.Errorf("unable to get key: %w", err) + } + + // Encrypt the token + encryptedToken, err := keys.EncryptWithKey(ctx, d.kms, key.Name, key.Key, []byte(cookie.Value)) + if err != nil { + return fmt.Errorf("unable to encrypt token: %w", err) + } + + // Add the cookie to the database + _, err = d.exec(ctx, d.db, "cookies/add_cookie", encryptedToken, cookie.Expires, subject, key.Id) + if err != nil { + return fmt.Errorf("unable to add cookie: %w", err) + } + + return nil +} + +func (d *database) ReportCookieFailure(ctx context.Context, id uint) error { + _, err := d.exec(ctx, d.db, "cookies/set_cookie_degraded", true, id) + if err != nil { + return fmt.Errorf("unable to report cookie failure: %w", err) + } + return nil +} + +func (d *database) RepairCookie(ctx context.Context, id uint) error { + _, err := d.exec(ctx, d.db, "cookies/set_cookie_degraded", false, id) + if err != nil { + return fmt.Errorf("unable to report cookie failure: %w", err) + } + return nil +} + +type IBDCookie struct { + ID uint + Token string + Expiry time.Time +} + +func (c *IBDCookie) ToHTTPCookie() *http.Cookie { + return &http.Cookie{ + Name: ".ASPXAUTH", + Value: c.Token, + Path: "/", + Domain: "investors.com", + Expires: c.Expiry, + Secure: true, + HttpOnly: false, + SameSite: http.SameSiteLaxMode, + } +} |