diff options
author | 2024-08-07 17:48:57 -0700 | |
---|---|---|
committer | 2024-08-07 18:48:10 -0700 | |
commit | e9ee45b9d2bd494332dcf8b2073714f92fd0738d (patch) | |
tree | d34af1af84984409d27003981538f13cde4ba218 /backend/internal | |
parent | 3de4ebb7560851ccbefe296c197456fe80c22901 (diff) | |
download | ibd-trader-e9ee45b9d2bd494332dcf8b2073714f92fd0738d.tar.gz ibd-trader-e9ee45b9d2bd494332dcf8b2073714f92fd0738d.tar.zst ibd-trader-e9ee45b9d2bd494332dcf8b2073714f92fd0738d.zip |
Refactor DB to remove restrictive query system
Diffstat (limited to 'backend/internal')
20 files changed, 489 insertions, 445 deletions
diff --git a/backend/internal/database/cookies.go b/backend/internal/database/cookies.go index 8bed854..d652b65 100644 --- a/backend/internal/database/cookies.go +++ b/backend/internal/database/cookies.go @@ -11,29 +11,21 @@ import ( "github.com/ansg191/ibd-trader-backend/internal/keys" ) -type CookieStore interface { - CookieSource - AddCookie(ctx context.Context, subject string, cookie *http.Cookie) error - RepairCookie(ctx context.Context, id uint) error -} - -type CookieSource interface { - GetAnyCookie(ctx context.Context) (*IBDCookie, error) - GetCookies(ctx context.Context, subject string, degraded bool) ([]IBDCookie, error) - ReportCookieFailure(ctx context.Context, id uint) error -} - -func (d *database) GetAnyCookie(ctx context.Context) (*IBDCookie, error) { - row, err := d.queryRow(ctx, d.db, "cookies/get_any_cookie") - if err != nil { - return nil, fmt.Errorf("unable to get any ibd cookie: %w", err) - } +func GetAnyCookie(ctx context.Context, exec Executor, kms keys.KeyManagementService) (*IBDCookie, error) { + row := exec.QueryRowContext(ctx, ` +SELECT ibd_tokens.id, token, encrypted_key, kms_key_name, expires_at +FROM ibd_tokens + INNER JOIN keys ON encryption_key = keys.id +WHERE expires_at > NOW() + AND degraded = FALSE +ORDER BY random() +LIMIT 1;`) var id uint var encryptedToken, encryptedKey []byte var keyName string var expiry time.Time - err = row.Scan(&id, &encryptedToken, &encryptedKey, &keyName, &expiry) + err := row.Scan(&id, &encryptedToken, &encryptedKey, &keyName, &expiry) if err != nil { if errors.Is(err, sql.ErrNoRows) { return nil, nil @@ -41,7 +33,11 @@ func (d *database) GetAnyCookie(ctx context.Context) (*IBDCookie, error) { return nil, fmt.Errorf("unable to scan sql row into ibd cookie: %w", err) } - token, err := keys.Decrypt(ctx, d.kms, keyName, encryptedToken, encryptedKey) + // Set the expiry to UTC explicitly. + // For some reason, the expiry time is set to location="". + expiry = expiry.UTC() + + token, err := keys.Decrypt(ctx, kms, keyName, encryptedToken, encryptedKey) if err != nil { return nil, fmt.Errorf("unable to decrypt token: %w", err) } @@ -51,24 +47,41 @@ func (d *database) GetAnyCookie(ctx context.Context) (*IBDCookie, error) { }, nil } -func (d *database) GetCookies(ctx context.Context, subject string, degraded bool) ([]IBDCookie, error) { - row, err := d.query(ctx, d.db, "cookies/get_cookies", subject, degraded) +func GetCookies( + ctx context.Context, + exec Executor, + kms keys.KeyManagementService, + subject string, + degraded bool, +) ([]IBDCookie, error) { + rows, err := exec.QueryContext(ctx, ` +SELECT ibd_tokens.id, token, encrypted_key, kms_key_name, expires_at +FROM ibd_tokens + INNER JOIN keys ON encryption_key = keys.id +WHERE user_subject = $1 + AND expires_at > NOW() + AND degraded = $2 +ORDER BY expires_at DESC;`, subject, degraded) if err != nil { return nil, fmt.Errorf("unable to get ibd cookies: %w", err) } cookies := make([]IBDCookie, 0) - for row.Next() { + for rows.Next() { var id uint var encryptedToken, encryptedKey []byte var keyName string var expiry time.Time - err = row.Scan(&id, &encryptedToken, &encryptedKey, &keyName, &expiry) + err = rows.Scan(&id, &encryptedToken, &encryptedKey, &keyName, &expiry) if err != nil { return nil, fmt.Errorf("unable to scan sql row into ibd cookie: %w", err) } - token, err := keys.Decrypt(ctx, d.kms, keyName, encryptedToken, encryptedKey) + // Set the expiry to UTC explicitly. + // For some reason, the expiry time is set to location="". + expiry = expiry.UTC() + + token, err := keys.Decrypt(ctx, kms, keyName, encryptedToken, encryptedKey) if err != nil { return nil, fmt.Errorf("unable to decrypt token: %w", err) } @@ -83,9 +96,15 @@ func (d *database) GetCookies(ctx context.Context, subject string, degraded bool return cookies, nil } -func (d *database) AddCookie(ctx context.Context, subject string, cookie *http.Cookie) error { +func AddCookie( + ctx context.Context, + exec Executor, + kms keys.KeyManagementService, + subject string, + cookie *http.Cookie, +) error { // Get the key ID for the user - user, err := d.GetUser(ctx, subject) + user, err := GetUser(ctx, exec, subject) if err != nil { return fmt.Errorf("unable to get user: %w", err) } @@ -94,19 +113,21 @@ func (d *database) AddCookie(ctx context.Context, subject string, cookie *http.C } // Get the key - key, err := d.GetKey(ctx, *user.EncryptionKeyID) + key, err := GetKey(ctx, exec, *user.EncryptionKeyID) if err != nil { return fmt.Errorf("unable to get key: %w", err) } // Encrypt the token - encryptedToken, err := keys.EncryptWithKey(ctx, d.kms, key.Name, key.Key, []byte(cookie.Value)) + encryptedToken, err := keys.EncryptWithKey(ctx, kms, key.Name, key.Key, []byte(cookie.Value)) if err != nil { return fmt.Errorf("unable to encrypt token: %w", err) } // Add the cookie to the database - _, err = d.exec(ctx, d.db, "cookies/add_cookie", encryptedToken, cookie.Expires, subject, key.Id) + _, err = exec.ExecContext(ctx, ` +INSERT INTO ibd_tokens (token, expires_at, user_subject, encryption_key) +VALUES ($1, $2, $3, $4)`, encryptedToken, cookie.Expires, subject, key.Id) if err != nil { return fmt.Errorf("unable to add cookie: %w", err) } @@ -114,16 +135,22 @@ func (d *database) AddCookie(ctx context.Context, subject string, cookie *http.C return nil } -func (d *database) ReportCookieFailure(ctx context.Context, id uint) error { - _, err := d.exec(ctx, d.db, "cookies/set_cookie_degraded", true, id) +func ReportCookieFailure(ctx context.Context, exec Executor, id uint) error { + _, err := exec.ExecContext(ctx, ` +UPDATE ibd_tokens +SET degraded = TRUE +WHERE id = $1;`, id) if err != nil { return fmt.Errorf("unable to report cookie failure: %w", err) } return nil } -func (d *database) RepairCookie(ctx context.Context, id uint) error { - _, err := d.exec(ctx, d.db, "cookies/set_cookie_degraded", false, id) +func RepairCookie(ctx context.Context, exec Executor, id uint) error { + _, err := exec.ExecContext(ctx, ` +UPDATE ibd_tokens +SET degraded = FALSE +WHERE id = $1;`, id) if err != nil { return fmt.Errorf("unable to report cookie failure: %w", err) } diff --git a/backend/internal/database/database.go b/backend/internal/database/database.go index 3c822bc..409dd3c 100644 --- a/backend/internal/database/database.go +++ b/backend/internal/database/database.go @@ -5,7 +5,6 @@ import ( "database/sql" "database/sql/driver" "errors" - "fmt" "io" "log/slog" "sync" @@ -22,12 +21,7 @@ import ( type Database interface { io.Closer - UserStore - CookieStore - KeyStore - SessionStore - StockStore - + TransactionExecutor driver.Pinger Migrate(ctx context.Context) error @@ -70,24 +64,7 @@ func (d *database) Close() error { } func (d *database) Migrate(ctx context.Context) error { - fs, err := iofs.New(db.Migrations, "migrations") - if err != nil { - return err - } - - m, err := migrate.NewWithSourceInstance("iofs", fs, d.url) - if err != nil { - return err - } - - d.logger.InfoContext(ctx, "Running DB migration") - err = m.Up() - if err != nil && !errors.Is(err, migrate.ErrNoChange) { - d.logger.ErrorContext(ctx, "DB migration failed", "error", err) - return err - } - - return nil + return Migrate(ctx, d.url) } func (d *database) Maintenance(ctx context.Context) { @@ -101,11 +78,9 @@ func (d *database) Maintenance(ctx context.Context) { var wg sync.WaitGroup wg.Add(1) - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute) + _, cancel := context.WithTimeout(context.Background(), 10*time.Minute) defer cancel() - go d.cleanupSessions(ctx, &wg) - wg.Wait() }() case <-ctx.Done(): @@ -114,65 +89,78 @@ func (d *database) Maintenance(ctx context.Context) { } } -func (d *database) Ping(ctx context.Context) error { - return d.db.PingContext(ctx) -} - -func (d *database) execInternal(ctx context.Context, queryName string, fn func(string) (any, error)) (any, error) { - query, err := db.GetQuery(queryName) +func Migrate(ctx context.Context, url string) error { + fs, err := iofs.New(db.Migrations, "migrations") if err != nil { - return nil, fmt.Errorf("unable to get query: %w", err) + return err } - d.logger.DebugContext(ctx, "Executing query", "name", queryName, "query", query) - - now := time.Now() - // Execute the query - result, err := fn(query) + m, err := migrate.NewWithSourceInstance("iofs", fs, url) if err != nil { - return nil, fmt.Errorf("unable to execute query: %w", err) + return err } - d.logger.DebugContext(ctx, "Query executed successfully", "name", queryName, "duration", time.Since(now)) + slog.InfoContext(ctx, "Running DB migration") + err = m.Up() + if err != nil && !errors.Is(err, migrate.ErrNoChange) { + slog.ErrorContext(ctx, "DB migration failed", "error", err) + return err + } - return result, nil + return nil } -func (d *database) exec(ctx context.Context, exec executor, queryName string, args ...any) (sql.Result, error) { - ret, err := d.execInternal(ctx, queryName, func(query string) (any, error) { - return exec.ExecContext(ctx, query, args...) - }) - if err != nil { - return nil, err - } else { - return ret.(sql.Result), nil - } +func (d *database) Ping(ctx context.Context) error { + return d.db.PingContext(ctx) +} + +type Executor interface { + ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) + QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) + QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row +} + +type TransactionExecutor interface { + Executor + BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) } -func (d *database) query(ctx context.Context, exec executor, queryName string, args ...any) (*sql.Rows, error) { - ret, err := d.execInternal(ctx, queryName, func(query string) (any, error) { - return exec.QueryContext(ctx, query, args...) - }) +func (d *database) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) { + d.logger.DebugContext(ctx, "Executing query", "query", query) + + now := time.Now() + ret, err := d.db.ExecContext(ctx, query, args...) if err != nil { return nil, err - } else { - return ret.(*sql.Rows), nil } + + d.logger.DebugContext(ctx, "Query executed successfully", "duration", time.Since(now)) + return ret, nil } -func (d *database) queryRow(ctx context.Context, exec executor, queryName string, args ...any) (*sql.Row, error) { - ret, err := d.execInternal(ctx, queryName, func(query string) (any, error) { - return exec.QueryRowContext(ctx, query, args...), nil - }) +func (d *database) QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) { + d.logger.DebugContext(ctx, "Executing query", "query", query) + + now := time.Now() + ret, err := d.db.QueryContext(ctx, query, args...) if err != nil { return nil, err - } else { - return ret.(*sql.Row), nil } + + d.logger.DebugContext(ctx, "Query executed successfully", "duration", time.Since(now)) + return ret, nil } -type executor interface { - ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) - QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) - QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row +func (d *database) QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row { + d.logger.DebugContext(ctx, "Executing query", "query", query) + + now := time.Now() + ret := d.db.QueryRowContext(ctx, query, args...) + + d.logger.DebugContext(ctx, "Query executed successfully", "duration", time.Since(now)) + return ret +} + +func (d *database) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) { + return d.db.BeginTx(ctx, opts) } diff --git a/backend/internal/database/keys.go b/backend/internal/database/keys.go index 0ec4b67..e2e2770 100644 --- a/backend/internal/database/keys.go +++ b/backend/internal/database/keys.go @@ -6,19 +6,14 @@ import ( "time" ) -type KeyStore interface { - AddKey(ctx context.Context, keyName string, key []byte) (int, error) - GetKey(ctx context.Context, keyId int) (*Key, error) -} - -func (d *database) AddKey(ctx context.Context, keyName string, key []byte) (int, error) { - row, err := d.queryRow(ctx, d.db, "keys/add_key", keyName, key) - if err != nil { - return 0, fmt.Errorf("unable to add key: %w", err) - } +func AddKey(ctx context.Context, exec Executor, keyName string, key []byte) (int, error) { + row := exec.QueryRowContext(ctx, ` +INSERT INTO keys (kms_key_name, encrypted_key) +VALUES ($1, $2) +RETURNING id;`, keyName, key) var keyId int - err = row.Scan(&keyId) + err := row.Scan(&keyId) if err != nil { return 0, fmt.Errorf("unable to scan key id: %w", err) } @@ -26,14 +21,14 @@ func (d *database) AddKey(ctx context.Context, keyName string, key []byte) (int, return keyId, nil } -func (d *database) GetKey(ctx context.Context, keyId int) (*Key, error) { - row, err := d.queryRow(ctx, d.db, "keys/get_key", keyId) - if err != nil { - return nil, fmt.Errorf("unable to get key: %w", err) - } +func GetKey(ctx context.Context, exec Executor, keyId int) (*Key, error) { + row := exec.QueryRowContext(ctx, ` +SELECT id, kms_key_name, encrypted_key, created_at +FROM keys +WHERE id = $1;`, keyId) key := &Key{} - err = row.Scan(&key.Id, &key.Name, &key.Key, &key.Created) + err := row.Scan(&key.Id, &key.Name, &key.Key, &key.Created) if err != nil { return nil, fmt.Errorf("unable to scan key: %w", err) } diff --git a/backend/internal/database/session.go b/backend/internal/database/session.go deleted file mode 100644 index 36867b3..0000000 --- a/backend/internal/database/session.go +++ /dev/null @@ -1,122 +0,0 @@ -package database - -import ( - "context" - "crypto/rand" - "database/sql" - "encoding/base64" - "errors" - "io" - "sync" - - "github.com/coreos/go-oidc/v3/oidc" - "golang.org/x/oauth2" -) - -type SessionStore interface { - CreateState(ctx context.Context) (string, error) - CheckState(ctx context.Context, state string) (bool, error) - CreateSession(ctx context.Context, token *oauth2.Token, idToken *oidc.IDToken) (string, error) - GetSession(ctx context.Context, sessionToken string) (*Session, error) -} - -func (d *database) CreateState(ctx context.Context) (string, error) { - // Generate a random CSRF state token - tokenBytes := make([]byte, 32) - if _, err := io.ReadFull(rand.Reader, tokenBytes); err != nil { - return "", err - } - token := base64.URLEncoding.EncodeToString(tokenBytes) - - // Insert the state into the database - _, err := d.exec(ctx, d.db, "sessions/create_state", token) - if err != nil { - return "", err - } - - return token, nil -} - -func (d *database) CheckState(ctx context.Context, state string) (bool, error) { - var exists bool - row, err := d.queryRow(ctx, d.db, "sessions/check_state", state) - if err != nil { - return false, err - } - err = row.Scan(&exists) - if err != nil { - if errors.Is(err, sql.ErrNoRows) { - return false, nil - } - return false, err - } - return exists, nil -} - -func (d *database) CreateSession( - ctx context.Context, - token *oauth2.Token, - idToken *oidc.IDToken, -) (sessionToken string, err error) { - // Generate a random session token - tokenBytes := make([]byte, 32) - if _, err = io.ReadFull(rand.Reader, tokenBytes); err != nil { - return - } - sessionToken = base64.URLEncoding.EncodeToString(tokenBytes) - - // Insert the session into the database - _, err = d.exec( - ctx, - d.db, - "sessions/create_session", - sessionToken, - idToken.Subject, - token.AccessToken, - token.Expiry, - ) - return -} - -func (d *database) GetSession(ctx context.Context, sessionToken string) (*Session, error) { - row, err := d.queryRow(ctx, d.db, "sessions/get_session", sessionToken) - if err != nil { - d.logger.ErrorContext(ctx, "Failed to get session", "error", err) - return nil, err - } - - var session Session - err = row.Scan(&session.Token, &session.Subject, &session.OAuthToken.AccessToken, &session.OAuthToken.Expiry) - if err != nil { - if errors.Is(err, sql.ErrNoRows) { - return nil, nil - } - d.logger.ErrorContext(ctx, "Failed to scan session", "error", err) - return nil, err - } - - return &session, nil -} - -func (d *database) cleanupSessions(ctx context.Context, wg *sync.WaitGroup) { - defer wg.Done() - - result, err := d.exec(ctx, d.db, "sessions/cleanup_sessions") - if err != nil { - d.logger.Error("Failed to clean up sessions", "error", err) - return - } - - rows, err := result.RowsAffected() - if err != nil { - d.logger.ErrorContext(ctx, "Failed to get rows affected", "error", err) - return - } - d.logger.DebugContext(ctx, "Cleaned up sessions", "rows", rows) -} - -type Session struct { - Token string - Subject string - OAuthToken oauth2.Token -} diff --git a/backend/internal/database/stocks.go b/backend/internal/database/stocks.go index f74e4e8..865aec4 100644 --- a/backend/internal/database/stocks.go +++ b/backend/internal/database/stocks.go @@ -14,23 +14,15 @@ import ( var ErrStockNotFound = errors.New("stock not found") -type StockStore interface { - GetStock(ctx context.Context, symbol string) (Stock, error) - AddStock(ctx context.Context, stock Stock) error - AddRanking(ctx context.Context, symbol string, ibd50, cap20 int) error - AddStockInfo(ctx context.Context, info *StockInfo) (string, error) - GetStockInfo(ctx context.Context, id string) (*StockInfo, error) - AddAnalysis(ctx context.Context, ratingId string, analysis *analyzer.Analysis) error -} - -func (d *database) GetStock(ctx context.Context, symbol string) (Stock, error) { - row, err := d.queryRow(ctx, d.db, "stocks/get_stock", symbol) - if err != nil { - return Stock{}, err - } +func GetStock(ctx context.Context, exec Executor, symbol string) (Stock, error) { + row := exec.QueryRowContext(ctx, ` +SELECT symbol, name, ibd_url +FROM stocks +WHERE symbol = $1; +`, symbol) var stock Stock - if err = row.Scan(&stock.Symbol, &stock.Name, &stock.IBDUrl); err != nil { + if err := row.Scan(&stock.Symbol, &stock.Name, &stock.IBDUrl); err != nil { if errors.Is(err, sql.ErrNoRows) { return Stock{}, ErrStockNotFound } @@ -40,20 +32,29 @@ func (d *database) GetStock(ctx context.Context, symbol string) (Stock, error) { return stock, nil } -func (d *database) AddStock(ctx context.Context, stock Stock) error { - _, err := d.exec(ctx, d.db, "stocks/add_stock", stock.Symbol, stock.Name, stock.IBDUrl) +func AddStock(ctx context.Context, exec Executor, stock Stock) error { + _, err := exec.ExecContext(ctx, ` +INSERT INTO stocks (symbol, name, ibd_url) +VALUES ($1, $2, $3) +ON CONFLICT (symbol) + DO UPDATE SET name = $2, + ibd_url = $3;`, stock.Symbol, stock.Name, stock.IBDUrl) return err } -func (d *database) AddRanking(ctx context.Context, symbol string, ibd50, cap20 int) error { +func AddRanking(ctx context.Context, exec Executor, symbol string, ibd50, cap20 int) error { if ibd50 > 0 { - _, err := d.exec(ctx, d.db, "stocks/add_rank", symbol, "ibd50", ibd50) + _, err := exec.ExecContext(ctx, ` +INSERT INTO stock_rank (symbol, rank_type, rank) +VALUES ($1, $2, $3)`, symbol, "ibd50", ibd50) if err != nil { return err } } if cap20 > 0 { - _, err := d.exec(ctx, d.db, "stocks/add_rank", symbol, "cap20", cap20) + _, err := exec.ExecContext(ctx, ` +INSERT INTO stock_rank (symbol, rank_type, rank) +VALUES ($1, $2, $3)`, symbol, "cap20", cap20) if err != nil { return err } @@ -61,8 +62,8 @@ func (d *database) AddRanking(ctx context.Context, symbol string, ibd50, cap20 i return nil } -func (d *database) AddStockInfo(ctx context.Context, info *StockInfo) (string, error) { - tx, err := d.db.BeginTx(ctx, nil) +func AddStockInfo(ctx context.Context, exec TransactionExecutor, info *StockInfo) (string, error) { + tx, err := exec.BeginTx(ctx, nil) if err != nil { return "", err } @@ -71,10 +72,10 @@ func (d *database) AddStockInfo(ctx context.Context, info *StockInfo) (string, e }(tx) // Add raw chart analysis - row, err := d.queryRow(ctx, tx, "stocks/add_raw_chart_analysis", info.ChartAnalysis) - if err != nil { - return "", err - } + row := tx.QueryRowContext(ctx, ` +INSERT INTO chart_analysis (raw_analysis) +VALUES ($1) +RETURNING id;`, info.ChartAnalysis) var chartAnalysisID string if err = row.Scan(&chartAnalysisID); err != nil { @@ -82,8 +83,11 @@ func (d *database) AddStockInfo(ctx context.Context, info *StockInfo) (string, e } // Add stock info - row, err = d.queryRow(ctx, tx, - "stocks/add_rating", + row = tx.QueryRowContext(ctx, + ` +INSERT INTO ratings (symbol, composite, eps, rel_str, group_rel_str, smr, acc_dis, chart_analysis, price) +VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) +RETURNING id;`, info.Symbol, info.Ratings.Composite, info.Ratings.EPS, @@ -94,9 +98,6 @@ func (d *database) AddStockInfo(ctx context.Context, info *StockInfo) (string, e chartAnalysisID, info.Price.Display(), ) - if err != nil { - return "", err - } var ratingsID string if err = row.Scan(&ratingsID); err != nil { @@ -106,15 +107,26 @@ func (d *database) AddStockInfo(ctx context.Context, info *StockInfo) (string, e return ratingsID, tx.Commit() } -func (d *database) GetStockInfo(ctx context.Context, id string) (*StockInfo, error) { - row, err := d.queryRow(ctx, d.db, "stocks/get_stock_info", id) - if err != nil { - return nil, err - } +func GetStockInfo(ctx context.Context, exec Executor, id string) (*StockInfo, error) { + row := exec.QueryRowContext(ctx, ` +SELECT r.symbol, + s.name, + ca.raw_analysis, + r.composite, + r.eps, + r.rel_str, + r.group_rel_str, + r.smr, + r.acc_dis, + r.price +FROM ratings r + INNER JOIN stocks s on r.symbol = s.symbol + INNER JOIN chart_analysis ca on r.chart_analysis = ca.id +WHERE r.id = $1;`, id) var info StockInfo var priceStr string - err = row.Scan( + err := row.Scan( &info.Symbol, &info.Name, &info.ChartAnalysis, @@ -138,8 +150,17 @@ func (d *database) GetStockInfo(ctx context.Context, id string) (*StockInfo, err return &info, nil } -func (d *database) AddAnalysis(ctx context.Context, ratingId string, analysis *analyzer.Analysis) error { - _, err := d.exec(ctx, d.db, "stocks/add_analysis", +func AddAnalysis(ctx context.Context, exec Executor, ratingId string, analysis *analyzer.Analysis) error { + _, err := exec.ExecContext(ctx, ` +UPDATE chart_analysis ca +SET processed = true, + action = $2, + price = $3, + reason = $4, + confidence = $5 +FROM ratings r +WHERE r.id = $1 + AND r.chart_analysis = ca.id;`, ratingId, analysis.Action, analysis.Price.Display(), diff --git a/backend/internal/database/users.go b/backend/internal/database/users.go index ff6f674..d023598 100644 --- a/backend/internal/database/users.go +++ b/backend/internal/database/users.go @@ -9,35 +9,25 @@ import ( "github.com/ansg191/ibd-trader-backend/internal/keys" ) -type UserStore interface { - AddUser(ctx context.Context, subject string) error - GetUser(ctx context.Context, subject string) (*User, error) - ListUsers(ctx context.Context, hasIBDCreds bool) ([]User, error) - AddIBDCreds(ctx context.Context, subject string, username string, password string) error - GetIBDCreds(ctx context.Context, subject string) (username string, password string, err error) -} - var ErrUserNotFound = fmt.Errorf("user not found") var ErrIBDCredsNotFound = fmt.Errorf("ibd creds not found") -func (d *database) AddUser(ctx context.Context, subject string) (err error) { - _, err = d.exec( - ctx, - d.db, - "users/add_user", - subject, - ) +func AddUser(ctx context.Context, exec Executor, subject string) (err error) { + _, err = exec.ExecContext(ctx, ` +INSERT INTO users (subject) +VALUES ($1) +ON CONFLICT DO NOTHING;`, subject) return } -func (d *database) GetUser(ctx context.Context, subject string) (*User, error) { - row, err := d.queryRow(ctx, d.db, "users/get_user", subject) - if err != nil { - return nil, fmt.Errorf("unable to get user: %w", err) - } +func GetUser(ctx context.Context, exec Executor, subject string) (*User, error) { + row := exec.QueryRowContext(ctx, ` +SELECT subject, ibd_username, ibd_password, encryption_key +FROM users +WHERE subject = $1;`, subject) user := &User{} - err = row.Scan(&user.Subject, &user.IBDUsername, &user.EncryptedIBDPassword, &user.EncryptionKeyID) + err := row.Scan(&user.Subject, &user.IBDUsername, &user.EncryptedIBDPassword, &user.EncryptionKeyID) if err != nil { if errors.Is(err, sql.ErrNoRows) { return nil, ErrUserNotFound @@ -48,8 +38,11 @@ func (d *database) GetUser(ctx context.Context, subject string) (*User, error) { return user, nil } -func (d *database) ListUsers(ctx context.Context, hasIBDCreds bool) ([]User, error) { - rows, err := d.query(ctx, d.db, "users/list_users") +func ListUsers(ctx context.Context, exec Executor, hasIBDCreds bool) ([]User, error) { + rows, err := exec.QueryContext(ctx, ` +SELECT subject, ibd_username, ibd_password, encryption_key +FROM users; +`) if err != nil { return nil, fmt.Errorf("unable to list users: %w", err) } @@ -71,13 +64,18 @@ func (d *database) ListUsers(ctx context.Context, hasIBDCreds bool) ([]User, err return users, nil } -func (d *database) AddIBDCreds(ctx context.Context, subject string, username string, password string) error { - encryptedPass, encryptedKey, err := keys.Encrypt(ctx, d.kms, d.keyName, []byte(password)) +func AddIBDCreds( + ctx context.Context, + exec TransactionExecutor, + kms keys.KeyManagementService, + keyName, subject, username, password string, +) error { + encryptedPass, encryptedKey, err := keys.Encrypt(ctx, kms, keyName, []byte(password)) if err != nil { return fmt.Errorf("unable to encrypt password: %w", err) } - tx, err := d.db.BeginTx(ctx, nil) + tx, err := exec.BeginTx(ctx, nil) if err != nil { return err } @@ -85,18 +83,17 @@ func (d *database) AddIBDCreds(ctx context.Context, subject string, username str _ = tx.Rollback() }(tx) - row, err := d.queryRow(ctx, tx, "keys/add_key", d.keyName, encryptedKey) + keyId, err := AddKey(ctx, tx, keyName, encryptedKey) if err != nil { return fmt.Errorf("unable to add ibd creds key: %w", err) } - var keyId int - err = row.Scan(&keyId) - if err != nil { - return fmt.Errorf("unable to scan key id: %w", err) - } - - _, err = d.exec(ctx, tx, "users/add_ibd_creds", subject, username, encryptedPass, keyId) + _, err = exec.ExecContext(ctx, ` +UPDATE users +SET ibd_username = $2, + ibd_password = $3, + encryption_key = $4 +WHERE subject = $1;`, subject, username, encryptedPass, keyId) if err != nil { return fmt.Errorf("unable to add ibd creds to user: %w", err) } @@ -108,11 +105,21 @@ func (d *database) AddIBDCreds(ctx context.Context, subject string, username str return nil } -func (d *database) GetIBDCreds(ctx context.Context, subject string) (username string, password string, err error) { - row, err := d.queryRow(ctx, d.db, "users/get_ibd_creds", subject) - if err != nil { - return "", "", fmt.Errorf("unable to get ibd creds: %w", err) - } +func GetIBDCreds( + ctx context.Context, + exec Executor, + kms keys.KeyManagementService, + subject string, +) ( + username string, + password string, + err error, +) { + row := exec.QueryRowContext(ctx, ` +SELECT ibd_username, ibd_password, encrypted_key, kms_key_name +FROM users +INNER JOIN public.keys k on k.id = users.encryption_key +WHERE subject = $1;`, subject) var encryptedPass, encryptedKey []byte var keyName string @@ -124,7 +131,7 @@ func (d *database) GetIBDCreds(ctx context.Context, subject string) (username st return "", "", fmt.Errorf("unable to scan sql row into ibd creds: %w", err) } - passwordBytes, err := keys.Decrypt(ctx, d.kms, keyName, encryptedPass, encryptedKey) + passwordBytes, err := keys.Decrypt(ctx, kms, keyName, encryptedPass, encryptedKey) if err != nil { return "", "", fmt.Errorf("unable to decrypt password: %w", err) } diff --git a/backend/internal/ibd/auth_test.go b/backend/internal/ibd/auth_test.go index 54ea98a..157b507 100644 --- a/backend/internal/ibd/auth_test.go +++ b/backend/internal/ibd/auth_test.go @@ -163,7 +163,7 @@ func TestClient_Authenticate(t *testing.T) { return resp, nil }) - client := NewClient(nil, newTransport(tp)) + client := NewClient(nil, nil, newTransport(tp)) cookie, err := client.Authenticate(context.Background(), "abc", "xyz") require.NoError(t, err) @@ -189,7 +189,7 @@ func TestClient_Authenticate_401(t *testing.T) { return httpmock.NewStringResponse(http.StatusUnauthorized, `{"name":"ValidationError","code":"ERR016","message":"Wrong username or password","description":"Wrong username or password"}`), nil }) - client := NewClient(nil, newTransport(tp)) + client := NewClient(nil, nil, newTransport(tp)) cookie, err := client.Authenticate(context.Background(), "abc", "xyz") assert.Nil(t, cookie) diff --git a/backend/internal/ibd/client.go b/backend/internal/ibd/client.go index 25c5173..c8575e3 100644 --- a/backend/internal/ibd/client.go +++ b/backend/internal/ibd/client.go @@ -9,6 +9,7 @@ import ( "github.com/ansg191/ibd-trader-backend/internal/database" "github.com/ansg191/ibd-trader-backend/internal/ibd/transport" + "github.com/ansg191/ibd-trader-backend/internal/keys" ) var ErrNoAvailableCookies = errors.New("no available cookies") @@ -16,20 +17,22 @@ var ErrNoAvailableTransports = errors.New("no available transports") type Client struct { transports []transport.Transport - cookies database.CookieSource + db database.Executor + kms keys.KeyManagementService } func NewClient( - cookies database.CookieSource, + db database.Executor, + kms keys.KeyManagementService, transports ...transport.Transport, ) *Client { - return &Client{transports, cookies} + return &Client{transports, db, kms} } func (c *Client) getCookie(ctx context.Context, subject *string) (uint, *http.Cookie, error) { if subject == nil { // No subject requirement, get any cookie - cookie, err := c.cookies.GetAnyCookie(ctx) + cookie, err := database.GetAnyCookie(ctx, c.db, c.kms) if err != nil { return 0, nil, err } @@ -41,7 +44,7 @@ func (c *Client) getCookie(ctx context.Context, subject *string) (uint, *http.Co } // Get cookie by subject - cookies, err := c.cookies.GetCookies(ctx, *subject, false) + cookies, err := database.GetCookies(ctx, c.db, c.kms, *subject, false) if err != nil { return 0, nil, err } diff --git a/backend/internal/ibd/client_test.go b/backend/internal/ibd/client_test.go index d2dc1b2..2368a31 100644 --- a/backend/internal/ibd/client_test.go +++ b/backend/internal/ibd/client_test.go @@ -2,30 +2,155 @@ package ibd import ( "context" + "database/sql" + "fmt" + "log" + "math/rand/v2" "testing" "time" "github.com/ansg191/ibd-trader-backend/internal/database" + "github.com/ansg191/ibd-trader-backend/internal/keys" + _ "github.com/lib/pq" + "github.com/ory/dockertest/v3" + "github.com/ory/dockertest/v3/docker" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) -func TestClient_getCookie(t *testing.T) { - t.Parallel() +var ( + db *sql.DB + maxTime = time.Date(2100, 1, 1, 0, 0, 0, 0, time.UTC) + letterRunes = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ") +) - t.Run("no cookies", func(t *testing.T) { - t.Parallel() +func TestMain(m *testing.M) { + pool, err := dockertest.NewPool("") + if err != nil { + log.Fatalf("Could not create pool: %s", err) + } + + err = pool.Client.Ping() + if err != nil { + log.Fatalf("Could not connect to Docker: %s", err) + } + + resource, err := pool.RunWithOptions(&dockertest.RunOptions{ + Repository: "postgres", + Tag: "16", + Env: []string{ + "POSTGRES_PASSWORD=secret", + "POSTGRES_USER=ibd-client-test", + "POSTGRES_DB=ibd-client-test", + "listen_addresses='*'", + }, + Cmd: []string{ + "postgres", + "-c", + "log_statement=all", + }, + }, func(config *docker.HostConfig) { + config.AutoRemove = true + config.RestartPolicy = docker.RestartPolicy{Name: "no"} + }) + if err != nil { + log.Fatalf("Could not start resource: %s", err) + } + + hostAndPort := resource.GetHostPort("5432/tcp") + databaseUrl := fmt.Sprintf("postgres://ibd-client-test:secret@%s/ibd-client-test?sslmode=disable", hostAndPort) + + // Kill container after 120 seconds + _ = resource.Expire(120) + + pool.MaxWait = 120 * time.Second + if err = pool.Retry(func() error { + db, err = sql.Open("postgres", databaseUrl) + if err != nil { + return err + } + return db.Ping() + }); err != nil { + log.Fatalf("Could not connect to database: %s", err) + } + + err = database.Migrate(context.Background(), databaseUrl) + if err != nil { + log.Fatalf("Could not migrate database: %s", err) + } + + defer func() { + if err := pool.Purge(resource); err != nil { + log.Fatalf("Could not purge resource: %s", err) + } + }() + + m.Run() +} + +func randStringRunes(n int) string { + b := make([]rune, n) + for i := range b { + b[i] = letterRunes[rand.IntN(len(letterRunes))] + } + return string(b) +} + +func addCookie(t *testing.T) (user, token string) { + t.Helper() + + // Randomly generate a user and token + user = randStringRunes(8) + token = randStringRunes(16) + + ciphertext, key, err := keys.Encrypt(context.Background(), new(kmsStub), "", []byte(token)) + require.NoError(t, err) + + tx, err := db.Begin() + require.NoError(t, err) + + var keyID uint + err = tx.QueryRow(` +INSERT INTO keys (kms_key_name, encrypted_key) + VALUES ('', $1) + RETURNING id; +`, key).Scan(&keyID) + require.NoError(t, err) + + _, err = tx.Exec(` +INSERT +INTO users (subject, encryption_key) +VALUES ($1, $2); +`, user, keyID) + require.NoError(t, err) + + _, err = tx.Exec(` +INSERT +INTO ibd_tokens (user_subject, token, encryption_key, expires_at) +VALUES ($1, $2, $3, $4);`, + user, + ciphertext, + keyID, + maxTime, + ) + require.NoError(t, err) + + err = tx.Commit() + require.NoError(t, err) + + return user, token +} - client := NewClient(new(emptyCookieSourceStub)) +func TestClient_getCookie(t *testing.T) { + t.Run("no cookies", func(t *testing.T) { + client := NewClient(db, new(kmsStub)) _, _, err := client.getCookie(context.Background(), nil) assert.ErrorIs(t, err, ErrNoAvailableCookies) }) t.Run("no cookies by subject", func(t *testing.T) { - t.Parallel() - - client := NewClient(new(emptyCookieSourceStub)) + client := NewClient(db, new(kmsStub)) subject := "test" _, _, err := client.getCookie(context.Background(), &subject) @@ -33,67 +158,44 @@ func TestClient_getCookie(t *testing.T) { }) t.Run("get any cookie", func(t *testing.T) { - t.Parallel() + _, token := addCookie(t) - client := NewClient(new(cookieSourceStub)) + client := NewClient(db, new(kmsStub)) - id, cookie, err := client.getCookie(context.Background(), nil) + _, cookie, err := client.getCookie(context.Background(), nil) require.NoError(t, err) - assert.Equal(t, uint(42), id) assert.Equal(t, cookieName, cookie.Name) - assert.Equal(t, "test-token", cookie.Value) + assert.Equal(t, token, cookie.Value) assert.Equal(t, "/", cookie.Path) - assert.Equal(t, time.Unix(0, 0), cookie.Expires) + assert.Equal(t, maxTime, cookie.Expires) assert.Equal(t, "investors.com", cookie.Domain) }) t.Run("get cookie by subject", func(t *testing.T) { - t.Parallel() + subject, token := addCookie(t) - client := NewClient(new(cookieSourceStub)) + client := NewClient(db, new(kmsStub)) - subject := "test" - id, cookie, err := client.getCookie(context.Background(), &subject) + _, cookie, err := client.getCookie(context.Background(), &subject) require.NoError(t, err) - assert.Equal(t, uint(42), id) assert.Equal(t, cookieName, cookie.Name) - assert.Equal(t, "test-token", cookie.Value) + assert.Equal(t, token, cookie.Value) assert.Equal(t, "/", cookie.Path) - assert.Equal(t, time.Unix(0, 0), cookie.Expires) + assert.Equal(t, maxTime, cookie.Expires) assert.Equal(t, "investors.com", cookie.Domain) }) } -type emptyCookieSourceStub struct{} - -func (c *emptyCookieSourceStub) GetAnyCookie(_ context.Context) (*database.IBDCookie, error) { - return nil, nil -} - -func (c *emptyCookieSourceStub) GetCookies(_ context.Context, _ string, _ bool) ([]database.IBDCookie, error) { - return nil, nil -} +type kmsStub struct{} -func (c *emptyCookieSourceStub) ReportCookieFailure(_ context.Context, _ uint) error { +func (k *kmsStub) Close() error { return nil } -var testCookie = database.IBDCookie{ - ID: 42, - Token: "test-token", - Expiry: time.Unix(0, 0), +func (k *kmsStub) Encrypt(_ context.Context, _ string, plaintext []byte) ([]byte, error) { + return plaintext, nil } -type cookieSourceStub struct{} - -func (c *cookieSourceStub) GetAnyCookie(_ context.Context) (*database.IBDCookie, error) { - return &testCookie, nil -} - -func (c *cookieSourceStub) GetCookies(_ context.Context, _ string, _ bool) ([]database.IBDCookie, error) { - return []database.IBDCookie{testCookie}, nil -} - -func (c *cookieSourceStub) ReportCookieFailure(_ context.Context, _ uint) error { - return nil +func (k *kmsStub) Decrypt(_ context.Context, _ string, ciphertext []byte) ([]byte, error) { + return ciphertext, nil } diff --git a/backend/internal/ibd/ibd50.go b/backend/internal/ibd/ibd50.go index ea02f82..52e28aa 100644 --- a/backend/internal/ibd/ibd50.go +++ b/backend/internal/ibd/ibd50.go @@ -9,6 +9,8 @@ import ( "net/http" "net/url" "strconv" + + "github.com/ansg191/ibd-trader-backend/internal/database" ) const ibd50Url = "https://research.investors.com/Services/SiteAjaxService.asmx/GetIBD50?sortcolumn1=%22ibd100rank%22&sortOrder1=%22asc%22&sortcolumn2=%22%22&sortOrder2=%22ASC%22" @@ -63,7 +65,7 @@ func (c *Client) GetIBD50(ctx context.Context) ([]*Stock, error) { // If there are less than 10 stocks in the IBD50 list, it's likely that authentication failed. if len(ibd50Resp.D.ETablesDataList) < 10 { // Report cookie failure to DB - if err = c.cookies.ReportCookieFailure(ctx, cookieId); err != nil { + if err = database.ReportCookieFailure(ctx, c.db, cookieId); err != nil { slog.Error("Failed to report cookie failure", "error", err) } return nil, errors.New("failed to get IBD50 list") diff --git a/backend/internal/ibd/search_test.go b/backend/internal/ibd/search_test.go index 99157cf..05e93dc 100644 --- a/backend/internal/ibd/search_test.go +++ b/backend/internal/ibd/search_test.go @@ -162,8 +162,6 @@ const emptySearchResponseJSON = ` }` func TestClient_Search(t *testing.T) { - t.Parallel() - tests := []struct { name string response string @@ -195,7 +193,11 @@ func TestClient_Search(t *testing.T) { tp := httpmock.NewMockTransport() tp.RegisterResponder("GET", searchUrl, httpmock.NewStringResponder(200, tt.response)) - client := NewClient(new(cookieSourceStub), transport.NewStandardTransport(&http.Client{Transport: tp})) + client := NewClient( + db, + new(kmsStub), + transport.NewStandardTransport(&http.Client{Transport: tp}), + ) tt.f(t, client) }) diff --git a/backend/internal/leader/manager/ibd/auth/auth.go b/backend/internal/leader/manager/ibd/auth/auth.go index 3419afd..9b5502d 100644 --- a/backend/internal/leader/manager/ibd/auth/auth.go +++ b/backend/internal/leader/manager/ibd/auth/auth.go @@ -20,13 +20,13 @@ const ( // Manager is responsible for sending authentication tasks to the workers. type Manager struct { queue taskqueue.TaskQueue[TaskInfo] - store database.UserStore + db database.Executor schedule cron.Schedule } func New( ctx context.Context, - store database.UserStore, + db database.Executor, rClient *redis.Client, schedule cron.Schedule, ) (*Manager, error) { @@ -43,7 +43,7 @@ func New( return &Manager{ queue: queue, - store: store, + db: db, schedule: schedule, }, nil } @@ -84,7 +84,7 @@ func (m *Manager) scrapeCookies(ctx context.Context, deadline time.Time) { defer cancel() // Get all users with IBD credentials - users, err := m.store.ListUsers(ctx, true) + users, err := database.ListUsers(ctx, m.db, true) if err != nil { slog.ErrorContext(ctx, "failed to get users", "error", err) return diff --git a/backend/internal/leader/manager/ibd/scrape/scrape.go b/backend/internal/leader/manager/ibd/scrape/scrape.go index e6cf490..870ce5e 100644 --- a/backend/internal/leader/manager/ibd/scrape/scrape.go +++ b/backend/internal/leader/manager/ibd/scrape/scrape.go @@ -23,7 +23,7 @@ const ( // Manager is responsible for sending scraping tasks to the workers. type Manager struct { client *ibd.Client - store database.StockStore + db database.Executor queue taskqueue.TaskQueue[TaskInfo] schedule cron.Schedule pubsub *redis.PubSub @@ -32,7 +32,7 @@ type Manager struct { func New( ctx context.Context, client *ibd.Client, - store database.StockStore, + db database.Executor, redis *redis.Client, schedule cron.Schedule, ) (*Manager, error) { @@ -49,7 +49,7 @@ func New( return &Manager{ client: client, - store: store, + db: db, queue: queue, schedule: schedule, pubsub: redis.Subscribe(ctx, Channel), @@ -107,7 +107,7 @@ func (m *Manager) scrapeIBD50(ctx context.Context, deadline time.Time) { for _, stock := range stocks { // Add stock to DB - err = m.store.AddStock(ctx, database.Stock{ + err = database.AddStock(ctx, m.db, database.Stock{ Symbol: stock.Symbol, Name: stock.Name, IBDUrl: stock.QuoteURL.String(), @@ -118,7 +118,7 @@ func (m *Manager) scrapeIBD50(ctx context.Context, deadline time.Time) { } // Add ranking to Db - err = m.store.AddRanking(ctx, stock.Symbol, int(stock.Rank), 0) + err = database.AddRanking(ctx, m.db, stock.Symbol, int(stock.Rank), 0) if err != nil { slog.ErrorContext(ctx, "failed to add ranking", "error", err) continue diff --git a/backend/internal/server/idb/stock/v1/stock.go b/backend/internal/server/idb/stock/v1/stock.go index d30bde3..8afc2b1 100644 --- a/backend/internal/server/idb/stock/v1/stock.go +++ b/backend/internal/server/idb/stock/v1/stock.go @@ -22,11 +22,11 @@ const ScrapeOperationPrefix = "scrape" type Server struct { pb.UnimplementedStockServiceServer - db database.StockStore + db database.Executor queue taskqueue.TaskQueue[scrape.TaskInfo] } -func New(db database.StockStore, queue taskqueue.TaskQueue[scrape.TaskInfo]) *Server { +func New(db database.Executor, queue taskqueue.TaskQueue[scrape.TaskInfo]) *Server { return &Server{db: db, queue: queue} } diff --git a/backend/internal/server/idb/user/v1/user.go b/backend/internal/server/idb/user/v1/user.go index c100465..2f32e03 100644 --- a/backend/internal/server/idb/user/v1/user.go +++ b/backend/internal/server/idb/user/v1/user.go @@ -7,6 +7,7 @@ import ( pb "github.com/ansg191/ibd-trader-backend/api/gen/idb/user/v1" "github.com/ansg191/ibd-trader-backend/internal/database" "github.com/ansg191/ibd-trader-backend/internal/ibd" + "github.com/ansg191/ibd-trader-backend/internal/keys" "github.com/mennanov/fmutils" "google.golang.org/grpc/codes" @@ -17,26 +18,28 @@ import ( type Server struct { pb.UnimplementedUserServiceServer - user database.UserStore - cookie database.CookieSource - client *ibd.Client + db database.TransactionExecutor + kms keys.KeyManagementService + keyName string + client *ibd.Client } -func New(userStore database.UserStore, cookieStore database.CookieStore, client *ibd.Client) *Server { +func New(db database.TransactionExecutor, kms keys.KeyManagementService, keyName string, client *ibd.Client) *Server { return &Server{ - user: userStore, - cookie: cookieStore, - client: client, + db: db, + kms: kms, + keyName: keyName, + client: client, } } func (u *Server) CreateUser(ctx context.Context, request *pb.CreateUserRequest) (*pb.CreateUserResponse, error) { - err := u.user.AddUser(ctx, request.Subject) + err := database.AddUser(ctx, u.db, request.Subject) if err != nil { return nil, status.Errorf(codes.Internal, "unable to create user: %v", err) } - user, err := u.user.GetUser(ctx, request.Subject) + user, err := database.GetUser(ctx, u.db, request.Subject) if err != nil { return nil, status.Errorf(codes.Internal, "unable to get user: %v", err) } @@ -51,7 +54,7 @@ func (u *Server) CreateUser(ctx context.Context, request *pb.CreateUserRequest) } func (u *Server) GetUser(ctx context.Context, request *pb.GetUserRequest) (*pb.GetUserResponse, error) { - user, err := u.user.GetUser(ctx, request.Subject) + user, err := database.GetUser(ctx, u.db, request.Subject) if errors.Is(err, database.ErrUserNotFound) { return nil, status.New(codes.NotFound, "user not found").Err() } @@ -88,7 +91,7 @@ func (u *Server) UpdateUser(ctx context.Context, request *pb.UpdateUserRequest) (newUser.IbdPassword != existingUser.IbdPassword || newUser.IbdUsername != existingUser.IbdUsername) { // Update IBD creds - err = u.user.AddIBDCreds(ctx, newUser.Subject, *newUser.IbdUsername, *newUser.IbdPassword) + err = database.AddIBDCreds(ctx, u.db, u.kms, u.keyName, newUser.Subject, *newUser.IbdUsername, *newUser.IbdPassword) if err != nil { return nil, status.Errorf(codes.Internal, "unable to update user: %v", err) } @@ -119,7 +122,7 @@ func (u *Server) CheckIBDUsername(ctx context.Context, req *pb.CheckIBDUsernameR func (u *Server) AuthenticateUser(ctx context.Context, req *pb.AuthenticateUserRequest) (*pb.AuthenticateUserResponse, error) { // Check if user has cookies - cookies, err := u.cookie.GetCookies(ctx, req.Subject, false) + cookies, err := database.GetCookies(ctx, u.db, u.kms, req.Subject, false) if err != nil { return nil, status.Errorf(codes.Internal, "unable to get cookies: %v", err) } @@ -131,7 +134,7 @@ func (u *Server) AuthenticateUser(ctx context.Context, req *pb.AuthenticateUserR // Authenticate user // Get IBD creds - username, password, err := u.user.GetIBDCreds(ctx, req.Subject) + username, password, err := database.GetIBDCreds(ctx, u.db, u.kms, req.Subject) if errors.Is(err, database.ErrIBDCredsNotFound) { return nil, status.New(codes.NotFound, "User has no IDB creds").Err() } diff --git a/backend/internal/server/server.go b/backend/internal/server/server.go index 186d581..c525cfd 100644 --- a/backend/internal/server/server.go +++ b/backend/internal/server/server.go @@ -11,6 +11,7 @@ import ( upb "github.com/ansg191/ibd-trader-backend/api/gen/idb/user/v1" "github.com/ansg191/ibd-trader-backend/internal/database" "github.com/ansg191/ibd-trader-backend/internal/ibd" + "github.com/ansg191/ibd-trader-backend/internal/keys" "github.com/ansg191/ibd-trader-backend/internal/leader/manager/ibd/scrape" "github.com/ansg191/ibd-trader-backend/internal/redis/taskqueue" "github.com/ansg191/ibd-trader-backend/internal/server/idb/stock/v1" @@ -30,9 +31,11 @@ type Server struct { func New( ctx context.Context, port uint16, - db database.Database, + db database.TransactionExecutor, rClient *redis.Client, client *ibd.Client, + kms keys.KeyManagementService, + keyName string, ) (*Server, error) { scrapeQueue, err := taskqueue.New( ctx, @@ -45,7 +48,7 @@ func New( } s := grpc.NewServer() - upb.RegisterUserServiceServer(s, user.New(db, db, client)) + upb.RegisterUserServiceServer(s, user.New(db, kms, keyName, client)) spb.RegisterStockServiceServer(s, stock.New(db, scrapeQueue)) longrunningpb.RegisterOperationsServer(s, newOperationServer(scrapeQueue)) reflection.Register(s) diff --git a/backend/internal/worker/analyzer/analyzer.go b/backend/internal/worker/analyzer/analyzer.go index 79a35ee..ea8069e 100644 --- a/backend/internal/worker/analyzer/analyzer.go +++ b/backend/internal/worker/analyzer/analyzer.go @@ -24,7 +24,7 @@ func RunAnalyzer( ctx context.Context, redis *redis.Client, analyzer analyzer.Analyzer, - db database.StockStore, + db database.Executor, name string, ) error { queue, err := taskqueue.New( @@ -52,7 +52,7 @@ func waitForTask( ctx context.Context, queue taskqueue.TaskQueue[TaskInfo], analyzer analyzer.Analyzer, - db database.StockStore, + db database.Executor, ) { task, err := queue.Dequeue(ctx, lockTimeout, dequeueTimeout) if err != nil { @@ -111,8 +111,8 @@ func waitForTask( } } -func analyzeStock(ctx context.Context, a analyzer.Analyzer, db database.StockStore, id string) error { - info, err := db.GetStockInfo(ctx, id) +func analyzeStock(ctx context.Context, a analyzer.Analyzer, db database.Executor, id string) error { + info, err := database.GetStockInfo(ctx, db, id) if err != nil { return err } @@ -127,7 +127,7 @@ func analyzeStock(ctx context.Context, a analyzer.Analyzer, db database.StockSto return err } - return db.AddAnalysis(ctx, id, analysis) + return database.AddAnalysis(ctx, db, id, analysis) } type TaskInfo struct { diff --git a/backend/internal/worker/auth/auth.go b/backend/internal/worker/auth/auth.go index 1f591fe..579a180 100644 --- a/backend/internal/worker/auth/auth.go +++ b/backend/internal/worker/auth/auth.go @@ -2,12 +2,15 @@ package auth import ( "context" + "database/sql" + "errors" "fmt" "log/slog" "time" "github.com/ansg191/ibd-trader-backend/internal/database" "github.com/ansg191/ibd-trader-backend/internal/ibd" + "github.com/ansg191/ibd-trader-backend/internal/keys" "github.com/ansg191/ibd-trader-backend/internal/leader/manager/ibd/auth" "github.com/ansg191/ibd-trader-backend/internal/redis/taskqueue" @@ -23,8 +26,8 @@ func RunAuthScraper( ctx context.Context, client *ibd.Client, redis *redis.Client, - users database.UserStore, - cookies database.CookieStore, + db database.Executor, + kms keys.KeyManagementService, name string, ) error { queue, err := taskqueue.New( @@ -43,7 +46,7 @@ func RunAuthScraper( case <-ctx.Done(): return ctx.Err() default: - waitForTask(ctx, queue, client, users, cookies) + waitForTask(ctx, queue, client, db, kms) } } } @@ -52,8 +55,8 @@ func waitForTask( ctx context.Context, queue taskqueue.TaskQueue[auth.TaskInfo], client *ibd.Client, - users database.UserStore, - cookies database.CookieStore, + db database.Executor, + kms keys.KeyManagementService, ) { task, err := queue.Dequeue(ctx, lockTimeout, dequeueTimeout) if err != nil { @@ -69,7 +72,7 @@ func waitForTask( ch := make(chan error) defer close(ch) go func() { - ch <- scrapeCookies(ctx, client, users, cookies, task.Data.UserSubject) + ch <- scrapeCookies(ctx, client, db, kms, task.Data.UserSubject) }() ticker := time.NewTicker(lockTimeout / 5) @@ -116,15 +119,15 @@ func waitForTask( func scrapeCookies( ctx context.Context, client *ibd.Client, - users database.UserStore, - store database.CookieStore, + db database.Executor, + kms keys.KeyManagementService, user string, ) error { ctx, cancel := context.WithTimeout(ctx, lockTimeout) defer cancel() // Check if the user has valid cookies - done, err := hasValidCookies(ctx, store, user) + done, err := hasValidCookies(ctx, db, user) if err != nil { return fmt.Errorf("failed to check cookies: %w", err) } @@ -133,7 +136,7 @@ func scrapeCookies( } // Health check degraded cookies - done, err = healthCheckDegradedCookies(ctx, client, store, user) + done, err = healthCheckDegradedCookies(ctx, client, db, kms, user) if err != nil { return fmt.Errorf("failed to health check cookies: %w", err) } @@ -142,31 +145,39 @@ func scrapeCookies( } // No cookies are valid, so scrape new cookies - return scrapeNewCookies(ctx, client, users, store, user) + return scrapeNewCookies(ctx, client, db, kms, user) } -func hasValidCookies(ctx context.Context, store database.CookieStore, user string) (bool, error) { +func hasValidCookies(ctx context.Context, db database.Executor, user string) (bool, error) { // Check if the user has non-degraded cookies - cookies, err := store.GetCookies(ctx, user, false) + row := db.QueryRowContext(ctx, ` +SELECT 1 +FROM ibd_tokens +WHERE user_subject = $1 + AND expires_at > NOW() + AND degraded = FALSE;`, user) + + var exists bool + err := row.Scan(&exists) + if errors.Is(err, sql.ErrNoRows) { + return false, nil + } if err != nil { return false, fmt.Errorf("failed to get non-degraded cookies: %w", err) } - // If the user has non-degraded cookies, return true - if len(cookies) > 0 { - return true, nil - } - return false, nil + return true, nil } func healthCheckDegradedCookies( ctx context.Context, client *ibd.Client, - store database.CookieStore, + db database.Executor, + kms keys.KeyManagementService, user string, ) (bool, error) { // Check if the user has degraded cookies - cookies, err := store.GetCookies(ctx, user, true) + cookies, err := database.GetCookies(ctx, db, kms, user, true) if err != nil { return false, fmt.Errorf("failed to get degraded cookies: %w", err) } @@ -190,7 +201,7 @@ func healthCheckDegradedCookies( valid = true // Update the cookie - err = store.RepairCookie(ctx, cookie.ID) + err = database.RepairCookie(ctx, db, cookie.ID) if err != nil { slog.ErrorContext(ctx, "Failed to repair cookie", "error", err) } @@ -202,12 +213,12 @@ func healthCheckDegradedCookies( func scrapeNewCookies( ctx context.Context, client *ibd.Client, - users database.UserStore, - store database.CookieStore, + db database.Executor, + kms keys.KeyManagementService, user string, ) error { // Get the user's credentials - username, password, err := users.GetIBDCreds(ctx, user) + username, password, err := database.GetIBDCreds(ctx, db, kms, user) if err != nil { return fmt.Errorf("failed to get IBD credentials: %w", err) } @@ -219,7 +230,7 @@ func scrapeNewCookies( } // Store the cookie - err = store.AddCookie(ctx, user, cookie) + err = database.AddCookie(ctx, db, kms, user, cookie) if err != nil { return fmt.Errorf("failed to store cookie: %w", err) } diff --git a/backend/internal/worker/scraper/scraper.go b/backend/internal/worker/scraper/scraper.go index ec71d62..4788834 100644 --- a/backend/internal/worker/scraper/scraper.go +++ b/backend/internal/worker/scraper/scraper.go @@ -25,7 +25,7 @@ func RunScraper( ctx context.Context, redis *redis.Client, client *ibd.Client, - store database.StockStore, + db database.TransactionExecutor, name string, ) error { queue, err := taskqueue.New( @@ -55,7 +55,7 @@ func RunScraper( case <-ctx.Done(): return ctx.Err() default: - waitForTask(ctx, queue, aQueue, client, store) + waitForTask(ctx, queue, aQueue, client, db) } } } @@ -65,7 +65,7 @@ func waitForTask( queue taskqueue.TaskQueue[scrape.TaskInfo], aQueue taskqueue.TaskQueue[analyzer.TaskInfo], client *ibd.Client, - store database.StockStore, + db database.TransactionExecutor, ) { task, err := queue.Dequeue(ctx, lockTimeout, dequeueTimeout) if err != nil { @@ -80,7 +80,7 @@ func waitForTask( ch := make(chan error) go func() { defer close(ch) - ch <- scrapeUrl(ctx, client, store, aQueue, task.Data.Symbol) + ch <- scrapeUrl(ctx, client, db, aQueue, task.Data.Symbol) }() ticker := time.NewTicker(lockTimeout / 5) @@ -127,14 +127,14 @@ func waitForTask( func scrapeUrl( ctx context.Context, client *ibd.Client, - store database.StockStore, + db database.TransactionExecutor, aQueue taskqueue.TaskQueue[analyzer.TaskInfo], symbol string, ) error { ctx, cancel := context.WithTimeout(ctx, lockTimeout) defer cancel() - stockUrl, err := getStockUrl(ctx, store, client, symbol) + stockUrl, err := getStockUrl(ctx, db, client, symbol) if err != nil { return fmt.Errorf("failed to get stock url: %w", err) } @@ -146,7 +146,7 @@ func scrapeUrl( } // Add stock info to the database. - id, err := store.AddStockInfo(ctx, info) + id, err := database.AddStockInfo(ctx, db, info) if err != nil { return fmt.Errorf("failed to add stock info: %w", err) } @@ -162,9 +162,9 @@ func scrapeUrl( return nil } -func getStockUrl(ctx context.Context, store database.StockStore, client *ibd.Client, symbol string) (string, error) { +func getStockUrl(ctx context.Context, db database.TransactionExecutor, client *ibd.Client, symbol string) (string, error) { // Get the stock from the database. - stock, err := store.GetStock(ctx, symbol) + stock, err := database.GetStock(ctx, db, symbol) if err == nil { return stock.IBDUrl, nil } @@ -182,7 +182,7 @@ func getStockUrl(ctx context.Context, store database.StockStore, client *ibd.Cli } // Add the stock to the database. - err = store.AddStock(ctx, stock) + err = database.AddStock(ctx, db, stock) if err != nil { return "", fmt.Errorf("failed to add stock: %w", err) } diff --git a/backend/internal/worker/worker.go b/backend/internal/worker/worker.go index 3d7e9c8..6017fb7 100644 --- a/backend/internal/worker/worker.go +++ b/backend/internal/worker/worker.go @@ -12,6 +12,7 @@ import ( "github.com/ansg191/ibd-trader-backend/internal/analyzer" "github.com/ansg191/ibd-trader-backend/internal/database" "github.com/ansg191/ibd-trader-backend/internal/ibd" + "github.com/ansg191/ibd-trader-backend/internal/keys" "github.com/ansg191/ibd-trader-backend/internal/leader/manager" analyzer2 "github.com/ansg191/ibd-trader-backend/internal/worker/analyzer" "github.com/ansg191/ibd-trader-backend/internal/worker/auth" @@ -30,7 +31,8 @@ func StartWorker( ctx context.Context, ibdClient *ibd.Client, client *redis.Client, - db database.Database, + db database.TransactionExecutor, + kms keys.KeyManagementService, a analyzer.Analyzer, ) error { // Get the worker name. @@ -49,7 +51,7 @@ func StartWorker( return scraper.RunScraper(ctx, client, ibdClient, db, name) }) g.Go(func() error { - return auth.RunAuthScraper(ctx, ibdClient, client, db, db, name) + return auth.RunAuthScraper(ctx, ibdClient, client, db, kms, name) }) g.Go(func() error { return analyzer2.RunAnalyzer(ctx, client, a, db, name) |