aboutsummaryrefslogtreecommitdiff
path: root/backend/internal/database/users.go
diff options
context:
space:
mode:
authorGravatar Anshul Gupta <ansg191@anshulg.com> 2024-08-05 18:55:10 -0700
committerGravatar Anshul Gupta <ansg191@anshulg.com> 2024-08-05 18:55:19 -0700
commitb96fcd1a54a46a95f98467b49a051564bc21c23c (patch)
tree93caeeb05f8d6310e241095608ea2428c749b18c /backend/internal/database/users.go
downloadibd-trader-b96fcd1a54a46a95f98467b49a051564bc21c23c.tar.gz
ibd-trader-b96fcd1a54a46a95f98467b49a051564bc21c23c.tar.zst
ibd-trader-b96fcd1a54a46a95f98467b49a051564bc21c23c.zip
Initial Commit
Diffstat (limited to 'backend/internal/database/users.go')
-rw-r--r--backend/internal/database/users.go140
1 files changed, 140 insertions, 0 deletions
diff --git a/backend/internal/database/users.go b/backend/internal/database/users.go
new file mode 100644
index 0000000..1950fcb
--- /dev/null
+++ b/backend/internal/database/users.go
@@ -0,0 +1,140 @@
+package database
+
+import (
+ "context"
+ "database/sql"
+ "errors"
+ "fmt"
+
+ "ibd-trader/internal/keys"
+)
+
+type UserStore interface {
+ AddUser(ctx context.Context, subject string) error
+ GetUser(ctx context.Context, subject string) (*User, error)
+ ListUsers(ctx context.Context, hasIBDCreds bool) ([]User, error)
+ AddIBDCreds(ctx context.Context, subject string, username string, password string) error
+ GetIBDCreds(ctx context.Context, subject string) (username string, password string, err error)
+}
+
+var ErrUserNotFound = fmt.Errorf("user not found")
+var ErrIBDCredsNotFound = fmt.Errorf("ibd creds not found")
+
+func (d *database) AddUser(ctx context.Context, subject string) (err error) {
+ _, err = d.exec(
+ ctx,
+ d.db,
+ "users/add_user",
+ subject,
+ )
+ return
+}
+
+func (d *database) GetUser(ctx context.Context, subject string) (*User, error) {
+ row, err := d.queryRow(ctx, d.db, "users/get_user", subject)
+ if err != nil {
+ return nil, fmt.Errorf("unable to get user: %w", err)
+ }
+
+ user := &User{}
+ err = row.Scan(&user.Subject, &user.IBDUsername, &user.EncryptedIBDPassword, &user.EncryptionKeyID)
+ if err != nil {
+ if errors.Is(err, sql.ErrNoRows) {
+ return nil, ErrUserNotFound
+ }
+ return nil, fmt.Errorf("unable to scan sql row into user: %w", err)
+ }
+
+ return user, nil
+}
+
+func (d *database) ListUsers(ctx context.Context, hasIBDCreds bool) ([]User, error) {
+ rows, err := d.query(ctx, d.db, "users/list_users")
+ if err != nil {
+ return nil, fmt.Errorf("unable to list users: %w", err)
+ }
+
+ users := make([]User, 0)
+ for rows.Next() {
+ user := User{}
+ err = rows.Scan(&user.Subject, &user.IBDUsername, &user.EncryptedIBDPassword, &user.EncryptionKeyID)
+ if err != nil {
+ return nil, fmt.Errorf("unable to scan sql row into user: %w", err)
+ }
+
+ if hasIBDCreds && user.IBDUsername == nil {
+ continue
+ }
+ users = append(users, user)
+ }
+
+ return users, nil
+}
+
+func (d *database) AddIBDCreds(ctx context.Context, subject string, username string, password string) error {
+ encryptedPass, encryptedKey, err := keys.Encrypt(ctx, d.kms, d.keyName, []byte(password))
+ if err != nil {
+ return fmt.Errorf("unable to encrypt password: %w", err)
+ }
+
+ tx, err := d.db.BeginTx(ctx, nil)
+ if err != nil {
+ return err
+ }
+ defer func(tx *sql.Tx) {
+ _ = tx.Rollback()
+ }(tx)
+
+ row, err := d.queryRow(ctx, tx, "keys/add_key", d.keyName, encryptedKey)
+ if err != nil {
+ return fmt.Errorf("unable to add ibd creds key: %w", err)
+ }
+
+ var keyId int
+ err = row.Scan(&keyId)
+ if err != nil {
+ return fmt.Errorf("unable to scan key id: %w", err)
+ }
+
+ _, err = d.exec(ctx, tx, "users/add_ibd_creds", subject, username, encryptedPass, keyId)
+ if err != nil {
+ return fmt.Errorf("unable to add ibd creds to user: %w", err)
+ }
+
+ if err = tx.Commit(); err != nil {
+ return fmt.Errorf("unable to commit transaction: %w", err)
+ }
+
+ return nil
+}
+
+func (d *database) GetIBDCreds(ctx context.Context, subject string) (username string, password string, err error) {
+ row, err := d.queryRow(ctx, d.db, "users/get_ibd_creds", subject)
+ if err != nil {
+ return "", "", fmt.Errorf("unable to get ibd creds: %w", err)
+ }
+
+ var encryptedPass, encryptedKey []byte
+ var keyName string
+ err = row.Scan(&username, &encryptedPass, &encryptedKey, &keyName)
+ if err != nil {
+ if errors.Is(err, sql.ErrNoRows) {
+ return "", "", ErrIBDCredsNotFound
+ }
+ return "", "", fmt.Errorf("unable to scan sql row into ibd creds: %w", err)
+ }
+
+ passwordBytes, err := keys.Decrypt(ctx, d.kms, keyName, encryptedPass, encryptedKey)
+ if err != nil {
+ return "", "", fmt.Errorf("unable to decrypt password: %w", err)
+ }
+
+ return username, string(passwordBytes), nil
+}
+
+type User struct {
+ Subject string
+ IBDUsername *string
+ EncryptedIBDPassword *string
+ EncryptionKeyID *int
+}