aboutsummaryrefslogtreecommitdiff
path: root/backend/internal/database/cookies.go
diff options
context:
space:
mode:
Diffstat (limited to 'backend/internal/database/cookies.go')
-rw-r--r--backend/internal/database/cookies.go189
1 files changed, 189 insertions, 0 deletions
diff --git a/backend/internal/database/cookies.go b/backend/internal/database/cookies.go
new file mode 100644
index 0000000..3ea21d0
--- /dev/null
+++ b/backend/internal/database/cookies.go
@@ -0,0 +1,189 @@
+package database
+
+import (
+ "context"
+ "database/sql"
+ "errors"
+ "fmt"
+ "net/http"
+ "time"
+
+ "github.com/ansg191/ibd-trader-backend/internal/keys"
+)
+
+func GetAnyCookie(ctx context.Context, exec Executor, kms keys.KeyManagementService) (*IBDCookie, error) {
+ row := exec.QueryRowContext(ctx, `
+SELECT ibd_tokens.id, token, encrypted_key, kms_key_name, expires_at
+FROM ibd_tokens
+ INNER JOIN keys ON encryption_key = keys.id
+WHERE expires_at > NOW()
+ AND degraded = FALSE
+ORDER BY random()
+LIMIT 1;`)
+
+ var id uint
+ var encryptedToken, encryptedKey []byte
+ var keyName string
+ var expiry time.Time
+ err := row.Scan(&id, &encryptedToken, &encryptedKey, &keyName, &expiry)
+ if err != nil {
+ if errors.Is(err, sql.ErrNoRows) {
+ return nil, nil
+ }
+ return nil, fmt.Errorf("unable to scan sql row into ibd cookie: %w", err)
+ }
+
+ // Set the expiry to UTC explicitly.
+ // For some reason, the expiry time is set to location="".
+ expiry = expiry.UTC()
+
+ token, err := keys.Decrypt(ctx, kms, keyName, encryptedToken, encryptedKey)
+ if err != nil {
+ return nil, fmt.Errorf("unable to decrypt token: %w", err)
+ }
+ return &IBDCookie{
+ Token: string(token),
+ Expiry: expiry,
+ }, nil
+}
+
+func GetCookies(
+ ctx context.Context,
+ exec Executor,
+ kms keys.KeyManagementService,
+ subject string,
+ degraded bool,
+) ([]IBDCookie, error) {
+ rows, err := exec.QueryContext(ctx, `
+SELECT ibd_tokens.id, token, encrypted_key, kms_key_name, expires_at
+FROM ibd_tokens
+ INNER JOIN keys ON encryption_key = keys.id
+WHERE user_subject = $1
+ AND expires_at > NOW()
+ AND degraded = $2
+ORDER BY expires_at DESC;`, subject, degraded)
+ if err != nil {
+ return nil, fmt.Errorf("unable to get ibd cookies: %w", err)
+ }
+
+ cookies := make([]IBDCookie, 0)
+ for rows.Next() {
+ var id uint
+ var encryptedToken, encryptedKey []byte
+ var keyName string
+ var expiry time.Time
+ err = rows.Scan(&id, &encryptedToken, &encryptedKey, &keyName, &expiry)
+ if err != nil {
+ return nil, fmt.Errorf("unable to scan sql row into ibd cookie: %w", err)
+ }
+
+ // Set the expiry to UTC explicitly.
+ // For some reason, the expiry time is set to location="".
+ expiry = expiry.UTC()
+
+ token, err := keys.Decrypt(ctx, kms, keyName, encryptedToken, encryptedKey)
+ if err != nil {
+ return nil, fmt.Errorf("unable to decrypt token: %w", err)
+ }
+ cookie := IBDCookie{
+ ID: id,
+ Token: string(token),
+ Expiry: expiry,
+ }
+ cookies = append(cookies, cookie)
+ }
+
+ return cookies, nil
+}
+
+func AddCookie(
+ ctx context.Context,
+ exec TransactionExecutor,
+ kms keys.KeyManagementService,
+ subject string,
+ cookie *http.Cookie,
+) error {
+ tx, err := exec.BeginTx(ctx, nil)
+ if err != nil {
+ return err
+ }
+
+ // Get the key ID for the user
+ user, err := GetUser(ctx, tx, subject)
+ if err != nil {
+ return fmt.Errorf("unable to get user: %w", err)
+ }
+ if user.EncryptionKeyID == nil {
+ return errors.New("user does not have an encryption key")
+ }
+
+ // Get the key
+ var keyName string
+ var key []byte
+ err = tx.QueryRowContext(ctx, `
+SELECT kms_key_name, encrypted_key
+FROM keys
+WHERE id = $1;`,
+ *user.EncryptionKeyID,
+ ).Scan(&keyName, &key)
+ if err != nil {
+ return fmt.Errorf("unable to get key: %w", err)
+ }
+
+ // Encrypt the token
+ encryptedToken, err := keys.EncryptWithKey(ctx, kms, keyName, key, []byte(cookie.Value))
+ if err != nil {
+ return fmt.Errorf("unable to encrypt token: %w", err)
+ }
+
+ // Add the cookie to the database
+ _, err = exec.ExecContext(ctx, `
+INSERT INTO ibd_tokens (token, expires_at, user_subject, encryption_key)
+VALUES ($1, $2, $3, $4)`, encryptedToken, cookie.Expires, subject, *user.EncryptionKeyID)
+ if err != nil {
+ return fmt.Errorf("unable to add cookie: %w", err)
+ }
+
+ return nil
+}
+
+func ReportCookieFailure(ctx context.Context, exec Executor, id uint) error {
+ _, err := exec.ExecContext(ctx, `
+UPDATE ibd_tokens
+SET degraded = TRUE
+WHERE id = $1;`, id)
+ if err != nil {
+ return fmt.Errorf("unable to report cookie failure: %w", err)
+ }
+ return nil
+}
+
+func RepairCookie(ctx context.Context, exec Executor, id uint) error {
+ _, err := exec.ExecContext(ctx, `
+UPDATE ibd_tokens
+SET degraded = FALSE
+WHERE id = $1;`, id)
+ if err != nil {
+ return fmt.Errorf("unable to report cookie failure: %w", err)
+ }
+ return nil
+}
+
+type IBDCookie struct {
+ ID uint
+ Token string
+ Expiry time.Time
+}
+
+func (c *IBDCookie) ToHTTPCookie() *http.Cookie {
+ return &http.Cookie{
+ Name: ".ASPXAUTH",
+ Value: c.Token,
+ Path: "/",
+ Domain: "investors.com",
+ Expires: c.Expiry,
+ Secure: true,
+ HttpOnly: false,
+ SameSite: http.SameSiteLaxMode,
+ }
+}