aboutsummaryrefslogtreecommitdiff
path: root/backend/internal/database/users.go
diff options
context:
space:
mode:
authorGravatar Anshul Gupta <ansg191@anshulg.com> 2024-08-07 17:48:57 -0700
committerGravatar Anshul Gupta <ansg191@anshulg.com> 2024-08-07 18:48:10 -0700
commite9ee45b9d2bd494332dcf8b2073714f92fd0738d (patch)
treed34af1af84984409d27003981538f13cde4ba218 /backend/internal/database/users.go
parent3de4ebb7560851ccbefe296c197456fe80c22901 (diff)
downloadibd-trader-e9ee45b9d2bd494332dcf8b2073714f92fd0738d.tar.gz
ibd-trader-e9ee45b9d2bd494332dcf8b2073714f92fd0738d.tar.zst
ibd-trader-e9ee45b9d2bd494332dcf8b2073714f92fd0738d.zip
Refactor DB to remove restrictive query system
Diffstat (limited to 'backend/internal/database/users.go')
-rw-r--r--backend/internal/database/users.go87
1 files changed, 47 insertions, 40 deletions
diff --git a/backend/internal/database/users.go b/backend/internal/database/users.go
index ff6f674..d023598 100644
--- a/backend/internal/database/users.go
+++ b/backend/internal/database/users.go
@@ -9,35 +9,25 @@ import (
"github.com/ansg191/ibd-trader-backend/internal/keys"
)
-type UserStore interface {
- AddUser(ctx context.Context, subject string) error
- GetUser(ctx context.Context, subject string) (*User, error)
- ListUsers(ctx context.Context, hasIBDCreds bool) ([]User, error)
- AddIBDCreds(ctx context.Context, subject string, username string, password string) error
- GetIBDCreds(ctx context.Context, subject string) (username string, password string, err error)
-}
-
var ErrUserNotFound = fmt.Errorf("user not found")
var ErrIBDCredsNotFound = fmt.Errorf("ibd creds not found")
-func (d *database) AddUser(ctx context.Context, subject string) (err error) {
- _, err = d.exec(
- ctx,
- d.db,
- "users/add_user",
- subject,
- )
+func AddUser(ctx context.Context, exec Executor, subject string) (err error) {
+ _, err = exec.ExecContext(ctx, `
+INSERT INTO users (subject)
+VALUES ($1)
+ON CONFLICT DO NOTHING;`, subject)
return
}
-func (d *database) GetUser(ctx context.Context, subject string) (*User, error) {
- row, err := d.queryRow(ctx, d.db, "users/get_user", subject)
- if err != nil {
- return nil, fmt.Errorf("unable to get user: %w", err)
- }
+func GetUser(ctx context.Context, exec Executor, subject string) (*User, error) {
+ row := exec.QueryRowContext(ctx, `
+SELECT subject, ibd_username, ibd_password, encryption_key
+FROM users
+WHERE subject = $1;`, subject)
user := &User{}
- err = row.Scan(&user.Subject, &user.IBDUsername, &user.EncryptedIBDPassword, &user.EncryptionKeyID)
+ err := row.Scan(&user.Subject, &user.IBDUsername, &user.EncryptedIBDPassword, &user.EncryptionKeyID)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, ErrUserNotFound
@@ -48,8 +38,11 @@ func (d *database) GetUser(ctx context.Context, subject string) (*User, error) {
return user, nil
}
-func (d *database) ListUsers(ctx context.Context, hasIBDCreds bool) ([]User, error) {
- rows, err := d.query(ctx, d.db, "users/list_users")
+func ListUsers(ctx context.Context, exec Executor, hasIBDCreds bool) ([]User, error) {
+ rows, err := exec.QueryContext(ctx, `
+SELECT subject, ibd_username, ibd_password, encryption_key
+FROM users;
+`)
if err != nil {
return nil, fmt.Errorf("unable to list users: %w", err)
}
@@ -71,13 +64,18 @@ func (d *database) ListUsers(ctx context.Context, hasIBDCreds bool) ([]User, err
return users, nil
}
-func (d *database) AddIBDCreds(ctx context.Context, subject string, username string, password string) error {
- encryptedPass, encryptedKey, err := keys.Encrypt(ctx, d.kms, d.keyName, []byte(password))
+func AddIBDCreds(
+ ctx context.Context,
+ exec TransactionExecutor,
+ kms keys.KeyManagementService,
+ keyName, subject, username, password string,
+) error {
+ encryptedPass, encryptedKey, err := keys.Encrypt(ctx, kms, keyName, []byte(password))
if err != nil {
return fmt.Errorf("unable to encrypt password: %w", err)
}
- tx, err := d.db.BeginTx(ctx, nil)
+ tx, err := exec.BeginTx(ctx, nil)
if err != nil {
return err
}
@@ -85,18 +83,17 @@ func (d *database) AddIBDCreds(ctx context.Context, subject string, username str
_ = tx.Rollback()
}(tx)
- row, err := d.queryRow(ctx, tx, "keys/add_key", d.keyName, encryptedKey)
+ keyId, err := AddKey(ctx, tx, keyName, encryptedKey)
if err != nil {
return fmt.Errorf("unable to add ibd creds key: %w", err)
}
- var keyId int
- err = row.Scan(&keyId)
- if err != nil {
- return fmt.Errorf("unable to scan key id: %w", err)
- }
-
- _, err = d.exec(ctx, tx, "users/add_ibd_creds", subject, username, encryptedPass, keyId)
+ _, err = exec.ExecContext(ctx, `
+UPDATE users
+SET ibd_username = $2,
+ ibd_password = $3,
+ encryption_key = $4
+WHERE subject = $1;`, subject, username, encryptedPass, keyId)
if err != nil {
return fmt.Errorf("unable to add ibd creds to user: %w", err)
}
@@ -108,11 +105,21 @@ func (d *database) AddIBDCreds(ctx context.Context, subject string, username str
return nil
}
-func (d *database) GetIBDCreds(ctx context.Context, subject string) (username string, password string, err error) {
- row, err := d.queryRow(ctx, d.db, "users/get_ibd_creds", subject)
- if err != nil {
- return "", "", fmt.Errorf("unable to get ibd creds: %w", err)
- }
+func GetIBDCreds(
+ ctx context.Context,
+ exec Executor,
+ kms keys.KeyManagementService,
+ subject string,
+) (
+ username string,
+ password string,
+ err error,
+) {
+ row := exec.QueryRowContext(ctx, `
+SELECT ibd_username, ibd_password, encrypted_key, kms_key_name
+FROM users
+INNER JOIN public.keys k on k.id = users.encryption_key
+WHERE subject = $1;`, subject)
var encryptedPass, encryptedKey []byte
var keyName string
@@ -124,7 +131,7 @@ func (d *database) GetIBDCreds(ctx context.Context, subject string) (username st
return "", "", fmt.Errorf("unable to scan sql row into ibd creds: %w", err)
}
- passwordBytes, err := keys.Decrypt(ctx, d.kms, keyName, encryptedPass, encryptedKey)
+ passwordBytes, err := keys.Decrypt(ctx, kms, keyName, encryptedPass, encryptedKey)
if err != nil {
return "", "", fmt.Errorf("unable to decrypt password: %w", err)
}