aboutsummaryrefslogtreecommitdiff
path: root/backend/internal/database
diff options
context:
space:
mode:
authorGravatar Anshul Gupta <ansg191@anshulg.com> 2024-08-11 13:15:50 -0700
committerGravatar Anshul Gupta <ansg191@anshulg.com> 2024-08-11 13:15:50 -0700
commit6a3c21fb0b1c126849f2bbff494403bbe901448e (patch)
tree5d7805524357c2c8a9819c39d2051a4e3633a1d5 /backend/internal/database
parent29c6040a51616e9e4cf6c70ee16391b2a3b238c9 (diff)
parentf34b92ded11b07f78575ac62c260a380c468e5ea (diff)
downloadibd-trader-6a3c21fb0b1c126849f2bbff494403bbe901448e.tar.gz
ibd-trader-6a3c21fb0b1c126849f2bbff494403bbe901448e.tar.zst
ibd-trader-6a3c21fb0b1c126849f2bbff494403bbe901448e.zip
Merge remote-tracking branch 'backend/main'
Diffstat (limited to 'backend/internal/database')
-rw-r--r--backend/internal/database/cookies.go189
-rw-r--r--backend/internal/database/database.go166
-rw-r--r--backend/internal/database/database_test.go79
-rw-r--r--backend/internal/database/stocks.go293
-rw-r--r--backend/internal/database/users.go151
5 files changed, 878 insertions, 0 deletions
diff --git a/backend/internal/database/cookies.go b/backend/internal/database/cookies.go
new file mode 100644
index 0000000..3ea21d0
--- /dev/null
+++ b/backend/internal/database/cookies.go
@@ -0,0 +1,189 @@
+package database
+
+import (
+ "context"
+ "database/sql"
+ "errors"
+ "fmt"
+ "net/http"
+ "time"
+
+ "github.com/ansg191/ibd-trader-backend/internal/keys"
+)
+
+func GetAnyCookie(ctx context.Context, exec Executor, kms keys.KeyManagementService) (*IBDCookie, error) {
+ row := exec.QueryRowContext(ctx, `
+SELECT ibd_tokens.id, token, encrypted_key, kms_key_name, expires_at
+FROM ibd_tokens
+ INNER JOIN keys ON encryption_key = keys.id
+WHERE expires_at > NOW()
+ AND degraded = FALSE
+ORDER BY random()
+LIMIT 1;`)
+
+ var id uint
+ var encryptedToken, encryptedKey []byte
+ var keyName string
+ var expiry time.Time
+ err := row.Scan(&id, &encryptedToken, &encryptedKey, &keyName, &expiry)
+ if err != nil {
+ if errors.Is(err, sql.ErrNoRows) {
+ return nil, nil
+ }
+ return nil, fmt.Errorf("unable to scan sql row into ibd cookie: %w", err)
+ }
+
+ // Set the expiry to UTC explicitly.
+ // For some reason, the expiry time is set to location="".
+ expiry = expiry.UTC()
+
+ token, err := keys.Decrypt(ctx, kms, keyName, encryptedToken, encryptedKey)
+ if err != nil {
+ return nil, fmt.Errorf("unable to decrypt token: %w", err)
+ }
+ return &IBDCookie{
+ Token: string(token),
+ Expiry: expiry,
+ }, nil
+}
+
+func GetCookies(
+ ctx context.Context,
+ exec Executor,
+ kms keys.KeyManagementService,
+ subject string,
+ degraded bool,
+) ([]IBDCookie, error) {
+ rows, err := exec.QueryContext(ctx, `
+SELECT ibd_tokens.id, token, encrypted_key, kms_key_name, expires_at
+FROM ibd_tokens
+ INNER JOIN keys ON encryption_key = keys.id
+WHERE user_subject = $1
+ AND expires_at > NOW()
+ AND degraded = $2
+ORDER BY expires_at DESC;`, subject, degraded)
+ if err != nil {
+ return nil, fmt.Errorf("unable to get ibd cookies: %w", err)
+ }
+
+ cookies := make([]IBDCookie, 0)
+ for rows.Next() {
+ var id uint
+ var encryptedToken, encryptedKey []byte
+ var keyName string
+ var expiry time.Time
+ err = rows.Scan(&id, &encryptedToken, &encryptedKey, &keyName, &expiry)
+ if err != nil {
+ return nil, fmt.Errorf("unable to scan sql row into ibd cookie: %w", err)
+ }
+
+ // Set the expiry to UTC explicitly.
+ // For some reason, the expiry time is set to location="".
+ expiry = expiry.UTC()
+
+ token, err := keys.Decrypt(ctx, kms, keyName, encryptedToken, encryptedKey)
+ if err != nil {
+ return nil, fmt.Errorf("unable to decrypt token: %w", err)
+ }
+ cookie := IBDCookie{
+ ID: id,
+ Token: string(token),
+ Expiry: expiry,
+ }
+ cookies = append(cookies, cookie)
+ }
+
+ return cookies, nil
+}
+
+func AddCookie(
+ ctx context.Context,
+ exec TransactionExecutor,
+ kms keys.KeyManagementService,
+ subject string,
+ cookie *http.Cookie,
+) error {
+ tx, err := exec.BeginTx(ctx, nil)
+ if err != nil {
+ return err
+ }
+
+ // Get the key ID for the user
+ user, err := GetUser(ctx, tx, subject)
+ if err != nil {
+ return fmt.Errorf("unable to get user: %w", err)
+ }
+ if user.EncryptionKeyID == nil {
+ return errors.New("user does not have an encryption key")
+ }
+
+ // Get the key
+ var keyName string
+ var key []byte
+ err = tx.QueryRowContext(ctx, `
+SELECT kms_key_name, encrypted_key
+FROM keys
+WHERE id = $1;`,
+ *user.EncryptionKeyID,
+ ).Scan(&keyName, &key)
+ if err != nil {
+ return fmt.Errorf("unable to get key: %w", err)
+ }
+
+ // Encrypt the token
+ encryptedToken, err := keys.EncryptWithKey(ctx, kms, keyName, key, []byte(cookie.Value))
+ if err != nil {
+ return fmt.Errorf("unable to encrypt token: %w", err)
+ }
+
+ // Add the cookie to the database
+ _, err = exec.ExecContext(ctx, `
+INSERT INTO ibd_tokens (token, expires_at, user_subject, encryption_key)
+VALUES ($1, $2, $3, $4)`, encryptedToken, cookie.Expires, subject, *user.EncryptionKeyID)
+ if err != nil {
+ return fmt.Errorf("unable to add cookie: %w", err)
+ }
+
+ return nil
+}
+
+func ReportCookieFailure(ctx context.Context, exec Executor, id uint) error {
+ _, err := exec.ExecContext(ctx, `
+UPDATE ibd_tokens
+SET degraded = TRUE
+WHERE id = $1;`, id)
+ if err != nil {
+ return fmt.Errorf("unable to report cookie failure: %w", err)
+ }
+ return nil
+}
+
+func RepairCookie(ctx context.Context, exec Executor, id uint) error {
+ _, err := exec.ExecContext(ctx, `
+UPDATE ibd_tokens
+SET degraded = FALSE
+WHERE id = $1;`, id)
+ if err != nil {
+ return fmt.Errorf("unable to report cookie failure: %w", err)
+ }
+ return nil
+}
+
+type IBDCookie struct {
+ ID uint
+ Token string
+ Expiry time.Time
+}
+
+func (c *IBDCookie) ToHTTPCookie() *http.Cookie {
+ return &http.Cookie{
+ Name: ".ASPXAUTH",
+ Value: c.Token,
+ Path: "/",
+ Domain: "investors.com",
+ Expires: c.Expiry,
+ Secure: true,
+ HttpOnly: false,
+ SameSite: http.SameSiteLaxMode,
+ }
+}
diff --git a/backend/internal/database/database.go b/backend/internal/database/database.go
new file mode 100644
index 0000000..409dd3c
--- /dev/null
+++ b/backend/internal/database/database.go
@@ -0,0 +1,166 @@
+package database
+
+import (
+ "context"
+ "database/sql"
+ "database/sql/driver"
+ "errors"
+ "io"
+ "log/slog"
+ "sync"
+ "time"
+
+ "github.com/ansg191/ibd-trader-backend/db"
+ "github.com/ansg191/ibd-trader-backend/internal/keys"
+
+ "github.com/golang-migrate/migrate/v4"
+ _ "github.com/golang-migrate/migrate/v4/database/postgres"
+ "github.com/golang-migrate/migrate/v4/source/iofs"
+ _ "github.com/lib/pq"
+)
+
+type Database interface {
+ io.Closer
+ TransactionExecutor
+ driver.Pinger
+
+ Migrate(ctx context.Context) error
+ Maintenance(ctx context.Context)
+}
+
+type database struct {
+ logger *slog.Logger
+
+ db *sql.DB
+ url string
+
+ kms keys.KeyManagementService
+ keyName string
+}
+
+func New(ctx context.Context, logger *slog.Logger, url string, kms keys.KeyManagementService, keyName string) (Database, error) {
+ sqlDB, err := sql.Open("postgres", url)
+ if err != nil {
+ return nil, err
+ }
+
+ err = sqlDB.PingContext(ctx)
+ if err != nil {
+ // Ping failed. Don't error, but give a warning.
+ logger.WarnContext(ctx, "Unable to ping database", "error", err)
+ }
+
+ return &database{
+ logger: logger,
+ db: sqlDB,
+ url: url,
+ kms: kms,
+ keyName: keyName,
+ }, nil
+}
+
+func (d *database) Close() error {
+ return d.db.Close()
+}
+
+func (d *database) Migrate(ctx context.Context) error {
+ return Migrate(ctx, d.url)
+}
+
+func (d *database) Maintenance(ctx context.Context) {
+ ticker := time.NewTicker(15 * time.Minute)
+ defer ticker.Stop()
+
+ for {
+ select {
+ case <-ticker.C:
+ func() {
+ var wg sync.WaitGroup
+ wg.Add(1)
+
+ _, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
+ defer cancel()
+
+ wg.Wait()
+ }()
+ case <-ctx.Done():
+ return
+ }
+ }
+}
+
+func Migrate(ctx context.Context, url string) error {
+ fs, err := iofs.New(db.Migrations, "migrations")
+ if err != nil {
+ return err
+ }
+
+ m, err := migrate.NewWithSourceInstance("iofs", fs, url)
+ if err != nil {
+ return err
+ }
+
+ slog.InfoContext(ctx, "Running DB migration")
+ err = m.Up()
+ if err != nil && !errors.Is(err, migrate.ErrNoChange) {
+ slog.ErrorContext(ctx, "DB migration failed", "error", err)
+ return err
+ }
+
+ return nil
+}
+
+func (d *database) Ping(ctx context.Context) error {
+ return d.db.PingContext(ctx)
+}
+
+type Executor interface {
+ ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
+ QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
+ QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row
+}
+
+type TransactionExecutor interface {
+ Executor
+ BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error)
+}
+
+func (d *database) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) {
+ d.logger.DebugContext(ctx, "Executing query", "query", query)
+
+ now := time.Now()
+ ret, err := d.db.ExecContext(ctx, query, args...)
+ if err != nil {
+ return nil, err
+ }
+
+ d.logger.DebugContext(ctx, "Query executed successfully", "duration", time.Since(now))
+ return ret, nil
+}
+
+func (d *database) QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) {
+ d.logger.DebugContext(ctx, "Executing query", "query", query)
+
+ now := time.Now()
+ ret, err := d.db.QueryContext(ctx, query, args...)
+ if err != nil {
+ return nil, err
+ }
+
+ d.logger.DebugContext(ctx, "Query executed successfully", "duration", time.Since(now))
+ return ret, nil
+}
+
+func (d *database) QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row {
+ d.logger.DebugContext(ctx, "Executing query", "query", query)
+
+ now := time.Now()
+ ret := d.db.QueryRowContext(ctx, query, args...)
+
+ d.logger.DebugContext(ctx, "Query executed successfully", "duration", time.Since(now))
+ return ret
+}
+
+func (d *database) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) {
+ return d.db.BeginTx(ctx, opts)
+}
diff --git a/backend/internal/database/database_test.go b/backend/internal/database/database_test.go
new file mode 100644
index 0000000..407a09a
--- /dev/null
+++ b/backend/internal/database/database_test.go
@@ -0,0 +1,79 @@
+package database
+
+import (
+ "context"
+ "database/sql"
+ "fmt"
+ "log"
+ "testing"
+ "time"
+
+ "github.com/ory/dockertest/v3"
+ "github.com/ory/dockertest/v3/docker"
+)
+
+var exec *sql.DB
+
+func TestMain(m *testing.M) {
+ pool, err := dockertest.NewPool("")
+ if err != nil {
+ log.Fatalf("Could not create pool: %s", err)
+ }
+
+ err = pool.Client.Ping()
+ if err != nil {
+ log.Fatalf("Could not connect to Docker: %s", err)
+ }
+
+ resource, err := pool.RunWithOptions(&dockertest.RunOptions{
+ Repository: "postgres",
+ Tag: "16",
+ Env: []string{
+ "POSTGRES_PASSWORD=secret",
+ "POSTGRES_USER=ibd-client-test",
+ "POSTGRES_DB=ibd-client-test",
+ "listen_addresses='*'",
+ },
+ Cmd: []string{
+ "postgres",
+ "-c",
+ "log_statement=all",
+ },
+ }, func(config *docker.HostConfig) {
+ config.AutoRemove = true
+ config.RestartPolicy = docker.RestartPolicy{Name: "no"}
+ })
+ if err != nil {
+ log.Fatalf("Could not start resource: %s", err)
+ }
+
+ hostAndPort := resource.GetHostPort("5432/tcp")
+ databaseUrl := fmt.Sprintf("postgres://ibd-client-test:secret@%s/ibd-client-test?sslmode=disable", hostAndPort)
+
+ // Kill container after 120 seconds
+ _ = resource.Expire(120)
+
+ pool.MaxWait = 120 * time.Second
+ if err = pool.Retry(func() error {
+ exec, err = sql.Open("postgres", databaseUrl)
+ if err != nil {
+ return err
+ }
+ return exec.Ping()
+ }); err != nil {
+ log.Fatalf("Could not connect to database: %s", err)
+ }
+
+ err = Migrate(context.Background(), databaseUrl)
+ if err != nil {
+ log.Fatalf("Could not migrate database: %s", err)
+ }
+
+ defer func() {
+ if err := pool.Purge(resource); err != nil {
+ log.Fatalf("Could not purge resource: %s", err)
+ }
+ }()
+
+ m.Run()
+}
diff --git a/backend/internal/database/stocks.go b/backend/internal/database/stocks.go
new file mode 100644
index 0000000..24f5fe7
--- /dev/null
+++ b/backend/internal/database/stocks.go
@@ -0,0 +1,293 @@
+package database
+
+import (
+ "context"
+ "database/sql"
+ "database/sql/driver"
+ "errors"
+
+ pb "github.com/ansg191/ibd-trader-backend/api/gen/idb/stock/v1"
+ "github.com/ansg191/ibd-trader-backend/internal/analyzer"
+ "github.com/ansg191/ibd-trader-backend/internal/utils"
+
+ "github.com/Rhymond/go-money"
+)
+
+var ErrStockNotFound = errors.New("stock not found")
+
+func GetStock(ctx context.Context, exec Executor, symbol string) (Stock, error) {
+ row := exec.QueryRowContext(ctx, `
+SELECT symbol, name, ibd_url
+FROM stocks
+WHERE symbol = $1;
+`, symbol)
+
+ var stock Stock
+ if err := row.Scan(&stock.Symbol, &stock.Name, &stock.IBDUrl); err != nil {
+ if errors.Is(err, sql.ErrNoRows) {
+ return Stock{}, ErrStockNotFound
+ }
+ return Stock{}, err
+ }
+
+ return stock, nil
+}
+
+func AddStock(ctx context.Context, exec Executor, stock Stock) error {
+ _, err := exec.ExecContext(ctx, `
+INSERT INTO stocks (symbol, name, ibd_url)
+VALUES ($1, $2, $3)
+ON CONFLICT (symbol)
+ DO UPDATE SET name = $2,
+ ibd_url = $3;`, stock.Symbol, stock.Name, stock.IBDUrl)
+ return err
+}
+
+func AddRanking(ctx context.Context, exec Executor, symbol string, ibd50, cap20 int) error {
+ if ibd50 > 0 {
+ _, err := exec.ExecContext(ctx, `
+INSERT INTO stock_rank (symbol, rank_type, rank)
+VALUES ($1, $2, $3)`, symbol, "ibd50", ibd50)
+ if err != nil {
+ return err
+ }
+ }
+ if cap20 > 0 {
+ _, err := exec.ExecContext(ctx, `
+INSERT INTO stock_rank (symbol, rank_type, rank)
+VALUES ($1, $2, $3)`, symbol, "cap20", cap20)
+ if err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+func AddStockInfo(ctx context.Context, exec TransactionExecutor, info *StockInfo) (string, error) {
+ tx, err := exec.BeginTx(ctx, nil)
+ if err != nil {
+ return "", err
+ }
+ defer func(tx *sql.Tx) {
+ _ = tx.Rollback()
+ }(tx)
+
+ // Add raw chart analysis
+ row := tx.QueryRowContext(ctx, `
+INSERT INTO chart_analysis (raw_analysis)
+VALUES ($1)
+RETURNING id;`, info.ChartAnalysis)
+
+ var chartAnalysisID string
+ if err = row.Scan(&chartAnalysisID); err != nil {
+ return "", err
+ }
+
+ // Add stock info
+ row = tx.QueryRowContext(ctx,
+ `
+INSERT INTO ratings (symbol, composite, eps, rel_str, group_rel_str, smr, acc_dis, chart_analysis, price)
+VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
+RETURNING id;`,
+ info.Symbol,
+ info.Ratings.Composite,
+ info.Ratings.EPS,
+ info.Ratings.RelStr,
+ info.Ratings.GroupRelStr,
+ info.Ratings.SMR,
+ info.Ratings.AccDis,
+ chartAnalysisID,
+ info.Price.Display(),
+ )
+
+ var ratingsID string
+ if err = row.Scan(&ratingsID); err != nil {
+ return "", err
+ }
+
+ return ratingsID, tx.Commit()
+}
+
+func GetStockInfo(ctx context.Context, exec Executor, id string) (*StockInfo, error) {
+ row := exec.QueryRowContext(ctx, `
+SELECT r.symbol,
+ s.name,
+ ca.raw_analysis,
+ r.composite,
+ r.eps,
+ r.rel_str,
+ r.group_rel_str,
+ r.smr,
+ r.acc_dis,
+ r.price
+FROM ratings r
+ INNER JOIN stocks s on r.symbol = s.symbol
+ INNER JOIN chart_analysis ca on r.chart_analysis = ca.id
+WHERE r.id = $1;`, id)
+
+ var info StockInfo
+ var priceStr string
+ err := row.Scan(
+ &info.Symbol,
+ &info.Name,
+ &info.ChartAnalysis,
+ &info.Ratings.Composite,
+ &info.Ratings.EPS,
+ &info.Ratings.RelStr,
+ &info.Ratings.GroupRelStr,
+ &info.Ratings.SMR,
+ &info.Ratings.AccDis,
+ &priceStr,
+ )
+ if err != nil {
+ return nil, err
+ }
+
+ info.Price, err = utils.ParseMoney(priceStr)
+ if err != nil {
+ return nil, err
+ }
+
+ return &info, nil
+}
+
+func AddAnalysis(
+ ctx context.Context,
+ exec Executor,
+ ratingId string,
+ analysis *analyzer.Analysis,
+) (id string, err error) {
+ err = exec.QueryRowContext(ctx, `
+UPDATE chart_analysis ca
+SET processed = true,
+ action = $2,
+ price = $3,
+ reason = $4,
+ confidence = $5
+FROM ratings r
+WHERE r.id = $1
+ AND r.chart_analysis = ca.id
+RETURNING ca.id;`,
+ ratingId,
+ analysis.Action,
+ analysis.Price.Display(),
+ analysis.Reason,
+ analysis.Confidence,
+ ).Scan(&id)
+ return id, err
+}
+
+type Stock struct {
+ Symbol string
+ Name string
+ IBDUrl string
+}
+
+type StockInfo struct {
+ Symbol string
+ Name string
+ ChartAnalysis string
+ Ratings Ratings
+ Price *money.Money
+}
+
+type Ratings struct {
+ Composite uint8
+ EPS uint8
+ RelStr uint8
+ GroupRelStr LetterRating
+ SMR LetterRating
+ AccDis LetterRating
+}
+
+type LetterRating pb.LetterGrade
+
+func (r LetterRating) String() string {
+ switch pb.LetterGrade(r) {
+ case pb.LetterGrade_LETTER_GRADE_E:
+ return "E"
+ case pb.LetterGrade_LETTER_GRADE_E_PLUS:
+ return "E+"
+ case pb.LetterGrade_LETTER_GRADE_D_MINUS:
+ return "D-"
+ case pb.LetterGrade_LETTER_GRADE_D:
+ return "D"
+ case pb.LetterGrade_LETTER_GRADE_D_PLUS:
+ return "D+"
+ case pb.LetterGrade_LETTER_GRADE_C_MINUS:
+ return "C-"
+ case pb.LetterGrade_LETTER_GRADE_C:
+ return "C"
+ case pb.LetterGrade_LETTER_GRADE_C_PLUS:
+ return "C+"
+ case pb.LetterGrade_LETTER_GRADE_B_MINUS:
+ return "B-"
+ case pb.LetterGrade_LETTER_GRADE_B:
+ return "B"
+ case pb.LetterGrade_LETTER_GRADE_B_PLUS:
+ return "B+"
+ case pb.LetterGrade_LETTER_GRADE_A_MINUS:
+ return "A-"
+ case pb.LetterGrade_LETTER_GRADE_A:
+ return "A"
+ case pb.LetterGrade_LETTER_GRADE_A_PLUS:
+ return "A+"
+ default:
+ return "NA"
+ }
+}
+
+func LetterRatingFromString(str string) LetterRating {
+ switch str {
+ case "E":
+ return LetterRating(pb.LetterGrade_LETTER_GRADE_E)
+ case "E+":
+ return LetterRating(pb.LetterGrade_LETTER_GRADE_E_PLUS)
+ case "D-":
+ return LetterRating(pb.LetterGrade_LETTER_GRADE_D_MINUS)
+ case "D":
+ return LetterRating(pb.LetterGrade_LETTER_GRADE_D)
+ case "D+":
+ return LetterRating(pb.LetterGrade_LETTER_GRADE_D_PLUS)
+ case "C-":
+ return LetterRating(pb.LetterGrade_LETTER_GRADE_C_MINUS)
+ case "C":
+ return LetterRating(pb.LetterGrade_LETTER_GRADE_C)
+ case "C+":
+ return LetterRating(pb.LetterGrade_LETTER_GRADE_C_PLUS)
+ case "B-":
+ return LetterRating(pb.LetterGrade_LETTER_GRADE_B_MINUS)
+ case "B":
+ return LetterRating(pb.LetterGrade_LETTER_GRADE_B)
+ case "B+":
+ return LetterRating(pb.LetterGrade_LETTER_GRADE_B_PLUS)
+ case "A-":
+ return LetterRating(pb.LetterGrade_LETTER_GRADE_A_MINUS)
+ case "A":
+ return LetterRating(pb.LetterGrade_LETTER_GRADE_A)
+ case "A+":
+ return LetterRating(pb.LetterGrade_LETTER_GRADE_A_PLUS)
+ case "NA":
+ fallthrough
+ default:
+ return LetterRating(pb.LetterGrade_LETTER_GRADE_UNSPECIFIED)
+ }
+}
+
+func (r LetterRating) Value() (driver.Value, error) {
+ return r.String(), nil
+}
+
+func (r *LetterRating) Scan(src any) error {
+ var source string
+ switch v := src.(type) {
+ case string:
+ source = v
+ case []byte:
+ source = string(v)
+ default:
+ return errors.New("incompatible type for LetterRating")
+ }
+ *r = LetterRatingFromString(source)
+ return nil
+}
diff --git a/backend/internal/database/users.go b/backend/internal/database/users.go
new file mode 100644
index 0000000..f7998fb
--- /dev/null
+++ b/backend/internal/database/users.go
@@ -0,0 +1,151 @@
+package database
+
+import (
+ "context"
+ "database/sql"
+ "errors"
+ "fmt"
+
+ "github.com/ansg191/ibd-trader-backend/internal/keys"
+)
+
+var ErrUserNotFound = fmt.Errorf("user not found")
+var ErrIBDCredsNotFound = fmt.Errorf("ibd creds not found")
+
+func AddUser(ctx context.Context, exec Executor, subject string) (err error) {
+ _, err = exec.ExecContext(ctx, `
+INSERT INTO users (subject)
+VALUES ($1)
+ON CONFLICT DO NOTHING;`, subject)
+ return
+}
+
+func GetUser(ctx context.Context, exec Executor, subject string) (*User, error) {
+ row := exec.QueryRowContext(ctx, `
+SELECT subject, ibd_username, ibd_password, encryption_key
+FROM users
+WHERE subject = $1;`, subject)
+
+ user := &User{}
+ err := row.Scan(&user.Subject, &user.IBDUsername, &user.EncryptedIBDPassword, &user.EncryptionKeyID)
+ if err != nil {
+ if errors.Is(err, sql.ErrNoRows) {
+ return nil, ErrUserNotFound
+ }
+ return nil, fmt.Errorf("unable to scan sql row into user: %w", err)
+ }
+
+ return user, nil
+}
+
+func ListUsers(ctx context.Context, exec Executor, hasIBDCreds bool) ([]User, error) {
+ rows, err := exec.QueryContext(ctx, `
+SELECT subject, ibd_username, ibd_password, encryption_key
+FROM users;
+`)
+ if err != nil {
+ return nil, fmt.Errorf("unable to list users: %w", err)
+ }
+
+ users := make([]User, 0)
+ for rows.Next() {
+ user := User{}
+ err = rows.Scan(&user.Subject, &user.IBDUsername, &user.EncryptedIBDPassword, &user.EncryptionKeyID)
+ if err != nil {
+ return nil, fmt.Errorf("unable to scan sql row into user: %w", err)
+ }
+
+ if hasIBDCreds && user.IBDUsername == nil {
+ continue
+ }
+ users = append(users, user)
+ }
+
+ return users, nil
+}
+
+func AddIBDCreds(
+ ctx context.Context,
+ exec TransactionExecutor,
+ kms keys.KeyManagementService,
+ keyName, subject, username, password string,
+) error {
+ encryptedPass, encryptedKey, err := keys.Encrypt(ctx, kms, keyName, []byte(password))
+ if err != nil {
+ return fmt.Errorf("unable to encrypt password: %w", err)
+ }
+
+ tx, err := exec.BeginTx(ctx, nil)
+ if err != nil {
+ return err
+ }
+ defer func(tx *sql.Tx) {
+ _ = tx.Rollback()
+ }(tx)
+
+ var keyId int
+ err = tx.QueryRowContext(ctx, `
+INSERT INTO keys (kms_key_name, encrypted_key)
+VALUES ($1, $2)
+RETURNING id;`, keyName, encryptedKey).Scan(&keyId)
+ if err != nil {
+ return fmt.Errorf("unable to add ibd creds key: %w", err)
+ }
+
+ _, err = exec.ExecContext(ctx, `
+UPDATE users
+SET ibd_username = $2,
+ ibd_password = $3,
+ encryption_key = $4
+WHERE subject = $1;`, subject, username, encryptedPass, keyId)
+ if err != nil {
+ return fmt.Errorf("unable to add ibd creds to user: %w", err)
+ }
+
+ if err = tx.Commit(); err != nil {
+ return fmt.Errorf("unable to commit transaction: %w", err)
+ }
+
+ return nil
+}
+
+func GetIBDCreds(
+ ctx context.Context,
+ exec Executor,
+ kms keys.KeyManagementService,
+ subject string,
+) (
+ username string,
+ password string,
+ err error,
+) {
+ row := exec.QueryRowContext(ctx, `
+SELECT ibd_username, ibd_password, encrypted_key, kms_key_name
+FROM users
+INNER JOIN public.keys k on k.id = users.encryption_key
+WHERE subject = $1;`, subject)
+
+ var encryptedPass, encryptedKey []byte
+ var keyName string
+ err = row.Scan(&username, &encryptedPass, &encryptedKey, &keyName)
+ if err != nil {
+ if errors.Is(err, sql.ErrNoRows) {
+ return "", "", ErrIBDCredsNotFound
+ }
+ return "", "", fmt.Errorf("unable to scan sql row into ibd creds: %w", err)
+ }
+
+ passwordBytes, err := keys.Decrypt(ctx, kms, keyName, encryptedPass, encryptedKey)
+ if err != nil {
+ return "", "", fmt.Errorf("unable to decrypt password: %w", err)
+ }
+
+ return username, string(passwordBytes), nil
+}
+
+type User struct {
+ Subject string
+ IBDUsername *string
+ EncryptedIBDPassword *string
+ EncryptionKeyID *int
+}