diff options
author | 2024-08-07 17:48:57 -0700 | |
---|---|---|
committer | 2024-08-07 18:48:10 -0700 | |
commit | e9ee45b9d2bd494332dcf8b2073714f92fd0738d (patch) | |
tree | d34af1af84984409d27003981538f13cde4ba218 /backend/internal/database/users.go | |
parent | 3de4ebb7560851ccbefe296c197456fe80c22901 (diff) | |
download | ibd-trader-e9ee45b9d2bd494332dcf8b2073714f92fd0738d.tar.gz ibd-trader-e9ee45b9d2bd494332dcf8b2073714f92fd0738d.tar.zst ibd-trader-e9ee45b9d2bd494332dcf8b2073714f92fd0738d.zip |
Refactor DB to remove restrictive query system
Diffstat (limited to 'backend/internal/database/users.go')
-rw-r--r-- | backend/internal/database/users.go | 87 |
1 files changed, 47 insertions, 40 deletions
diff --git a/backend/internal/database/users.go b/backend/internal/database/users.go index ff6f674..d023598 100644 --- a/backend/internal/database/users.go +++ b/backend/internal/database/users.go @@ -9,35 +9,25 @@ import ( "github.com/ansg191/ibd-trader-backend/internal/keys" ) -type UserStore interface { - AddUser(ctx context.Context, subject string) error - GetUser(ctx context.Context, subject string) (*User, error) - ListUsers(ctx context.Context, hasIBDCreds bool) ([]User, error) - AddIBDCreds(ctx context.Context, subject string, username string, password string) error - GetIBDCreds(ctx context.Context, subject string) (username string, password string, err error) -} - var ErrUserNotFound = fmt.Errorf("user not found") var ErrIBDCredsNotFound = fmt.Errorf("ibd creds not found") -func (d *database) AddUser(ctx context.Context, subject string) (err error) { - _, err = d.exec( - ctx, - d.db, - "users/add_user", - subject, - ) +func AddUser(ctx context.Context, exec Executor, subject string) (err error) { + _, err = exec.ExecContext(ctx, ` +INSERT INTO users (subject) +VALUES ($1) +ON CONFLICT DO NOTHING;`, subject) return } -func (d *database) GetUser(ctx context.Context, subject string) (*User, error) { - row, err := d.queryRow(ctx, d.db, "users/get_user", subject) - if err != nil { - return nil, fmt.Errorf("unable to get user: %w", err) - } +func GetUser(ctx context.Context, exec Executor, subject string) (*User, error) { + row := exec.QueryRowContext(ctx, ` +SELECT subject, ibd_username, ibd_password, encryption_key +FROM users +WHERE subject = $1;`, subject) user := &User{} - err = row.Scan(&user.Subject, &user.IBDUsername, &user.EncryptedIBDPassword, &user.EncryptionKeyID) + err := row.Scan(&user.Subject, &user.IBDUsername, &user.EncryptedIBDPassword, &user.EncryptionKeyID) if err != nil { if errors.Is(err, sql.ErrNoRows) { return nil, ErrUserNotFound @@ -48,8 +38,11 @@ func (d *database) GetUser(ctx context.Context, subject string) (*User, error) { return user, nil } -func (d *database) ListUsers(ctx context.Context, hasIBDCreds bool) ([]User, error) { - rows, err := d.query(ctx, d.db, "users/list_users") +func ListUsers(ctx context.Context, exec Executor, hasIBDCreds bool) ([]User, error) { + rows, err := exec.QueryContext(ctx, ` +SELECT subject, ibd_username, ibd_password, encryption_key +FROM users; +`) if err != nil { return nil, fmt.Errorf("unable to list users: %w", err) } @@ -71,13 +64,18 @@ func (d *database) ListUsers(ctx context.Context, hasIBDCreds bool) ([]User, err return users, nil } -func (d *database) AddIBDCreds(ctx context.Context, subject string, username string, password string) error { - encryptedPass, encryptedKey, err := keys.Encrypt(ctx, d.kms, d.keyName, []byte(password)) +func AddIBDCreds( + ctx context.Context, + exec TransactionExecutor, + kms keys.KeyManagementService, + keyName, subject, username, password string, +) error { + encryptedPass, encryptedKey, err := keys.Encrypt(ctx, kms, keyName, []byte(password)) if err != nil { return fmt.Errorf("unable to encrypt password: %w", err) } - tx, err := d.db.BeginTx(ctx, nil) + tx, err := exec.BeginTx(ctx, nil) if err != nil { return err } @@ -85,18 +83,17 @@ func (d *database) AddIBDCreds(ctx context.Context, subject string, username str _ = tx.Rollback() }(tx) - row, err := d.queryRow(ctx, tx, "keys/add_key", d.keyName, encryptedKey) + keyId, err := AddKey(ctx, tx, keyName, encryptedKey) if err != nil { return fmt.Errorf("unable to add ibd creds key: %w", err) } - var keyId int - err = row.Scan(&keyId) - if err != nil { - return fmt.Errorf("unable to scan key id: %w", err) - } - - _, err = d.exec(ctx, tx, "users/add_ibd_creds", subject, username, encryptedPass, keyId) + _, err = exec.ExecContext(ctx, ` +UPDATE users +SET ibd_username = $2, + ibd_password = $3, + encryption_key = $4 +WHERE subject = $1;`, subject, username, encryptedPass, keyId) if err != nil { return fmt.Errorf("unable to add ibd creds to user: %w", err) } @@ -108,11 +105,21 @@ func (d *database) AddIBDCreds(ctx context.Context, subject string, username str return nil } -func (d *database) GetIBDCreds(ctx context.Context, subject string) (username string, password string, err error) { - row, err := d.queryRow(ctx, d.db, "users/get_ibd_creds", subject) - if err != nil { - return "", "", fmt.Errorf("unable to get ibd creds: %w", err) - } +func GetIBDCreds( + ctx context.Context, + exec Executor, + kms keys.KeyManagementService, + subject string, +) ( + username string, + password string, + err error, +) { + row := exec.QueryRowContext(ctx, ` +SELECT ibd_username, ibd_password, encrypted_key, kms_key_name +FROM users +INNER JOIN public.keys k on k.id = users.encryption_key +WHERE subject = $1;`, subject) var encryptedPass, encryptedKey []byte var keyName string @@ -124,7 +131,7 @@ func (d *database) GetIBDCreds(ctx context.Context, subject string) (username st return "", "", fmt.Errorf("unable to scan sql row into ibd creds: %w", err) } - passwordBytes, err := keys.Decrypt(ctx, d.kms, keyName, encryptedPass, encryptedKey) + passwordBytes, err := keys.Decrypt(ctx, kms, keyName, encryptedPass, encryptedKey) if err != nil { return "", "", fmt.Errorf("unable to decrypt password: %w", err) } |