1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
|
package database
import (
"context"
"database/sql"
"errors"
"fmt"
"github.com/ansg191/ibd-trader-backend/internal/keys"
)
// UserStore persists users and their IBD credentials.
type UserStore interface {
	// AddUser creates a user record for subject.
	AddUser(ctx context.Context, subject string) error
	// GetUser returns the user identified by subject, or ErrUserNotFound
	// when there is none.
	GetUser(ctx context.Context, subject string) (*User, error)
	// ListUsers returns all users. When hasIBDCreds is true, only users that
	// have an IBD username on record are returned.
	ListUsers(ctx context.Context, hasIBDCreds bool) ([]User, error)
	// AddIBDCreds stores the IBD username and password for subject (the
	// database implementation encrypts the password before storing it).
	AddIBDCreds(ctx context.Context, subject, username, password string) error
	// GetIBDCreds returns the plaintext IBD credentials for subject, or
	// ErrIBDCredsNotFound when none are stored.
	GetIBDCreds(ctx context.Context, subject string) (username, password string, err error)
}
// Sentinel errors returned by the database's UserStore methods. Compare
// against them with errors.Is.
var (
	// ErrUserNotFound is returned by GetUser when the lookup yields no row.
	ErrUserNotFound = errors.New("user not found")
	// ErrIBDCredsNotFound is returned by GetIBDCreds when the lookup yields
	// no row.
	ErrIBDCredsNotFound = errors.New("ibd creds not found")
)
// AddUser runs the "users/add_user" statement for subject and returns
// whatever error the statement produced, unmodified.
func (d *database) AddUser(ctx context.Context, subject string) error {
	_, execErr := d.exec(ctx, d.db, "users/add_user", subject)
	return execErr
}
// GetUser fetches the user identified by subject using the "users/get_user"
// query.
//
// It returns ErrUserNotFound when the query yields no row. Any other query
// or scan failure is wrapped with context.
func (d *database) GetUser(ctx context.Context, subject string) (*User, error) {
	row, err := d.queryRow(ctx, d.db, "users/get_user", subject)
	if err != nil {
		return nil, fmt.Errorf("unable to get user: %w", err)
	}

	var u User
	scanErr := row.Scan(&u.Subject, &u.IBDUsername, &u.EncryptedIBDPassword, &u.EncryptionKeyID)
	switch {
	case errors.Is(scanErr, sql.ErrNoRows):
		return nil, ErrUserNotFound
	case scanErr != nil:
		return nil, fmt.Errorf("unable to scan sql row into user: %w", scanErr)
	}
	return &u, nil
}
// ListUsers returns every user produced by the "users/list_users" query.
//
// When hasIBDCreds is true, users whose IBDUsername is nil (NULL in the
// database) are skipped. The returned slice is never nil; an empty result
// yields an empty slice.
//
// Query, scan, and row-iteration failures are wrapped with context.
func (d *database) ListUsers(ctx context.Context, hasIBDCreds bool) ([]User, error) {
	rows, err := d.query(ctx, d.db, "users/list_users")
	if err != nil {
		return nil, fmt.Errorf("unable to list users: %w", err)
	}
	// Close the result set on every path, including the early return on a
	// scan error, so its connection is released. A Close error is not
	// actionable here; iteration errors are reported via rows.Err below.
	defer func() { _ = rows.Close() }()

	users := make([]User, 0)
	for rows.Next() {
		var user User
		if err := rows.Scan(&user.Subject, &user.IBDUsername, &user.EncryptedIBDPassword, &user.EncryptionKeyID); err != nil {
			return nil, fmt.Errorf("unable to scan sql row into user: %w", err)
		}
		if hasIBDCreds && user.IBDUsername == nil {
			continue
		}
		users = append(users, user)
	}
	// Next returns false both at the end of the results and when iteration
	// fails; Err tells the two apart so a failure is not reported as a
	// short, successful listing.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("unable to iterate users: %w", err)
	}
	return users, nil
}
// AddIBDCreds stores IBD credentials for the user identified by subject.
//
// The password is first encrypted with keys.Encrypt using d.kms and
// d.keyName, before the database is touched. Then, inside a single
// transaction:
//   - the encrypted key returned by keys.Encrypt is added via "keys/add_key",
//     which yields a key id;
//   - the username, encrypted password and that key id are recorded for
//     subject via "users/add_ibd_creds".
//
// The transaction is committed only if both statements succeed. On any
// error it is rolled back. All errors are wrapped with context.
func (d *database) AddIBDCreds(ctx context.Context, subject string, username string, password string) error {
	encryptedPass, encryptedKey, err := keys.Encrypt(ctx, d.kms, d.keyName, []byte(password))
	if err != nil {
		return fmt.Errorf("unable to encrypt password: %w", err)
	}

	tx, err := d.db.BeginTx(ctx, nil)
	if err != nil {
		return fmt.Errorf("unable to begin transaction: %w", err)
	}
	// Roll back on every early return. After a successful Commit, Rollback
	// is a no-op that returns sql.ErrTxDone, so its error is ignored.
	defer func() { _ = tx.Rollback() }()

	row, err := d.queryRow(ctx, tx, "keys/add_key", d.keyName, encryptedKey)
	if err != nil {
		return fmt.Errorf("unable to add ibd creds key: %w", err)
	}
	var keyID int
	if err := row.Scan(&keyID); err != nil {
		return fmt.Errorf("unable to scan key id: %w", err)
	}

	if _, err := d.exec(ctx, tx, "users/add_ibd_creds", subject, username, encryptedPass, keyID); err != nil {
		return fmt.Errorf("unable to add ibd creds to user: %w", err)
	}

	if err := tx.Commit(); err != nil {
		return fmt.Errorf("unable to commit transaction: %w", err)
	}
	return nil
}
// GetIBDCreds loads and decrypts the IBD credentials stored for subject.
//
// The "users/get_ibd_creds" query supplies the username, the encrypted
// password, the encrypted key, and the name of the key used to encrypt them.
// The password is decrypted with keys.Decrypt using that stored key name
// rather than d.keyName.
//
// It returns ErrIBDCredsNotFound when the query yields no row. Any other
// failure is wrapped with context, and both strings are empty on error.
func (d *database) GetIBDCreds(ctx context.Context, subject string) (username string, password string, err error) {
	row, queryErr := d.queryRow(ctx, d.db, "users/get_ibd_creds", subject)
	if queryErr != nil {
		return "", "", fmt.Errorf("unable to get ibd creds: %w", queryErr)
	}

	var (
		storedUser    string
		encryptedPass []byte
		encryptedKey  []byte
		storedKeyName string
	)
	switch scanErr := row.Scan(&storedUser, &encryptedPass, &encryptedKey, &storedKeyName); {
	case errors.Is(scanErr, sql.ErrNoRows):
		return "", "", ErrIBDCredsNotFound
	case scanErr != nil:
		return "", "", fmt.Errorf("unable to scan sql row into ibd creds: %w", scanErr)
	}

	plaintext, decryptErr := keys.Decrypt(ctx, d.kms, storedKeyName, encryptedPass, encryptedKey)
	if decryptErr != nil {
		return "", "", fmt.Errorf("unable to decrypt password: %w", decryptErr)
	}
	return storedUser, string(plaintext), nil
}
// User holds one row as scanned by GetUser ("users/get_user") and ListUsers
// ("users/list_users"). The pointer fields are nil when the corresponding
// column is NULL.
type User struct {
	// Subject identifies the user; it is the value passed as subject to
	// AddUser and GetUser.
	Subject string
	// IBDUsername is nil when no IBD username is stored. ListUsers relies on
	// this to filter out users without IBD credentials.
	IBDUsername *string
	// EncryptedIBDPassword is the stored, encrypted IBD password, if any.
	// NOTE(review): GetIBDCreds scans the encrypted password as []byte while
	// this field is *string. Confirm the column type and whether a string
	// round-trip is safe for ciphertext.
	EncryptedIBDPassword *string
	// EncryptionKeyID references the key used to encrypt the password.
	// Presumably this is the id returned by "keys/add_key" in AddIBDCreds;
	// verify against the schema.
	EncryptionKeyID *int
}
|