aboutsummaryrefslogtreecommitdiff
path: root/backend/internal/database/cookies.go
diff options
context:
space:
mode:
authorGravatar Anshul Gupta <ansg191@anshulg.com> 2024-08-05 18:55:10 -0700
committerGravatar Anshul Gupta <ansg191@anshulg.com> 2024-08-05 18:55:19 -0700
commitb96fcd1a54a46a95f98467b49a051564bc21c23c (patch)
tree93caeeb05f8d6310e241095608ea2428c749b18c /backend/internal/database/cookies.go
downloadibd-trader-b96fcd1a54a46a95f98467b49a051564bc21c23c.tar.gz
ibd-trader-b96fcd1a54a46a95f98467b49a051564bc21c23c.tar.zst
ibd-trader-b96fcd1a54a46a95f98467b49a051564bc21c23c.zip
Initial Commit
Diffstat (limited to 'backend/internal/database/cookies.go')
-rw-r--r--backend/internal/database/cookies.go150
1 files changed, 150 insertions, 0 deletions
diff --git a/backend/internal/database/cookies.go b/backend/internal/database/cookies.go
new file mode 100644
index 0000000..cb38272
--- /dev/null
+++ b/backend/internal/database/cookies.go
@@ -0,0 +1,150 @@
+package database
+
+import (
+ "context"
+ "database/sql"
+ "errors"
+ "fmt"
+ "net/http"
+ "time"
+
+ "ibd-trader/internal/keys"
+)
+
+type CookieStore interface {
+ CookieSource
+ AddCookie(ctx context.Context, subject string, cookie *http.Cookie) error
+ RepairCookie(ctx context.Context, id uint) error
+}
+
+type CookieSource interface {
+ GetAnyCookie(ctx context.Context) (*IBDCookie, error)
+ GetCookies(ctx context.Context, subject string, degraded bool) ([]IBDCookie, error)
+ ReportCookieFailure(ctx context.Context, id uint) error
+}
+
+func (d *database) GetAnyCookie(ctx context.Context) (*IBDCookie, error) {
+ row, err := d.queryRow(ctx, d.db, "cookies/get_any_cookie")
+ if err != nil {
+ return nil, fmt.Errorf("unable to get any ibd cookie: %w", err)
+ }
+
+ var id uint
+ var encryptedToken, encryptedKey []byte
+ var keyName string
+ var expiry time.Time
+ err = row.Scan(&id, &encryptedToken, &encryptedKey, &keyName, &expiry)
+ if err != nil {
+ if errors.Is(err, sql.ErrNoRows) {
+ return nil, nil
+ }
+ return nil, fmt.Errorf("unable to scan sql row into ibd cookie: %w", err)
+ }
+
+ token, err := keys.Decrypt(ctx, d.kms, keyName, encryptedToken, encryptedKey)
+ if err != nil {
+ return nil, fmt.Errorf("unable to decrypt token: %w", err)
+ }
+ return &IBDCookie{
+ Token: string(token),
+ Expiry: expiry,
+ }, nil
+}
+
+func (d *database) GetCookies(ctx context.Context, subject string, degraded bool) ([]IBDCookie, error) {
+ row, err := d.query(ctx, d.db, "cookies/get_cookies", subject, degraded)
+ if err != nil {
+ return nil, fmt.Errorf("unable to get ibd cookies: %w", err)
+ }
+
+ cookies := make([]IBDCookie, 0)
+ for row.Next() {
+ var id uint
+ var encryptedToken, encryptedKey []byte
+ var keyName string
+ var expiry time.Time
+ err = row.Scan(&id, &encryptedToken, &encryptedKey, &keyName, &expiry)
+ if err != nil {
+ return nil, fmt.Errorf("unable to scan sql row into ibd cookie: %w", err)
+ }
+
+ token, err := keys.Decrypt(ctx, d.kms, keyName, encryptedToken, encryptedKey)
+ if err != nil {
+ return nil, fmt.Errorf("unable to decrypt token: %w", err)
+ }
+ cookie := IBDCookie{
+ ID: id,
+ Token: string(token),
+ Expiry: expiry,
+ }
+ cookies = append(cookies, cookie)
+ }
+
+ return cookies, nil
+}
+
+func (d *database) AddCookie(ctx context.Context, subject string, cookie *http.Cookie) error {
+ // Get the key ID for the user
+ user, err := d.GetUser(ctx, subject)
+ if err != nil {
+ return fmt.Errorf("unable to get user: %w", err)
+ }
+ if user.EncryptionKeyID == nil {
+ return errors.New("user does not have an encryption key")
+ }
+
+ // Get the key
+ key, err := d.GetKey(ctx, *user.EncryptionKeyID)
+ if err != nil {
+ return fmt.Errorf("unable to get key: %w", err)
+ }
+
+ // Encrypt the token
+ encryptedToken, err := keys.EncryptWithKey(ctx, d.kms, key.Name, key.Key, []byte(cookie.Value))
+ if err != nil {
+ return fmt.Errorf("unable to encrypt token: %w", err)
+ }
+
+ // Add the cookie to the database
+ _, err = d.exec(ctx, d.db, "cookies/add_cookie", encryptedToken, cookie.Expires, subject, key.Id)
+ if err != nil {
+ return fmt.Errorf("unable to add cookie: %w", err)
+ }
+
+ return nil
+}
+
+func (d *database) ReportCookieFailure(ctx context.Context, id uint) error {
+ _, err := d.exec(ctx, d.db, "cookies/set_cookie_degraded", true, id)
+ if err != nil {
+ return fmt.Errorf("unable to report cookie failure: %w", err)
+ }
+ return nil
+}
+
+func (d *database) RepairCookie(ctx context.Context, id uint) error {
+ _, err := d.exec(ctx, d.db, "cookies/set_cookie_degraded", false, id)
+ if err != nil {
+ return fmt.Errorf("unable to report cookie failure: %w", err)
+ }
+ return nil
+}
+
+type IBDCookie struct {
+ ID uint
+ Token string
+ Expiry time.Time
+}
+
+func (c *IBDCookie) ToHTTPCookie() *http.Cookie {
+ return &http.Cookie{
+ Name: ".ASPXAUTH",
+ Value: c.Token,
+ Path: "/",
+ Domain: "investors.com",
+ Expires: c.Expiry,
+ Secure: true,
+ HttpOnly: false,
+ SameSite: http.SameSiteLaxMode,
+ }
+}