aboutsummaryrefslogtreecommitdiff
path: root/backend/internal/database/users.go
blob: f7998fb6afbbf8e0d90367e2f0d02f41d08735b6 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
package database

import (
	"context"
	"database/sql"
	"errors"
	"fmt"

	"github.com/ansg191/ibd-trader-backend/internal/keys"
)

var (
	// ErrUserNotFound is returned when no users row matches the requested
	// subject (for example by GetUser).
	ErrUserNotFound = errors.New("user not found")

	// ErrIBDCredsNotFound is returned by GetIBDCreds when no stored IBD
	// credentials can be found for the requested subject.
	ErrIBDCredsNotFound = errors.New("ibd creds not found")
)

// AddUser inserts a row for subject into the users table.
//
// The INSERT uses ON CONFLICT DO NOTHING, so any conflict with an existing
// row (e.g. the subject is already present) is silently ignored and is not
// reported as an error.
func AddUser(ctx context.Context, exec Executor, subject string) error {
	_, err := exec.ExecContext(ctx, `
INSERT INTO users (subject)
VALUES ($1)
ON CONFLICT DO NOTHING;`, subject)
	if err != nil {
		return fmt.Errorf("unable to add user: %w", err)
	}
	return nil
}

// GetUser loads the users row identified by subject.
//
// It returns ErrUserNotFound when no row matches subject; any other scan
// failure is wrapped and returned. Pointer fields of the returned User are
// nil when the corresponding column is NULL.
func GetUser(ctx context.Context, exec Executor, subject string) (*User, error) {
	var u User
	scanErr := exec.QueryRowContext(ctx, `
SELECT subject, ibd_username, ibd_password, encryption_key
FROM users
WHERE subject = $1;`, subject).Scan(&u.Subject, &u.IBDUsername, &u.EncryptedIBDPassword, &u.EncryptionKeyID)

	switch {
	case errors.Is(scanErr, sql.ErrNoRows):
		return nil, ErrUserNotFound
	case scanErr != nil:
		return nil, fmt.Errorf("unable to scan sql row into user: %w", scanErr)
	}
	return &u, nil
}

// ListUsers returns the rows of the users table.
//
// When hasIBDCreds is true, users whose ibd_username is NULL are omitted.
// On success the returned slice is never nil (it is empty when nothing
// matches). An error during query, scan, or row iteration is wrapped and
// returned with a nil slice.
func ListUsers(ctx context.Context, exec Executor, hasIBDCreds bool) ([]User, error) {
	rows, err := exec.QueryContext(ctx, `
SELECT subject, ibd_username, ibd_password, encryption_key
FROM users;
`)
	if err != nil {
		return nil, fmt.Errorf("unable to list users: %w", err)
	}
	// Always release the rows (and their connection), including when we bail
	// out early on a scan error. The close error is ignored because any
	// iteration failure is reported by rows.Err below.
	defer func() { _ = rows.Close() }()

	users := make([]User, 0)
	for rows.Next() {
		var user User
		if err = rows.Scan(&user.Subject, &user.IBDUsername, &user.EncryptedIBDPassword, &user.EncryptionKeyID); err != nil {
			return nil, fmt.Errorf("unable to scan sql row into user: %w", err)
		}

		if hasIBDCreds && user.IBDUsername == nil {
			continue
		}
		users = append(users, user)
	}
	// rows.Next returns false both at the end of the result set and on error;
	// without this check a mid-iteration failure would look like a short list.
	if err = rows.Err(); err != nil {
		return nil, fmt.Errorf("unable to iterate users: %w", err)
	}

	return users, nil
}

// AddIBDCreds stores IBD credentials for the user identified by subject.
//
// The password is encrypted with keys.Encrypt using kms and keyName. The
// resulting encrypted key is inserted into the keys table, and the user's
// ibd_username, ibd_password and encryption_key columns are updated to
// reference it. Both writes run inside a single transaction started on exec,
// so either both are committed or neither is.
//
// It returns ErrUserNotFound if no users row matches subject. In that case the
// transaction is rolled back, so no orphan keys row is left behind.
func AddIBDCreds(
	ctx context.Context,
	exec TransactionExecutor,
	kms keys.KeyManagementService,
	keyName, subject, username, password string,
) error {
	encryptedPass, encryptedKey, err := keys.Encrypt(ctx, kms, keyName, []byte(password))
	if err != nil {
		return fmt.Errorf("unable to encrypt password: %w", err)
	}

	tx, err := exec.BeginTx(ctx, nil)
	if err != nil {
		return fmt.Errorf("unable to begin transaction: %w", err)
	}
	// Rollback after a successful Commit is a no-op that returns sql.ErrTxDone,
	// so deferring it unconditionally is safe; its error is deliberately ignored.
	defer func() { _ = tx.Rollback() }()

	var keyID int
	err = tx.QueryRowContext(ctx, `
INSERT INTO keys (kms_key_name, encrypted_key)
VALUES ($1, $2)
RETURNING id;`, keyName, encryptedKey).Scan(&keyID)
	if err != nil {
		return fmt.Errorf("unable to add ibd creds key: %w", err)
	}

	// The UPDATE must go through tx. Through exec it would run outside the
	// transaction (typically on another pooled connection). It would then not
	// see the uncommitted keys row, and the two writes would not be atomic.
	res, err := tx.ExecContext(ctx, `
UPDATE users
SET ibd_username   = $2,
    ibd_password   = $3,
    encryption_key = $4
WHERE subject = $1;`, subject, username, encryptedPass, keyID)
	if err != nil {
		return fmt.Errorf("unable to add ibd creds to user: %w", err)
	}

	// An UPDATE matching no rows is not an SQL error. Detect it explicitly so
	// we do not commit a keys row that no user references.
	n, err := res.RowsAffected()
	if err != nil {
		return fmt.Errorf("unable to determine updated users: %w", err)
	}
	if n == 0 {
		return ErrUserNotFound
	}

	if err = tx.Commit(); err != nil {
		return fmt.Errorf("unable to commit transaction: %w", err)
	}

	return nil
}

// GetIBDCreds returns the IBD username and decrypted password stored for
// subject.
//
// It reads the user's ibd_username and ibd_password together with the
// encrypted key and KMS key name from the keys row referenced by
// users.encryption_key. It then decrypts the password with keys.Decrypt.
//
// ErrIBDCredsNotFound is returned when the INNER JOIN yields no row. That
// happens when the user does not exist or has no encryption_key matching a
// keys row.
func GetIBDCreds(
	ctx context.Context,
	exec Executor,
	kms keys.KeyManagementService,
	subject string,
) (
	username string,
	password string,
	err error,
) {
	var (
		storedUsername string
		cipherPassword []byte
		cipherDataKey  []byte
		kmsKeyName     string
	)
	scanErr := exec.QueryRowContext(ctx, `
SELECT ibd_username, ibd_password, encrypted_key, kms_key_name
FROM users
INNER JOIN public.keys k on k.id = users.encryption_key
WHERE subject = $1;`, subject).Scan(&storedUsername, &cipherPassword, &cipherDataKey, &kmsKeyName)

	switch {
	case errors.Is(scanErr, sql.ErrNoRows):
		return "", "", ErrIBDCredsNotFound
	case scanErr != nil:
		return "", "", fmt.Errorf("unable to scan sql row into ibd creds: %w", scanErr)
	}

	plaintext, decryptErr := keys.Decrypt(ctx, kms, kmsKeyName, cipherPassword, cipherDataKey)
	if decryptErr != nil {
		return "", "", fmt.Errorf("unable to decrypt password: %w", decryptErr)
	}

	return storedUsername, string(plaintext), nil
}

// User mirrors a row of the users table as scanned by GetUser and ListUsers.
//
// Fields:
//   - Subject: the users.subject column, the value users are looked up by.
//   - IBDUsername: the users.ibd_username column.
//   - EncryptedIBDPassword: the users.ibd_password column. AddIBDCreds stores
//     the keys.Encrypt ciphertext there, not the plaintext; GetIBDCreds
//     decrypts it. NOTE(review): the ciphertext is held in a *string here
//     but scanned as []byte in GetIBDCreds; confirm that is intended.
//   - EncryptionKeyID: the users.encryption_key column, the id of the keys
//     row holding the encrypted key used for the password.
//
// The pointer fields are nil when the corresponding column is NULL.
type User struct {
	Subject              string
	IBDUsername          *string
	EncryptedIBDPassword *string
	EncryptionKeyID      *int
}