aboutsummaryrefslogtreecommitdiff
path: root/backend/internal/database
diff options
context:
space:
mode:
Diffstat (limited to 'backend/internal/database')
-rw-r--r--backend/internal/database/cookies.go93
-rw-r--r--backend/internal/database/database.go124
-rw-r--r--backend/internal/database/keys.go29
-rw-r--r--backend/internal/database/session.go122
-rw-r--r--backend/internal/database/stocks.go99
-rw-r--r--backend/internal/database/users.go87
6 files changed, 235 insertions, 319 deletions
diff --git a/backend/internal/database/cookies.go b/backend/internal/database/cookies.go
index 8bed854..d652b65 100644
--- a/backend/internal/database/cookies.go
+++ b/backend/internal/database/cookies.go
@@ -11,29 +11,21 @@ import (
"github.com/ansg191/ibd-trader-backend/internal/keys"
)
-type CookieStore interface {
- CookieSource
- AddCookie(ctx context.Context, subject string, cookie *http.Cookie) error
- RepairCookie(ctx context.Context, id uint) error
-}
-
-type CookieSource interface {
- GetAnyCookie(ctx context.Context) (*IBDCookie, error)
- GetCookies(ctx context.Context, subject string, degraded bool) ([]IBDCookie, error)
- ReportCookieFailure(ctx context.Context, id uint) error
-}
-
-func (d *database) GetAnyCookie(ctx context.Context) (*IBDCookie, error) {
- row, err := d.queryRow(ctx, d.db, "cookies/get_any_cookie")
- if err != nil {
- return nil, fmt.Errorf("unable to get any ibd cookie: %w", err)
- }
+func GetAnyCookie(ctx context.Context, exec Executor, kms keys.KeyManagementService) (*IBDCookie, error) {
+ row := exec.QueryRowContext(ctx, `
+SELECT ibd_tokens.id, token, encrypted_key, kms_key_name, expires_at
+FROM ibd_tokens
+ INNER JOIN keys ON encryption_key = keys.id
+WHERE expires_at > NOW()
+ AND degraded = FALSE
+ORDER BY random()
+LIMIT 1;`)
var id uint
var encryptedToken, encryptedKey []byte
var keyName string
var expiry time.Time
- err = row.Scan(&id, &encryptedToken, &encryptedKey, &keyName, &expiry)
+ err := row.Scan(&id, &encryptedToken, &encryptedKey, &keyName, &expiry)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, nil
@@ -41,7 +33,11 @@ func (d *database) GetAnyCookie(ctx context.Context) (*IBDCookie, error) {
return nil, fmt.Errorf("unable to scan sql row into ibd cookie: %w", err)
}
- token, err := keys.Decrypt(ctx, d.kms, keyName, encryptedToken, encryptedKey)
+ // Set the expiry to UTC explicitly.
+ // For some reason, the expiry time is set to location="".
+ expiry = expiry.UTC()
+
+ token, err := keys.Decrypt(ctx, kms, keyName, encryptedToken, encryptedKey)
if err != nil {
return nil, fmt.Errorf("unable to decrypt token: %w", err)
}
@@ -51,24 +47,41 @@ func (d *database) GetAnyCookie(ctx context.Context) (*IBDCookie, error) {
}, nil
}
-func (d *database) GetCookies(ctx context.Context, subject string, degraded bool) ([]IBDCookie, error) {
- row, err := d.query(ctx, d.db, "cookies/get_cookies", subject, degraded)
+func GetCookies(
+ ctx context.Context,
+ exec Executor,
+ kms keys.KeyManagementService,
+ subject string,
+ degraded bool,
+) ([]IBDCookie, error) {
+ rows, err := exec.QueryContext(ctx, `
+SELECT ibd_tokens.id, token, encrypted_key, kms_key_name, expires_at
+FROM ibd_tokens
+ INNER JOIN keys ON encryption_key = keys.id
+WHERE user_subject = $1
+ AND expires_at > NOW()
+ AND degraded = $2
+ORDER BY expires_at DESC;`, subject, degraded)
if err != nil {
return nil, fmt.Errorf("unable to get ibd cookies: %w", err)
}
cookies := make([]IBDCookie, 0)
- for row.Next() {
+ for rows.Next() {
var id uint
var encryptedToken, encryptedKey []byte
var keyName string
var expiry time.Time
- err = row.Scan(&id, &encryptedToken, &encryptedKey, &keyName, &expiry)
+ err = rows.Scan(&id, &encryptedToken, &encryptedKey, &keyName, &expiry)
if err != nil {
return nil, fmt.Errorf("unable to scan sql row into ibd cookie: %w", err)
}
- token, err := keys.Decrypt(ctx, d.kms, keyName, encryptedToken, encryptedKey)
+ // Set the expiry to UTC explicitly.
+ // For some reason, the expiry time is set to location="".
+ expiry = expiry.UTC()
+
+ token, err := keys.Decrypt(ctx, kms, keyName, encryptedToken, encryptedKey)
if err != nil {
return nil, fmt.Errorf("unable to decrypt token: %w", err)
}
@@ -83,9 +96,15 @@ func (d *database) GetCookies(ctx context.Context, subject string, degraded bool
return cookies, nil
}
-func (d *database) AddCookie(ctx context.Context, subject string, cookie *http.Cookie) error {
+func AddCookie(
+ ctx context.Context,
+ exec Executor,
+ kms keys.KeyManagementService,
+ subject string,
+ cookie *http.Cookie,
+) error {
// Get the key ID for the user
- user, err := d.GetUser(ctx, subject)
+ user, err := GetUser(ctx, exec, subject)
if err != nil {
return fmt.Errorf("unable to get user: %w", err)
}
@@ -94,19 +113,21 @@ func (d *database) AddCookie(ctx context.Context, subject string, cookie *http.C
}
// Get the key
- key, err := d.GetKey(ctx, *user.EncryptionKeyID)
+ key, err := GetKey(ctx, exec, *user.EncryptionKeyID)
if err != nil {
return fmt.Errorf("unable to get key: %w", err)
}
// Encrypt the token
- encryptedToken, err := keys.EncryptWithKey(ctx, d.kms, key.Name, key.Key, []byte(cookie.Value))
+ encryptedToken, err := keys.EncryptWithKey(ctx, kms, key.Name, key.Key, []byte(cookie.Value))
if err != nil {
return fmt.Errorf("unable to encrypt token: %w", err)
}
// Add the cookie to the database
- _, err = d.exec(ctx, d.db, "cookies/add_cookie", encryptedToken, cookie.Expires, subject, key.Id)
+ _, err = exec.ExecContext(ctx, `
+INSERT INTO ibd_tokens (token, expires_at, user_subject, encryption_key)
+VALUES ($1, $2, $3, $4)`, encryptedToken, cookie.Expires, subject, key.Id)
if err != nil {
return fmt.Errorf("unable to add cookie: %w", err)
}
@@ -114,16 +135,22 @@ func (d *database) AddCookie(ctx context.Context, subject string, cookie *http.C
return nil
}
-func (d *database) ReportCookieFailure(ctx context.Context, id uint) error {
- _, err := d.exec(ctx, d.db, "cookies/set_cookie_degraded", true, id)
+func ReportCookieFailure(ctx context.Context, exec Executor, id uint) error {
+ _, err := exec.ExecContext(ctx, `
+UPDATE ibd_tokens
+SET degraded = TRUE
+WHERE id = $1;`, id)
if err != nil {
return fmt.Errorf("unable to report cookie failure: %w", err)
}
return nil
}
-func (d *database) RepairCookie(ctx context.Context, id uint) error {
- _, err := d.exec(ctx, d.db, "cookies/set_cookie_degraded", false, id)
+func RepairCookie(ctx context.Context, exec Executor, id uint) error {
+ _, err := exec.ExecContext(ctx, `
+UPDATE ibd_tokens
+SET degraded = FALSE
+WHERE id = $1;`, id)
if err != nil {
return fmt.Errorf("unable to report cookie failure: %w", err)
}
diff --git a/backend/internal/database/database.go b/backend/internal/database/database.go
index 3c822bc..409dd3c 100644
--- a/backend/internal/database/database.go
+++ b/backend/internal/database/database.go
@@ -5,7 +5,6 @@ import (
"database/sql"
"database/sql/driver"
"errors"
- "fmt"
"io"
"log/slog"
"sync"
@@ -22,12 +21,7 @@ import (
type Database interface {
io.Closer
- UserStore
- CookieStore
- KeyStore
- SessionStore
- StockStore
-
+ TransactionExecutor
driver.Pinger
Migrate(ctx context.Context) error
@@ -70,24 +64,7 @@ func (d *database) Close() error {
}
func (d *database) Migrate(ctx context.Context) error {
- fs, err := iofs.New(db.Migrations, "migrations")
- if err != nil {
- return err
- }
-
- m, err := migrate.NewWithSourceInstance("iofs", fs, d.url)
- if err != nil {
- return err
- }
-
- d.logger.InfoContext(ctx, "Running DB migration")
- err = m.Up()
- if err != nil && !errors.Is(err, migrate.ErrNoChange) {
- d.logger.ErrorContext(ctx, "DB migration failed", "error", err)
- return err
- }
-
- return nil
+ return Migrate(ctx, d.url)
}
func (d *database) Maintenance(ctx context.Context) {
@@ -101,11 +78,9 @@ func (d *database) Maintenance(ctx context.Context) {
var wg sync.WaitGroup
wg.Add(1)
- ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
+ _, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
defer cancel()
- go d.cleanupSessions(ctx, &wg)
-
wg.Wait()
}()
case <-ctx.Done():
@@ -114,65 +89,78 @@ func (d *database) Maintenance(ctx context.Context) {
}
}
-func (d *database) Ping(ctx context.Context) error {
- return d.db.PingContext(ctx)
-}
-
-func (d *database) execInternal(ctx context.Context, queryName string, fn func(string) (any, error)) (any, error) {
- query, err := db.GetQuery(queryName)
+func Migrate(ctx context.Context, url string) error {
+ fs, err := iofs.New(db.Migrations, "migrations")
if err != nil {
- return nil, fmt.Errorf("unable to get query: %w", err)
+ return err
}
- d.logger.DebugContext(ctx, "Executing query", "name", queryName, "query", query)
-
- now := time.Now()
- // Execute the query
- result, err := fn(query)
+ m, err := migrate.NewWithSourceInstance("iofs", fs, url)
if err != nil {
- return nil, fmt.Errorf("unable to execute query: %w", err)
+ return err
}
- d.logger.DebugContext(ctx, "Query executed successfully", "name", queryName, "duration", time.Since(now))
+ slog.InfoContext(ctx, "Running DB migration")
+ err = m.Up()
+ if err != nil && !errors.Is(err, migrate.ErrNoChange) {
+ slog.ErrorContext(ctx, "DB migration failed", "error", err)
+ return err
+ }
- return result, nil
+ return nil
}
-func (d *database) exec(ctx context.Context, exec executor, queryName string, args ...any) (sql.Result, error) {
- ret, err := d.execInternal(ctx, queryName, func(query string) (any, error) {
- return exec.ExecContext(ctx, query, args...)
- })
- if err != nil {
- return nil, err
- } else {
- return ret.(sql.Result), nil
- }
+func (d *database) Ping(ctx context.Context) error {
+ return d.db.PingContext(ctx)
+}
+
+type Executor interface {
+ ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
+ QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
+ QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row
+}
+
+type TransactionExecutor interface {
+ Executor
+ BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error)
}
-func (d *database) query(ctx context.Context, exec executor, queryName string, args ...any) (*sql.Rows, error) {
- ret, err := d.execInternal(ctx, queryName, func(query string) (any, error) {
- return exec.QueryContext(ctx, query, args...)
- })
+func (d *database) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) {
+ d.logger.DebugContext(ctx, "Executing query", "query", query)
+
+ now := time.Now()
+ ret, err := d.db.ExecContext(ctx, query, args...)
if err != nil {
return nil, err
- } else {
- return ret.(*sql.Rows), nil
}
+
+ d.logger.DebugContext(ctx, "Query executed successfully", "duration", time.Since(now))
+ return ret, nil
}
-func (d *database) queryRow(ctx context.Context, exec executor, queryName string, args ...any) (*sql.Row, error) {
- ret, err := d.execInternal(ctx, queryName, func(query string) (any, error) {
- return exec.QueryRowContext(ctx, query, args...), nil
- })
+func (d *database) QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) {
+ d.logger.DebugContext(ctx, "Executing query", "query", query)
+
+ now := time.Now()
+ ret, err := d.db.QueryContext(ctx, query, args...)
if err != nil {
return nil, err
- } else {
- return ret.(*sql.Row), nil
}
+
+ d.logger.DebugContext(ctx, "Query executed successfully", "duration", time.Since(now))
+ return ret, nil
}
-type executor interface {
- ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
- QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
- QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row
+func (d *database) QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row {
+ d.logger.DebugContext(ctx, "Executing query", "query", query)
+
+ now := time.Now()
+ ret := d.db.QueryRowContext(ctx, query, args...)
+
+ d.logger.DebugContext(ctx, "Query executed successfully", "duration", time.Since(now))
+ return ret
+}
+
+func (d *database) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) {
+ return d.db.BeginTx(ctx, opts)
}
diff --git a/backend/internal/database/keys.go b/backend/internal/database/keys.go
index 0ec4b67..e2e2770 100644
--- a/backend/internal/database/keys.go
+++ b/backend/internal/database/keys.go
@@ -6,19 +6,14 @@ import (
"time"
)
-type KeyStore interface {
- AddKey(ctx context.Context, keyName string, key []byte) (int, error)
- GetKey(ctx context.Context, keyId int) (*Key, error)
-}
-
-func (d *database) AddKey(ctx context.Context, keyName string, key []byte) (int, error) {
- row, err := d.queryRow(ctx, d.db, "keys/add_key", keyName, key)
- if err != nil {
- return 0, fmt.Errorf("unable to add key: %w", err)
- }
+func AddKey(ctx context.Context, exec Executor, keyName string, key []byte) (int, error) {
+ row := exec.QueryRowContext(ctx, `
+INSERT INTO keys (kms_key_name, encrypted_key)
+VALUES ($1, $2)
+RETURNING id;`, keyName, key)
var keyId int
- err = row.Scan(&keyId)
+ err := row.Scan(&keyId)
if err != nil {
return 0, fmt.Errorf("unable to scan key id: %w", err)
}
@@ -26,14 +21,14 @@ func (d *database) AddKey(ctx context.Context, keyName string, key []byte) (int,
return keyId, nil
}
-func (d *database) GetKey(ctx context.Context, keyId int) (*Key, error) {
- row, err := d.queryRow(ctx, d.db, "keys/get_key", keyId)
- if err != nil {
- return nil, fmt.Errorf("unable to get key: %w", err)
- }
+func GetKey(ctx context.Context, exec Executor, keyId int) (*Key, error) {
+ row := exec.QueryRowContext(ctx, `
+SELECT id, kms_key_name, encrypted_key, created_at
+FROM keys
+WHERE id = $1;`, keyId)
key := &Key{}
- err = row.Scan(&key.Id, &key.Name, &key.Key, &key.Created)
+ err := row.Scan(&key.Id, &key.Name, &key.Key, &key.Created)
if err != nil {
return nil, fmt.Errorf("unable to scan key: %w", err)
}
diff --git a/backend/internal/database/session.go b/backend/internal/database/session.go
deleted file mode 100644
index 36867b3..0000000
--- a/backend/internal/database/session.go
+++ /dev/null
@@ -1,122 +0,0 @@
-package database
-
-import (
- "context"
- "crypto/rand"
- "database/sql"
- "encoding/base64"
- "errors"
- "io"
- "sync"
-
- "github.com/coreos/go-oidc/v3/oidc"
- "golang.org/x/oauth2"
-)
-
-type SessionStore interface {
- CreateState(ctx context.Context) (string, error)
- CheckState(ctx context.Context, state string) (bool, error)
- CreateSession(ctx context.Context, token *oauth2.Token, idToken *oidc.IDToken) (string, error)
- GetSession(ctx context.Context, sessionToken string) (*Session, error)
-}
-
-func (d *database) CreateState(ctx context.Context) (string, error) {
- // Generate a random CSRF state token
- tokenBytes := make([]byte, 32)
- if _, err := io.ReadFull(rand.Reader, tokenBytes); err != nil {
- return "", err
- }
- token := base64.URLEncoding.EncodeToString(tokenBytes)
-
- // Insert the state into the database
- _, err := d.exec(ctx, d.db, "sessions/create_state", token)
- if err != nil {
- return "", err
- }
-
- return token, nil
-}
-
-func (d *database) CheckState(ctx context.Context, state string) (bool, error) {
- var exists bool
- row, err := d.queryRow(ctx, d.db, "sessions/check_state", state)
- if err != nil {
- return false, err
- }
- err = row.Scan(&exists)
- if err != nil {
- if errors.Is(err, sql.ErrNoRows) {
- return false, nil
- }
- return false, err
- }
- return exists, nil
-}
-
-func (d *database) CreateSession(
- ctx context.Context,
- token *oauth2.Token,
- idToken *oidc.IDToken,
-) (sessionToken string, err error) {
- // Generate a random session token
- tokenBytes := make([]byte, 32)
- if _, err = io.ReadFull(rand.Reader, tokenBytes); err != nil {
- return
- }
- sessionToken = base64.URLEncoding.EncodeToString(tokenBytes)
-
- // Insert the session into the database
- _, err = d.exec(
- ctx,
- d.db,
- "sessions/create_session",
- sessionToken,
- idToken.Subject,
- token.AccessToken,
- token.Expiry,
- )
- return
-}
-
-func (d *database) GetSession(ctx context.Context, sessionToken string) (*Session, error) {
- row, err := d.queryRow(ctx, d.db, "sessions/get_session", sessionToken)
- if err != nil {
- d.logger.ErrorContext(ctx, "Failed to get session", "error", err)
- return nil, err
- }
-
- var session Session
- err = row.Scan(&session.Token, &session.Subject, &session.OAuthToken.AccessToken, &session.OAuthToken.Expiry)
- if err != nil {
- if errors.Is(err, sql.ErrNoRows) {
- return nil, nil
- }
- d.logger.ErrorContext(ctx, "Failed to scan session", "error", err)
- return nil, err
- }
-
- return &session, nil
-}
-
-func (d *database) cleanupSessions(ctx context.Context, wg *sync.WaitGroup) {
- defer wg.Done()
-
- result, err := d.exec(ctx, d.db, "sessions/cleanup_sessions")
- if err != nil {
- d.logger.Error("Failed to clean up sessions", "error", err)
- return
- }
-
- rows, err := result.RowsAffected()
- if err != nil {
- d.logger.ErrorContext(ctx, "Failed to get rows affected", "error", err)
- return
- }
- d.logger.DebugContext(ctx, "Cleaned up sessions", "rows", rows)
-}
-
-type Session struct {
- Token string
- Subject string
- OAuthToken oauth2.Token
-}
diff --git a/backend/internal/database/stocks.go b/backend/internal/database/stocks.go
index f74e4e8..865aec4 100644
--- a/backend/internal/database/stocks.go
+++ b/backend/internal/database/stocks.go
@@ -14,23 +14,15 @@ import (
var ErrStockNotFound = errors.New("stock not found")
-type StockStore interface {
- GetStock(ctx context.Context, symbol string) (Stock, error)
- AddStock(ctx context.Context, stock Stock) error
- AddRanking(ctx context.Context, symbol string, ibd50, cap20 int) error
- AddStockInfo(ctx context.Context, info *StockInfo) (string, error)
- GetStockInfo(ctx context.Context, id string) (*StockInfo, error)
- AddAnalysis(ctx context.Context, ratingId string, analysis *analyzer.Analysis) error
-}
-
-func (d *database) GetStock(ctx context.Context, symbol string) (Stock, error) {
- row, err := d.queryRow(ctx, d.db, "stocks/get_stock", symbol)
- if err != nil {
- return Stock{}, err
- }
+func GetStock(ctx context.Context, exec Executor, symbol string) (Stock, error) {
+ row := exec.QueryRowContext(ctx, `
+SELECT symbol, name, ibd_url
+FROM stocks
+WHERE symbol = $1;
+`, symbol)
var stock Stock
- if err = row.Scan(&stock.Symbol, &stock.Name, &stock.IBDUrl); err != nil {
+ if err := row.Scan(&stock.Symbol, &stock.Name, &stock.IBDUrl); err != nil {
if errors.Is(err, sql.ErrNoRows) {
return Stock{}, ErrStockNotFound
}
@@ -40,20 +32,29 @@ func (d *database) GetStock(ctx context.Context, symbol string) (Stock, error) {
return stock, nil
}
-func (d *database) AddStock(ctx context.Context, stock Stock) error {
- _, err := d.exec(ctx, d.db, "stocks/add_stock", stock.Symbol, stock.Name, stock.IBDUrl)
+func AddStock(ctx context.Context, exec Executor, stock Stock) error {
+ _, err := exec.ExecContext(ctx, `
+INSERT INTO stocks (symbol, name, ibd_url)
+VALUES ($1, $2, $3)
+ON CONFLICT (symbol)
+ DO UPDATE SET name = $2,
+ ibd_url = $3;`, stock.Symbol, stock.Name, stock.IBDUrl)
return err
}
-func (d *database) AddRanking(ctx context.Context, symbol string, ibd50, cap20 int) error {
+func AddRanking(ctx context.Context, exec Executor, symbol string, ibd50, cap20 int) error {
if ibd50 > 0 {
- _, err := d.exec(ctx, d.db, "stocks/add_rank", symbol, "ibd50", ibd50)
+ _, err := exec.ExecContext(ctx, `
+INSERT INTO stock_rank (symbol, rank_type, rank)
+VALUES ($1, $2, $3)`, symbol, "ibd50", ibd50)
if err != nil {
return err
}
}
if cap20 > 0 {
- _, err := d.exec(ctx, d.db, "stocks/add_rank", symbol, "cap20", cap20)
+ _, err := exec.ExecContext(ctx, `
+INSERT INTO stock_rank (symbol, rank_type, rank)
+VALUES ($1, $2, $3)`, symbol, "cap20", cap20)
if err != nil {
return err
}
@@ -61,8 +62,8 @@ func (d *database) AddRanking(ctx context.Context, symbol string, ibd50, cap20 i
return nil
}
-func (d *database) AddStockInfo(ctx context.Context, info *StockInfo) (string, error) {
- tx, err := d.db.BeginTx(ctx, nil)
+func AddStockInfo(ctx context.Context, exec TransactionExecutor, info *StockInfo) (string, error) {
+ tx, err := exec.BeginTx(ctx, nil)
if err != nil {
return "", err
}
@@ -71,10 +72,10 @@ func (d *database) AddStockInfo(ctx context.Context, info *StockInfo) (string, e
}(tx)
// Add raw chart analysis
- row, err := d.queryRow(ctx, tx, "stocks/add_raw_chart_analysis", info.ChartAnalysis)
- if err != nil {
- return "", err
- }
+ row := tx.QueryRowContext(ctx, `
+INSERT INTO chart_analysis (raw_analysis)
+VALUES ($1)
+RETURNING id;`, info.ChartAnalysis)
var chartAnalysisID string
if err = row.Scan(&chartAnalysisID); err != nil {
@@ -82,8 +83,11 @@ func (d *database) AddStockInfo(ctx context.Context, info *StockInfo) (string, e
}
// Add stock info
- row, err = d.queryRow(ctx, tx,
- "stocks/add_rating",
+ row = tx.QueryRowContext(ctx,
+ `
+INSERT INTO ratings (symbol, composite, eps, rel_str, group_rel_str, smr, acc_dis, chart_analysis, price)
+VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
+RETURNING id;`,
info.Symbol,
info.Ratings.Composite,
info.Ratings.EPS,
@@ -94,9 +98,6 @@ func (d *database) AddStockInfo(ctx context.Context, info *StockInfo) (string, e
chartAnalysisID,
info.Price.Display(),
)
- if err != nil {
- return "", err
- }
var ratingsID string
if err = row.Scan(&ratingsID); err != nil {
@@ -106,15 +107,26 @@ func (d *database) AddStockInfo(ctx context.Context, info *StockInfo) (string, e
return ratingsID, tx.Commit()
}
-func (d *database) GetStockInfo(ctx context.Context, id string) (*StockInfo, error) {
- row, err := d.queryRow(ctx, d.db, "stocks/get_stock_info", id)
- if err != nil {
- return nil, err
- }
+func GetStockInfo(ctx context.Context, exec Executor, id string) (*StockInfo, error) {
+ row := exec.QueryRowContext(ctx, `
+SELECT r.symbol,
+ s.name,
+ ca.raw_analysis,
+ r.composite,
+ r.eps,
+ r.rel_str,
+ r.group_rel_str,
+ r.smr,
+ r.acc_dis,
+ r.price
+FROM ratings r
+ INNER JOIN stocks s on r.symbol = s.symbol
+ INNER JOIN chart_analysis ca on r.chart_analysis = ca.id
+WHERE r.id = $1;`, id)
var info StockInfo
var priceStr string
- err = row.Scan(
+ err := row.Scan(
&info.Symbol,
&info.Name,
&info.ChartAnalysis,
@@ -138,8 +150,17 @@ func (d *database) GetStockInfo(ctx context.Context, id string) (*StockInfo, err
return &info, nil
}
-func (d *database) AddAnalysis(ctx context.Context, ratingId string, analysis *analyzer.Analysis) error {
- _, err := d.exec(ctx, d.db, "stocks/add_analysis",
+func AddAnalysis(ctx context.Context, exec Executor, ratingId string, analysis *analyzer.Analysis) error {
+ _, err := exec.ExecContext(ctx, `
+UPDATE chart_analysis ca
+SET processed = true,
+ action = $2,
+ price = $3,
+ reason = $4,
+ confidence = $5
+FROM ratings r
+WHERE r.id = $1
+ AND r.chart_analysis = ca.id;`,
ratingId,
analysis.Action,
analysis.Price.Display(),
diff --git a/backend/internal/database/users.go b/backend/internal/database/users.go
index ff6f674..d023598 100644
--- a/backend/internal/database/users.go
+++ b/backend/internal/database/users.go
@@ -9,35 +9,25 @@ import (
"github.com/ansg191/ibd-trader-backend/internal/keys"
)
-type UserStore interface {
- AddUser(ctx context.Context, subject string) error
- GetUser(ctx context.Context, subject string) (*User, error)
- ListUsers(ctx context.Context, hasIBDCreds bool) ([]User, error)
- AddIBDCreds(ctx context.Context, subject string, username string, password string) error
- GetIBDCreds(ctx context.Context, subject string) (username string, password string, err error)
-}
-
var ErrUserNotFound = fmt.Errorf("user not found")
var ErrIBDCredsNotFound = fmt.Errorf("ibd creds not found")
-func (d *database) AddUser(ctx context.Context, subject string) (err error) {
- _, err = d.exec(
- ctx,
- d.db,
- "users/add_user",
- subject,
- )
+func AddUser(ctx context.Context, exec Executor, subject string) (err error) {
+ _, err = exec.ExecContext(ctx, `
+INSERT INTO users (subject)
+VALUES ($1)
+ON CONFLICT DO NOTHING;`, subject)
return
}
-func (d *database) GetUser(ctx context.Context, subject string) (*User, error) {
- row, err := d.queryRow(ctx, d.db, "users/get_user", subject)
- if err != nil {
- return nil, fmt.Errorf("unable to get user: %w", err)
- }
+func GetUser(ctx context.Context, exec Executor, subject string) (*User, error) {
+ row := exec.QueryRowContext(ctx, `
+SELECT subject, ibd_username, ibd_password, encryption_key
+FROM users
+WHERE subject = $1;`, subject)
user := &User{}
- err = row.Scan(&user.Subject, &user.IBDUsername, &user.EncryptedIBDPassword, &user.EncryptionKeyID)
+ err := row.Scan(&user.Subject, &user.IBDUsername, &user.EncryptedIBDPassword, &user.EncryptionKeyID)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, ErrUserNotFound
@@ -48,8 +38,11 @@ func (d *database) GetUser(ctx context.Context, subject string) (*User, error) {
return user, nil
}
-func (d *database) ListUsers(ctx context.Context, hasIBDCreds bool) ([]User, error) {
- rows, err := d.query(ctx, d.db, "users/list_users")
+func ListUsers(ctx context.Context, exec Executor, hasIBDCreds bool) ([]User, error) {
+ rows, err := exec.QueryContext(ctx, `
+SELECT subject, ibd_username, ibd_password, encryption_key
+FROM users;
+`)
if err != nil {
return nil, fmt.Errorf("unable to list users: %w", err)
}
@@ -71,13 +64,18 @@ func (d *database) ListUsers(ctx context.Context, hasIBDCreds bool) ([]User, err
return users, nil
}
-func (d *database) AddIBDCreds(ctx context.Context, subject string, username string, password string) error {
- encryptedPass, encryptedKey, err := keys.Encrypt(ctx, d.kms, d.keyName, []byte(password))
+func AddIBDCreds(
+ ctx context.Context,
+ exec TransactionExecutor,
+ kms keys.KeyManagementService,
+ keyName, subject, username, password string,
+) error {
+ encryptedPass, encryptedKey, err := keys.Encrypt(ctx, kms, keyName, []byte(password))
if err != nil {
return fmt.Errorf("unable to encrypt password: %w", err)
}
- tx, err := d.db.BeginTx(ctx, nil)
+ tx, err := exec.BeginTx(ctx, nil)
if err != nil {
return err
}
@@ -85,18 +83,17 @@ func (d *database) AddIBDCreds(ctx context.Context, subject string, username str
_ = tx.Rollback()
}(tx)
- row, err := d.queryRow(ctx, tx, "keys/add_key", d.keyName, encryptedKey)
+ keyId, err := AddKey(ctx, tx, keyName, encryptedKey)
if err != nil {
return fmt.Errorf("unable to add ibd creds key: %w", err)
}
- var keyId int
- err = row.Scan(&keyId)
- if err != nil {
- return fmt.Errorf("unable to scan key id: %w", err)
- }
-
- _, err = d.exec(ctx, tx, "users/add_ibd_creds", subject, username, encryptedPass, keyId)
+ _, err = exec.ExecContext(ctx, `
+UPDATE users
+SET ibd_username = $2,
+ ibd_password = $3,
+ encryption_key = $4
+WHERE subject = $1;`, subject, username, encryptedPass, keyId)
if err != nil {
return fmt.Errorf("unable to add ibd creds to user: %w", err)
}
@@ -108,11 +105,21 @@ func (d *database) AddIBDCreds(ctx context.Context, subject string, username str
return nil
}
-func (d *database) GetIBDCreds(ctx context.Context, subject string) (username string, password string, err error) {
- row, err := d.queryRow(ctx, d.db, "users/get_ibd_creds", subject)
- if err != nil {
- return "", "", fmt.Errorf("unable to get ibd creds: %w", err)
- }
+func GetIBDCreds(
+ ctx context.Context,
+ exec Executor,
+ kms keys.KeyManagementService,
+ subject string,
+) (
+ username string,
+ password string,
+ err error,
+) {
+ row := exec.QueryRowContext(ctx, `
+SELECT ibd_username, ibd_password, encrypted_key, kms_key_name
+FROM users
+INNER JOIN public.keys k on k.id = users.encryption_key
+WHERE subject = $1;`, subject)
var encryptedPass, encryptedKey []byte
var keyName string
@@ -124,7 +131,7 @@ func (d *database) GetIBDCreds(ctx context.Context, subject string) (username st
return "", "", fmt.Errorf("unable to scan sql row into ibd creds: %w", err)
}
- passwordBytes, err := keys.Decrypt(ctx, d.kms, keyName, encryptedPass, encryptedKey)
+ passwordBytes, err := keys.Decrypt(ctx, kms, keyName, encryptedPass, encryptedKey)
if err != nil {
return "", "", fmt.Errorf("unable to decrypt password: %w", err)
}