package database

import (
	"context"
	"database/sql"
	"errors"
	"fmt"

	"github.com/ansg191/ibd-trader-backend/internal/keys"
)

// UserStore persists users and their encrypted IBD credentials.
type UserStore interface {
	// AddUser executes the "users/add_user" query for subject.
	AddUser(ctx context.Context, subject string) error
	// GetUser returns the user identified by subject, or ErrUserNotFound when
	// no row matches.
	GetUser(ctx context.Context, subject string) (*User, error)
	// ListUsers returns all users. When hasIBDCreds is true, users whose
	// IBDUsername is nil are omitted.
	ListUsers(ctx context.Context, hasIBDCreds bool) ([]User, error)
	// AddIBDCreds encrypts password and stores it, together with username and
	// the encryption key, for the user identified by subject in one transaction.
	AddIBDCreds(ctx context.Context, subject string, username string, password string) error
	// GetIBDCreds returns the username and decrypted password stored for
	// subject, or ErrIBDCredsNotFound when no row matches.
	GetIBDCreds(ctx context.Context, subject string) (username string, password string, err error)
}

var (
	// ErrUserNotFound is returned by GetUser when no user matches the subject.
	ErrUserNotFound = errors.New("user not found")
	// ErrIBDCredsNotFound is returned by GetIBDCreds when no credentials row
	// matches the subject.
	ErrIBDCredsNotFound = errors.New("ibd creds not found")
)

// AddUser executes the "users/add_user" query with subject as its argument.
// The error from the query helper is returned unwrapped so callers can inspect
// driver errors directly.
func (d *database) AddUser(ctx context.Context, subject string) error {
	_, err := d.exec(ctx, d.db, "users/add_user", subject)
	return err
}

// GetUser loads the user identified by subject via the "users/get_user" query.
// It returns ErrUserNotFound when the query yields no row.
func (d *database) GetUser(ctx context.Context, subject string) (*User, error) {
	row, err := d.queryRow(ctx, d.db, "users/get_user", subject)
	if err != nil {
		return nil, fmt.Errorf("unable to get user: %w", err)
	}

	user := &User{}
	err = row.Scan(&user.Subject, &user.IBDUsername, &user.EncryptedIBDPassword, &user.EncryptionKeyID)
	if err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			return nil, ErrUserNotFound
		}
		return nil, fmt.Errorf("unable to scan sql row into user: %w", err)
	}
	return user, nil
}

// ListUsers returns every row produced by the "users/list_users" query.
// When hasIBDCreds is true, users with a nil IBDUsername are skipped.
// The result is never nil; it is an empty slice when no users match.
func (d *database) ListUsers(ctx context.Context, hasIBDCreds bool) ([]User, error) {
	rows, err := d.query(ctx, d.db, "users/list_users")
	if err != nil {
		return nil, fmt.Errorf("unable to list users: %w", err)
	}
	// Release the connection even when we return early on a scan error;
	// rows (presumably *sql.Rows) is only auto-closed when Next runs to the end.
	defer rows.Close()

	users := make([]User, 0)
	for rows.Next() {
		var user User
		if err := rows.Scan(&user.Subject, &user.IBDUsername, &user.EncryptedIBDPassword, &user.EncryptionKeyID); err != nil {
			return nil, fmt.Errorf("unable to scan sql row into user: %w", err)
		}
		if hasIBDCreds && user.IBDUsername == nil {
			continue
		}
		users = append(users, user)
	}
	// Next returns false both at the end of the result set and on error;
	// only Err distinguishes the two, so check it before reporting success.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("unable to iterate user rows: %w", err)
	}
	return users, nil
}

// AddIBDCreds encrypts password with keys.Encrypt using d.keyName, then, in a
// single transaction, stores the encrypted key ("keys/add_key") and the
// username/encrypted password referencing that key ("users/add_ibd_creds").
// Nothing is persisted if any step fails.
func (d *database) AddIBDCreds(ctx context.Context, subject string, username string, password string) error {
	encryptedPass, encryptedKey, err := keys.Encrypt(ctx, d.kms, d.keyName, []byte(password))
	if err != nil {
		return fmt.Errorf("unable to encrypt password: %w", err)
	}

	tx, err := d.db.BeginTx(ctx, nil)
	if err != nil {
		return fmt.Errorf("unable to begin transaction: %w", err)
	}
	// Rollback after a successful Commit is a no-op (it returns sql.ErrTxDone),
	// so deferring it unconditionally undoes the transaction on every early return.
	defer func() { _ = tx.Rollback() }()

	row, err := d.queryRow(ctx, tx, "keys/add_key", d.keyName, encryptedKey)
	if err != nil {
		return fmt.Errorf("unable to add ibd creds key: %w", err)
	}
	var keyID int
	if err := row.Scan(&keyID); err != nil {
		return fmt.Errorf("unable to scan key id: %w", err)
	}

	if _, err := d.exec(ctx, tx, "users/add_ibd_creds", subject, username, encryptedPass, keyID); err != nil {
		return fmt.Errorf("unable to add ibd creds to user: %w", err)
	}

	if err := tx.Commit(); err != nil {
		return fmt.Errorf("unable to commit transaction: %w", err)
	}
	return nil
}

// GetIBDCreds loads the credentials for subject via the "users/get_ibd_creds"
// query and decrypts the password with keys.Decrypt using the key name stored
// alongside it. It returns ErrIBDCredsNotFound when the query yields no row.
// On any error both returned strings are empty.
func (d *database) GetIBDCreds(ctx context.Context, subject string) (username string, password string, err error) {
	row, err := d.queryRow(ctx, d.db, "users/get_ibd_creds", subject)
	if err != nil {
		return "", "", fmt.Errorf("unable to get ibd creds: %w", err)
	}

	var encryptedPass, encryptedKey []byte
	var keyName string
	if err := row.Scan(&username, &encryptedPass, &encryptedKey, &keyName); err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			return "", "", ErrIBDCredsNotFound
		}
		return "", "", fmt.Errorf("unable to scan sql row into ibd creds: %w", err)
	}

	passwordBytes, err := keys.Decrypt(ctx, d.kms, keyName, encryptedPass, encryptedKey)
	if err != nil {
		return "", "", fmt.Errorf("unable to decrypt password: %w", err)
	}
	return username, string(passwordBytes), nil
}

// User is a row of the users table as returned by GetUser and ListUsers.
type User struct {
	// Subject identifies the user; it is the key passed to every UserStore method.
	Subject string
	// IBDUsername is nil when the user has no IBD credentials (ListUsers
	// relies on this to filter).
	IBDUsername *string
	// EncryptedIBDPassword holds the encrypted IBD password, if any.
	EncryptedIBDPassword *string
	// EncryptionKeyID presumably references the key row inserted by
	// AddIBDCreds ("keys/add_key") — TODO confirm against the schema.
	EncryptionKeyID *int
}