package database

import (
	"context"
	"database/sql"
	"errors"
	"fmt"

	"github.com/ansg191/ibd-trader/backend/internal/keys"
)

var (
	// ErrUserNotFound is returned by GetUser when no users row matches the
	// requested subject.
	ErrUserNotFound = errors.New("user not found")

	// ErrIBDCredsNotFound is returned by GetIBDCreds when no users row joined
	// to its encryption key matches the requested subject.
	ErrIBDCredsNotFound = errors.New("ibd creds not found")
)

// AddUser inserts a users row for subject. The statement uses
// ON CONFLICT DO NOTHING, so a conflicting existing row is left untouched and
// no error is reported for it.
func AddUser(ctx context.Context, exec Executor, subject string) error {
	_, err := exec.ExecContext(ctx, `
INSERT INTO users (subject)
VALUES ($1)
ON CONFLICT DO NOTHING;`, subject)
	return err
}

// GetUser loads the users row identified by subject.
//
// It returns ErrUserNotFound when no row matches; any other scan failure is
// wrapped and returned.
func GetUser(ctx context.Context, exec Executor, subject string) (*User, error) {
	row := exec.QueryRowContext(ctx, `
SELECT subject, ibd_username, ibd_password, encryption_key
FROM users
WHERE subject = $1;`, subject)

	user := &User{}
	err := row.Scan(&user.Subject, &user.IBDUsername, &user.EncryptedIBDPassword, &user.EncryptionKeyID)
	if err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			return nil, ErrUserNotFound
		}
		return nil, fmt.Errorf("unable to scan sql row into user: %w", err)
	}

	return user, nil
}

// ListUsers returns all rows of the users table. When hasIBDCreds is true,
// users whose ibd_username is NULL are left out of the result.
//
// The returned slice is never nil on success; it is empty when nothing
// matches.
func ListUsers(ctx context.Context, exec Executor, hasIBDCreds bool) ([]User, error) {
	rows, err := exec.QueryContext(ctx, `
SELECT subject, ibd_username, ibd_password, encryption_key
FROM users;
`)
	if err != nil {
		return nil, fmt.Errorf("unable to list users: %w", err)
	}
	// Release the result set on every return path, including the early return
	// on a scan failure. The Close error is ignored because iteration errors
	// are reported through rows.Err below.
	defer func() { _ = rows.Close() }()

	users := make([]User, 0)
	for rows.Next() {
		user := User{}
		if err := rows.Scan(&user.Subject, &user.IBDUsername, &user.EncryptedIBDPassword, &user.EncryptionKeyID); err != nil {
			return nil, fmt.Errorf("unable to scan sql row into user: %w", err)
		}

		if hasIBDCreds && user.IBDUsername == nil {
			continue
		}
		users = append(users, user)
	}
	// rows.Next also returns false when iteration fails part-way; without this
	// check a truncated list would be returned as if it were complete.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("unable to list users: %w", err)
	}

	return users, nil
}

// AddIBDCreds encrypts password with keys.Encrypt (using kms and keyName) and
// stores the result for the user identified by subject.
//
// Inside a single transaction it inserts the encrypted key into the keys table
// and then updates the user's ibd_username, ibd_password and encryption_key
// columns to reference it. Either both statements are committed or neither is.
//
// The number of rows affected by the UPDATE is not inspected: if no users row
// matches subject, the key row is still committed and nil is returned.
func AddIBDCreds(
	ctx context.Context,
	exec TransactionExecutor,
	kms keys.KeyManagementService,
	keyName, subject, username, password string,
) error {
	encryptedPass, encryptedKey, err := keys.Encrypt(ctx, kms, keyName, []byte(password))
	if err != nil {
		return fmt.Errorf("unable to encrypt password: %w", err)
	}

	tx, err := exec.BeginTx(ctx, nil)
	if err != nil {
		return fmt.Errorf("unable to begin transaction: %w", err)
	}
	// Roll back on any early return. After a successful Commit, Rollback only
	// reports sql.ErrTxDone, which is safe to ignore.
	defer func() { _ = tx.Rollback() }()

	var keyID int
	err = tx.QueryRowContext(ctx, `
INSERT INTO keys (kms_key_name, encrypted_key)
VALUES ($1, $2)
RETURNING id;`, keyName, encryptedKey).Scan(&keyID)
	if err != nil {
		return fmt.Errorf("unable to add ibd creds key: %w", err)
	}

	// This statement must run on tx, not exec: it references the key row
	// inserted above and has to commit or roll back together with it.
	_, err = tx.ExecContext(ctx, `
UPDATE users
SET ibd_username = $2,
    ibd_password = $3,
    encryption_key = $4
WHERE subject = $1;`, subject, username, encryptedPass, keyID)
	if err != nil {
		return fmt.Errorf("unable to add ibd creds to user: %w", err)
	}

	if err = tx.Commit(); err != nil {
		return fmt.Errorf("unable to commit transaction: %w", err)
	}

	return nil
}

// GetIBDCreds returns the IBD username and decrypted password stored for
// subject.
//
// It reads the user's row joined to its keys row and passes the stored
// password, encrypted key and KMS key name to keys.Decrypt. It returns
// ErrIBDCredsNotFound when the joined query yields no row; scan and decrypt
// failures are wrapped and returned.
func GetIBDCreds(
	ctx context.Context,
	exec Executor,
	kms keys.KeyManagementService,
	subject string,
) (
	username string,
	password string,
	err error,
) {
	row := exec.QueryRowContext(ctx, `
SELECT ibd_username, ibd_password, encrypted_key, kms_key_name
FROM users
INNER JOIN public.keys k on k.id = users.encryption_key
WHERE subject = $1;`, subject)

	var encryptedPass, encryptedKey []byte
	var keyName string
	err = row.Scan(&username, &encryptedPass, &encryptedKey, &keyName)
	if err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			return "", "", ErrIBDCredsNotFound
		}
		return "", "", fmt.Errorf("unable to scan sql row into ibd creds: %w", err)
	}

	passwordBytes, err := keys.Decrypt(ctx, kms, keyName, encryptedPass, encryptedKey)
	if err != nil {
		return "", "", fmt.Errorf("unable to decrypt password: %w", err)
	}

	return username, string(passwordBytes), nil
}

// User mirrors the columns selected from the users table by GetUser and
// ListUsers. The pointer fields are nil when the corresponding column is NULL.
type User struct {
	// Subject is the value of the subject column used to look users up.
	Subject string
	// IBDUsername is the ibd_username column.
	IBDUsername *string
	// EncryptedIBDPassword is the ibd_password column, as stored (still
	// encrypted).
	EncryptedIBDPassword *string
	// EncryptionKeyID is the encryption_key column, which GetIBDCreds joins
	// against keys.id.
	EncryptionKeyID *int
}