diff options
Diffstat (limited to 'backend/internal/database')
-rw-r--r-- | backend/internal/database/cookies.go | 93 | ||||
-rw-r--r-- | backend/internal/database/database.go | 124 | ||||
-rw-r--r-- | backend/internal/database/keys.go | 29 | ||||
-rw-r--r-- | backend/internal/database/session.go | 122 | ||||
-rw-r--r-- | backend/internal/database/stocks.go | 99 | ||||
-rw-r--r-- | backend/internal/database/users.go | 87 |
6 files changed, 235 insertions, 319 deletions
diff --git a/backend/internal/database/cookies.go b/backend/internal/database/cookies.go index 8bed854..d652b65 100644 --- a/backend/internal/database/cookies.go +++ b/backend/internal/database/cookies.go @@ -11,29 +11,21 @@ import ( "github.com/ansg191/ibd-trader-backend/internal/keys" ) -type CookieStore interface { - CookieSource - AddCookie(ctx context.Context, subject string, cookie *http.Cookie) error - RepairCookie(ctx context.Context, id uint) error -} - -type CookieSource interface { - GetAnyCookie(ctx context.Context) (*IBDCookie, error) - GetCookies(ctx context.Context, subject string, degraded bool) ([]IBDCookie, error) - ReportCookieFailure(ctx context.Context, id uint) error -} - -func (d *database) GetAnyCookie(ctx context.Context) (*IBDCookie, error) { - row, err := d.queryRow(ctx, d.db, "cookies/get_any_cookie") - if err != nil { - return nil, fmt.Errorf("unable to get any ibd cookie: %w", err) - } +func GetAnyCookie(ctx context.Context, exec Executor, kms keys.KeyManagementService) (*IBDCookie, error) { + row := exec.QueryRowContext(ctx, ` +SELECT ibd_tokens.id, token, encrypted_key, kms_key_name, expires_at +FROM ibd_tokens + INNER JOIN keys ON encryption_key = keys.id +WHERE expires_at > NOW() + AND degraded = FALSE +ORDER BY random() +LIMIT 1;`) var id uint var encryptedToken, encryptedKey []byte var keyName string var expiry time.Time - err = row.Scan(&id, &encryptedToken, &encryptedKey, &keyName, &expiry) + err := row.Scan(&id, &encryptedToken, &encryptedKey, &keyName, &expiry) if err != nil { if errors.Is(err, sql.ErrNoRows) { return nil, nil @@ -41,7 +33,11 @@ func (d *database) GetAnyCookie(ctx context.Context) (*IBDCookie, error) { return nil, fmt.Errorf("unable to scan sql row into ibd cookie: %w", err) } - token, err := keys.Decrypt(ctx, d.kms, keyName, encryptedToken, encryptedKey) + // Set the expiry to UTC explicitly. + // For some reason, the expiry time is set to location="". + expiry = expiry.UTC() + + token, err := keys.Decrypt(ctx, kms, keyName, encryptedToken, encryptedKey) if err != nil { return nil, fmt.Errorf("unable to decrypt token: %w", err) } @@ -51,24 +47,41 @@ func (d *database) GetAnyCookie(ctx context.Context) (*IBDCookie, error) { }, nil } -func (d *database) GetCookies(ctx context.Context, subject string, degraded bool) ([]IBDCookie, error) { - row, err := d.query(ctx, d.db, "cookies/get_cookies", subject, degraded) +func GetCookies( + ctx context.Context, + exec Executor, + kms keys.KeyManagementService, + subject string, + degraded bool, +) ([]IBDCookie, error) { + rows, err := exec.QueryContext(ctx, ` +SELECT ibd_tokens.id, token, encrypted_key, kms_key_name, expires_at +FROM ibd_tokens + INNER JOIN keys ON encryption_key = keys.id +WHERE user_subject = $1 + AND expires_at > NOW() + AND degraded = $2 +ORDER BY expires_at DESC;`, subject, degraded) if err != nil { return nil, fmt.Errorf("unable to get ibd cookies: %w", err) } cookies := make([]IBDCookie, 0) - for row.Next() { + for rows.Next() { var id uint var encryptedToken, encryptedKey []byte var keyName string var expiry time.Time - err = row.Scan(&id, &encryptedToken, &encryptedKey, &keyName, &expiry) + err = rows.Scan(&id, &encryptedToken, &encryptedKey, &keyName, &expiry) if err != nil { return nil, fmt.Errorf("unable to scan sql row into ibd cookie: %w", err) } - token, err := keys.Decrypt(ctx, d.kms, keyName, encryptedToken, encryptedKey) + // Set the expiry to UTC explicitly. + // For some reason, the expiry time is set to location="". + expiry = expiry.UTC() + + token, err := keys.Decrypt(ctx, kms, keyName, encryptedToken, encryptedKey) if err != nil { return nil, fmt.Errorf("unable to decrypt token: %w", err) } @@ -83,9 +96,15 @@ func (d *database) GetCookies(ctx context.Context, subject string, degraded bool return cookies, nil } -func (d *database) AddCookie(ctx context.Context, subject string, cookie *http.Cookie) error { +func AddCookie( + ctx context.Context, + exec Executor, + kms keys.KeyManagementService, + subject string, + cookie *http.Cookie, +) error { // Get the key ID for the user - user, err := d.GetUser(ctx, subject) + user, err := GetUser(ctx, exec, subject) if err != nil { return fmt.Errorf("unable to get user: %w", err) } @@ -94,19 +113,21 @@ func (d *database) AddCookie(ctx context.Context, subject string, cookie *http.C } // Get the key - key, err := d.GetKey(ctx, *user.EncryptionKeyID) + key, err := GetKey(ctx, exec, *user.EncryptionKeyID) if err != nil { return fmt.Errorf("unable to get key: %w", err) } // Encrypt the token - encryptedToken, err := keys.EncryptWithKey(ctx, d.kms, key.Name, key.Key, []byte(cookie.Value)) + encryptedToken, err := keys.EncryptWithKey(ctx, kms, key.Name, key.Key, []byte(cookie.Value)) if err != nil { return fmt.Errorf("unable to encrypt token: %w", err) } // Add the cookie to the database - _, err = d.exec(ctx, d.db, "cookies/add_cookie", encryptedToken, cookie.Expires, subject, key.Id) + _, err = exec.ExecContext(ctx, ` +INSERT INTO ibd_tokens (token, expires_at, user_subject, encryption_key) +VALUES ($1, $2, $3, $4)`, encryptedToken, cookie.Expires, subject, key.Id) if err != nil { return fmt.Errorf("unable to add cookie: %w", err) } @@ -114,16 +135,22 @@ func (d *database) AddCookie(ctx context.Context, subject string, cookie *http.C return nil } -func (d *database) ReportCookieFailure(ctx context.Context, id uint) error { - _, err := d.exec(ctx, d.db, "cookies/set_cookie_degraded", true, id) +func ReportCookieFailure(ctx context.Context, exec Executor, id uint) error { + _, err := exec.ExecContext(ctx, ` +UPDATE ibd_tokens +SET degraded = TRUE +WHERE id = $1;`, id) if err != nil { return fmt.Errorf("unable to report cookie failure: %w", err) } return nil } -func (d *database) RepairCookie(ctx context.Context, id uint) error { - _, err := d.exec(ctx, d.db, "cookies/set_cookie_degraded", false, id) +func RepairCookie(ctx context.Context, exec Executor, id uint) error { + _, err := exec.ExecContext(ctx, ` +UPDATE ibd_tokens +SET degraded = FALSE +WHERE id = $1;`, id) if err != nil { return fmt.Errorf("unable to report cookie failure: %w", err) } diff --git a/backend/internal/database/database.go b/backend/internal/database/database.go index 3c822bc..409dd3c 100644 --- a/backend/internal/database/database.go +++ b/backend/internal/database/database.go @@ -5,7 +5,6 @@ import ( "database/sql" "database/sql/driver" "errors" - "fmt" "io" "log/slog" "sync" @@ -22,12 +21,7 @@ import ( type Database interface { io.Closer - UserStore - CookieStore - KeyStore - SessionStore - StockStore - + TransactionExecutor driver.Pinger Migrate(ctx context.Context) error @@ -70,24 +64,7 @@ func (d *database) Close() error { } func (d *database) Migrate(ctx context.Context) error { - fs, err := iofs.New(db.Migrations, "migrations") - if err != nil { - return err - } - - m, err := migrate.NewWithSourceInstance("iofs", fs, d.url) - if err != nil { - return err - } - - d.logger.InfoContext(ctx, "Running DB migration") - err = m.Up() - if err != nil && !errors.Is(err, migrate.ErrNoChange) { - d.logger.ErrorContext(ctx, "DB migration failed", "error", err) - return err - } - - return nil + return Migrate(ctx, d.url) } func (d *database) Maintenance(ctx context.Context) { @@ -101,11 +78,9 @@ func (d *database) Maintenance(ctx context.Context) { var wg sync.WaitGroup wg.Add(1) - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute) + _, cancel := context.WithTimeout(context.Background(), 10*time.Minute) defer cancel() - go d.cleanupSessions(ctx, &wg) - wg.Wait() }() case <-ctx.Done(): @@ -114,65 +89,78 @@ func (d *database) Maintenance(ctx context.Context) { } } -func (d *database) Ping(ctx context.Context) error { - return d.db.PingContext(ctx) -} - -func (d *database) execInternal(ctx context.Context, queryName string, fn func(string) (any, error)) (any, error) { - query, err := db.GetQuery(queryName) +func Migrate(ctx context.Context, url string) error { + fs, err := iofs.New(db.Migrations, "migrations") if err != nil { - return nil, fmt.Errorf("unable to get query: %w", err) + return err } - d.logger.DebugContext(ctx, "Executing query", "name", queryName, "query", query) - - now := time.Now() - // Execute the query - result, err := fn(query) + m, err := migrate.NewWithSourceInstance("iofs", fs, url) if err != nil { - return nil, fmt.Errorf("unable to execute query: %w", err) + return err } - d.logger.DebugContext(ctx, "Query executed successfully", "name", queryName, "duration", time.Since(now)) + slog.InfoContext(ctx, "Running DB migration") + err = m.Up() + if err != nil && !errors.Is(err, migrate.ErrNoChange) { + slog.ErrorContext(ctx, "DB migration failed", "error", err) + return err + } - return result, nil + return nil } -func (d *database) exec(ctx context.Context, exec executor, queryName string, args ...any) (sql.Result, error) { - ret, err := d.execInternal(ctx, queryName, func(query string) (any, error) { - return exec.ExecContext(ctx, query, args...) - }) - if err != nil { - return nil, err - } else { - return ret.(sql.Result), nil - } +func (d *database) Ping(ctx context.Context) error { + return d.db.PingContext(ctx) +} + +type Executor interface { + ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) + QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) + QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row +} + +type TransactionExecutor interface { + Executor + BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) } -func (d *database) query(ctx context.Context, exec executor, queryName string, args ...any) (*sql.Rows, error) { - ret, err := d.execInternal(ctx, queryName, func(query string) (any, error) { - return exec.QueryContext(ctx, query, args...) - }) +func (d *database) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) { + d.logger.DebugContext(ctx, "Executing query", "query", query) + + now := time.Now() + ret, err := d.db.ExecContext(ctx, query, args...) if err != nil { return nil, err - } else { - return ret.(*sql.Rows), nil } + + d.logger.DebugContext(ctx, "Query executed successfully", "duration", time.Since(now)) + return ret, nil } -func (d *database) queryRow(ctx context.Context, exec executor, queryName string, args ...any) (*sql.Row, error) { - ret, err := d.execInternal(ctx, queryName, func(query string) (any, error) { - return exec.QueryRowContext(ctx, query, args...), nil - }) +func (d *database) QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) { + d.logger.DebugContext(ctx, "Executing query", "query", query) + + now := time.Now() + ret, err := d.db.QueryContext(ctx, query, args...) if err != nil { return nil, err - } else { - return ret.(*sql.Row), nil } + + d.logger.DebugContext(ctx, "Query executed successfully", "duration", time.Since(now)) + return ret, nil } -type executor interface { - ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) - QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) - QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row +func (d *database) QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row { + d.logger.DebugContext(ctx, "Executing query", "query", query) + + now := time.Now() + ret := d.db.QueryRowContext(ctx, query, args...) + + d.logger.DebugContext(ctx, "Query executed successfully", "duration", time.Since(now)) + return ret +} + +func (d *database) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) { + return d.db.BeginTx(ctx, opts) } diff --git a/backend/internal/database/keys.go b/backend/internal/database/keys.go index 0ec4b67..e2e2770 100644 --- a/backend/internal/database/keys.go +++ b/backend/internal/database/keys.go @@ -6,19 +6,14 @@ import ( "time" ) -type KeyStore interface { - AddKey(ctx context.Context, keyName string, key []byte) (int, error) - GetKey(ctx context.Context, keyId int) (*Key, error) -} - -func (d *database) AddKey(ctx context.Context, keyName string, key []byte) (int, error) { - row, err := d.queryRow(ctx, d.db, "keys/add_key", keyName, key) - if err != nil { - return 0, fmt.Errorf("unable to add key: %w", err) - } +func AddKey(ctx context.Context, exec Executor, keyName string, key []byte) (int, error) { + row := exec.QueryRowContext(ctx, ` +INSERT INTO keys (kms_key_name, encrypted_key) +VALUES ($1, $2) +RETURNING id;`, keyName, key) var keyId int - err = row.Scan(&keyId) + err := row.Scan(&keyId) if err != nil { return 0, fmt.Errorf("unable to scan key id: %w", err) } @@ -26,14 +21,14 @@ func (d *database) AddKey(ctx context.Context, keyName string, key []byte) (int, return keyId, nil } -func (d *database) GetKey(ctx context.Context, keyId int) (*Key, error) { - row, err := d.queryRow(ctx, d.db, "keys/get_key", keyId) - if err != nil { - return nil, fmt.Errorf("unable to get key: %w", err) - } +func GetKey(ctx context.Context, exec Executor, keyId int) (*Key, error) { + row := exec.QueryRowContext(ctx, ` +SELECT id, kms_key_name, encrypted_key, created_at +FROM keys +WHERE id = $1;`, keyId) key := &Key{} - err = row.Scan(&key.Id, &key.Name, &key.Key, &key.Created) + err := row.Scan(&key.Id, &key.Name, &key.Key, &key.Created) if err != nil { return nil, fmt.Errorf("unable to scan key: %w", err) } diff --git a/backend/internal/database/session.go b/backend/internal/database/session.go deleted file mode 100644 index 36867b3..0000000 --- a/backend/internal/database/session.go +++ /dev/null @@ -1,122 +0,0 @@ -package database - -import ( - "context" - "crypto/rand" - "database/sql" - "encoding/base64" - "errors" - "io" - "sync" - - "github.com/coreos/go-oidc/v3/oidc" - "golang.org/x/oauth2" -) - -type SessionStore interface { - CreateState(ctx context.Context) (string, error) - CheckState(ctx context.Context, state string) (bool, error) - CreateSession(ctx context.Context, token *oauth2.Token, idToken *oidc.IDToken) (string, error) - GetSession(ctx context.Context, sessionToken string) (*Session, error) -} - -func (d *database) CreateState(ctx context.Context) (string, error) { - // Generate a random CSRF state token - tokenBytes := make([]byte, 32) - if _, err := io.ReadFull(rand.Reader, tokenBytes); err != nil { - return "", err - } - token := base64.URLEncoding.EncodeToString(tokenBytes) - - // Insert the state into the database - _, err := d.exec(ctx, d.db, "sessions/create_state", token) - if err != nil { - return "", err - } - - return token, nil -} - -func (d *database) CheckState(ctx context.Context, state string) (bool, error) { - var exists bool - row, err := d.queryRow(ctx, d.db, "sessions/check_state", state) - if err != nil { - return false, err - } - err = row.Scan(&exists) - if err != nil { - if errors.Is(err, sql.ErrNoRows) { - return false, nil - } - return false, err - } - return exists, nil -} - -func (d *database) CreateSession( - ctx context.Context, - token *oauth2.Token, - idToken *oidc.IDToken, -) (sessionToken string, err error) { - // Generate a random session token - tokenBytes := make([]byte, 32) - if _, err = io.ReadFull(rand.Reader, tokenBytes); err != nil { - return - } - sessionToken = base64.URLEncoding.EncodeToString(tokenBytes) - - // Insert the session into the database - _, err = d.exec( - ctx, - d.db, - "sessions/create_session", - sessionToken, - idToken.Subject, - token.AccessToken, - token.Expiry, - ) - return -} - -func (d *database) GetSession(ctx context.Context, sessionToken string) (*Session, error) { - row, err := d.queryRow(ctx, d.db, "sessions/get_session", sessionToken) - if err != nil { - d.logger.ErrorContext(ctx, "Failed to get session", "error", err) - return nil, err - } - - var session Session - err = row.Scan(&session.Token, &session.Subject, &session.OAuthToken.AccessToken, &session.OAuthToken.Expiry) - if err != nil { - if errors.Is(err, sql.ErrNoRows) { - return nil, nil - } - d.logger.ErrorContext(ctx, "Failed to scan session", "error", err) - return nil, err - } - - return &session, nil -} - -func (d *database) cleanupSessions(ctx context.Context, wg *sync.WaitGroup) { - defer wg.Done() - - result, err := d.exec(ctx, d.db, "sessions/cleanup_sessions") - if err != nil { - d.logger.Error("Failed to clean up sessions", "error", err) - return - } - - rows, err := result.RowsAffected() - if err != nil { - d.logger.ErrorContext(ctx, "Failed to get rows affected", "error", err) - return - } - d.logger.DebugContext(ctx, "Cleaned up sessions", "rows", rows) -} - -type Session struct { - Token string - Subject string - OAuthToken oauth2.Token -} diff --git a/backend/internal/database/stocks.go b/backend/internal/database/stocks.go index f74e4e8..865aec4 100644 --- a/backend/internal/database/stocks.go +++ b/backend/internal/database/stocks.go @@ -14,23 +14,15 @@ import ( var ErrStockNotFound = errors.New("stock not found") -type StockStore interface { - GetStock(ctx context.Context, symbol string) (Stock, error) - AddStock(ctx context.Context, stock Stock) error - AddRanking(ctx context.Context, symbol string, ibd50, cap20 int) error - AddStockInfo(ctx context.Context, info *StockInfo) (string, error) - GetStockInfo(ctx context.Context, id string) (*StockInfo, error) - AddAnalysis(ctx context.Context, ratingId string, analysis *analyzer.Analysis) error -} - -func (d *database) GetStock(ctx context.Context, symbol string) (Stock, error) { - row, err := d.queryRow(ctx, d.db, "stocks/get_stock", symbol) - if err != nil { - return Stock{}, err - } +func GetStock(ctx context.Context, exec Executor, symbol string) (Stock, error) { + row := exec.QueryRowContext(ctx, ` +SELECT symbol, name, ibd_url +FROM stocks +WHERE symbol = $1; +`, symbol) var stock Stock - if err = row.Scan(&stock.Symbol, &stock.Name, &stock.IBDUrl); err != nil { + if err := row.Scan(&stock.Symbol, &stock.Name, &stock.IBDUrl); err != nil { if errors.Is(err, sql.ErrNoRows) { return Stock{}, ErrStockNotFound } @@ -40,20 +32,29 @@ func (d *database) GetStock(ctx context.Context, symbol string) (Stock, error) { return stock, nil } -func (d *database) AddStock(ctx context.Context, stock Stock) error { - _, err := d.exec(ctx, d.db, "stocks/add_stock", stock.Symbol, stock.Name, stock.IBDUrl) +func AddStock(ctx context.Context, exec Executor, stock Stock) error { + _, err := exec.ExecContext(ctx, ` +INSERT INTO stocks (symbol, name, ibd_url) +VALUES ($1, $2, $3) +ON CONFLICT (symbol) + DO UPDATE SET name = $2, + ibd_url = $3;`, stock.Symbol, stock.Name, stock.IBDUrl) return err } -func (d *database) AddRanking(ctx context.Context, symbol string, ibd50, cap20 int) error { +func AddRanking(ctx context.Context, exec Executor, symbol string, ibd50, cap20 int) error { if ibd50 > 0 { - _, err := d.exec(ctx, d.db, "stocks/add_rank", symbol, "ibd50", ibd50) + _, err := exec.ExecContext(ctx, ` +INSERT INTO stock_rank (symbol, rank_type, rank) +VALUES ($1, $2, $3)`, symbol, "ibd50", ibd50) if err != nil { return err } } if cap20 > 0 { - _, err := d.exec(ctx, d.db, "stocks/add_rank", symbol, "cap20", cap20) + _, err := exec.ExecContext(ctx, ` +INSERT INTO stock_rank (symbol, rank_type, rank) +VALUES ($1, $2, $3)`, symbol, "cap20", cap20) if err != nil { return err } @@ -61,8 +62,8 @@ func (d *database) AddRanking(ctx context.Context, symbol string, ibd50, cap20 i return nil } -func (d *database) AddStockInfo(ctx context.Context, info *StockInfo) (string, error) { - tx, err := d.db.BeginTx(ctx, nil) +func AddStockInfo(ctx context.Context, exec TransactionExecutor, info *StockInfo) (string, error) { + tx, err := exec.BeginTx(ctx, nil) if err != nil { return "", err } @@ -71,10 +72,10 @@ func (d *database) AddStockInfo(ctx context.Context, info *StockInfo) (string, e }(tx) // Add raw chart analysis - row, err := d.queryRow(ctx, tx, "stocks/add_raw_chart_analysis", info.ChartAnalysis) - if err != nil { - return "", err - } + row := tx.QueryRowContext(ctx, ` +INSERT INTO chart_analysis (raw_analysis) +VALUES ($1) +RETURNING id;`, info.ChartAnalysis) var chartAnalysisID string if err = row.Scan(&chartAnalysisID); err != nil { @@ -82,8 +83,11 @@ func (d *database) AddStockInfo(ctx context.Context, info *StockInfo) (string, e } // Add stock info - row, err = d.queryRow(ctx, tx, - "stocks/add_rating", + row = tx.QueryRowContext(ctx, + ` +INSERT INTO ratings (symbol, composite, eps, rel_str, group_rel_str, smr, acc_dis, chart_analysis, price) +VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) +RETURNING id;`, info.Symbol, info.Ratings.Composite, info.Ratings.EPS, @@ -94,9 +98,6 @@ func (d *database) AddStockInfo(ctx context.Context, info *StockInfo) (string, e chartAnalysisID, info.Price.Display(), ) - if err != nil { - return "", err - } var ratingsID string if err = row.Scan(&ratingsID); err != nil { @@ -106,15 +107,26 @@ func (d *database) AddStockInfo(ctx context.Context, info *StockInfo) (string, e return ratingsID, tx.Commit() } -func (d *database) GetStockInfo(ctx context.Context, id string) (*StockInfo, error) { - row, err := d.queryRow(ctx, d.db, "stocks/get_stock_info", id) - if err != nil { - return nil, err - } +func GetStockInfo(ctx context.Context, exec Executor, id string) (*StockInfo, error) { + row := exec.QueryRowContext(ctx, ` +SELECT r.symbol, + s.name, + ca.raw_analysis, + r.composite, + r.eps, + r.rel_str, + r.group_rel_str, + r.smr, + r.acc_dis, + r.price +FROM ratings r + INNER JOIN stocks s on r.symbol = s.symbol + INNER JOIN chart_analysis ca on r.chart_analysis = ca.id +WHERE r.id = $1;`, id) var info StockInfo var priceStr string - err = row.Scan( + err := row.Scan( &info.Symbol, &info.Name, &info.ChartAnalysis, @@ -138,8 +150,17 @@ func (d *database) GetStockInfo(ctx context.Context, id string) (*StockInfo, err return &info, nil } -func (d *database) AddAnalysis(ctx context.Context, ratingId string, analysis *analyzer.Analysis) error { - _, err := d.exec(ctx, d.db, "stocks/add_analysis", +func AddAnalysis(ctx context.Context, exec Executor, ratingId string, analysis *analyzer.Analysis) error { + _, err := exec.ExecContext(ctx, ` +UPDATE chart_analysis ca +SET processed = true, + action = $2, + price = $3, + reason = $4, + confidence = $5 +FROM ratings r +WHERE r.id = $1 + AND r.chart_analysis = ca.id;`, ratingId, analysis.Action, analysis.Price.Display(), diff --git a/backend/internal/database/users.go b/backend/internal/database/users.go index ff6f674..d023598 100644 --- a/backend/internal/database/users.go +++ b/backend/internal/database/users.go @@ -9,35 +9,25 @@ import ( "github.com/ansg191/ibd-trader-backend/internal/keys" ) -type UserStore interface { - AddUser(ctx context.Context, subject string) error - GetUser(ctx context.Context, subject string) (*User, error) - ListUsers(ctx context.Context, hasIBDCreds bool) ([]User, error) - AddIBDCreds(ctx context.Context, subject string, username string, password string) error - GetIBDCreds(ctx context.Context, subject string) (username string, password string, err error) -} - var ErrUserNotFound = fmt.Errorf("user not found") var ErrIBDCredsNotFound = fmt.Errorf("ibd creds not found") -func (d *database) AddUser(ctx context.Context, subject string) (err error) { - _, err = d.exec( - ctx, - d.db, - "users/add_user", - subject, - ) +func AddUser(ctx context.Context, exec Executor, subject string) (err error) { + _, err = exec.ExecContext(ctx, ` +INSERT INTO users (subject) +VALUES ($1) +ON CONFLICT DO NOTHING;`, subject) return } -func (d *database) GetUser(ctx context.Context, subject string) (*User, error) { - row, err := d.queryRow(ctx, d.db, "users/get_user", subject) - if err != nil { - return nil, fmt.Errorf("unable to get user: %w", err) - } +func GetUser(ctx context.Context, exec Executor, subject string) (*User, error) { + row := exec.QueryRowContext(ctx, ` +SELECT subject, ibd_username, ibd_password, encryption_key +FROM users +WHERE subject = $1;`, subject) user := &User{} - err = row.Scan(&user.Subject, &user.IBDUsername, &user.EncryptedIBDPassword, &user.EncryptionKeyID) + err := row.Scan(&user.Subject, &user.IBDUsername, &user.EncryptedIBDPassword, &user.EncryptionKeyID) if err != nil { if errors.Is(err, sql.ErrNoRows) { return nil, ErrUserNotFound @@ -48,8 +38,11 @@ func (d *database) GetUser(ctx context.Context, subject string) (*User, error) { return user, nil } -func (d *database) ListUsers(ctx context.Context, hasIBDCreds bool) ([]User, error) { - rows, err := d.query(ctx, d.db, "users/list_users") +func ListUsers(ctx context.Context, exec Executor, hasIBDCreds bool) ([]User, error) { + rows, err := exec.QueryContext(ctx, ` +SELECT subject, ibd_username, ibd_password, encryption_key +FROM users; +`) if err != nil { return nil, fmt.Errorf("unable to list users: %w", err) } @@ -71,13 +64,18 @@ func (d *database) ListUsers(ctx context.Context, hasIBDCreds bool) ([]User, err return users, nil } -func (d *database) AddIBDCreds(ctx context.Context, subject string, username string, password string) error { - encryptedPass, encryptedKey, err := keys.Encrypt(ctx, d.kms, d.keyName, []byte(password)) +func AddIBDCreds( + ctx context.Context, + exec TransactionExecutor, + kms keys.KeyManagementService, + keyName, subject, username, password string, +) error { + encryptedPass, encryptedKey, err := keys.Encrypt(ctx, kms, keyName, []byte(password)) if err != nil { return fmt.Errorf("unable to encrypt password: %w", err) } - tx, err := d.db.BeginTx(ctx, nil) + tx, err := exec.BeginTx(ctx, nil) if err != nil { return err } @@ -85,18 +83,17 @@ func (d *database) AddIBDCreds(ctx context.Context, subject string, username str _ = tx.Rollback() }(tx) - row, err := d.queryRow(ctx, tx, "keys/add_key", d.keyName, encryptedKey) + keyId, err := AddKey(ctx, tx, keyName, encryptedKey) if err != nil { return fmt.Errorf("unable to add ibd creds key: %w", err) } - var keyId int - err = row.Scan(&keyId) - if err != nil { - return fmt.Errorf("unable to scan key id: %w", err) - } - - _, err = d.exec(ctx, tx, "users/add_ibd_creds", subject, username, encryptedPass, keyId) + _, err = exec.ExecContext(ctx, ` +UPDATE users +SET ibd_username = $2, + ibd_password = $3, + encryption_key = $4 +WHERE subject = $1;`, subject, username, encryptedPass, keyId) if err != nil { return fmt.Errorf("unable to add ibd creds to user: %w", err) } @@ -108,11 +105,21 @@ func (d *database) AddIBDCreds(ctx context.Context, subject string, username str return nil } -func (d *database) GetIBDCreds(ctx context.Context, subject string) (username string, password string, err error) { - row, err := d.queryRow(ctx, d.db, "users/get_ibd_creds", subject) - if err != nil { - return "", "", fmt.Errorf("unable to get ibd creds: %w", err) - } +func GetIBDCreds( + ctx context.Context, + exec Executor, + kms keys.KeyManagementService, + subject string, +) ( + username string, + password string, + err error, +) { + row := exec.QueryRowContext(ctx, ` +SELECT ibd_username, ibd_password, encrypted_key, kms_key_name +FROM users +INNER JOIN public.keys k on k.id = users.encryption_key +WHERE subject = $1;`, subject) var encryptedPass, encryptedKey []byte var keyName string @@ -124,7 +131,7 @@ func (d *database) GetIBDCreds(ctx context.Context, subject string) (username st return "", "", fmt.Errorf("unable to scan sql row into ibd creds: %w", err) } - passwordBytes, err := keys.Decrypt(ctx, d.kms, keyName, encryptedPass, encryptedKey) + passwordBytes, err := keys.Decrypt(ctx, kms, keyName, encryptedPass, encryptedKey) if err != nil { return "", "", fmt.Errorf("unable to decrypt password: %w", err) } |