aboutsummaryrefslogtreecommitdiff
path: root/backend/internal/database
diff options
context:
space:
mode:
authorGravatar Anshul Gupta <ansg191@anshulg.com> 2024-08-05 18:55:10 -0700
committerGravatar Anshul Gupta <ansg191@anshulg.com> 2024-08-05 18:55:19 -0700
commitb96fcd1a54a46a95f98467b49a051564bc21c23c (patch)
tree93caeeb05f8d6310e241095608ea2428c749b18c /backend/internal/database
downloadibd-trader-b96fcd1a54a46a95f98467b49a051564bc21c23c.tar.gz
ibd-trader-b96fcd1a54a46a95f98467b49a051564bc21c23c.tar.zst
ibd-trader-b96fcd1a54a46a95f98467b49a051564bc21c23c.zip
Initial Commit
Diffstat (limited to 'backend/internal/database')
-rw-r--r--backend/internal/database/cookies.go150
-rw-r--r--backend/internal/database/database.go178
-rw-r--r--backend/internal/database/keys.go49
-rw-r--r--backend/internal/database/session.go122
-rw-r--r--backend/internal/database/stocks.go264
-rw-r--r--backend/internal/database/users.go140
6 files changed, 903 insertions, 0 deletions
diff --git a/backend/internal/database/cookies.go b/backend/internal/database/cookies.go
new file mode 100644
index 0000000..cb38272
--- /dev/null
+++ b/backend/internal/database/cookies.go
@@ -0,0 +1,150 @@
+package database
+
+import (
+ "context"
+ "database/sql"
+ "errors"
+ "fmt"
+ "net/http"
+ "time"
+
+ "ibd-trader/internal/keys"
+)
+
+type CookieStore interface {
+ CookieSource
+ AddCookie(ctx context.Context, subject string, cookie *http.Cookie) error
+ RepairCookie(ctx context.Context, id uint) error
+}
+
+type CookieSource interface {
+ GetAnyCookie(ctx context.Context) (*IBDCookie, error)
+ GetCookies(ctx context.Context, subject string, degraded bool) ([]IBDCookie, error)
+ ReportCookieFailure(ctx context.Context, id uint) error
+}
+
+func (d *database) GetAnyCookie(ctx context.Context) (*IBDCookie, error) {
+ row, err := d.queryRow(ctx, d.db, "cookies/get_any_cookie")
+ if err != nil {
+ return nil, fmt.Errorf("unable to get any ibd cookie: %w", err)
+ }
+
+ var id uint
+ var encryptedToken, encryptedKey []byte
+ var keyName string
+ var expiry time.Time
+ err = row.Scan(&id, &encryptedToken, &encryptedKey, &keyName, &expiry)
+ if err != nil {
+ if errors.Is(err, sql.ErrNoRows) {
+ return nil, nil
+ }
+ return nil, fmt.Errorf("unable to scan sql row into ibd cookie: %w", err)
+ }
+
+ token, err := keys.Decrypt(ctx, d.kms, keyName, encryptedToken, encryptedKey)
+ if err != nil {
+ return nil, fmt.Errorf("unable to decrypt token: %w", err)
+ }
+ return &IBDCookie{
+ Token: string(token),
+ Expiry: expiry,
+ }, nil
+}
+
+func (d *database) GetCookies(ctx context.Context, subject string, degraded bool) ([]IBDCookie, error) {
+ row, err := d.query(ctx, d.db, "cookies/get_cookies", subject, degraded)
+ if err != nil {
+ return nil, fmt.Errorf("unable to get ibd cookies: %w", err)
+ }
+
+ cookies := make([]IBDCookie, 0)
+ for row.Next() {
+ var id uint
+ var encryptedToken, encryptedKey []byte
+ var keyName string
+ var expiry time.Time
+ err = row.Scan(&id, &encryptedToken, &encryptedKey, &keyName, &expiry)
+ if err != nil {
+ return nil, fmt.Errorf("unable to scan sql row into ibd cookie: %w", err)
+ }
+
+ token, err := keys.Decrypt(ctx, d.kms, keyName, encryptedToken, encryptedKey)
+ if err != nil {
+ return nil, fmt.Errorf("unable to decrypt token: %w", err)
+ }
+ cookie := IBDCookie{
+ ID: id,
+ Token: string(token),
+ Expiry: expiry,
+ }
+ cookies = append(cookies, cookie)
+ }
+
+ return cookies, nil
+}
+
+func (d *database) AddCookie(ctx context.Context, subject string, cookie *http.Cookie) error {
+ // Get the key ID for the user
+ user, err := d.GetUser(ctx, subject)
+ if err != nil {
+ return fmt.Errorf("unable to get user: %w", err)
+ }
+ if user.EncryptionKeyID == nil {
+ return errors.New("user does not have an encryption key")
+ }
+
+ // Get the key
+ key, err := d.GetKey(ctx, *user.EncryptionKeyID)
+ if err != nil {
+ return fmt.Errorf("unable to get key: %w", err)
+ }
+
+ // Encrypt the token
+ encryptedToken, err := keys.EncryptWithKey(ctx, d.kms, key.Name, key.Key, []byte(cookie.Value))
+ if err != nil {
+ return fmt.Errorf("unable to encrypt token: %w", err)
+ }
+
+ // Add the cookie to the database
+ _, err = d.exec(ctx, d.db, "cookies/add_cookie", encryptedToken, cookie.Expires, subject, key.Id)
+ if err != nil {
+ return fmt.Errorf("unable to add cookie: %w", err)
+ }
+
+ return nil
+}
+
+func (d *database) ReportCookieFailure(ctx context.Context, id uint) error {
+ _, err := d.exec(ctx, d.db, "cookies/set_cookie_degraded", true, id)
+ if err != nil {
+ return fmt.Errorf("unable to report cookie failure: %w", err)
+ }
+ return nil
+}
+
+func (d *database) RepairCookie(ctx context.Context, id uint) error {
+ _, err := d.exec(ctx, d.db, "cookies/set_cookie_degraded", false, id)
+ if err != nil {
+ return fmt.Errorf("unable to report cookie failure: %w", err)
+ }
+ return nil
+}
+
+type IBDCookie struct {
+ ID uint
+ Token string
+ Expiry time.Time
+}
+
+func (c *IBDCookie) ToHTTPCookie() *http.Cookie {
+ return &http.Cookie{
+ Name: ".ASPXAUTH",
+ Value: c.Token,
+ Path: "/",
+ Domain: "investors.com",
+ Expires: c.Expiry,
+ Secure: true,
+ HttpOnly: false,
+ SameSite: http.SameSiteLaxMode,
+ }
+}
diff --git a/backend/internal/database/database.go b/backend/internal/database/database.go
new file mode 100644
index 0000000..4022dde
--- /dev/null
+++ b/backend/internal/database/database.go
@@ -0,0 +1,178 @@
+package database
+
+import (
+ "context"
+ "database/sql"
+ "database/sql/driver"
+ "errors"
+ "fmt"
+ "io"
+ "log/slog"
+ "sync"
+ "time"
+
+ "ibd-trader/db"
+ "ibd-trader/internal/keys"
+
+ "github.com/golang-migrate/migrate/v4"
+ _ "github.com/golang-migrate/migrate/v4/database/postgres"
+ "github.com/golang-migrate/migrate/v4/source/iofs"
+ _ "github.com/lib/pq"
+)
+
+type Database interface {
+ io.Closer
+ UserStore
+ CookieStore
+ KeyStore
+ SessionStore
+ StockStore
+
+ driver.Pinger
+
+ Migrate(ctx context.Context) error
+ Maintenance(ctx context.Context)
+}
+
+type database struct {
+ logger *slog.Logger
+
+ db *sql.DB
+ url string
+
+ kms keys.KeyManagementService
+ keyName string
+}
+
+func New(ctx context.Context, logger *slog.Logger, url string, kms keys.KeyManagementService, keyName string) (Database, error) {
+ sqlDB, err := sql.Open("postgres", url)
+ if err != nil {
+ return nil, err
+ }
+
+ err = sqlDB.PingContext(ctx)
+ if err != nil {
+ // Ping failed. Don't error, but give a warning.
+ logger.WarnContext(ctx, "Unable to ping database", "error", err)
+ }
+
+ return &database{
+ logger: logger,
+ db: sqlDB,
+ url: url,
+ kms: kms,
+ keyName: keyName,
+ }, nil
+}
+
+func (d *database) Close() error {
+ return d.db.Close()
+}
+
+func (d *database) Migrate(ctx context.Context) error {
+ fs, err := iofs.New(db.Migrations, "migrations")
+ if err != nil {
+ return err
+ }
+
+ m, err := migrate.NewWithSourceInstance("iofs", fs, d.url)
+ if err != nil {
+ return err
+ }
+
+ d.logger.InfoContext(ctx, "Running DB migration")
+ err = m.Up()
+ if err != nil && !errors.Is(err, migrate.ErrNoChange) {
+ d.logger.ErrorContext(ctx, "DB migration failed", "error", err)
+ return err
+ }
+
+ return nil
+}
+
+func (d *database) Maintenance(ctx context.Context) {
+ ticker := time.NewTicker(15 * time.Minute)
+ defer ticker.Stop()
+
+ for {
+ select {
+ case <-ticker.C:
+ func() {
+ var wg sync.WaitGroup
+ wg.Add(1)
+
+ ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
+ defer cancel()
+
+ go d.cleanupSessions(ctx, &wg)
+
+ wg.Wait()
+ }()
+ case <-ctx.Done():
+ return
+ }
+ }
+}
+
+func (d *database) Ping(ctx context.Context) error {
+ return d.db.PingContext(ctx)
+}
+
+func (d *database) execInternal(ctx context.Context, queryName string, fn func(string) (any, error)) (any, error) {
+ query, err := db.GetQuery(queryName)
+ if err != nil {
+ return nil, fmt.Errorf("unable to get query: %w", err)
+ }
+ d.logger.DebugContext(ctx, "Executing query", "name", queryName, "query", query)
+
+ now := time.Now()
+
+ // Execute the query
+ result, err := fn(query)
+ if err != nil {
+ return nil, fmt.Errorf("unable to execute query: %w", err)
+ }
+
+ d.logger.DebugContext(ctx, "Query executed successfully", "name", queryName, "duration", time.Since(now))
+
+ return result, nil
+}
+
+func (d *database) exec(ctx context.Context, exec executor, queryName string, args ...any) (sql.Result, error) {
+ ret, err := d.execInternal(ctx, queryName, func(query string) (any, error) {
+ return exec.ExecContext(ctx, query, args...)
+ })
+ if err != nil {
+ return nil, err
+ } else {
+ return ret.(sql.Result), nil
+ }
+}
+
+func (d *database) query(ctx context.Context, exec executor, queryName string, args ...any) (*sql.Rows, error) {
+ ret, err := d.execInternal(ctx, queryName, func(query string) (any, error) {
+ return exec.QueryContext(ctx, query, args...)
+ })
+ if err != nil {
+ return nil, err
+ } else {
+ return ret.(*sql.Rows), nil
+ }
+}
+
+func (d *database) queryRow(ctx context.Context, exec executor, queryName string, args ...any) (*sql.Row, error) {
+ ret, err := d.execInternal(ctx, queryName, func(query string) (any, error) {
+ return exec.QueryRowContext(ctx, query, args...), nil
+ })
+ if err != nil {
+ return nil, err
+ } else {
+ return ret.(*sql.Row), nil
+ }
+}
+
+type executor interface {
+ ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
+ QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
+ QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row
+}
diff --git a/backend/internal/database/keys.go b/backend/internal/database/keys.go
new file mode 100644
index 0000000..0ec4b67
--- /dev/null
+++ b/backend/internal/database/keys.go
@@ -0,0 +1,49 @@
+package database
+
+import (
+ "context"
+ "fmt"
+ "time"
+)
+
+type KeyStore interface {
+ AddKey(ctx context.Context, keyName string, key []byte) (int, error)
+ GetKey(ctx context.Context, keyId int) (*Key, error)
+}
+
+func (d *database) AddKey(ctx context.Context, keyName string, key []byte) (int, error) {
+ row, err := d.queryRow(ctx, d.db, "keys/add_key", keyName, key)
+ if err != nil {
+ return 0, fmt.Errorf("unable to add key: %w", err)
+ }
+
+ var keyId int
+ err = row.Scan(&keyId)
+ if err != nil {
+ return 0, fmt.Errorf("unable to scan key id: %w", err)
+ }
+
+ return keyId, nil
+}
+
+func (d *database) GetKey(ctx context.Context, keyId int) (*Key, error) {
+ row, err := d.queryRow(ctx, d.db, "keys/get_key", keyId)
+ if err != nil {
+ return nil, fmt.Errorf("unable to get key: %w", err)
+ }
+
+ key := &Key{}
+ err = row.Scan(&key.Id, &key.Name, &key.Key, &key.Created)
+ if err != nil {
+ return nil, fmt.Errorf("unable to scan key: %w", err)
+ }
+
+ return key, nil
+}
+
+type Key struct {
+ Id int
+ Name string
+ Key []byte
+ Created time.Time
+}
diff --git a/backend/internal/database/session.go b/backend/internal/database/session.go
new file mode 100644
index 0000000..36867b3
--- /dev/null
+++ b/backend/internal/database/session.go
@@ -0,0 +1,122 @@
+package database
+
+import (
+ "context"
+ "crypto/rand"
+ "database/sql"
+ "encoding/base64"
+ "errors"
+ "io"
+ "sync"
+
+ "github.com/coreos/go-oidc/v3/oidc"
+ "golang.org/x/oauth2"
+)
+
+type SessionStore interface {
+ CreateState(ctx context.Context) (string, error)
+ CheckState(ctx context.Context, state string) (bool, error)
+ CreateSession(ctx context.Context, token *oauth2.Token, idToken *oidc.IDToken) (string, error)
+ GetSession(ctx context.Context, sessionToken string) (*Session, error)
+}
+
+func (d *database) CreateState(ctx context.Context) (string, error) {
+ // Generate a random CSRF state token
+ tokenBytes := make([]byte, 32)
+ if _, err := io.ReadFull(rand.Reader, tokenBytes); err != nil {
+ return "", err
+ }
+ token := base64.URLEncoding.EncodeToString(tokenBytes)
+
+ // Insert the state into the database
+ _, err := d.exec(ctx, d.db, "sessions/create_state", token)
+ if err != nil {
+ return "", err
+ }
+
+ return token, nil
+}
+
+func (d *database) CheckState(ctx context.Context, state string) (bool, error) {
+ var exists bool
+ row, err := d.queryRow(ctx, d.db, "sessions/check_state", state)
+ if err != nil {
+ return false, err
+ }
+ err = row.Scan(&exists)
+ if err != nil {
+ if errors.Is(err, sql.ErrNoRows) {
+ return false, nil
+ }
+ return false, err
+ }
+ return exists, nil
+}
+
+func (d *database) CreateSession(
+ ctx context.Context,
+ token *oauth2.Token,
+ idToken *oidc.IDToken,
+) (sessionToken string, err error) {
+ // Generate a random session token
+ tokenBytes := make([]byte, 32)
+ if _, err = io.ReadFull(rand.Reader, tokenBytes); err != nil {
+ return
+ }
+ sessionToken = base64.URLEncoding.EncodeToString(tokenBytes)
+
+ // Insert the session into the database
+ _, err = d.exec(
+ ctx,
+ d.db,
+ "sessions/create_session",
+ sessionToken,
+ idToken.Subject,
+ token.AccessToken,
+ token.Expiry,
+ )
+ return
+}
+
+func (d *database) GetSession(ctx context.Context, sessionToken string) (*Session, error) {
+ row, err := d.queryRow(ctx, d.db, "sessions/get_session", sessionToken)
+ if err != nil {
+ d.logger.ErrorContext(ctx, "Failed to get session", "error", err)
+ return nil, err
+ }
+
+ var session Session
+ err = row.Scan(&session.Token, &session.Subject, &session.OAuthToken.AccessToken, &session.OAuthToken.Expiry)
+ if err != nil {
+ if errors.Is(err, sql.ErrNoRows) {
+ return nil, nil
+ }
+ d.logger.ErrorContext(ctx, "Failed to scan session", "error", err)
+ return nil, err
+ }
+
+ return &session, nil
+}
+
+func (d *database) cleanupSessions(ctx context.Context, wg *sync.WaitGroup) {
+ defer wg.Done()
+
+ result, err := d.exec(ctx, d.db, "sessions/cleanup_sessions")
+ if err != nil {
+ d.logger.Error("Failed to clean up sessions", "error", err)
+ return
+ }
+
+ rows, err := result.RowsAffected()
+ if err != nil {
+ d.logger.ErrorContext(ctx, "Failed to get rows affected", "error", err)
+ return
+ }
+ d.logger.DebugContext(ctx, "Cleaned up sessions", "rows", rows)
+}
+
+type Session struct {
+ Token string
+ Subject string
+ OAuthToken oauth2.Token
+}
diff --git a/backend/internal/database/stocks.go b/backend/internal/database/stocks.go
new file mode 100644
index 0000000..701201e
--- /dev/null
+++ b/backend/internal/database/stocks.go
@@ -0,0 +1,264 @@
+package database
+
+import (
+ "context"
+ "database/sql"
+ "errors"
+ "fmt"
+
+ "ibd-trader/internal/analyzer"
+ "ibd-trader/internal/utils"
+
+ "github.com/Rhymond/go-money"
+)
+
+var ErrStockNotFound = errors.New("stock not found")
+
+type StockStore interface {
+ GetStock(ctx context.Context, symbol string) (Stock, error)
+ AddStock(ctx context.Context, stock Stock) error
+ AddRanking(ctx context.Context, symbol string, ibd50, cap20 int) error
+ AddStockInfo(ctx context.Context, info *StockInfo) (string, error)
+ GetStockInfo(ctx context.Context, id string) (*StockInfo, error)
+ AddAnalysis(ctx context.Context, ratingId string, analysis *analyzer.Analysis) error
+}
+
+func (d *database) GetStock(ctx context.Context, symbol string) (Stock, error) {
+ row, err := d.queryRow(ctx, d.db, "stocks/get_stock", symbol)
+ if err != nil {
+ return Stock{}, err
+ }
+
+ var stock Stock
+ if err = row.Scan(&stock.Symbol, &stock.Name, &stock.IBDUrl); err != nil {
+ if errors.Is(err, sql.ErrNoRows) {
+ return Stock{}, ErrStockNotFound
+ }
+ return Stock{}, err
+ }
+
+ return stock, nil
+}
+
+func (d *database) AddStock(ctx context.Context, stock Stock) error {
+ _, err := d.exec(ctx, d.db, "stocks/add_stock", stock.Symbol, stock.Name, stock.IBDUrl)
+ return err
+}
+
+func (d *database) AddRanking(ctx context.Context, symbol string, ibd50, cap20 int) error {
+ if ibd50 > 0 {
+ _, err := d.exec(ctx, d.db, "stocks/add_rank", symbol, "ibd50", ibd50)
+ if err != nil {
+ return err
+ }
+ }
+ if cap20 > 0 {
+ _, err := d.exec(ctx, d.db, "stocks/add_rank", symbol, "cap20", cap20)
+ if err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+func (d *database) AddStockInfo(ctx context.Context, info *StockInfo) (string, error) {
+ tx, err := d.db.BeginTx(ctx, nil)
+ if err != nil {
+ return "", err
+ }
+ defer func(tx *sql.Tx) {
+ _ = tx.Rollback()
+ }(tx)
+
+ // Add raw chart analysis
+ row, err := d.queryRow(ctx, tx, "stocks/add_raw_chart_analysis", info.ChartAnalysis)
+ if err != nil {
+ return "", err
+ }
+
+ var chartAnalysisID string
+ if err = row.Scan(&chartAnalysisID); err != nil {
+ return "", err
+ }
+
+ // Add stock info
+ row, err = d.queryRow(ctx, tx,
+ "stocks/add_rating",
+ info.Symbol,
+ info.Ratings.Composite,
+ info.Ratings.EPS,
+ info.Ratings.RelStr,
+ info.Ratings.GroupRelStr,
+ info.Ratings.SMR,
+ info.Ratings.AccDis,
+ chartAnalysisID,
+ info.Price.Display(),
+ )
+ if err != nil {
+ return "", err
+ }
+
+ var ratingsID string
+ if err = row.Scan(&ratingsID); err != nil {
+ return "", err
+ }
+
+ return ratingsID, tx.Commit()
+}
+
+func (d *database) GetStockInfo(ctx context.Context, id string) (*StockInfo, error) {
+ row, err := d.queryRow(ctx, d.db, "stocks/get_stock_info", id)
+ if err != nil {
+ return nil, err
+ }
+
+ var info StockInfo
+ var priceStr string
+ err = row.Scan(
+ &info.Symbol,
+ &info.Name,
+ &info.ChartAnalysis,
+ &info.Ratings.Composite,
+ &info.Ratings.EPS,
+ &info.Ratings.RelStr,
+ &info.Ratings.GroupRelStr,
+ &info.Ratings.SMR,
+ &info.Ratings.AccDis,
+ &priceStr,
+ )
+ if err != nil {
+ return nil, err
+ }
+
+ info.Price, err = utils.ParseMoney(priceStr)
+ if err != nil {
+ return nil, err
+ }
+
+ return &info, nil
+}
+
+func (d *database) AddAnalysis(ctx context.Context, ratingId string, analysis *analyzer.Analysis) error {
+ _, err := d.exec(ctx, d.db, "stocks/add_analysis",
+ ratingId,
+ analysis.Action,
+ analysis.Price.Display(),
+ analysis.Reason,
+ analysis.Confidence,
+ )
+ return err
+}
+
+type Stock struct {
+ Symbol string
+ Name string
+ IBDUrl string
+}
+
+type StockInfo struct {
+ Symbol string
+ Name string
+ ChartAnalysis string
+ Ratings Ratings
+ Price *money.Money
+}
+
+type Ratings struct {
+ Composite uint8
+ EPS uint8
+ RelStr uint8
+ GroupRelStr LetterRating
+ SMR LetterRating
+ AccDis LetterRating
+}
+
+type LetterRating uint8
+
+const (
+ LetterRatingE LetterRating = iota
+ LetterRatingEPlus
+ LetterRatingDMinus
+ LetterRatingD
+ LetterRatingDPlus
+ LetterRatingCMinus
+ LetterRatingC
+ LetterRatingCPlus
+ LetterRatingBMinus
+ LetterRatingB
+ LetterRatingBPlus
+ LetterRatingAMinus
+ LetterRatingA
+ LetterRatingAPlus
+)
+
+func (r LetterRating) String() string {
+ switch r {
+ case LetterRatingE:
+ return "E"
+ case LetterRatingEPlus:
+ return "E+"
+ case LetterRatingDMinus:
+ return "D-"
+ case LetterRatingD:
+ return "D"
+ case LetterRatingDPlus:
+ return "D+"
+ case LetterRatingCMinus:
+ return "C-"
+ case LetterRatingC:
+ return "C"
+ case LetterRatingCPlus:
+ return "C+"
+ case LetterRatingBMinus:
+ return "B-"
+ case LetterRatingB:
+ return "B"
+ case LetterRatingBPlus:
+ return "B+"
+ case LetterRatingAMinus:
+ return "A-"
+ case LetterRatingA:
+ return "A"
+ case LetterRatingAPlus:
+ return "A+"
+ default:
+ return "Unknown"
+ }
+}
+
+func LetterRatingFromString(str string) (LetterRating, error) {
+ switch str {
+ case "N/A":
+ fallthrough
+ case "E":
+ return LetterRatingE, nil
+ case "E+":
+ return LetterRatingEPlus, nil
+ case "D-":
+ return LetterRatingDMinus, nil
+ case "D":
+ return LetterRatingD, nil
+ case "D+":
+ return LetterRatingDPlus, nil
+ case "C-":
+ return LetterRatingCMinus, nil
+ case "C":
+ return LetterRatingC, nil
+ case "C+":
+ return LetterRatingCPlus, nil
+ case "B-":
+ return LetterRatingBMinus, nil
+ case "B":
+ return LetterRatingB, nil
+ case "B+":
+ return LetterRatingBPlus, nil
+ case "A-":
+ return LetterRatingAMinus, nil
+ case "A":
+ return LetterRatingA, nil
+ case "A+":
+ return LetterRatingAPlus, nil
+ default:
+ return 0, fmt.Errorf("unknown rating: %s", str)
+ }
+}
diff --git a/backend/internal/database/users.go b/backend/internal/database/users.go
new file mode 100644
index 0000000..1950fcb
--- /dev/null
+++ b/backend/internal/database/users.go
@@ -0,0 +1,140 @@
+package database
+
+import (
+ "context"
+ "database/sql"
+ "errors"
+ "fmt"
+
+ "ibd-trader/internal/keys"
+)
+
+type UserStore interface {
+ AddUser(ctx context.Context, subject string) error
+ GetUser(ctx context.Context, subject string) (*User, error)
+ ListUsers(ctx context.Context, hasIBDCreds bool) ([]User, error)
+ AddIBDCreds(ctx context.Context, subject string, username string, password string) error
+ GetIBDCreds(ctx context.Context, subject string) (username string, password string, err error)
+}
+
+var ErrUserNotFound = fmt.Errorf("user not found")
+var ErrIBDCredsNotFound = fmt.Errorf("ibd creds not found")
+
+func (d *database) AddUser(ctx context.Context, subject string) (err error) {
+ _, err = d.exec(
+ ctx,
+ d.db,
+ "users/add_user",
+ subject,
+ )
+ return
+}
+
+func (d *database) GetUser(ctx context.Context, subject string) (*User, error) {
+ row, err := d.queryRow(ctx, d.db, "users/get_user", subject)
+ if err != nil {
+ return nil, fmt.Errorf("unable to get user: %w", err)
+ }
+
+ user := &User{}
+ err = row.Scan(&user.Subject, &user.IBDUsername, &user.EncryptedIBDPassword, &user.EncryptionKeyID)
+ if err != nil {
+ if errors.Is(err, sql.ErrNoRows) {
+ return nil, ErrUserNotFound
+ }
+ return nil, fmt.Errorf("unable to scan sql row into user: %w", err)
+ }
+
+ return user, nil
+}
+
+func (d *database) ListUsers(ctx context.Context, hasIBDCreds bool) ([]User, error) {
+ rows, err := d.query(ctx, d.db, "users/list_users")
+ if err != nil {
+ return nil, fmt.Errorf("unable to list users: %w", err)
+ }
+
+ users := make([]User, 0)
+ for rows.Next() {
+ user := User{}
+ err = rows.Scan(&user.Subject, &user.IBDUsername, &user.EncryptedIBDPassword, &user.EncryptionKeyID)
+ if err != nil {
+ return nil, fmt.Errorf("unable to scan sql row into user: %w", err)
+ }
+
+ if hasIBDCreds && user.IBDUsername == nil {
+ continue
+ }
+ users = append(users, user)
+ }
+
+ return users, nil
+}
+
+func (d *database) AddIBDCreds(ctx context.Context, subject string, username string, password string) error {
+ encryptedPass, encryptedKey, err := keys.Encrypt(ctx, d.kms, d.keyName, []byte(password))
+ if err != nil {
+ return fmt.Errorf("unable to encrypt password: %w", err)
+ }
+
+ tx, err := d.db.BeginTx(ctx, nil)
+ if err != nil {
+ return err
+ }
+ defer func(tx *sql.Tx) {
+ _ = tx.Rollback()
+ }(tx)
+
+ row, err := d.queryRow(ctx, tx, "keys/add_key", d.keyName, encryptedKey)
+ if err != nil {
+ return fmt.Errorf("unable to add ibd creds key: %w", err)
+ }
+
+ var keyId int
+ err = row.Scan(&keyId)
+ if err != nil {
+ return fmt.Errorf("unable to scan key id: %w", err)
+ }
+
+ _, err = d.exec(ctx, tx, "users/add_ibd_creds", subject, username, encryptedPass, keyId)
+ if err != nil {
+ return fmt.Errorf("unable to add ibd creds to user: %w", err)
+ }
+
+ if err = tx.Commit(); err != nil {
+ return fmt.Errorf("unable to commit transaction: %w", err)
+ }
+
+ return nil
+}
+
+func (d *database) GetIBDCreds(ctx context.Context, subject string) (username string, password string, err error) {
+ row, err := d.queryRow(ctx, d.db, "users/get_ibd_creds", subject)
+ if err != nil {
+ return "", "", fmt.Errorf("unable to get ibd creds: %w", err)
+ }
+
+ var encryptedPass, encryptedKey []byte
+ var keyName string
+ err = row.Scan(&username, &encryptedPass, &encryptedKey, &keyName)
+ if err != nil {
+ if errors.Is(err, sql.ErrNoRows) {
+ return "", "", ErrIBDCredsNotFound
+ }
+ return "", "", fmt.Errorf("unable to scan sql row into ibd creds: %w", err)
+ }
+
+ passwordBytes, err := keys.Decrypt(ctx, d.kms, keyName, encryptedPass, encryptedKey)
+ if err != nil {
+ return "", "", fmt.Errorf("unable to decrypt password: %w", err)
+ }
+
+ return username, string(passwordBytes), nil
+}
+
+type User struct {
+ Subject string
+ IBDUsername *string
+ EncryptedIBDPassword *string
+ EncryptionKeyID *int
+}