diff options
author | 2024-08-05 18:55:10 -0700 | |
---|---|---|
committer | 2024-08-05 18:55:19 -0700 | |
commit | b96fcd1a54a46a95f98467b49a051564bc21c23c (patch) | |
tree | 93caeeb05f8d6310e241095608ea2428c749b18c /backend/internal/database/users.go | |
download | ibd-trader-b96fcd1a54a46a95f98467b49a051564bc21c23c.tar.gz ibd-trader-b96fcd1a54a46a95f98467b49a051564bc21c23c.tar.zst ibd-trader-b96fcd1a54a46a95f98467b49a051564bc21c23c.zip |
Initial Commit
Diffstat (limited to 'backend/internal/database/users.go')
-rw-r--r-- | backend/internal/database/users.go | 140 |
1 files changed, 140 insertions, 0 deletions
diff --git a/backend/internal/database/users.go b/backend/internal/database/users.go new file mode 100644 index 0000000..1950fcb --- /dev/null +++ b/backend/internal/database/users.go @@ -0,0 +1,140 @@ +package database + +import ( + "context" + "database/sql" + "errors" + "fmt" + + "ibd-trader/internal/keys" +) + +type UserStore interface { + AddUser(ctx context.Context, subject string) error + GetUser(ctx context.Context, subject string) (*User, error) + ListUsers(ctx context.Context, hasIBDCreds bool) ([]User, error) + AddIBDCreds(ctx context.Context, subject string, username string, password string) error + GetIBDCreds(ctx context.Context, subject string) (username string, password string, err error) +} + +var ErrUserNotFound = fmt.Errorf("user not found") +var ErrIBDCredsNotFound = fmt.Errorf("ibd creds not found") + +func (d *database) AddUser(ctx context.Context, subject string) (err error) { + _, err = d.exec( + ctx, + d.db, + "users/add_user", + subject, + ) + return +} + +func (d *database) GetUser(ctx context.Context, subject string) (*User, error) { + row, err := d.queryRow(ctx, d.db, "users/get_user", subject) + if err != nil { + return nil, fmt.Errorf("unable to get user: %w", err) + } + + user := &User{} + err = row.Scan(&user.Subject, &user.IBDUsername, &user.EncryptedIBDPassword, &user.EncryptionKeyID) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, ErrUserNotFound + } + return nil, fmt.Errorf("unable to scan sql row into user: %w", err) + } + + return user, nil +} + +func (d *database) ListUsers(ctx context.Context, hasIBDCreds bool) ([]User, error) { + rows, err := d.query(ctx, d.db, "users/list_users") + if err != nil { + return nil, fmt.Errorf("unable to list users: %w", err) + } + + users := make([]User, 0) + for rows.Next() { + user := User{} + err = rows.Scan(&user.Subject, &user.IBDUsername, &user.EncryptedIBDPassword, &user.EncryptionKeyID) + if err != nil { + return nil, fmt.Errorf("unable to scan sql row into user: %w", err) + } + + if hasIBDCreds && user.IBDUsername == nil { + continue + } + users = append(users, user) + } + + return users, nil +} + +func (d *database) AddIBDCreds(ctx context.Context, subject string, username string, password string) error { + encryptedPass, encryptedKey, err := keys.Encrypt(ctx, d.kms, d.keyName, []byte(password)) + if err != nil { + return fmt.Errorf("unable to encrypt password: %w", err) + } + + tx, err := d.db.BeginTx(ctx, nil) + if err != nil { + return err + } + defer func(tx *sql.Tx) { + _ = tx.Rollback() + }(tx) + + row, err := d.queryRow(ctx, tx, "keys/add_key", d.keyName, encryptedKey) + if err != nil { + return fmt.Errorf("unable to add ibd creds key: %w", err) + } + + var keyId int + err = row.Scan(&keyId) + if err != nil { + return fmt.Errorf("unable to scan key id: %w", err) + } + + _, err = d.exec(ctx, tx, "users/add_ibd_creds", subject, username, encryptedPass, keyId) + if err != nil { + return fmt.Errorf("unable to add ibd creds to user: %w", err) + } + + if err = tx.Commit(); err != nil { + return fmt.Errorf("unable to commit transaction: %w", err) + } + + return nil +} + +func (d *database) GetIBDCreds(ctx context.Context, subject string) (username string, password string, err error) { + row, err := d.queryRow(ctx, d.db, "users/get_ibd_creds", subject) + if err != nil { + return "", "", fmt.Errorf("unable to get ibd creds: %w", err) + } + + var encryptedPass, encryptedKey []byte + var keyName string + err = row.Scan(&username, &encryptedPass, &encryptedKey, &keyName) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return "", "", ErrIBDCredsNotFound + } + return "", "", fmt.Errorf("unable to scan sql row into ibd creds: %w", err) + } + + passwordBytes, err := keys.Decrypt(ctx, d.kms, keyName, encryptedPass, encryptedKey) + if err != nil { + return "", "", fmt.Errorf("unable to decrypt password: %w", err) + } + + return username, string(passwordBytes), nil +} + +type User struct { + Subject string + IBDUsername *string + EncryptedIBDPassword *string + EncryptionKeyID *int +} |