aboutsummaryrefslogtreecommitdiff
path: root/backend/internal/database/users.go
diff options
context:
space:
mode:
Diffstat (limited to 'backend/internal/database/users.go')
-rw-r--r--backend/internal/database/users.go151
1 files changed, 151 insertions, 0 deletions
diff --git a/backend/internal/database/users.go b/backend/internal/database/users.go
new file mode 100644
index 0000000..f7998fb
--- /dev/null
+++ b/backend/internal/database/users.go
@@ -0,0 +1,151 @@
+package database
+
+import (
+ "context"
+ "database/sql"
+ "errors"
+ "fmt"
+
+ "github.com/ansg191/ibd-trader-backend/internal/keys"
+)
+
+var ErrUserNotFound = fmt.Errorf("user not found")
+var ErrIBDCredsNotFound = fmt.Errorf("ibd creds not found")
+
+func AddUser(ctx context.Context, exec Executor, subject string) (err error) {
+ _, err = exec.ExecContext(ctx, `
+INSERT INTO users (subject)
+VALUES ($1)
+ON CONFLICT DO NOTHING;`, subject)
+ return
+}
+
+func GetUser(ctx context.Context, exec Executor, subject string) (*User, error) {
+ row := exec.QueryRowContext(ctx, `
+SELECT subject, ibd_username, ibd_password, encryption_key
+FROM users
+WHERE subject = $1;`, subject)
+
+ user := &User{}
+ err := row.Scan(&user.Subject, &user.IBDUsername, &user.EncryptedIBDPassword, &user.EncryptionKeyID)
+ if err != nil {
+ if errors.Is(err, sql.ErrNoRows) {
+ return nil, ErrUserNotFound
+ }
+ return nil, fmt.Errorf("unable to scan sql row into user: %w", err)
+ }
+
+ return user, nil
+}
+
+func ListUsers(ctx context.Context, exec Executor, hasIBDCreds bool) ([]User, error) {
+ rows, err := exec.QueryContext(ctx, `
+SELECT subject, ibd_username, ibd_password, encryption_key
+FROM users;
+`)
+ if err != nil {
+ return nil, fmt.Errorf("unable to list users: %w", err)
+ }
+
+ users := make([]User, 0)
+ for rows.Next() {
+ user := User{}
+ err = rows.Scan(&user.Subject, &user.IBDUsername, &user.EncryptedIBDPassword, &user.EncryptionKeyID)
+ if err != nil {
+ return nil, fmt.Errorf("unable to scan sql row into user: %w", err)
+ }
+
+ if hasIBDCreds && user.IBDUsername == nil {
+ continue
+ }
+ users = append(users, user)
+ }
+
+ return users, nil
+}
+
+func AddIBDCreds(
+ ctx context.Context,
+ exec TransactionExecutor,
+ kms keys.KeyManagementService,
+ keyName, subject, username, password string,
+) error {
+ encryptedPass, encryptedKey, err := keys.Encrypt(ctx, kms, keyName, []byte(password))
+ if err != nil {
+ return fmt.Errorf("unable to encrypt password: %w", err)
+ }
+
+ tx, err := exec.BeginTx(ctx, nil)
+ if err != nil {
+ return err
+ }
+ defer func(tx *sql.Tx) {
+ _ = tx.Rollback()
+ }(tx)
+
+ var keyId int
+ err = tx.QueryRowContext(ctx, `
+INSERT INTO keys (kms_key_name, encrypted_key)
+VALUES ($1, $2)
+RETURNING id;`, keyName, encryptedKey).Scan(&keyId)
+ if err != nil {
+ return fmt.Errorf("unable to add ibd creds key: %w", err)
+ }
+
+ _, err = exec.ExecContext(ctx, `
+UPDATE users
+SET ibd_username = $2,
+ ibd_password = $3,
+ encryption_key = $4
+WHERE subject = $1;`, subject, username, encryptedPass, keyId)
+ if err != nil {
+ return fmt.Errorf("unable to add ibd creds to user: %w", err)
+ }
+
+ if err = tx.Commit(); err != nil {
+ return fmt.Errorf("unable to commit transaction: %w", err)
+ }
+
+ return nil
+}
+
+func GetIBDCreds(
+ ctx context.Context,
+ exec Executor,
+ kms keys.KeyManagementService,
+ subject string,
+) (
+ username string,
+ password string,
+ err error,
+) {
+ row := exec.QueryRowContext(ctx, `
+SELECT ibd_username, ibd_password, encrypted_key, kms_key_name
+FROM users
+INNER JOIN public.keys k on k.id = users.encryption_key
+WHERE subject = $1;`, subject)
+
+ var encryptedPass, encryptedKey []byte
+ var keyName string
+ err = row.Scan(&username, &encryptedPass, &encryptedKey, &keyName)
+ if err != nil {
+ if errors.Is(err, sql.ErrNoRows) {
+ return "", "", ErrIBDCredsNotFound
+ }
+ return "", "", fmt.Errorf("unable to scan sql row into ibd creds: %w", err)
+ }
+
+ passwordBytes, err := keys.Decrypt(ctx, kms, keyName, encryptedPass, encryptedKey)
+ if err != nil {
+ return "", "", fmt.Errorf("unable to decrypt password: %w", err)
+ }
+
+ return username, string(passwordBytes), nil
+}
+
+type User struct {
+ Subject string
+ IBDUsername *string
+ EncryptedIBDPassword *string
+ EncryptionKeyID *int
+}