diff options
Diffstat (limited to 'backend/internal/database')
-rw-r--r-- | backend/internal/database/cookies.go | 189 | ||||
-rw-r--r-- | backend/internal/database/database.go | 166 | ||||
-rw-r--r-- | backend/internal/database/database_test.go | 79 | ||||
-rw-r--r-- | backend/internal/database/stocks.go | 293 | ||||
-rw-r--r-- | backend/internal/database/users.go | 151 |
5 files changed, 878 insertions, 0 deletions
diff --git a/backend/internal/database/cookies.go b/backend/internal/database/cookies.go new file mode 100644 index 0000000..3ea21d0 --- /dev/null +++ b/backend/internal/database/cookies.go @@ -0,0 +1,189 @@ +package database + +import ( + "context" + "database/sql" + "errors" + "fmt" + "net/http" + "time" + + "github.com/ansg191/ibd-trader-backend/internal/keys" +) + +func GetAnyCookie(ctx context.Context, exec Executor, kms keys.KeyManagementService) (*IBDCookie, error) { + row := exec.QueryRowContext(ctx, ` +SELECT ibd_tokens.id, token, encrypted_key, kms_key_name, expires_at +FROM ibd_tokens + INNER JOIN keys ON encryption_key = keys.id +WHERE expires_at > NOW() + AND degraded = FALSE +ORDER BY random() +LIMIT 1;`) + + var id uint + var encryptedToken, encryptedKey []byte + var keyName string + var expiry time.Time + err := row.Scan(&id, &encryptedToken, &encryptedKey, &keyName, &expiry) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, nil + } + return nil, fmt.Errorf("unable to scan sql row into ibd cookie: %w", err) + } + + // Set the expiry to UTC explicitly. + // For some reason, the expiry time is set to location="". + expiry = expiry.UTC() + + token, err := keys.Decrypt(ctx, kms, keyName, encryptedToken, encryptedKey) + if err != nil { + return nil, fmt.Errorf("unable to decrypt token: %w", err) + } + return &IBDCookie{ + Token: string(token), + Expiry: expiry, + }, nil +} + +func GetCookies( + ctx context.Context, + exec Executor, + kms keys.KeyManagementService, + subject string, + degraded bool, +) ([]IBDCookie, error) { + rows, err := exec.QueryContext(ctx, ` +SELECT ibd_tokens.id, token, encrypted_key, kms_key_name, expires_at +FROM ibd_tokens + INNER JOIN keys ON encryption_key = keys.id +WHERE user_subject = $1 + AND expires_at > NOW() + AND degraded = $2 +ORDER BY expires_at DESC;`, subject, degraded) + if err != nil { + return nil, fmt.Errorf("unable to get ibd cookies: %w", err) + } + + cookies := make([]IBDCookie, 0) + for rows.Next() { + var id uint + var encryptedToken, encryptedKey []byte + var keyName string + var expiry time.Time + err = rows.Scan(&id, &encryptedToken, &encryptedKey, &keyName, &expiry) + if err != nil { + return nil, fmt.Errorf("unable to scan sql row into ibd cookie: %w", err) + } + + // Set the expiry to UTC explicitly. + // For some reason, the expiry time is set to location="". + expiry = expiry.UTC() + + token, err := keys.Decrypt(ctx, kms, keyName, encryptedToken, encryptedKey) + if err != nil { + return nil, fmt.Errorf("unable to decrypt token: %w", err) + } + cookie := IBDCookie{ + ID: id, + Token: string(token), + Expiry: expiry, + } + cookies = append(cookies, cookie) + } + + return cookies, nil +} + +func AddCookie( + ctx context.Context, + exec TransactionExecutor, + kms keys.KeyManagementService, + subject string, + cookie *http.Cookie, +) error { + tx, err := exec.BeginTx(ctx, nil) + if err != nil { + return err + } + + // Get the key ID for the user + user, err := GetUser(ctx, tx, subject) + if err != nil { + return fmt.Errorf("unable to get user: %w", err) + } + if user.EncryptionKeyID == nil { + return errors.New("user does not have an encryption key") + } + + // Get the key + var keyName string + var key []byte + err = tx.QueryRowContext(ctx, ` +SELECT kms_key_name, encrypted_key +FROM keys +WHERE id = $1;`, + *user.EncryptionKeyID, + ).Scan(&keyName, &key) + if err != nil { + return fmt.Errorf("unable to get key: %w", err) + } + + // Encrypt the token + encryptedToken, err := keys.EncryptWithKey(ctx, kms, keyName, key, []byte(cookie.Value)) + if err != nil { + return fmt.Errorf("unable to encrypt token: %w", err) + } + + // Add the cookie to the database + _, err = exec.ExecContext(ctx, ` +INSERT INTO ibd_tokens (token, expires_at, user_subject, encryption_key) +VALUES ($1, $2, $3, $4)`, encryptedToken, cookie.Expires, subject, *user.EncryptionKeyID) + if err != nil { + return fmt.Errorf("unable to add cookie: %w", err) + } + + return nil +} + +func ReportCookieFailure(ctx context.Context, exec Executor, id uint) error { + _, err := exec.ExecContext(ctx, ` +UPDATE ibd_tokens +SET degraded = TRUE +WHERE id = $1;`, id) + if err != nil { + return fmt.Errorf("unable to report cookie failure: %w", err) + } + return nil +} + +func RepairCookie(ctx context.Context, exec Executor, id uint) error { + _, err := exec.ExecContext(ctx, ` +UPDATE ibd_tokens +SET degraded = FALSE +WHERE id = $1;`, id) + if err != nil { + return fmt.Errorf("unable to report cookie failure: %w", err) + } + return nil +} + +type IBDCookie struct { + ID uint + Token string + Expiry time.Time +} + +func (c *IBDCookie) ToHTTPCookie() *http.Cookie { + return &http.Cookie{ + Name: ".ASPXAUTH", + Value: c.Token, + Path: "/", + Domain: "investors.com", + Expires: c.Expiry, + Secure: true, + HttpOnly: false, + SameSite: http.SameSiteLaxMode, + } +} diff --git a/backend/internal/database/database.go b/backend/internal/database/database.go new file mode 100644 index 0000000..409dd3c --- /dev/null +++ b/backend/internal/database/database.go @@ -0,0 +1,166 @@ +package database + +import ( + "context" + "database/sql" + "database/sql/driver" + "errors" + "io" + "log/slog" + "sync" + "time" + + "github.com/ansg191/ibd-trader-backend/db" + "github.com/ansg191/ibd-trader-backend/internal/keys" + + "github.com/golang-migrate/migrate/v4" + _ "github.com/golang-migrate/migrate/v4/database/postgres" + "github.com/golang-migrate/migrate/v4/source/iofs" + _ "github.com/lib/pq" +) + +type Database interface { + io.Closer + TransactionExecutor + driver.Pinger + + Migrate(ctx context.Context) error + Maintenance(ctx context.Context) +} + +type database struct { + logger *slog.Logger + + db *sql.DB + url string + + kms keys.KeyManagementService + keyName string +} + +func New(ctx context.Context, logger *slog.Logger, url string, kms keys.KeyManagementService, keyName string) (Database, error) { + sqlDB, err := sql.Open("postgres", url) + if err != nil { + return nil, err + } + + err = sqlDB.PingContext(ctx) + if err != nil { + // Ping failed. Don't error, but give a warning. + logger.WarnContext(ctx, "Unable to ping database", "error", err) + } + + return &database{ + logger: logger, + db: sqlDB, + url: url, + kms: kms, + keyName: keyName, + }, nil +} + +func (d *database) Close() error { + return d.db.Close() +} + +func (d *database) Migrate(ctx context.Context) error { + return Migrate(ctx, d.url) +} + +func (d *database) Maintenance(ctx context.Context) { + ticker := time.NewTicker(15 * time.Minute) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + func() { + var wg sync.WaitGroup + wg.Add(1) + + _, cancel := context.WithTimeout(context.Background(), 10*time.Minute) + defer cancel() + + wg.Wait() + }() + case <-ctx.Done(): + return + } + } +} + +func Migrate(ctx context.Context, url string) error { + fs, err := iofs.New(db.Migrations, "migrations") + if err != nil { + return err + } + + m, err := migrate.NewWithSourceInstance("iofs", fs, url) + if err != nil { + return err + } + + slog.InfoContext(ctx, "Running DB migration") + err = m.Up() + if err != nil && !errors.Is(err, migrate.ErrNoChange) { + slog.ErrorContext(ctx, "DB migration failed", "error", err) + return err + } + + return nil +} + +func (d *database) Ping(ctx context.Context) error { + return d.db.PingContext(ctx) +} + +type Executor interface { + ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) + QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) + QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row +} + +type TransactionExecutor interface { + Executor + BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) +} + +func (d *database) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) { + d.logger.DebugContext(ctx, "Executing query", "query", query) + + now := time.Now() + ret, err := d.db.ExecContext(ctx, query, args...) + if err != nil { + return nil, err + } + + d.logger.DebugContext(ctx, "Query executed successfully", "duration", time.Since(now)) + return ret, nil +} + +func (d *database) QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) { + d.logger.DebugContext(ctx, "Executing query", "query", query) + + now := time.Now() + ret, err := d.db.QueryContext(ctx, query, args...) + if err != nil { + return nil, err + } + + d.logger.DebugContext(ctx, "Query executed successfully", "duration", time.Since(now)) + return ret, nil +} + +func (d *database) QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row { + d.logger.DebugContext(ctx, "Executing query", "query", query) + + now := time.Now() + ret := d.db.QueryRowContext(ctx, query, args...) + + d.logger.DebugContext(ctx, "Query executed successfully", "duration", time.Since(now)) + return ret +} + +func (d *database) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) { + return d.db.BeginTx(ctx, opts) +} diff --git a/backend/internal/database/database_test.go b/backend/internal/database/database_test.go new file mode 100644 index 0000000..407a09a --- /dev/null +++ b/backend/internal/database/database_test.go @@ -0,0 +1,79 @@ +package database + +import ( + "context" + "database/sql" + "fmt" + "log" + "testing" + "time" + + "github.com/ory/dockertest/v3" + "github.com/ory/dockertest/v3/docker" +) + +var exec *sql.DB + +func TestMain(m *testing.M) { + pool, err := dockertest.NewPool("") + if err != nil { + log.Fatalf("Could not create pool: %s", err) + } + + err = pool.Client.Ping() + if err != nil { + log.Fatalf("Could not connect to Docker: %s", err) + } + + resource, err := pool.RunWithOptions(&dockertest.RunOptions{ + Repository: "postgres", + Tag: "16", + Env: []string{ + "POSTGRES_PASSWORD=secret", + "POSTGRES_USER=ibd-client-test", + "POSTGRES_DB=ibd-client-test", + "listen_addresses='*'", + }, + Cmd: []string{ + "postgres", + "-c", + "log_statement=all", + }, + }, func(config *docker.HostConfig) { + config.AutoRemove = true + config.RestartPolicy = docker.RestartPolicy{Name: "no"} + }) + if err != nil { + log.Fatalf("Could not start resource: %s", err) + } + + hostAndPort := resource.GetHostPort("5432/tcp") + databaseUrl := fmt.Sprintf("postgres://ibd-client-test:secret@%s/ibd-client-test?sslmode=disable", hostAndPort) + + // Kill container after 120 seconds + _ = resource.Expire(120) + + pool.MaxWait = 120 * time.Second + if err = pool.Retry(func() error { + exec, err = sql.Open("postgres", databaseUrl) + if err != nil { + return err + } + return exec.Ping() + }); err != nil { + log.Fatalf("Could not connect to database: %s", err) + } + + err = Migrate(context.Background(), databaseUrl) + if err != nil { + log.Fatalf("Could not migrate database: %s", err) + } + + defer func() { + if err := pool.Purge(resource); err != nil { + log.Fatalf("Could not purge resource: %s", err) + } + }() + + m.Run() +} diff --git a/backend/internal/database/stocks.go b/backend/internal/database/stocks.go new file mode 100644 index 0000000..24f5fe7 --- /dev/null +++ b/backend/internal/database/stocks.go @@ -0,0 +1,293 @@ +package database + +import ( + "context" + "database/sql" + "database/sql/driver" + "errors" + + pb "github.com/ansg191/ibd-trader-backend/api/gen/idb/stock/v1" + "github.com/ansg191/ibd-trader-backend/internal/analyzer" + "github.com/ansg191/ibd-trader-backend/internal/utils" + + "github.com/Rhymond/go-money" +) + +var ErrStockNotFound = errors.New("stock not found") + +func GetStock(ctx context.Context, exec Executor, symbol string) (Stock, error) { + row := exec.QueryRowContext(ctx, ` +SELECT symbol, name, ibd_url +FROM stocks +WHERE symbol = $1; +`, symbol) + + var stock Stock + if err := row.Scan(&stock.Symbol, &stock.Name, &stock.IBDUrl); err != nil { + if errors.Is(err, sql.ErrNoRows) { + return Stock{}, ErrStockNotFound + } + return Stock{}, err + } + + return stock, nil +} + +func AddStock(ctx context.Context, exec Executor, stock Stock) error { + _, err := exec.ExecContext(ctx, ` +INSERT INTO stocks (symbol, name, ibd_url) +VALUES ($1, $2, $3) +ON CONFLICT (symbol) + DO UPDATE SET name = $2, + ibd_url = $3;`, stock.Symbol, stock.Name, stock.IBDUrl) + return err +} + +func AddRanking(ctx context.Context, exec Executor, symbol string, ibd50, cap20 int) error { + if ibd50 > 0 { + _, err := exec.ExecContext(ctx, ` +INSERT INTO stock_rank (symbol, rank_type, rank) +VALUES ($1, $2, $3)`, symbol, "ibd50", ibd50) + if err != nil { + return err + } + } + if cap20 > 0 { + _, err := exec.ExecContext(ctx, ` +INSERT INTO stock_rank (symbol, rank_type, rank) +VALUES ($1, $2, $3)`, symbol, "cap20", cap20) + if err != nil { + return err + } + } + return nil +} + +func AddStockInfo(ctx context.Context, exec TransactionExecutor, info *StockInfo) (string, error) { + tx, err := exec.BeginTx(ctx, nil) + if err != nil { + return "", err + } + defer func(tx *sql.Tx) { + _ = tx.Rollback() + }(tx) + + // Add raw chart analysis + row := tx.QueryRowContext(ctx, ` +INSERT INTO chart_analysis (raw_analysis) +VALUES ($1) +RETURNING id;`, info.ChartAnalysis) + + var chartAnalysisID string + if err = row.Scan(&chartAnalysisID); err != nil { + return "", err + } + + // Add stock info + row = tx.QueryRowContext(ctx, + ` +INSERT INTO ratings (symbol, composite, eps, rel_str, group_rel_str, smr, acc_dis, chart_analysis, price) +VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) +RETURNING id;`, + info.Symbol, + info.Ratings.Composite, + info.Ratings.EPS, + info.Ratings.RelStr, + info.Ratings.GroupRelStr, + info.Ratings.SMR, + info.Ratings.AccDis, + chartAnalysisID, + info.Price.Display(), + ) + + var ratingsID string + if err = row.Scan(&ratingsID); err != nil { + return "", err + } + + return ratingsID, tx.Commit() +} + +func GetStockInfo(ctx context.Context, exec Executor, id string) (*StockInfo, error) { + row := exec.QueryRowContext(ctx, ` +SELECT r.symbol, + s.name, + ca.raw_analysis, + r.composite, + r.eps, + r.rel_str, + r.group_rel_str, + r.smr, + r.acc_dis, + r.price +FROM ratings r + INNER JOIN stocks s on r.symbol = s.symbol + INNER JOIN chart_analysis ca on r.chart_analysis = ca.id +WHERE r.id = $1;`, id) + + var info StockInfo + var priceStr string + err := row.Scan( + &info.Symbol, + &info.Name, + &info.ChartAnalysis, + &info.Ratings.Composite, + &info.Ratings.EPS, + &info.Ratings.RelStr, + &info.Ratings.GroupRelStr, + &info.Ratings.SMR, + &info.Ratings.AccDis, + &priceStr, + ) + if err != nil { + return nil, err + } + + info.Price, err = utils.ParseMoney(priceStr) + if err != nil { + return nil, err + } + + return &info, nil +} + +func AddAnalysis( + ctx context.Context, + exec Executor, + ratingId string, + analysis *analyzer.Analysis, +) (id string, err error) { + err = exec.QueryRowContext(ctx, ` +UPDATE chart_analysis ca +SET processed = true, + action = $2, + price = $3, + reason = $4, + confidence = $5 +FROM ratings r +WHERE r.id = $1 + AND r.chart_analysis = ca.id +RETURNING ca.id;`, + ratingId, + analysis.Action, + analysis.Price.Display(), + analysis.Reason, + analysis.Confidence, + ).Scan(&id) + return id, err +} + +type Stock struct { + Symbol string + Name string + IBDUrl string +} + +type StockInfo struct { + Symbol string + Name string + ChartAnalysis string + Ratings Ratings + Price *money.Money +} + +type Ratings struct { + Composite uint8 + EPS uint8 + RelStr uint8 + GroupRelStr LetterRating + SMR LetterRating + AccDis LetterRating +} + +type LetterRating pb.LetterGrade + +func (r LetterRating) String() string { + switch pb.LetterGrade(r) { + case pb.LetterGrade_LETTER_GRADE_E: + return "E" + case pb.LetterGrade_LETTER_GRADE_E_PLUS: + return "E+" + case pb.LetterGrade_LETTER_GRADE_D_MINUS: + return "D-" + case pb.LetterGrade_LETTER_GRADE_D: + return "D" + case pb.LetterGrade_LETTER_GRADE_D_PLUS: + return "D+" + case pb.LetterGrade_LETTER_GRADE_C_MINUS: + return "C-" + case pb.LetterGrade_LETTER_GRADE_C: + return "C" + case pb.LetterGrade_LETTER_GRADE_C_PLUS: + return "C+" + case pb.LetterGrade_LETTER_GRADE_B_MINUS: + return "B-" + case pb.LetterGrade_LETTER_GRADE_B: + return "B" + case pb.LetterGrade_LETTER_GRADE_B_PLUS: + return "B+" + case pb.LetterGrade_LETTER_GRADE_A_MINUS: + return "A-" + case pb.LetterGrade_LETTER_GRADE_A: + return "A" + case pb.LetterGrade_LETTER_GRADE_A_PLUS: + return "A+" + default: + return "NA" + } +} + +func LetterRatingFromString(str string) LetterRating { + switch str { + case "E": + return LetterRating(pb.LetterGrade_LETTER_GRADE_E) + case "E+": + return LetterRating(pb.LetterGrade_LETTER_GRADE_E_PLUS) + case "D-": + return LetterRating(pb.LetterGrade_LETTER_GRADE_D_MINUS) + case "D": + return LetterRating(pb.LetterGrade_LETTER_GRADE_D) + case "D+": + return LetterRating(pb.LetterGrade_LETTER_GRADE_D_PLUS) + case "C-": + return LetterRating(pb.LetterGrade_LETTER_GRADE_C_MINUS) + case "C": + return LetterRating(pb.LetterGrade_LETTER_GRADE_C) + case "C+": + return LetterRating(pb.LetterGrade_LETTER_GRADE_C_PLUS) + case "B-": + return LetterRating(pb.LetterGrade_LETTER_GRADE_B_MINUS) + case "B": + return LetterRating(pb.LetterGrade_LETTER_GRADE_B) + case "B+": + return LetterRating(pb.LetterGrade_LETTER_GRADE_B_PLUS) + case "A-": + return LetterRating(pb.LetterGrade_LETTER_GRADE_A_MINUS) + case "A": + return LetterRating(pb.LetterGrade_LETTER_GRADE_A) + case "A+": + return LetterRating(pb.LetterGrade_LETTER_GRADE_A_PLUS) + case "NA": + fallthrough + default: + return LetterRating(pb.LetterGrade_LETTER_GRADE_UNSPECIFIED) + } +} + +func (r LetterRating) Value() (driver.Value, error) { + return r.String(), nil +} + +func (r *LetterRating) Scan(src any) error { + var source string + switch v := src.(type) { + case string: + source = v + case []byte: + source = string(v) + default: + return errors.New("incompatible type for LetterRating") + } + *r = LetterRatingFromString(source) + return nil +} diff --git a/backend/internal/database/users.go b/backend/internal/database/users.go new file mode 100644 index 0000000..f7998fb --- /dev/null +++ b/backend/internal/database/users.go @@ -0,0 +1,151 @@ +package database + +import ( + "context" + "database/sql" + "errors" + "fmt" + + "github.com/ansg191/ibd-trader-backend/internal/keys" +) + +var ErrUserNotFound = fmt.Errorf("user not found") +var ErrIBDCredsNotFound = fmt.Errorf("ibd creds not found") + +func AddUser(ctx context.Context, exec Executor, subject string) (err error) { + _, err = exec.ExecContext(ctx, ` +INSERT INTO users (subject) +VALUES ($1) +ON CONFLICT DO NOTHING;`, subject) + return +} + +func GetUser(ctx context.Context, exec Executor, subject string) (*User, error) { + row := exec.QueryRowContext(ctx, ` +SELECT subject, ibd_username, ibd_password, encryption_key +FROM users +WHERE subject = $1;`, subject) + + user := &User{} + err := row.Scan(&user.Subject, &user.IBDUsername, &user.EncryptedIBDPassword, &user.EncryptionKeyID) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, ErrUserNotFound + } + return nil, fmt.Errorf("unable to scan sql row into user: %w", err) + } + + return user, nil +} + +func ListUsers(ctx context.Context, exec Executor, hasIBDCreds bool) ([]User, error) { + rows, err := exec.QueryContext(ctx, ` +SELECT subject, ibd_username, ibd_password, encryption_key +FROM users; +`) + if err != nil { + return nil, fmt.Errorf("unable to list users: %w", err) + } + + users := make([]User, 0) + for rows.Next() { + user := User{} + err = rows.Scan(&user.Subject, &user.IBDUsername, &user.EncryptedIBDPassword, &user.EncryptionKeyID) + if err != nil { + return nil, fmt.Errorf("unable to scan sql row into user: %w", err) + } + + if hasIBDCreds && user.IBDUsername == nil { + continue + } + users = append(users, user) + } + + return users, nil +} + +func AddIBDCreds( + ctx context.Context, + exec TransactionExecutor, + kms keys.KeyManagementService, + keyName, subject, username, password string, +) error { + encryptedPass, encryptedKey, err := keys.Encrypt(ctx, kms, keyName, []byte(password)) + if err != nil { + return fmt.Errorf("unable to encrypt password: %w", err) + } + + tx, err := exec.BeginTx(ctx, nil) + if err != nil { + return err + } + defer func(tx *sql.Tx) { + _ = tx.Rollback() + }(tx) + + var keyId int + err = tx.QueryRowContext(ctx, ` +INSERT INTO keys (kms_key_name, encrypted_key) +VALUES ($1, $2) +RETURNING id;`, keyName, encryptedKey).Scan(&keyId) + if err != nil { + return fmt.Errorf("unable to add ibd creds key: %w", err) + } + + _, err = exec.ExecContext(ctx, ` +UPDATE users +SET ibd_username = $2, + ibd_password = $3, + encryption_key = $4 +WHERE subject = $1;`, subject, username, encryptedPass, keyId) + if err != nil { + return fmt.Errorf("unable to add ibd creds to user: %w", err) + } + + if err = tx.Commit(); err != nil { + return fmt.Errorf("unable to commit transaction: %w", err) + } + + return nil +} + +func GetIBDCreds( + ctx context.Context, + exec Executor, + kms keys.KeyManagementService, + subject string, +) ( + username string, + password string, + err error, +) { + row := exec.QueryRowContext(ctx, ` +SELECT ibd_username, ibd_password, encrypted_key, kms_key_name +FROM users +INNER JOIN public.keys k on k.id = users.encryption_key +WHERE subject = $1;`, subject) + + var encryptedPass, encryptedKey []byte + var keyName string + err = row.Scan(&username, &encryptedPass, &encryptedKey, &keyName) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return "", "", ErrIBDCredsNotFound + } + return "", "", fmt.Errorf("unable to scan sql row into ibd creds: %w", err) + } + + passwordBytes, err := keys.Decrypt(ctx, kms, keyName, encryptedPass, encryptedKey) + if err != nil { + return "", "", fmt.Errorf("unable to decrypt password: %w", err) + } + + return username, string(passwordBytes), nil +} + +type User struct { + Subject string + IBDUsername *string + EncryptedIBDPassword *string + EncryptionKeyID *int +} |