diff options
Diffstat (limited to 'backend/internal')
57 files changed, 7393 insertions, 0 deletions
diff --git a/backend/internal/analyzer/analyzer.go b/backend/internal/analyzer/analyzer.go new file mode 100644 index 0000000..c055647 --- /dev/null +++ b/backend/internal/analyzer/analyzer.go @@ -0,0 +1,32 @@ +package analyzer + +import ( + "context" + + "github.com/Rhymond/go-money" +) + +type Analyzer interface { + Analyze( + ctx context.Context, + symbol string, + price *money.Money, + rawAnalysis string, + ) (*Analysis, error) +} + +type ChartAction string + +const ( + Buy ChartAction = "buy" + Sell ChartAction = "sell" + Hold ChartAction = "hold" + Unknown ChartAction = "unknown" +) + +type Analysis struct { + Action ChartAction + Price *money.Money + Reason string + Confidence uint8 +} diff --git a/backend/internal/analyzer/openai/openai.go b/backend/internal/analyzer/openai/openai.go new file mode 100644 index 0000000..66dfc05 --- /dev/null +++ b/backend/internal/analyzer/openai/openai.go @@ -0,0 +1,126 @@ +package openai + +import ( + "context" + "encoding/json" + "fmt" + "math" + "strconv" + "strings" + "time" + + "ibd-trader/internal/analyzer" + "ibd-trader/internal/utils" + + "github.com/Rhymond/go-money" + "github.com/sashabaranov/go-openai" +) + +type Client interface { + CreateChatCompletion( + ctx context.Context, + request openai.ChatCompletionRequest, + ) (response openai.ChatCompletionResponse, err error) +} + +type Analyzer struct { + client Client + model string + systemMsg string + temperature float32 +} + +func NewAnalyzer(opts ...Option) *Analyzer { + a := &Analyzer{ + client: nil, + model: defaultModel, + systemMsg: defaultSystemMsg, + temperature: defaultTemperature, + } + for _, option := range opts { + option(a) + } + if a.client == nil { + panic("client is required") + } + + return a +} + +func (a *Analyzer) Analyze( + ctx context.Context, + symbol string, + price *money.Money, + rawAnalysis string, +) (*analyzer.Analysis, error) { + usrMsg := fmt.Sprintf( + "%s\n%s\n%s\n%s\n", + time.Now().Format(time.RFC3339), + symbol, + 
price.Display(), + rawAnalysis, + ) + res, err := a.client.CreateChatCompletion(ctx, openai.ChatCompletionRequest{ + Model: a.model, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleSystem, + Content: a.systemMsg, + }, + { + Role: openai.ChatMessageRoleUser, + Content: usrMsg, + }, + }, + MaxTokens: 0, + Temperature: a.temperature, + ResponseFormat: &openai.ChatCompletionResponseFormat{Type: openai.ChatCompletionResponseFormatTypeJSONObject}, + }) + if err != nil { + return nil, err + } + + var resp response + if err = json.Unmarshal([]byte(res.Choices[0].Message.Content), &resp); err != nil { + return nil, fmt.Errorf("failed to unmarshal gpt response: %w", err) + } + + var action analyzer.ChartAction + switch strings.ToLower(resp.Action) { + case "buy": + action = analyzer.Buy + case "sell": + action = analyzer.Sell + case "hold": + action = analyzer.Hold + default: + action = analyzer.Unknown + } + + m, err := utils.ParseMoney(resp.Price) + if err != nil { + return nil, fmt.Errorf("failed to parse price: %w", err) + } + + confidence, err := strconv.ParseFloat(resp.Confidence, 64) + if err != nil { + return nil, fmt.Errorf("failed to parse confidence: %w", err) + } + if confidence < 0 || confidence > 100 { + return nil, fmt.Errorf("confidence must be between 0 and 100, got %f", confidence) + } + + return &analyzer.Analysis{ + Action: action, + Price: m, + Reason: resp.Reason, + Confidence: uint8(math.Floor(confidence)), + }, nil +} + +type response struct { + Action string `json:"action"` + Price string `json:"price"` + Reason string `json:"reason"` + Confidence string `json:"confidence"` +} diff --git a/backend/internal/analyzer/openai/openai_test.go b/backend/internal/analyzer/openai/openai_test.go new file mode 100644 index 0000000..0aac709 --- /dev/null +++ b/backend/internal/analyzer/openai/openai_test.go @@ -0,0 +1 @@ +package openai diff --git a/backend/internal/analyzer/openai/options.go 
b/backend/internal/analyzer/openai/options.go new file mode 100644 index 0000000..11d691f --- /dev/null +++ b/backend/internal/analyzer/openai/options.go @@ -0,0 +1,45 @@ +package openai + +import ( + _ "embed" + + "github.com/sashabaranov/go-openai" +) + +//go:embed system.txt +var defaultSystemMsg string + +const defaultModel = openai.GPT4o +const defaultTemperature = 0.25 + +type Option func(*Analyzer) + +func WithClientConfig(cfg openai.ClientConfig) Option { + return func(a *Analyzer) { + a.client = openai.NewClientWithConfig(cfg) + } +} + +func WithDefaultConfig(apiKey string) Option { + return func(a *Analyzer) { + a.client = openai.NewClient(apiKey) + } +} + +func WithModel(model string) Option { + return func(a *Analyzer) { + a.model = model + } +} + +func WithSystemMsg(msg string) Option { + return func(a *Analyzer) { + a.systemMsg = msg + } +} + +func WithTemperature(temp float32) Option { + return func(a *Analyzer) { + a.temperature = temp + } +} diff --git a/backend/internal/analyzer/openai/system.txt b/backend/internal/analyzer/openai/system.txt new file mode 100644 index 0000000..82e3b2a --- /dev/null +++ b/backend/internal/analyzer/openai/system.txt @@ -0,0 +1,34 @@ +You're a stock analyzer. +You will be given a stock symbol, its current price, and its chart analysis. +Your job is to determine the best course of action to do with the stock (buy, sell, hold, or unknown), +the price at which the action should be taken, the reason for the action, and the confidence (0-100) +level you have in the action. + +The reason should be a paragraph explaining why you chose the action. + +The date the chart analysis was done may be mentioned in the chart analysis. +Make sure to take that into account when making your decision. +If the chart analysis is older than 1 week, lower your confidence accordingly and mention that in the reason. +If the chart analysis is too outdated, set the action to "unknown". 
+ +The information will be given in the following format: +``` +<current datetime> +<stock symbol> +<current price> +<chart analysis> +``` + +Your response should be in the following JSON format: +``` +{ + "action": "<action>", + "price": "<price>", + "reason": "<reason>", + "confidence": "<confidence>" +} +``` +All fields are required and must be strings. +`action` must be one of the following: "buy", "sell", "hold", or "unknown". +`price` must contain the symbol for the currency and the price (e.g. "$100"). +The system WILL validate your response.
\ No newline at end of file diff --git a/backend/internal/auth/auth.go b/backend/internal/auth/auth.go new file mode 100644 index 0000000..4361b58 --- /dev/null +++ b/backend/internal/auth/auth.go @@ -0,0 +1,55 @@ +package auth + +import ( + "context" + "errors" + + "ibd-trader/internal/config" + + "github.com/coreos/go-oidc/v3/oidc" + "golang.org/x/oauth2" +) + +// Authenticator is used to authenticate our users. +type Authenticator struct { + *oidc.Provider + oauth2.Config +} + +// New instantiates the *Authenticator. +func New(cfg *config.Config) (*Authenticator, error) { + provider, err := oidc.NewProvider( + context.Background(), + "https://"+cfg.Auth.Domain+"/", + ) + if err != nil { + return nil, err + } + + conf := oauth2.Config{ + ClientID: cfg.Auth.ClientID, + ClientSecret: cfg.Auth.ClientSecret, + RedirectURL: cfg.Auth.CallbackURL, + Endpoint: provider.Endpoint(), + Scopes: []string{oidc.ScopeOpenID, "profile", "email"}, + } + + return &Authenticator{ + Provider: provider, + Config: conf, + }, nil +} + +// VerifyIDToken verifies that an *oauth2.Token is a valid *oidc.IDToken. 
+func (a *Authenticator) VerifyIDToken(ctx context.Context, token *oauth2.Token) (*oidc.IDToken, error) { + rawIDToken, ok := token.Extra("id_token").(string) + if !ok { + return nil, errors.New("no id_token field in oauth2 token") + } + + oidcConfig := &oidc.Config{ + ClientID: a.ClientID, + } + + return a.Verifier(oidcConfig).Verify(ctx, rawIDToken) +} diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go new file mode 100644 index 0000000..02434e5 --- /dev/null +++ b/backend/internal/config/config.go @@ -0,0 +1,114 @@ +package config + +import ( + "ibd-trader/internal/keys" + "ibd-trader/internal/leader/manager/ibd" + + "github.com/spf13/viper" +) + +type Config struct { + // Logging configuration + Log struct { + // Log level + Level string + // Add source info to log messages + AddSource bool + // Enable colorized output + Color bool + } + // Database configuration + DB struct { + // Database URL + URL string + } + // Redis configuration + Redis struct { + // Redis address + Addr string + // Redis password + Password string + } + // KMS configuration + KMS struct { + // GCP KMS configuration + GCP *keys.GCPKeyName + } + // Server configuration + Server struct { + // Server port + Port uint16 + } + // OAuth 2.0 configuration + Auth struct { + // OAuth 2.0 domain + Domain string + // OAuth 2.0 client ID + ClientID string + // OAuth 2.0 client secret + ClientSecret string + // OAuth 2.0 callback URL + CallbackURL string + } + // IBD configuration + IBD struct { + // Scraper API Key + APIKey string + // Proxy URL + ProxyURL string + // Scrape schedules. In cron format. 
+ Schedules ibd.Schedules + } + // Analyzer configuration + Analyzer struct { + // Use OpenAI for analysis + OpenAI *struct { + // OpenAI API Key + APIKey string + } + } +} + +func New() (*Config, error) { + v := viper.New() + + v.SetDefault("log.level", "INFO") + v.SetDefault("log.addSource", false) + v.SetDefault("log.color", false) + v.SetDefault("server.port", 8000) + + v.SetConfigName("config") + v.AddConfigPath("/etc/ibd-trader/") + v.AddConfigPath("$HOME/.ibd-trader") + v.AddConfigPath(".") + err := v.ReadInConfig() + if err != nil { + if _, ok := err.(viper.ConfigFileNotFoundError); ok { + // Config file not found; ignore error + } else { + return nil, err + } + } + + v.MustBindEnv("db.url", "DATABASE_URL") + v.MustBindEnv("redis.addr", "REDIS_ADDR") + v.MustBindEnv("redis.password", "REDIS_PASSWORD") + v.MustBindEnv("log.level", "LOG_LEVEL") + v.MustBindEnv("server.port", "SERVER_PORT") + v.MustBindEnv("auth.domain", "AUTH_DOMAIN") + v.MustBindEnv("auth.clientID", "AUTH_CLIENT_ID") + v.MustBindEnv("auth.clientSecret", "AUTH_CLIENT_SECRET") + v.MustBindEnv("auth.callbackURL", "AUTH_CALLBACK_URL") + + config := new(Config) + err = v.Unmarshal(config) + if err != nil { + return nil, err + } + + return config, config.assert() +} + +func (c *Config) assert() error { + return nil +} diff --git a/backend/internal/database/cookies.go b/backend/internal/database/cookies.go new file mode 100644 index 0000000..cb38272 --- /dev/null +++ b/backend/internal/database/cookies.go @@ -0,0 +1,150 @@ +package database + +import ( + "context" + "database/sql" + "errors" + "fmt" + "net/http" + "time" + + "ibd-trader/internal/keys" +) + +type CookieStore interface { + CookieSource + AddCookie(ctx context.Context, subject string, cookie *http.Cookie) error + RepairCookie(ctx context.Context, id uint) error +} + +type CookieSource interface { + GetAnyCookie(ctx context.Context) (*IBDCookie, error) + GetCookies(ctx context.Context, subject string, degraded bool) ([]IBDCookie, 
error) + ReportCookieFailure(ctx context.Context, id uint) error +} + +func (d *database) GetAnyCookie(ctx context.Context) (*IBDCookie, error) { + row, err := d.queryRow(ctx, d.db, "cookies/get_any_cookie") + if err != nil { + return nil, fmt.Errorf("unable to get any ibd cookie: %w", err) + } + + var id uint + var encryptedToken, encryptedKey []byte + var keyName string + var expiry time.Time + err = row.Scan(&id, &encryptedToken, &encryptedKey, &keyName, &expiry) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, nil + } + return nil, fmt.Errorf("unable to scan sql row into ibd cookie: %w", err) + } + + token, err := keys.Decrypt(ctx, d.kms, keyName, encryptedToken, encryptedKey) + if err != nil { + return nil, fmt.Errorf("unable to decrypt token: %w", err) + } + return &IBDCookie{ + Token: string(token), + Expiry: expiry, + }, nil +} + +func (d *database) GetCookies(ctx context.Context, subject string, degraded bool) ([]IBDCookie, error) { + row, err := d.query(ctx, d.db, "cookies/get_cookies", subject, degraded) + if err != nil { + return nil, fmt.Errorf("unable to get ibd cookies: %w", err) + } + + cookies := make([]IBDCookie, 0) + for row.Next() { + var id uint + var encryptedToken, encryptedKey []byte + var keyName string + var expiry time.Time + err = row.Scan(&id, &encryptedToken, &encryptedKey, &keyName, &expiry) + if err != nil { + return nil, fmt.Errorf("unable to scan sql row into ibd cookie: %w", err) + } + + token, err := keys.Decrypt(ctx, d.kms, keyName, encryptedToken, encryptedKey) + if err != nil { + return nil, fmt.Errorf("unable to decrypt token: %w", err) + } + cookie := IBDCookie{ + ID: id, + Token: string(token), + Expiry: expiry, + } + cookies = append(cookies, cookie) + } + + return cookies, nil +} + +func (d *database) AddCookie(ctx context.Context, subject string, cookie *http.Cookie) error { + // Get the key ID for the user + user, err := d.GetUser(ctx, subject) + if err != nil { + return fmt.Errorf("unable to get 
user: %w", err) + } + if user.EncryptionKeyID == nil { + return errors.New("user does not have an encryption key") + } + + // Get the key + key, err := d.GetKey(ctx, *user.EncryptionKeyID) + if err != nil { + return fmt.Errorf("unable to get key: %w", err) + } + + // Encrypt the token + encryptedToken, err := keys.EncryptWithKey(ctx, d.kms, key.Name, key.Key, []byte(cookie.Value)) + if err != nil { + return fmt.Errorf("unable to encrypt token: %w", err) + } + + // Add the cookie to the database + _, err = d.exec(ctx, d.db, "cookies/add_cookie", encryptedToken, cookie.Expires, subject, key.Id) + if err != nil { + return fmt.Errorf("unable to add cookie: %w", err) + } + + return nil +} + +func (d *database) ReportCookieFailure(ctx context.Context, id uint) error { + _, err := d.exec(ctx, d.db, "cookies/set_cookie_degraded", true, id) + if err != nil { + return fmt.Errorf("unable to report cookie failure: %w", err) + } + return nil +} + +func (d *database) RepairCookie(ctx context.Context, id uint) error { + _, err := d.exec(ctx, d.db, "cookies/set_cookie_degraded", false, id) + if err != nil { + return fmt.Errorf("unable to report cookie failure: %w", err) + } + return nil +} + +type IBDCookie struct { + ID uint + Token string + Expiry time.Time +} + +func (c *IBDCookie) ToHTTPCookie() *http.Cookie { + return &http.Cookie{ + Name: ".ASPXAUTH", + Value: c.Token, + Path: "/", + Domain: "investors.com", + Expires: c.Expiry, + Secure: true, + HttpOnly: false, + SameSite: http.SameSiteLaxMode, + } +} diff --git a/backend/internal/database/database.go b/backend/internal/database/database.go new file mode 100644 index 0000000..4022dde --- /dev/null +++ b/backend/internal/database/database.go @@ -0,0 +1,178 @@ +package database + +import ( + "context" + "database/sql" + "database/sql/driver" + "errors" + "fmt" + "io" + "log/slog" + "sync" + "time" + + "ibd-trader/db" + "ibd-trader/internal/keys" + + "github.com/golang-migrate/migrate/v4" + _ 
"github.com/golang-migrate/migrate/v4/database/postgres" + "github.com/golang-migrate/migrate/v4/source/iofs" + _ "github.com/lib/pq" +) + +type Database interface { + io.Closer + UserStore + CookieStore + KeyStore + SessionStore + StockStore + + driver.Pinger + + Migrate(ctx context.Context) error + Maintenance(ctx context.Context) +} + +type database struct { + logger *slog.Logger + + db *sql.DB + url string + + kms keys.KeyManagementService + keyName string +} + +func New(ctx context.Context, logger *slog.Logger, url string, kms keys.KeyManagementService, keyName string) (Database, error) { + sqlDB, err := sql.Open("postgres", url) + if err != nil { + return nil, err + } + + err = sqlDB.PingContext(ctx) + if err != nil { + // Ping failed. Don't error, but give a warning. + logger.WarnContext(ctx, "Unable to ping database", "error", err) + } + + return &database{ + logger: logger, + db: sqlDB, + url: url, + kms: kms, + keyName: keyName, + }, nil +} + +func (d *database) Close() error { + return d.db.Close() +} + +func (d *database) Migrate(ctx context.Context) error { + fs, err := iofs.New(db.Migrations, "migrations") + if err != nil { + return err + } + + m, err := migrate.NewWithSourceInstance("iofs", fs, d.url) + if err != nil { + return err + } + + d.logger.InfoContext(ctx, "Running DB migration") + err = m.Up() + if err != nil && !errors.Is(err, migrate.ErrNoChange) { + d.logger.ErrorContext(ctx, "DB migration failed", "error", err) + return err + } + + return nil +} + +func (d *database) Maintenance(ctx context.Context) { + ticker := time.NewTicker(15 * time.Minute) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + func() { + var wg sync.WaitGroup + wg.Add(1) + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute) + defer cancel() + + go d.cleanupSessions(ctx, &wg) + + wg.Wait() + }() + case <-ctx.Done(): + return + } + } +} + +func (d *database) Ping(ctx context.Context) error { + return d.db.PingContext(ctx) +} + +func 
(d *database) execInternal(ctx context.Context, queryName string, fn func(string) (any, error)) (any, error) { + query, err := db.GetQuery(queryName) + if err != nil { + return nil, fmt.Errorf("unable to get query: %w", err) + } + d.logger.DebugContext(ctx, "Executing query", "name", queryName, "query", query) + + now := time.Now() + + // Execute the query + result, err := fn(query) + if err != nil { + return nil, fmt.Errorf("unable to execute query: %w", err) + } + + d.logger.DebugContext(ctx, "Query executed successfully", "name", queryName, "duration", time.Since(now)) + + return result, nil +} + +func (d *database) exec(ctx context.Context, exec executor, queryName string, args ...any) (sql.Result, error) { + ret, err := d.execInternal(ctx, queryName, func(query string) (any, error) { + return exec.ExecContext(ctx, query, args...) + }) + if err != nil { + return nil, err + } else { + return ret.(sql.Result), nil + } +} + +func (d *database) query(ctx context.Context, exec executor, queryName string, args ...any) (*sql.Rows, error) { + ret, err := d.execInternal(ctx, queryName, func(query string) (any, error) { + return exec.QueryContext(ctx, query, args...) 
+ }) + if err != nil { + return nil, err + } else { + return ret.(*sql.Rows), nil + } +} + +func (d *database) queryRow(ctx context.Context, exec executor, queryName string, args ...any) (*sql.Row, error) { + ret, err := d.execInternal(ctx, queryName, func(query string) (any, error) { + return exec.QueryRowContext(ctx, query, args...), nil + }) + if err != nil { + return nil, err + } else { + return ret.(*sql.Row), nil + } +} + +type executor interface { + ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) + QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) + QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row +} diff --git a/backend/internal/database/keys.go b/backend/internal/database/keys.go new file mode 100644 index 0000000..0ec4b67 --- /dev/null +++ b/backend/internal/database/keys.go @@ -0,0 +1,49 @@ +package database + +import ( + "context" + "fmt" + "time" +) + +type KeyStore interface { + AddKey(ctx context.Context, keyName string, key []byte) (int, error) + GetKey(ctx context.Context, keyId int) (*Key, error) +} + +func (d *database) AddKey(ctx context.Context, keyName string, key []byte) (int, error) { + row, err := d.queryRow(ctx, d.db, "keys/add_key", keyName, key) + if err != nil { + return 0, fmt.Errorf("unable to add key: %w", err) + } + + var keyId int + err = row.Scan(&keyId) + if err != nil { + return 0, fmt.Errorf("unable to scan key id: %w", err) + } + + return keyId, nil +} + +func (d *database) GetKey(ctx context.Context, keyId int) (*Key, error) { + row, err := d.queryRow(ctx, d.db, "keys/get_key", keyId) + if err != nil { + return nil, fmt.Errorf("unable to get key: %w", err) + } + + key := &Key{} + err = row.Scan(&key.Id, &key.Name, &key.Key, &key.Created) + if err != nil { + return nil, fmt.Errorf("unable to scan key: %w", err) + } + + return key, nil +} + +type Key struct { + Id int + Name string + Key []byte + Created time.Time +} diff --git 
a/backend/internal/database/session.go b/backend/internal/database/session.go new file mode 100644 index 0000000..36867b3 --- /dev/null +++ b/backend/internal/database/session.go @@ -0,0 +1,122 @@ +package database + +import ( + "context" + "crypto/rand" + "database/sql" + "encoding/base64" + "errors" + "io" + "sync" + + "github.com/coreos/go-oidc/v3/oidc" + "golang.org/x/oauth2" +) + +type SessionStore interface { + CreateState(ctx context.Context) (string, error) + CheckState(ctx context.Context, state string) (bool, error) + CreateSession(ctx context.Context, token *oauth2.Token, idToken *oidc.IDToken) (string, error) + GetSession(ctx context.Context, sessionToken string) (*Session, error) +} + +func (d *database) CreateState(ctx context.Context) (string, error) { + // Generate a random CSRF state token + tokenBytes := make([]byte, 32) + if _, err := io.ReadFull(rand.Reader, tokenBytes); err != nil { + return "", err + } + token := base64.URLEncoding.EncodeToString(tokenBytes) + + // Insert the state into the database + _, err := d.exec(ctx, d.db, "sessions/create_state", token) + if err != nil { + return "", err + } + + return token, nil +} + +func (d *database) CheckState(ctx context.Context, state string) (bool, error) { + var exists bool + row, err := d.queryRow(ctx, d.db, "sessions/check_state", state) + if err != nil { + return false, err + } + err = row.Scan(&exists) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return false, nil + } + return false, err + } + return exists, nil +} + +func (d *database) CreateSession( + ctx context.Context, + token *oauth2.Token, + idToken *oidc.IDToken, +) (sessionToken string, err error) { + // Generate a random session token + tokenBytes := make([]byte, 32) + if _, err = io.ReadFull(rand.Reader, tokenBytes); err != nil { + return + } + sessionToken = base64.URLEncoding.EncodeToString(tokenBytes) + + // Insert the session into the database + _, err = d.exec( + ctx, + d.db, + "sessions/create_session", + 
sessionToken, + idToken.Subject, + token.AccessToken, + token.Expiry, + ) + return +} + +func (d *database) GetSession(ctx context.Context, sessionToken string) (*Session, error) { + row, err := d.queryRow(ctx, d.db, "sessions/get_session", sessionToken) + if err != nil { + d.logger.ErrorContext(ctx, "Failed to get session", "error", err) + return nil, err + } + + var session Session + err = row.Scan(&session.Token, &session.Subject, &session.OAuthToken.AccessToken, &session.OAuthToken.Expiry) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, nil + } + d.logger.ErrorContext(ctx, "Failed to scan session", "error", err) + return nil, err + } + + return &session, nil +} + +func (d *database) cleanupSessions(ctx context.Context, wg *sync.WaitGroup) { + defer wg.Done() + + result, err := d.exec(ctx, d.db, "sessions/cleanup_sessions") + if err != nil { + d.logger.Error("Failed to clean up sessions", "error", err) + return + } + + rows, err := result.RowsAffected() + if err != nil { + d.logger.ErrorContext(ctx, "Failed to get rows affected", "error", err) + return + } + d.logger.DebugContext(ctx, "Cleaned up sessions", "rows", rows) +} + +type Session struct { + Token string + Subject string + OAuthToken oauth2.Token +} diff --git a/backend/internal/database/stocks.go b/backend/internal/database/stocks.go new file mode 100644 index 0000000..701201e --- /dev/null +++ b/backend/internal/database/stocks.go @@ -0,0 +1,264 @@ +package database + +import ( + "context" + "database/sql" + "errors" + "fmt" + + "ibd-trader/internal/analyzer" + "ibd-trader/internal/utils" + + "github.com/Rhymond/go-money" +) + +var ErrStockNotFound = errors.New("stock not found") + +type StockStore interface { + GetStock(ctx context.Context, symbol string) (Stock, error) + AddStock(ctx context.Context, stock Stock) error + AddRanking(ctx context.Context, symbol string, ibd50, cap20 int) error + AddStockInfo(ctx context.Context, info *StockInfo) (string, error) + GetStockInfo(ctx 
context.Context, id string) (*StockInfo, error) + AddAnalysis(ctx context.Context, ratingId string, analysis *analyzer.Analysis) error +} + +func (d *database) GetStock(ctx context.Context, symbol string) (Stock, error) { + row, err := d.queryRow(ctx, d.db, "stocks/get_stock", symbol) + if err != nil { + return Stock{}, err + } + + var stock Stock + if err = row.Scan(&stock.Symbol, &stock.Name, &stock.IBDUrl); err != nil { + if errors.Is(err, sql.ErrNoRows) { + return Stock{}, ErrStockNotFound + } + return Stock{}, err + } + + return stock, nil +} + +func (d *database) AddStock(ctx context.Context, stock Stock) error { + _, err := d.exec(ctx, d.db, "stocks/add_stock", stock.Symbol, stock.Name, stock.IBDUrl) + return err +} + +func (d *database) AddRanking(ctx context.Context, symbol string, ibd50, cap20 int) error { + if ibd50 > 0 { + _, err := d.exec(ctx, d.db, "stocks/add_rank", symbol, "ibd50", ibd50) + if err != nil { + return err + } + } + if cap20 > 0 { + _, err := d.exec(ctx, d.db, "stocks/add_rank", symbol, "cap20", cap20) + if err != nil { + return err + } + } + return nil +} + +func (d *database) AddStockInfo(ctx context.Context, info *StockInfo) (string, error) { + tx, err := d.db.BeginTx(ctx, nil) + if err != nil { + return "", err + } + defer func(tx *sql.Tx) { + _ = tx.Rollback() + }(tx) + + // Add raw chart analysis + row, err := d.queryRow(ctx, tx, "stocks/add_raw_chart_analysis", info.ChartAnalysis) + if err != nil { + return "", err + } + + var chartAnalysisID string + if err = row.Scan(&chartAnalysisID); err != nil { + return "", err + } + + // Add stock info + row, err = d.queryRow(ctx, tx, + "stocks/add_rating", + info.Symbol, + info.Ratings.Composite, + info.Ratings.EPS, + info.Ratings.RelStr, + info.Ratings.GroupRelStr, + info.Ratings.SMR, + info.Ratings.AccDis, + chartAnalysisID, + info.Price.Display(), + ) + if err != nil { + return "", err + } + + var ratingsID string + if err = row.Scan(&ratingsID); err != nil { + return "", err + } + + 
return ratingsID, tx.Commit() +} + +func (d *database) GetStockInfo(ctx context.Context, id string) (*StockInfo, error) { + row, err := d.queryRow(ctx, d.db, "stocks/get_stock_info", id) + if err != nil { + return nil, err + } + + var info StockInfo + var priceStr string + err = row.Scan( + &info.Symbol, + &info.Name, + &info.ChartAnalysis, + &info.Ratings.Composite, + &info.Ratings.EPS, + &info.Ratings.RelStr, + &info.Ratings.GroupRelStr, + &info.Ratings.SMR, + &info.Ratings.AccDis, + &priceStr, + ) + if err != nil { + return nil, err + } + + info.Price, err = utils.ParseMoney(priceStr) + if err != nil { + return nil, err + } + + return &info, nil +} + +func (d *database) AddAnalysis(ctx context.Context, ratingId string, analysis *analyzer.Analysis) error { + _, err := d.exec(ctx, d.db, "stocks/add_analysis", + ratingId, + analysis.Action, + analysis.Price.Display(), + analysis.Reason, + analysis.Confidence, + ) + return err +} + +type Stock struct { + Symbol string + Name string + IBDUrl string +} + +type StockInfo struct { + Symbol string + Name string + ChartAnalysis string + Ratings Ratings + Price *money.Money +} + +type Ratings struct { + Composite uint8 + EPS uint8 + RelStr uint8 + GroupRelStr LetterRating + SMR LetterRating + AccDis LetterRating +} + +type LetterRating uint8 + +const ( + LetterRatingE LetterRating = iota + LetterRatingEPlus + LetterRatingDMinus + LetterRatingD + LetterRatingDPlus + LetterRatingCMinus + LetterRatingC + LetterRatingCPlus + LetterRatingBMinus + LetterRatingB + LetterRatingBPlus + LetterRatingAMinus + LetterRatingA + LetterRatingAPlus +) + +func (r LetterRating) String() string { + switch r { + case LetterRatingE: + return "E" + case LetterRatingEPlus: + return "E+" + case LetterRatingDMinus: + return "D-" + case LetterRatingD: + return "D" + case LetterRatingDPlus: + return "D+" + case LetterRatingCMinus: + return "C-" + case LetterRatingC: + return "C" + case LetterRatingCPlus: + return "C+" + case LetterRatingBMinus: + 
return "B-" + case LetterRatingB: + return "B" + case LetterRatingBPlus: + return "B+" + case LetterRatingAMinus: + return "A-" + case LetterRatingA: + return "A" + case LetterRatingAPlus: + return "A+" + default: + return "Unknown" + } +} + +func LetterRatingFromString(str string) (LetterRating, error) { + switch str { + case "N/A": + fallthrough + case "E": + return LetterRatingE, nil + case "E+": + return LetterRatingEPlus, nil + case "D-": + return LetterRatingDMinus, nil + case "D": + return LetterRatingD, nil + case "D+": + return LetterRatingDPlus, nil + case "C-": + return LetterRatingCMinus, nil + case "C": + return LetterRatingC, nil + case "C+": + return LetterRatingCPlus, nil + case "B-": + return LetterRatingBMinus, nil + case "B": + return LetterRatingB, nil + case "B+": + return LetterRatingBPlus, nil + case "A-": + return LetterRatingAMinus, nil + case "A": + return LetterRatingA, nil + case "A+": + return LetterRatingAPlus, nil + default: + return 0, fmt.Errorf("unknown rating: %s", str) + } +} diff --git a/backend/internal/database/users.go b/backend/internal/database/users.go new file mode 100644 index 0000000..1950fcb --- /dev/null +++ b/backend/internal/database/users.go @@ -0,0 +1,140 @@ +package database + +import ( + "context" + "database/sql" + "errors" + "fmt" + + "ibd-trader/internal/keys" +) + +type UserStore interface { + AddUser(ctx context.Context, subject string) error + GetUser(ctx context.Context, subject string) (*User, error) + ListUsers(ctx context.Context, hasIBDCreds bool) ([]User, error) + AddIBDCreds(ctx context.Context, subject string, username string, password string) error + GetIBDCreds(ctx context.Context, subject string) (username string, password string, err error) +} + +var ErrUserNotFound = fmt.Errorf("user not found") +var ErrIBDCredsNotFound = fmt.Errorf("ibd creds not found") + +func (d *database) AddUser(ctx context.Context, subject string) (err error) { + _, err = d.exec( + ctx, + d.db, + "users/add_user", + 
subject, + ) + return +} + +func (d *database) GetUser(ctx context.Context, subject string) (*User, error) { + row, err := d.queryRow(ctx, d.db, "users/get_user", subject) + if err != nil { + return nil, fmt.Errorf("unable to get user: %w", err) + } + + user := &User{} + err = row.Scan(&user.Subject, &user.IBDUsername, &user.EncryptedIBDPassword, &user.EncryptionKeyID) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, ErrUserNotFound + } + return nil, fmt.Errorf("unable to scan sql row into user: %w", err) + } + + return user, nil +} + +func (d *database) ListUsers(ctx context.Context, hasIBDCreds bool) ([]User, error) { + rows, err := d.query(ctx, d.db, "users/list_users") + if err != nil { + return nil, fmt.Errorf("unable to list users: %w", err) + } + + users := make([]User, 0) + for rows.Next() { + user := User{} + err = rows.Scan(&user.Subject, &user.IBDUsername, &user.EncryptedIBDPassword, &user.EncryptionKeyID) + if err != nil { + return nil, fmt.Errorf("unable to scan sql row into user: %w", err) + } + + if hasIBDCreds && user.IBDUsername == nil { + continue + } + users = append(users, user) + } + + return users, nil +} + +func (d *database) AddIBDCreds(ctx context.Context, subject string, username string, password string) error { + encryptedPass, encryptedKey, err := keys.Encrypt(ctx, d.kms, d.keyName, []byte(password)) + if err != nil { + return fmt.Errorf("unable to encrypt password: %w", err) + } + + tx, err := d.db.BeginTx(ctx, nil) + if err != nil { + return err + } + defer func(tx *sql.Tx) { + _ = tx.Rollback() + }(tx) + + row, err := d.queryRow(ctx, tx, "keys/add_key", d.keyName, encryptedKey) + if err != nil { + return fmt.Errorf("unable to add ibd creds key: %w", err) + } + + var keyId int + err = row.Scan(&keyId) + if err != nil { + return fmt.Errorf("unable to scan key id: %w", err) + } + + _, err = d.exec(ctx, tx, "users/add_ibd_creds", subject, username, encryptedPass, keyId) + if err != nil { + return fmt.Errorf("unable to 
// GetIBDCreds returns the plaintext IBD credentials stored for subject.
// The stored password is encrypted; it is decrypted via keys.Decrypt using
// the wrapped key material stored alongside it. Returns ErrIBDCredsNotFound
// when no credentials row exists for subject.
func (d *database) GetIBDCreds(ctx context.Context, subject string) (username string, password string, err error) {
	row, err := d.queryRow(ctx, d.db, "users/get_ibd_creds", subject)
	if err != nil {
		return "", "", fmt.Errorf("unable to get ibd creds: %w", err)
	}

	// The query yields the username plus the encrypted password, the
	// encrypted (wrapped) data key, and the name of the key that wrapped it.
	var encryptedPass, encryptedKey []byte
	var keyName string
	err = row.Scan(&username, &encryptedPass, &encryptedKey, &keyName)
	if err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			return "", "", ErrIBDCredsNotFound
		}
		return "", "", fmt.Errorf("unable to scan sql row into ibd creds: %w", err)
	}

	// Unwrap the data key and decrypt the password with it.
	passwordBytes, err := keys.Decrypt(ctx, d.kms, keyName, encryptedPass, encryptedKey)
	if err != nil {
		return "", "", fmt.Errorf("unable to decrypt password: %w", err)
	}

	return username, string(passwordBytes), nil
}
return nil, err + } + + token, params, err := c.sendAuthRequest(ctx, cfg, username, password) + if err != nil { + return nil, err + } + + return c.sendPostAuth(ctx, token, params) +} + +func (c *Client) getLoginPage(ctx context.Context) (*authConfig, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, signInUrl, nil) + if err != nil { + return nil, err + } + + resp, err := c.Do(req) + if err != nil { + return nil, err + } + + if resp.Result.StatusCode != http.StatusOK { + return nil, fmt.Errorf( + "unexpected status code %d: %s", + resp.Result.StatusCode, + resp.Result.Content, + ) + } + + node, err := html.Parse(strings.NewReader(resp.Result.Content)) + if err != nil { + return nil, err + } + + cfg, err := extractAuthConfig(node) + if err != nil { + return nil, fmt.Errorf("failed to extract auth config: %w", err) + } + + return cfg, nil +} + +func (c *Client) sendAuthRequest(ctx context.Context, cfg *authConfig, username, password string) (string, string, error) { + body := authRequestBody{ + ClientId: cfg.ClientID, + RedirectUri: cfg.CallbackURL, + Tenant: "sso", + ResponseType: cfg.ExtraParams.ResponseType, + Username: username, + Password: password, + Scope: cfg.ExtraParams.Scope, + State: cfg.ExtraParams.State, + Headers: struct { + XRemoteUser string `json:"x-_remote-_user"` + }(struct{ XRemoteUser string }{ + XRemoteUser: username, + }), + XOidcProvider: "localop", + Protocol: cfg.ExtraParams.Protocol, + Nonce: cfg.ExtraParams.Nonce, + UiLocales: cfg.ExtraParams.UiLocales, + Csrf: cfg.ExtraParams.Csrf, + Intstate: cfg.ExtraParams.Intstate, + Connection: "DJldap", + } + bodyJson, err := json.Marshal(body) + if err != nil { + return "", "", fmt.Errorf("failed to marshal auth request body: %w", err) + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, authenticateUrl, bytes.NewReader(bodyJson)) + if err != nil { + return "", "", err + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Auth0-Client", 
"eyJuYW1lIjoiYXV0aDAuanMtdWxwIiwidmVyc2lvbiI6IjkuMjQuMSJ9") + + resp, err := c.Do(req) + if err != nil { + return "", "", err + } + + if resp.Result.StatusCode == http.StatusUnauthorized { + return "", "", ErrBadCredentials + } else if resp.Result.StatusCode != http.StatusOK { + return "", "", fmt.Errorf( + "unexpected status code %d: %s", + resp.Result.StatusCode, + resp.Result.Content, + ) + } + + node, err := html.Parse(strings.NewReader(resp.Result.Content)) + if err != nil { + return "", "", err + } + + return extractTokenParams(node) +} + +func (c *Client) sendPostAuth(ctx context.Context, token, params string) (*http.Cookie, error) { + body := fmt.Sprintf("token=%s¶ms=%s", token, params) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, postAuthUrl, strings.NewReader(body)) + if err != nil { + return nil, err + } + + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + resp, err := c.Do(req) + if err != nil { + return nil, err + } + + if resp.Result.StatusCode != http.StatusOK { + return nil, fmt.Errorf( + "unexpected status code %d: %s", + resp.Result.StatusCode, + resp.Result.Content, + ) + } + + // Extract cookie + for _, cookie := range resp.Result.Cookies { + if cookie.Name == cookieName { + return cookie.ToHTTPCookie() + } + } + + return nil, ErrAuthCookieNotFound +} + +func extractAuthConfig(node *html.Node) (*authConfig, error) { + // Find `root` element + root := findId(node, "root") + if root == nil { + return nil, fmt.Errorf("root element not found") + } + + // Get adjacent script element + var script *html.Node + for s := root.NextSibling; s != nil; s = s.NextSibling { + if s.Type == html.ElementNode && s.Data == "script" { + script = s + break + } + } + + if script == nil { + return nil, fmt.Errorf("script element not found") + } + + // Get script content + content := extractText(script) + + // Find `AUTH_CONFIG` variable + const authConfigVar = "const AUTH_CONFIG = '" + i := strings.Index(content, authConfigVar) 
+ if i == -1 { + return nil, fmt.Errorf("AUTH_CONFIG not found") + } + + // Find end of `AUTH_CONFIG` variable + j := strings.Index(content[i+len(authConfigVar):], "'") + + // Extract `AUTH_CONFIG` value + authConfigJSONB64 := content[i+len(authConfigVar) : i+len(authConfigVar)+j] + + // Decode `AUTH_CONFIG` value + authConfigJSON, err := base64.StdEncoding.DecodeString(authConfigJSONB64) + if err != nil { + return nil, fmt.Errorf("failed to decode AUTH_CONFIG: %w", err) + } + + // Unmarshal `AUTH_CONFIG` value + var cfg authConfig + if err = json.Unmarshal(authConfigJSON, &cfg); err != nil { + return nil, fmt.Errorf("failed to unmarshal AUTH_CONFIG: %w", err) + } + + return &cfg, nil +} + +type authConfig struct { + Auth0Domain string `json:"auth0Domain"` + CallbackURL string `json:"callbackURL"` + ClientID string `json:"clientID"` + ExtraParams struct { + Protocol string `json:"protocol"` + Scope string `json:"scope"` + ResponseType string `json:"response_type"` + Nonce string `json:"nonce"` + UiLocales string `json:"ui_locales"` + Csrf string `json:"_csrf"` + Intstate string `json:"_intstate"` + State string `json:"state"` + } `json:"extraParams"` + InternalOptions struct { + ResponseType string `json:"response_type"` + ClientId string `json:"client_id"` + Scope string `json:"scope"` + RedirectUri string `json:"redirect_uri"` + UiLocales string `json:"ui_locales"` + Eurl string `json:"eurl"` + Nonce string `json:"nonce"` + State string `json:"state"` + Resource string `json:"resource"` + Protocol string `json:"protocol"` + Client string `json:"client"` + } `json:"internalOptions"` + IsThirdPartyClient bool `json:"isThirdPartyClient"` + AuthorizationServer struct { + Url string `json:"url"` + Issuer string `json:"issuer"` + } `json:"authorizationServer"` +} + +func extractTokenParams(node *html.Node) (token string, params string, err error) { + inputs := findChildrenRecursive(node, func(node *html.Node) bool { + return node.Type == html.ElementNode && node.Data 
// authRequestBody is the JSON payload POSTed to the Dow Jones SSO
// /authenticate endpoint. The field set and the unusual JSON tags
// (e.g. "x-_remote-_user") mirror what the hosted login page itself
// submits — do not "fix" the tags without verifying against the live flow.
type authRequestBody struct {
	ClientId     string `json:"client_id"`
	RedirectUri  string `json:"redirect_uri"`
	Tenant       string `json:"tenant"`
	ResponseType string `json:"response_type"`
	Username     string `json:"username"`
	Password     string `json:"password"`
	Scope        string `json:"scope"`
	State        string `json:"state"`
	// Headers is nested verbatim as the login page sends it.
	Headers struct {
		XRemoteUser string `json:"x-_remote-_user"`
	} `json:"headers"`
	XOidcProvider string `json:"x-_oidc-_provider"`
	Protocol      string `json:"protocol"`
	Nonce         string `json:"nonce"`
	UiLocales     string `json:"ui_locales"`
	Csrf          string `json:"_csrf"`
	Intstate      string `json:"_intstate"`
	Connection    string `json:"connection"`
}
content="width=device-width,initial-scale=1"/> + <meta name="description" content="Dow Jones One Identity Login page"/> + <link rel="apple-touch-icon" sizes="180x180" href="/one_identity_login_pages/login/0ce1520be322adcd762319573804f56d/images/apple-touch-icon.png"/> + <link rel="icon" type="image/png" sizes="32x32" href="/one_identity_login_pages/login/0ce1520be322adcd762319573804f56d/images/favicon-32x32.png"/> + <link rel="icon" type="image/png" sizes="16x16" href="/one_identity_login_pages/login/0ce1520be322adcd762319573804f56d/images/favicon-16x16.png"/> + <link rel="icon" type="image/png" sizes="192x192" href="/one_identity_login_pages/login/0ce1520be322adcd762319573804f56d/images/android-chrome-192x192.png"/> + <link rel="icon" type="image/png" sizes="512x512" href="/one_identity_login_pages/login/0ce1520be322adcd762319573804f56d/images/android-chrome-512x512.png"/> + <link rel="prefetch" href="https://cdn.optimizely.com/js/14856860742.js"/> + <link rel="preconnect" href="//cdn.optimizely.com"/> + <link rel="preconnect" href="//logx.optimizely.com"/> + <script type="module" crossorigin src="/one_identity_login_pages/login/0ce1520be322adcd762319573804f56d/js/index.js"></script> + <link rel="modulepreload" crossorigin href="/one_identity_login_pages/login/0ce1520be322adcd762319573804f56d/js/vendor.js"> + <link rel="modulepreload" crossorigin href="/one_identity_login_pages/login/0ce1520be322adcd762319573804f56d/js/auth.js"> + <link rel="modulepreload" crossorigin href="/one_identity_login_pages/login/0ce1520be322adcd762319573804f56d/js/router.js"> + <link rel="stylesheet" crossorigin href="/one_identity_login_pages/login/0ce1520be322adcd762319573804f56d/assets/styles.css"> + </head> + <body> + <div id="root" vaul-drawer-wrapper="" class="root-container"></div> + <script> + const AUTH_CONFIG = 
'eyJhdXRoMERvbWFpbiI6InNzby5hY2NvdW50cy5kb3dqb25lcy5jb20iLCJjYWxsYmFja1VSTCI6Imh0dHBzOi8vbXlpYmQuaW52ZXN0b3JzLmNvbS9vaWRjL2NhbGxiYWNrIiwiY2xpZW50SUQiOiJHU1UxcEcyQnJnZDNQdjJLQm5BWjI0enZ5NXVXU0NRbiIsImV4dHJhUGFyYW1zIjp7InByb3RvY29sIjoib2F1dGgyIiwic2NvcGUiOiJvcGVuaWQgaWRwX2lkIHJvbGVzIGVtYWlsIGdpdmVuX25hbWUgZmFtaWx5X25hbWUgdXVpZCBkalVzZXJuYW1lIGRqU3RhdHVzIHRyYWNraWQgdGFncyBwcnRzIHVwZGF0ZWRfYXQgY3JlYXRlZF9hdCBvZmZsaW5lX2FjY2VzcyBkamlkIiwicmVzcG9uc2VfdHlwZSI6ImNvZGUiLCJub25jZSI6IjY0MDJmYWJiLTFiNzUtNGEyYy1hODRmLTExYWQ2MWFhZGI2YiIsInVpX2xvY2FsZXMiOiJlbi11cy14LWliZC0yMy03IiwiX2NzcmYiOiJOZFVSZ3dPQ3VYRU5URXFDcDhNV25tcGtxd3lva2JjU2E2VV9fLTVib3lWc1NzQVNWTkhLU0EiLCJfaW50c3RhdGUiOiJkZXByZWNhdGVkIiwic3RhdGUiOiJlYXJjN3E2UnE2a3lHS3h5LlltbGlxOU4xRXZvU1V0ejhDVjhuMFZBYzZWc1V4RElSTTRTcmxtSWJXMmsifSwiaW50ZXJuYWxPcHRpb25zIjp7InJlc3BvbnNlX3R5cGUiOiJjb2RlIiwiY2xpZW50X2lkIjoiR1NVMXBHMkJyZ2QzUHYyS0JuQVoyNHp2eTV1V1NDUW4iLCJzY29wZSI6Im9wZW5pZCBpZHBfaWQgcm9sZXMgZW1haWwgZ2l2ZW5fbmFtZSBmYW1pbHlfbmFtZSB1dWlkIGRqVXNlcm5hbWUgZGpTdGF0dXMgdHJhY2tpZCB0YWdzIHBydHMgdXBkYXRlZF9hdCBjcmVhdGVkX2F0IG9mZmxpbmVfYWNjZXNzIGRqaWQiLCJyZWRpcmVjdF91cmkiOiJodHRwczovL215aWJkLmludmVzdG9ycy5jb20vb2lkYy9jYWxsYmFjayIsInVpX2xvY2FsZXMiOiJlbi11cy14LWliZC0yMy03IiwiZXVybCI6Imh0dHBzOi8vd3d3LmludmVzdG9ycy5jb20iLCJub25jZSI6IjY0MDJmYWJiLTFiNzUtNGEyYy1hODRmLTExYWQ2MWFhZGI2YiIsInN0YXRlIjoiZWFyYzdxNlJxNmt5R0t4eS5ZbWxpcTlOMUV2b1NVdHo4Q1Y4bjBWQWM2VnNVeERJUk00U3JsbUliVzJrIiwicmVzb3VyY2UiOiJodHRwcyUzQSUyRiUyRnd3dy5pbnZlc3RvcnMuY29tIiwicHJvdG9jb2wiOiJvYXV0aDIiLCJjbGllbnQiOiJHU1UxcEcyQnJnZDNQdjJLQm5BWjI0enZ5NXVXU0NRbiJ9LCJpc1RoaXJkUGFydHlDbGllbnQiOmZhbHNlLCJhdXRob3JpemF0aW9uU2VydmVyIjp7InVybCI6Imh0dHBzOi8vc3NvLmFjY291bnRzLmRvd2pvbmVzLmNvbSIsImlzc3VlciI6Imh0dHBzOi8vc3NvLmFjY291bnRzLmRvd2pvbmVzLmNvbS8ifX0=' + const ENV_CONFIG = 'production' + + window.sessionStorage.setItem('auth-config', AUTH_CONFIG) + window.sessionStorage.setItem('env-config', ENV_CONFIG) + </script> + <script src="https://cdn.optimizely.com/js/14856860742.js" 
crossorigin="anonymous"></script> + <script type="text/javascript" src="https://dcdd29eaa743c493e732-7dc0216bc6cc2f4ed239035dfc17235b.ssl.cf3.rackcdn.com/tags/wsj/hokbottom.js"></script> + <script type="text/javascript" src="/R8As7u5b/init.js"></script> + </body> +</html> +` + +func Test_extractAuthConfig(t *testing.T) { + t.Parallel() + expectedJSON := ` +{ + "auth0Domain": "sso.accounts.dowjones.com", + "callbackURL": "https://myibd.investors.com/oidc/callback", + "clientID": "GSU1pG2Brgd3Pv2KBnAZ24zvy5uWSCQn", + "extraParams": { + "protocol": "oauth2", + "scope": "openid idp_id roles email given_name family_name uuid djUsername djStatus trackid tags prts updated_at created_at offline_access djid", + "response_type": "code", + "nonce": "6402fabb-1b75-4a2c-a84f-11ad61aadb6b", + "ui_locales": "en-us-x-ibd-23-7", + "_csrf": "NdURgwOCuXENTEqCp8MWnmpkqwyokbcSa6U__-5boyVsSsASVNHKSA", + "_intstate": "deprecated", + "state": "earc7q6Rq6kyGKxy.Ymliq9N1EvoSUtz8CV8n0VAc6VsUxDIRM4SrlmIbW2k" + }, + "internalOptions": { + "response_type": "code", + "client_id": "GSU1pG2Brgd3Pv2KBnAZ24zvy5uWSCQn", + "scope": "openid idp_id roles email given_name family_name uuid djUsername djStatus trackid tags prts updated_at created_at offline_access djid", + "redirect_uri": "https://myibd.investors.com/oidc/callback", + "ui_locales": "en-us-x-ibd-23-7", + "eurl": "https://www.investors.com", + "nonce": "6402fabb-1b75-4a2c-a84f-11ad61aadb6b", + "state": "earc7q6Rq6kyGKxy.Ymliq9N1EvoSUtz8CV8n0VAc6VsUxDIRM4SrlmIbW2k", + "resource": "https%3A%2F%2Fwww.investors.com", + "protocol": "oauth2", + "client": "GSU1pG2Brgd3Pv2KBnAZ24zvy5uWSCQn" + }, + "isThirdPartyClient": false, + "authorizationServer": { + "url": "https://sso.accounts.dowjones.com", + "issuer": "https://sso.accounts.dowjones.com/" + } +}` + var expectedCfg authConfig + err := json.Unmarshal([]byte(expectedJSON), &expectedCfg) + require.NoError(t, err) + + node, err := html.Parse(strings.NewReader(extractAuthHtml)) + require.NoError(t, 
err) + + cfg, err := extractAuthConfig(node) + require.NoError(t, err) + require.NotNil(t, cfg) + + assert.Equal(t, expectedCfg, *cfg) +} + +const extractTokenParamsHtml = ` +<form method="post" name="hiddenform" action="https://sso.accounts.dowjones.com/postauth/handler"> + <input type="hidden" name="token" value="eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJkalVzZXJuYW1lIjoiYW5zZzE5MUB5YWhvby5jb20iLCJpZCI6IjAxZWFmNTE5LTA0OWItNGIyOS04ZjZhLWQyNjIyZjNiMWJjNiIsImdpdmVuX25hbWUiOiJBbnNodWwiLCJmYW1pbHlfbmFtZSI6Ikd1cHRhIiwibmFtZSI6IkFuc2h1bCBHdXB0YSIsImVtYWlsIjoiYW5zZzE5MUB5YWhvby5jb20iLCJlbWFpbF92ZXJpZmllZCI6dHJ1ZSwiYWNjb3VudF9pZCI6Ijk5NzI5Mzc0NDIxMiIsImRqaWQiOiIwMWVhZjUxOS0wNDliLTRiMjktOGY2YS1kMjYyMmYzYjFiYzYiLCJ0cmFja2lkIjoiMWM0NGQyMTRmM2VlYTZiMzcyNDYxNDc3NDc0NDMyODJmMTRmY2ZjYmI4NmE4NmVjYTI0MDc2ZDVlMzU4ZmUzZCIsInVwZGF0ZWRfYXQiOjE3MTI3OTQxNTYsImNyZWF0ZWRfYXQiOjE3MTI3OTQxNTYsInVhdCI6MTcyMjU1MjMzOSwicm9sZXMiOlsiQkFSUk9OUy1DSEFOR0VQQVNTV09SRCIsIkZSRUVSRUctQkFTRSIsIkZSRUVSRUctSU5ESVZJRFVBTCIsIldTSi1DSEFOR0VQQVNTV09SRCIsIldTSi1BUkNISVZFIiwiV1NKLVNFTEZTRVJWIiwiSUJELUlORElWSURVQUwiLCJJQkQtSUNBIiwiSUJELUFFSSJdLCJkalN0YXR1cyI6WyJJQkRfVVNFUlMiXSwicHJ0cyI6IjIwMjQwNDEwMTcwOTE2LTA0MDAiLCJjcmVhdGVUaW1lc3RhbXAiOiIyMDI0MDQxMTAwMDkxNloiLCJzdXVpZCI6Ik1ERmxZV1kxTVRrdE1EUTVZaTAwWWpJNUxUaG1ObUV0WkRJMk1qSm1NMkl4WW1NMi50S09fM014VkVReks3dE5qTkdxUXNZMlBNbXp5cUxGRkxySnBrZGhrcDZrIiwic3ViIjoiMDFlYWY1MTktMDQ5Yi00YjI5LThmNmEtZDI2MjJmM2IxYmM2IiwiYXVkIjoiR1NVMXBHMkJyZ2QzUHYyS0JuQVoyNHp2eTV1V1NDUW4iLCJpc3MiOiJodHRwczovL3Nzby5hY2NvdW50cy5kb3dqb25lcy5jb20vIiwiaWF0IjoxNzIyNTUyMzM5MTI0LCJleHAiOjE3MjI1NTI3NzExMjR9.HVn33IFttQrG1JKEV2oElIy3mm8TJ-3GpV_jqZE81_cY22z4IMWPz7zUGz0WgOoUuQGyrYXiaNrfxD6GaoimRL6wxrH0Fy5iYC3dOEdlGfldswfgEOwSiZkBJRc2wWTVQLm93EeJ5ZZyKIXGY_ZkwcYfhrwaTAz8McBBnRmZkm0eiNJQ5YK-QZL-yFa3DxMdPPW91jLA2rjOIVnJ-I_0nMwaJ4ZwXHG2Sw4aAXxtbFqIqarKwIdOUSpRFOCSYpeWcxmbliurKlP1djrKrYgYSZxsKOHZhnbikZDtoDCAlPRlfbKOO4u36KXooDYGJ6p__s2kGCLOLLkP_QLHMNU8Jg"> + <input type="hidden" name="params" 
value="%7B%22response_type%22%3A%22code%22%2C%22client_id%22%3A%22GSU1pG2Brgd3Pv2KBnAZ24zvy5uWSCQn%22%2C%22redirect_uri%22%3A%22https%3A%2F%2Fmyibd.investors.com%2Foidc%2Fcallback%22%2C%22state%22%3A%22J-ihUYZIYzey682D.aOLszineC9qjPkM6Y6wWgFC61ABYBiuK9u48AHTFS5I%22%2C%22scope%22%3A%22openid%20idp_id%20roles%20email%20given_name%20family_name%20uuid%20djUsername%20djStatus%20trackid%20tags%20prts%20updated_at%20created_at%20offline_access%20djid%22%2C%22nonce%22%3A%22457bb517-f490-43b6-a55f-d93f90d698ad%22%7D"> + <noscript> + <p>Script is disabled. Click Submit to continue.</p> + <input type="submit" value="Submit"> + </noscript> +</form> +` +const extractTokenExpectedToken = "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJkalVzZXJuYW1lIjoiYW5zZzE5MUB5YWhvby5jb20iLCJpZCI6IjAxZWFmNTE5LTA0OWItNGIyOS04ZjZhLWQyNjIyZjNiMWJjNiIsImdpdmVuX25hbWUiOiJBbnNodWwiLCJmYW1pbHlfbmFtZSI6Ikd1cHRhIiwibmFtZSI6IkFuc2h1bCBHdXB0YSIsImVtYWlsIjoiYW5zZzE5MUB5YWhvby5jb20iLCJlbWFpbF92ZXJpZmllZCI6dHJ1ZSwiYWNjb3VudF9pZCI6Ijk5NzI5Mzc0NDIxMiIsImRqaWQiOiIwMWVhZjUxOS0wNDliLTRiMjktOGY2YS1kMjYyMmYzYjFiYzYiLCJ0cmFja2lkIjoiMWM0NGQyMTRmM2VlYTZiMzcyNDYxNDc3NDc0NDMyODJmMTRmY2ZjYmI4NmE4NmVjYTI0MDc2ZDVlMzU4ZmUzZCIsInVwZGF0ZWRfYXQiOjE3MTI3OTQxNTYsImNyZWF0ZWRfYXQiOjE3MTI3OTQxNTYsInVhdCI6MTcyMjU1MjMzOSwicm9sZXMiOlsiQkFSUk9OUy1DSEFOR0VQQVNTV09SRCIsIkZSRUVSRUctQkFTRSIsIkZSRUVSRUctSU5ESVZJRFVBTCIsIldTSi1DSEFOR0VQQVNTV09SRCIsIldTSi1BUkNISVZFIiwiV1NKLVNFTEZTRVJWIiwiSUJELUlORElWSURVQUwiLCJJQkQtSUNBIiwiSUJELUFFSSJdLCJkalN0YXR1cyI6WyJJQkRfVVNFUlMiXSwicHJ0cyI6IjIwMjQwNDEwMTcwOTE2LTA0MDAiLCJjcmVhdGVUaW1lc3RhbXAiOiIyMDI0MDQxMTAwMDkxNloiLCJzdXVpZCI6Ik1ERmxZV1kxTVRrdE1EUTVZaTAwWWpJNUxUaG1ObUV0WkRJMk1qSm1NMkl4WW1NMi50S09fM014VkVReks3dE5qTkdxUXNZMlBNbXp5cUxGRkxySnBrZGhrcDZrIiwic3ViIjoiMDFlYWY1MTktMDQ5Yi00YjI5LThmNmEtZDI2MjJmM2IxYmM2IiwiYXVkIjoiR1NVMXBHMkJyZ2QzUHYyS0JuQVoyNHp2eTV1V1NDUW4iLCJpc3MiOiJodHRwczovL3Nzby5hY2NvdW50cy5kb3dqb25lcy5jb20vIiwiaWF0IjoxNzIyNTUyMzM5MTI0LCJleHAiOjE3MjI1NTI3NzExMjR9.HVn33IFttQrG1JKEV2oElIy3mm8TJ-3GpV_jq
// Test_extractTokenParams verifies that the token and params hidden inputs
// are extracted verbatim from the SSO post-auth form (params stays
// URL-encoded here; it is unescaped later in the flow).
func Test_extractTokenParams(t *testing.T) {
	t.Parallel()

	node, err := html.Parse(strings.NewReader(extractTokenParamsHtml))
	require.NoError(t, err)

	token, params, err := extractTokenParams(node)
	require.NoError(t, err)
	assert.Equal(t, extractTokenExpectedToken, token)
	assert.Equal(t, extractTokenExpectedParams, params)
}
// TestClient_Authenticate_401 verifies that a 401 from the SSO
// /authenticate endpoint surfaces as ErrBadCredentials and that no cookie
// is returned.
func TestClient_Authenticate_401(t *testing.T) {
	t.Parallel()

	server := newServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		uri := r.URL.String()
		switch uri {
		case signInUrl:
			// Serve the real sign-in page fixture so the client can
			// extract AUTH_CONFIG.
			w.Header().Set("Content-Type", "text/html")
			_, err := w.Write([]byte(extractAuthHtml))
			require.NoError(t, err)
			return
		case authenticateUrl:
			var body authRequestBody
			require.NoError(t, json.NewDecoder(r.Body).Decode(&body))
			assert.Equal(t, "abc", body.Username)
			assert.Equal(t, "xyz", body.Password)

			// Reject the credentials the way the real endpoint does.
			w.Header().Set("Content-Type", "application/json")
			w.WriteHeader(http.StatusUnauthorized)
			_, err := w.Write([]byte(`{"name":"ValidationError","code":"ERR016","message":"Wrong username or password","description":"Wrong username or password"}`))
			require.NoError(t, err)
			return
		default:
			t.Fatalf("unexpected URL: %s", uri)
		}
	}))

	client, err := NewClient(http.DefaultClient, apiKey, nil, "", WithBaseURL(server.URL))
	require.NoError(t, err)

	cookie, err := client.Authenticate(context.Background(), "abc", "xyz")
	assert.Nil(t, cookie)
	assert.ErrorIs(t, err, ErrBadCredentials)
}
--git a/backend/internal/ibd/client.go b/backend/internal/ibd/client.go new file mode 100644 index 0000000..eb3d27e --- /dev/null +++ b/backend/internal/ibd/client.go @@ -0,0 +1,146 @@ +package ibd + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "net/url" + "strconv" + + "ibd-trader/internal/database" +) + +var ErrNoAvailableCookies = errors.New("no available cookies") + +type Client struct { + // HTTP client used to make requests + client *http.Client + // Scrapfly API key + apiKey string + // Client-wide Scrape options + options ScrapeOptions + // Cookie source + cookies database.CookieSource + // Proxy URL for non-scrapfly requests + proxyUrl *url.URL +} + +func NewClient( + client *http.Client, + apiKey string, + cookies database.CookieSource, + proxyUrl string, + opts ...ScrapeOption, +) (*Client, error) { + options := defaultScrapeOptions + for _, opt := range opts { + opt(&options) + } + + pProxyUrl, err := url.Parse(proxyUrl) + if err != nil { + return nil, err + } + + return &Client{ + client: client, + options: options, + apiKey: apiKey, + cookies: cookies, + proxyUrl: pProxyUrl, + }, nil +} + +func (c *Client) getCookie(ctx context.Context, subject *string) (uint, *http.Cookie, error) { + if subject == nil { + // No subject requirement, get any cookie + cookie, err := c.cookies.GetAnyCookie(ctx) + if err != nil { + return 0, nil, err + } + if cookie == nil { + return 0, nil, ErrNoAvailableCookies + } + + return cookie.ID, cookie.ToHTTPCookie(), nil + } + + // Get cookie by subject + cookies, err := c.cookies.GetCookies(ctx, *subject, false) + if err != nil { + return 0, nil, err + } + + if len(cookies) == 0 { + return 0, nil, ErrNoAvailableCookies + } + + cookie := cookies[0] + + return cookie.ID, cookie.ToHTTPCookie(), nil +} + +func (c *Client) Do(req *http.Request, opts ...ScrapeOption) (*ScraperResponse, error) { + options := c.options + for _, opt := range opts { + opt(&options) + } + + // Construct scrape request 
URL + scrapeUrl, err := url.Parse(options.baseURL) + if err != nil { + panic(err) + } + scrapeUrl.RawQuery = c.constructRawQuery(options, req.URL, req.Header) + + // Construct scrape request + scrapeReq, err := http.NewRequestWithContext(req.Context(), req.Method, scrapeUrl.String(), req.Body) + if err != nil { + return nil, err + } + + // Send scrape request + resp, err := c.client.Do(scrapeReq) + if err != nil { + return nil, err + } + defer func(Body io.ReadCloser) { + _ = Body.Close() + }(resp.Body) + + // Parse scrape response + scraperResponse := new(ScraperResponse) + err = json.NewDecoder(resp.Body).Decode(scraperResponse) + if err != nil { + return nil, err + } + + return scraperResponse, nil +} + +func (c *Client) constructRawQuery(options ScrapeOptions, u *url.URL, headers http.Header) string { + params := url.Values{} + params.Set("key", c.apiKey) + params.Set("url", u.String()) + if options.country != nil { + params.Set("country", *options.country) + } + params.Set("asp", strconv.FormatBool(options.asp)) + params.Set("proxy_pool", options.proxyPool.String()) + params.Set("render_js", strconv.FormatBool(options.renderJS)) + params.Set("cache", strconv.FormatBool(options.cache)) + + for k, v := range headers { + for i, vv := range v { + params.Add( + fmt.Sprintf("headers[%s][%d]", k, i), + vv, + ) + } + } + + return params.Encode() +} diff --git a/backend/internal/ibd/client_test.go b/backend/internal/ibd/client_test.go new file mode 100644 index 0000000..577987d --- /dev/null +++ b/backend/internal/ibd/client_test.go @@ -0,0 +1,241 @@ +package ibd + +import ( + "bytes" + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "ibd-trader/internal/database" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +const apiKey = "test-api-key-123" + +func newServer(t *testing.T, handler http.Handler) *httptest.Server { + t.Helper() + + return 
// reconstructReq rebuilds the original outbound request from the scrape-API
// query parameters that the client encoded (key, url, headers[Name][i]),
// so the test handler can be written against the "real" request.
func reconstructReq(t *testing.T, r *http.Request) *http.Request {
	t.Helper()

	params := r.URL.Query()
	require.Equal(t, apiKey, params.Get("key"))

	// Reconstruct the request from the query params
	var key string
	var url string
	headers := make(http.Header)
	for k, v := range params {
		switch k {
		case "key":
			key = v[0]
		case "url":
			url = v[0]
		default:
			if strings.HasPrefix(k, "headers") {
				var name string
				// Get index of first [
				i := strings.Index(k, "[")
				if i == -1 {
					t.Fatalf("invalid header key: %s", k)
				}
				// Get index of first ]
				j := strings.Index(k, "]")
				if j == -1 {
					t.Fatalf("invalid header key: %s", k)
				}

				// Get the header name
				// NOTE(review): Set overwrites on repeated names, so a
				// header with multiple values (headers[K][0], headers[K][1])
				// keeps only one of them — confirm whether multi-value
				// headers matter to these tests.
				name = k[i+1 : j]
				headers.Set(name, v[0])
			}
		}
	}
	require.Equal(t, apiKey, key)
	require.NotEmpty(t, url)

	req, err := http.NewRequestWithContext(r.Context(), r.Method, url, r.Body)
	require.NoError(t, err)
	req.Header = headers

	return req
}
continue + } + w.ret.Result.ResponseHeaders[k] = v[0] + } + + req := http.Response{Header: w.headers} + w.ret.Result.Cookies = make([]ScraperCookie, 0) + for _, c := range req.Cookies() { + var cookie ScraperCookie + cookie.FromHTTPCookie(c) + w.ret.Result.Cookies = append(w.ret.Result.Cookies, cookie) + } + + rw.WriteHeader(http.StatusOK) + return json.NewEncoder(rw).Encode(w.ret) +} + +func TestClient_getCookie(t *testing.T) { + t.Parallel() + + t.Run("no cookies", func(t *testing.T) { + t.Parallel() + + client, err := NewClient( + http.DefaultClient, + apiKey, + new(emptyCookieSourceStub), + "", + ) + require.NoError(t, err) + + _, _, err = client.getCookie(context.Background(), nil) + assert.ErrorIs(t, err, ErrNoAvailableCookies) + }) + + t.Run("no cookies by subject", func(t *testing.T) { + t.Parallel() + + client, err := NewClient( + http.DefaultClient, + apiKey, + new(emptyCookieSourceStub), + "", + ) + require.NoError(t, err) + + subject := "test" + _, _, err = client.getCookie(context.Background(), &subject) + assert.ErrorIs(t, err, ErrNoAvailableCookies) + }) + + t.Run("get any cookie", func(t *testing.T) { + t.Parallel() + + client, err := NewClient( + http.DefaultClient, + apiKey, + new(cookieSourceStub), + "", + ) + require.NoError(t, err) + + id, cookie, err := client.getCookie(context.Background(), nil) + require.NoError(t, err) + assert.Equal(t, uint(42), id) + assert.Equal(t, cookieName, cookie.Name) + assert.Equal(t, "test-token", cookie.Value) + assert.Equal(t, "/", cookie.Path) + assert.Equal(t, time.Unix(0, 0), cookie.Expires) + assert.Equal(t, "investors.com", cookie.Domain) + }) + + t.Run("get cookie by subject", func(t *testing.T) { + t.Parallel() + + client, err := NewClient( + http.DefaultClient, + apiKey, + new(cookieSourceStub), + "", + ) + require.NoError(t, err) + + subject := "test" + id, cookie, err := client.getCookie(context.Background(), &subject) + require.NoError(t, err) + assert.Equal(t, uint(42), id) + assert.Equal(t, 
cookieName, cookie.Name) + assert.Equal(t, "test-token", cookie.Value) + assert.Equal(t, "/", cookie.Path) + assert.Equal(t, time.Unix(0, 0), cookie.Expires) + assert.Equal(t, "investors.com", cookie.Domain) + }) +} + +type emptyCookieSourceStub struct{} + +func (c *emptyCookieSourceStub) GetAnyCookie(_ context.Context) (*database.IBDCookie, error) { + return nil, nil +} + +func (c *emptyCookieSourceStub) GetCookies(_ context.Context, _ string, _ bool) ([]database.IBDCookie, error) { + return nil, nil +} + +func (c *emptyCookieSourceStub) ReportCookieFailure(_ context.Context, _ uint) error { + return nil +} + +var testCookie = database.IBDCookie{ + ID: 42, + Token: "test-token", + Expiry: time.Unix(0, 0), +} + +type cookieSourceStub struct{} + +func (c *cookieSourceStub) GetAnyCookie(_ context.Context) (*database.IBDCookie, error) { + return &testCookie, nil +} + +func (c *cookieSourceStub) GetCookies(_ context.Context, _ string, _ bool) ([]database.IBDCookie, error) { + return []database.IBDCookie{testCookie}, nil +} + +func (c *cookieSourceStub) ReportCookieFailure(_ context.Context, _ uint) error { + return nil +} diff --git a/backend/internal/ibd/html_helpers.go b/backend/internal/ibd/html_helpers.go new file mode 100644 index 0000000..0176bc5 --- /dev/null +++ b/backend/internal/ibd/html_helpers.go @@ -0,0 +1,99 @@ +package ibd + +import ( + "strings" + + "golang.org/x/net/html" +) + +func findChildren(node *html.Node, f func(node *html.Node) bool) (found []*html.Node) { + for c := node.FirstChild; c != nil; c = c.NextSibling { + if f(c) { + found = append(found, c) + } + } + return +} + +func findChildrenRecursive(node *html.Node, f func(node *html.Node) bool) (found []*html.Node) { + if f(node) { + found = append(found, node) + } + + for c := node.FirstChild; c != nil; c = c.NextSibling { + found = append(found, findChildrenRecursive(c, f)...) 
+ } + + return +} + +func findClass(node *html.Node, className string) (found *html.Node) { + if isClass(node, className) { + return node + } + + for c := node.FirstChild; c != nil; c = c.NextSibling { + if found = findClass(c, className); found != nil { + return + } + } + + return +} + +func isClass(node *html.Node, className string) bool { + if node.Type == html.ElementNode { + for _, attr := range node.Attr { + if attr.Key != "class" { + continue + } + classes := strings.Fields(attr.Val) + for _, class := range classes { + if class == className { + return true + } + } + } + } + return false +} + +func extractText(node *html.Node) string { + var result strings.Builder + extractTextInner(node, &result) + return result.String() +} + +func extractTextInner(node *html.Node, result *strings.Builder) { + if node.Type == html.TextNode { + result.WriteString(node.Data) + } + for c := node.FirstChild; c != nil; c = c.NextSibling { + extractTextInner(c, result) + } +} + +func findId(node *html.Node, id string) (found *html.Node) { + if isId(node, id) { + return node + } + + for c := node.FirstChild; c != nil; c = c.NextSibling { + if found = findId(c, id); found != nil { + return + } + } + + return +} + +func isId(node *html.Node, id string) bool { + if node.Type == html.ElementNode { + for _, attr := range node.Attr { + if attr.Key == "id" && attr.Val == id { + return true + } + } + } + return false +} diff --git a/backend/internal/ibd/html_helpers_test.go b/backend/internal/ibd/html_helpers_test.go new file mode 100644 index 0000000..d251c39 --- /dev/null +++ b/backend/internal/ibd/html_helpers_test.go @@ -0,0 +1,79 @@ +package ibd + +import ( + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/net/html" +) + +func Test_findClass(t *testing.T) { + t.Parallel() + tests := []struct { + name string + html string + className string + found bool + expData string + }{ + { + name: "class exists", + html: `<div 
class="foo"></div>`, + className: "foo", + found: true, + expData: "div", + }, + { + name: "class exists nested", + html: `<div class="foo"><a class="abc"></a></div>`, + className: "abc", + found: true, + expData: "a", + }, + { + name: "class exists multiple", + html: `<div class="foo"><a class="foo"></a></div>`, + className: "foo", + found: true, + expData: "div", + }, + { + name: "class missing", + html: `<div class="abc"><a class="xyz"></a></div>`, + className: "foo", + found: false, + expData: "", + }, + { + name: "class missing", + html: `<div id="foo"><a abc="xyz"></a></div>`, + className: "foo", + found: false, + expData: "", + }, + { + name: "class exists multiple save div", + html: `<div class="foo bar"></div>`, + className: "bar", + found: true, + expData: "div", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + node, err := html.Parse(strings.NewReader(tt.html)) + require.NoError(t, err) + + got := findClass(node, tt.className) + if !tt.found { + require.Nil(t, got) + return + } + require.NotNil(t, got) + assert.Equal(t, tt.expData, got.Data) + }) + } +} diff --git a/backend/internal/ibd/ibd50.go b/backend/internal/ibd/ibd50.go new file mode 100644 index 0000000..93aa31d --- /dev/null +++ b/backend/internal/ibd/ibd50.go @@ -0,0 +1,185 @@ +package ibd + +import ( + "context" + "encoding/json" + "errors" + "io" + "log/slog" + "net/http" + "net/url" + "strconv" +) + +const ibd50Url = "https://research.investors.com/Services/SiteAjaxService.asmx/GetIBD50?sortcolumn1=%22ibd100rank%22&sortOrder1=%22asc%22&sortcolumn2=%22%22&sortOrder2=%22ASC%22" + +// GetIBD50 returns the IBD50 list. +func (c *Client) GetIBD50(ctx context.Context) ([]*Stock, error) { + // We cannot use the scraper here because scrapfly does not support + // Content-Type in GET requests. 
+ req, err := http.NewRequestWithContext(ctx, http.MethodGet, ibd50Url, nil) + if err != nil { + return nil, err + } + + cookieId, cookie, err := c.getCookie(ctx, nil) + if err != nil { + return nil, err + } + req.AddCookie(cookie) + + req.Header.Add("content-type", "application/json; charset=utf-8") + // Add browser-emulating headers + req.Header.Add("accept", "*/*") + req.Header.Add("accept-language", "en-US,en;q=0.9") + req.Header.Add("newrelic", "eyJ2IjpbMCwxXSwiZCI6eyJ0eSI6IkJyb3dzZXIiLCJhYyI6IjMzOTYxMDYiLCJhcCI6IjEzODU5ODMwMDEiLCJpZCI6IjM1Zjk5NmM2MzNjYTViMWYiLCJ0ciI6IjM3ZmRhZmJlOGY2YjhmYTMwYWMzOTkzOGNlMmM0OWMxIiwidGkiOjE3MjIyNzg0NTk3MjUsInRrIjoiMTAyMjY4MSJ9fQ==") + req.Header.Add("priority", "u=1, i") + req.Header.Add("referer", "https://research.investors.com/stock-lists/ibd-50/") + req.Header.Add("sec-ch-ua", "\"Not/A)Brand\";v=\"8\", \"Chromium\";v=\"126\", \"Google Chrome\";v=\"126\"") + req.Header.Add("sec-ch-ua-mobile", "?0") + req.Header.Add("sec-ch-ua-platform", "\"macOS\"") + req.Header.Add("sec-fetch-dest", "empty") + req.Header.Add("sec-fetch-mode", "cors") + req.Header.Add("sec-fetch-site", "same-origin") + req.Header.Add("traceparent", "00-37fdafbe8f6b8fa30ac39938ce2c49c1-35f996c633ca5b1f-01") + req.Header.Add("tracestate", "1022681@nr=0-1-3396106-1385983001-35f996c633ca5b1f----1722278459725") + req.Header.Add("user-agent", "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/126.0.0.0 Safari/537.36") + req.Header.Add("x-newrelic-id", "VwUOV1dTDhABV1FRBgQOVVUF") + req.Header.Add("x-requested-with", "XMLHttpRequest") + + // Clone client to add proxy + client := *(c.client) + transport := http.DefaultTransport.(*http.Transport).Clone() + transport.Proxy = http.ProxyURL(c.proxyUrl) + + resp, err := client.Do(req) + if err != nil { + return nil, err + } + defer func(Body io.ReadCloser) { + _ = Body.Close() + }(resp.Body) + + var ibd50Resp getIBD50Response + if err = 
json.NewDecoder(resp.Body).Decode(&ibd50Resp); err != nil { + return nil, err + } + + // If there are less than 10 stocks in the IBD50 list, it's likely that authentication failed. + if len(ibd50Resp.D.ETablesDataList) < 10 { + // Report cookie failure to DB + if err = c.cookies.ReportCookieFailure(ctx, cookieId); err != nil { + slog.Error("Failed to report cookie failure", "error", err) + } + return nil, errors.New("failed to get IBD50 list") + } + + return ibd50Resp.ToStockList(), nil +} + +type Stock struct { + Rank int64 + Symbol string + Name string + + QuoteURL *url.URL +} + +type getIBD50Response struct { + D struct { + Type *string `json:"__type"` + ETablesDataList []struct { + Rank string `json:"Rank"` + Symbol string `json:"Symbol"` + CompanyName string `json:"CompanyName"` + CompRating *string `json:"CompRating"` + EPSRank *string `json:"EPSRank"` + RelSt *string `json:"RelSt"` + GrpStr *string `json:"GrpStr"` + Smr *string `json:"Smr"` + AccDis *string `json:"AccDis"` + SponRating *string `json:"SponRating"` + Price *string `json:"Price"` + PriceClose *string `json:"PriceClose"` + PriceChange *string `json:"PriceChange"` + PricePerChange *string `json:"PricePerChange"` + VolPerChange *string `json:"VolPerChange"` + DailyVol *string `json:"DailyVol"` + WeekHigh52 *string `json:"WeekHigh52"` + PerOffHigh *string `json:"PerOffHigh"` + PERatio *string `json:"PERatio"` + DivYield *string `json:"DivYield"` + LastQtrSalesPerChg *string `json:"LastQtrSalesPerChg"` + LastQtrEpsPerChg *string `json:"LastQtrEpsPerChg"` + ConsecQtrEpsGrt15 *string `json:"ConsecQtrEpsGrt15"` + CurQtrEpsEstPerChg *string `json:"CurQtrEpsEstPerChg"` + CurYrEpsEstPerChg *string `json:"CurYrEpsEstPerChg"` + PretaxMargin *string `json:"PretaxMargin"` + ROE *string `json:"ROE"` + MgmtOwnsPer *string `json:"MgmtOwnsPer"` + QuoteUrl *string `json:"QuoteUrl"` + StockCheckupUrl *string `json:"StockCheckupUrl"` + MarketsmithUrl *string `json:"MarketsmithUrl"` + LeaderboardUrl *string 
`json:"LeaderboardUrl"` + ChartAnalysisUrl *string `json:"ChartAnalysisUrl"` + Ibd100NewEntryFlag *string `json:"Ibd100NewEntryFlag"` + Ibd100UpInRankFlag *string `json:"Ibd100UpInRankFlag"` + IbdBigCap20NewEntryFlag *string `json:"IbdBigCap20NewEntryFlag"` + CompDesc *string `json:"CompDesc"` + NumberFunds *string `json:"NumberFunds"` + GlobalRank *string `json:"GlobalRank"` + EPSPriorQtr *string `json:"EPSPriorQtr"` + QtrsFundIncrease *string `json:"QtrsFundIncrease"` + } `json:"ETablesDataList"` + IBD50PdfUrl *string `json:"IBD50PdfUrl"` + CAP20PdfUrl *string `json:"CAP20PdfUrl"` + IBD50Date *string `json:"IBD50Date"` + CAP20Date *string `json:"CAP20Date"` + UpdatedDate *string `json:"UpdatedDate"` + GetAllFlags *string `json:"getAllFlags"` + Flag *int `json:"flag"` + Message *string `json:"Message"` + PaywallDesktopMarkup *string `json:"PaywallDesktopMarkup"` + PaywallMobileMarkup *string `json:"PaywallMobileMarkup"` + } `json:"d"` +} + +func (r getIBD50Response) ToStockList() (ibd []*Stock) { + ibd = make([]*Stock, 0, len(r.D.ETablesDataList)) + for _, data := range r.D.ETablesDataList { + rank, err := strconv.ParseInt(data.Rank, 10, 64) + if err != nil { + slog.Error( + "Failed to parse Rank", + "error", err, + "rank", data.Rank, + "symbol", data.Symbol, + "name", data.CompanyName, + ) + continue + } + + var quoteUrl *url.URL + if data.QuoteUrl != nil { + quoteUrl, err = url.Parse(*data.QuoteUrl) + if err != nil { + slog.Error( + "Failed to parse QuoteUrl", + "error", err, + "quoteUrl", *data.QuoteUrl, + "rank", data.Rank, + "symbol", data.Symbol, + "name", data.CompanyName, + ) + } + } + + ibd = append(ibd, &Stock{ + Rank: rank, + Symbol: data.Symbol, + Name: data.CompanyName, + QuoteURL: quoteUrl, + }) + } + return +} diff --git a/backend/internal/ibd/options.go b/backend/internal/ibd/options.go new file mode 100644 index 0000000..a07241e --- /dev/null +++ b/backend/internal/ibd/options.go @@ -0,0 +1,84 @@ +package ibd + +const BaseURL = 
"https://api.scrapfly.io/scrape" + +var defaultScrapeOptions = ScrapeOptions{ + baseURL: BaseURL, + country: nil, + asp: true, + proxyPool: ProxyPoolDatacenter, + renderJS: false, + cache: false, +} + +type ScrapeOption func(*ScrapeOptions) + +type ScrapeOptions struct { + baseURL string + country *string + asp bool + proxyPool ProxyPool + renderJS bool + cache bool + debug bool +} + +type ProxyPool uint8 + +const ( + ProxyPoolDatacenter ProxyPool = iota + ProxyPoolResidential +) + +func (p ProxyPool) String() string { + switch p { + case ProxyPoolDatacenter: + return "public_datacenter_pool" + case ProxyPoolResidential: + return "public_residential_pool" + default: + panic("invalid proxy pool") + } +} + +func WithCountry(country string) ScrapeOption { + return func(o *ScrapeOptions) { + o.country = &country + } +} + +func WithASP(asp bool) ScrapeOption { + return func(o *ScrapeOptions) { + o.asp = asp + } +} + +func WithProxyPool(proxyPool ProxyPool) ScrapeOption { + return func(o *ScrapeOptions) { + o.proxyPool = proxyPool + } +} + +func WithRenderJS(jsRender bool) ScrapeOption { + return func(o *ScrapeOptions) { + o.renderJS = jsRender + } +} + +func WithCache(cache bool) ScrapeOption { + return func(o *ScrapeOptions) { + o.cache = cache + } +} + +func WithDebug(debug bool) ScrapeOption { + return func(o *ScrapeOptions) { + o.debug = debug + } +} + +func WithBaseURL(baseURL string) ScrapeOption { + return func(o *ScrapeOptions) { + o.baseURL = baseURL + } +} diff --git a/backend/internal/ibd/scraper_types.go b/backend/internal/ibd/scraper_types.go new file mode 100644 index 0000000..c21ed1c --- /dev/null +++ b/backend/internal/ibd/scraper_types.go @@ -0,0 +1,227 @@ +package ibd + +import ( + "fmt" + "net/http" + "time" +) + +type ScraperResponse struct { + Config struct { + Asp bool `json:"asp"` + AutoScroll bool `json:"auto_scroll"` + Body interface{} `json:"body"` + Cache bool `json:"cache"` + CacheClear bool `json:"cache_clear"` + CacheTtl int 
`json:"cache_ttl"` + CorrelationId interface{} `json:"correlation_id"` + CostBudget interface{} `json:"cost_budget"` + Country interface{} `json:"country"` + Debug bool `json:"debug"` + Dns bool `json:"dns"` + Env string `json:"env"` + Extract interface{} `json:"extract"` + ExtractionModel interface{} `json:"extraction_model"` + ExtractionModelCustomSchema interface{} `json:"extraction_model_custom_schema"` + ExtractionPrompt interface{} `json:"extraction_prompt"` + ExtractionTemplate interface{} `json:"extraction_template"` + Format string `json:"format"` + Geolocation interface{} `json:"geolocation"` + Headers struct { + Cookie []string `json:"Cookie"` + } `json:"headers"` + JobUuid interface{} `json:"job_uuid"` + Js interface{} `json:"js"` + JsScenario interface{} `json:"js_scenario"` + Lang interface{} `json:"lang"` + LogEvictionDate string `json:"log_eviction_date"` + Method string `json:"method"` + Origin string `json:"origin"` + Os interface{} `json:"os"` + Project string `json:"project"` + ProxyPool string `json:"proxy_pool"` + RenderJs bool `json:"render_js"` + RenderingStage string `json:"rendering_stage"` + RenderingWait int `json:"rendering_wait"` + Retry bool `json:"retry"` + ScheduleName interface{} `json:"schedule_name"` + ScreenshotFlags interface{} `json:"screenshot_flags"` + ScreenshotResolution interface{} `json:"screenshot_resolution"` + Screenshots interface{} `json:"screenshots"` + Session interface{} `json:"session"` + SessionStickyProxy bool `json:"session_sticky_proxy"` + Ssl bool `json:"ssl"` + Tags interface{} `json:"tags"` + Timeout int `json:"timeout"` + Url string `json:"url"` + UserUuid string `json:"user_uuid"` + Uuid string `json:"uuid"` + WaitForSelector interface{} `json:"wait_for_selector"` + WebhookName interface{} `json:"webhook_name"` + } `json:"config"` + Context struct { + Asp interface{} `json:"asp"` + BandwidthConsumed int `json:"bandwidth_consumed"` + BandwidthImagesConsumed int `json:"bandwidth_images_consumed"` + Cache 
struct { + Entry interface{} `json:"entry"` + State string `json:"state"` + } `json:"cache"` + Cookies []struct { + Comment interface{} `json:"comment"` + Domain string `json:"domain"` + Expires *string `json:"expires"` + HttpOnly bool `json:"http_only"` + MaxAge interface{} `json:"max_age"` + Name string `json:"name"` + Path string `json:"path"` + Secure bool `json:"secure"` + Size int `json:"size"` + Value string `json:"value"` + Version interface{} `json:"version"` + } `json:"cookies"` + Cost struct { + Details []struct { + Amount int `json:"amount"` + Code string `json:"code"` + Description string `json:"description"` + } `json:"details"` + Total int `json:"total"` + } `json:"cost"` + CreatedAt string `json:"created_at"` + Debug interface{} `json:"debug"` + Env string `json:"env"` + Fingerprint string `json:"fingerprint"` + Headers struct { + Cookie string `json:"Cookie"` + } `json:"headers"` + IsXmlHttpRequest bool `json:"is_xml_http_request"` + Job interface{} `json:"job"` + Lang []string `json:"lang"` + Os struct { + Distribution string `json:"distribution"` + Name string `json:"name"` + Type string `json:"type"` + Version string `json:"version"` + } `json:"os"` + Project string `json:"project"` + Proxy struct { + Country string `json:"country"` + Identity string `json:"identity"` + Network string `json:"network"` + Pool string `json:"pool"` + } `json:"proxy"` + Redirects []interface{} `json:"redirects"` + Retry int `json:"retry"` + Schedule interface{} `json:"schedule"` + Session interface{} `json:"session"` + Spider interface{} `json:"spider"` + Throttler interface{} `json:"throttler"` + Uri struct { + BaseUrl string `json:"base_url"` + Fragment interface{} `json:"fragment"` + Host string `json:"host"` + Params interface{} `json:"params"` + Port int `json:"port"` + Query string `json:"query"` + RootDomain string `json:"root_domain"` + Scheme string `json:"scheme"` + } `json:"uri"` + Url string `json:"url"` + Webhook interface{} `json:"webhook"` + } 
`json:"context"` + Insights interface{} `json:"insights"` + Result ScraperResult `json:"result"` + Uuid string `json:"uuid"` +} + +type ScraperResult struct { + BrowserData struct { + JavascriptEvaluationResult interface{} `json:"javascript_evaluation_result"` + JsScenario []interface{} `json:"js_scenario"` + LocalStorageData struct { + } `json:"local_storage_data"` + SessionStorageData struct { + } `json:"session_storage_data"` + Websockets []interface{} `json:"websockets"` + XhrCall interface{} `json:"xhr_call"` + } `json:"browser_data"` + Content string `json:"content"` + ContentEncoding string `json:"content_encoding"` + ContentFormat string `json:"content_format"` + ContentType string `json:"content_type"` + Cookies []ScraperCookie `json:"cookies"` + Data interface{} `json:"data"` + Dns interface{} `json:"dns"` + Duration float64 `json:"duration"` + Error interface{} `json:"error"` + ExtractedData interface{} `json:"extracted_data"` + Format string `json:"format"` + Iframes []interface{} `json:"iframes"` + LogUrl string `json:"log_url"` + Reason string `json:"reason"` + RequestHeaders map[string]string `json:"request_headers"` + ResponseHeaders map[string]string `json:"response_headers"` + Screenshots struct { + } `json:"screenshots"` + Size int `json:"size"` + Ssl interface{} `json:"ssl"` + Status string `json:"status"` + StatusCode int `json:"status_code"` + Success bool `json:"success"` + Url string `json:"url"` +} + +type ScraperCookie struct { + Name string `json:"name"` + Value string `json:"value"` + Expires string `json:"expires"` + Path string `json:"path"` + Comment string `json:"comment"` + Domain string `json:"domain"` + MaxAge int `json:"max_age"` + Secure bool `json:"secure"` + HttpOnly bool `json:"http_only"` + Version string `json:"version"` + Size int `json:"size"` +} + +func (c *ScraperCookie) ToHTTPCookie() (*http.Cookie, error) { + var expires time.Time + if c.Expires != "" { + var err error + expires, err = time.Parse("2006-01-02 
15:04:05", c.Expires) + if err != nil { + return nil, fmt.Errorf("failed to parse cookie expiration: %w", err) + } + } + return &http.Cookie{ + Name: c.Name, + Value: c.Value, + Path: c.Path, + Domain: c.Domain, + Expires: expires, + Secure: c.Secure, + HttpOnly: c.HttpOnly, + }, nil +} + +func (c *ScraperCookie) FromHTTPCookie(cookie *http.Cookie) { + var expires string + if !cookie.Expires.IsZero() { + expires = cookie.Expires.Format("2006-01-02 15:04:05") + } + *c = ScraperCookie{ + Comment: "", + Domain: cookie.Domain, + Expires: expires, + HttpOnly: cookie.HttpOnly, + MaxAge: cookie.MaxAge, + Name: cookie.Name, + Path: cookie.Path, + Secure: cookie.Secure, + Size: len(cookie.Value), + Value: cookie.Value, + Version: "", + } +} diff --git a/backend/internal/ibd/search.go b/backend/internal/ibd/search.go new file mode 100644 index 0000000..981bd97 --- /dev/null +++ b/backend/internal/ibd/search.go @@ -0,0 +1,103 @@ +package ibd + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/url" + "time" + + "ibd-trader/internal/database" +) + +const ( + searchUrl = "https://ibdservices.investors.com/im/api/search" +) + +var ErrSymbolNotFound = fmt.Errorf("symbol not found") + +func (c *Client) Search(ctx context.Context, symbol string) (database.Stock, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, searchUrl, nil) + if err != nil { + return database.Stock{}, err + } + + _, cookie, err := c.getCookie(ctx, nil) + if err != nil { + return database.Stock{}, err + } + req.AddCookie(cookie) + + params := url.Values{} + params.Set("key", symbol) + req.URL.RawQuery = params.Encode() + + resp, err := c.Do(req) + if err != nil { + return database.Stock{}, err + } + + if resp.Result.StatusCode != http.StatusOK { + return database.Stock{}, fmt.Errorf( + "unexpected status code %d: %s", + resp.Result.StatusCode, + resp.Result.Content, + ) + } + + var sr searchResponse + if err = json.Unmarshal([]byte(resp.Result.Content), &sr); err != nil { + 
return database.Stock{}, err + } + + for _, stock := range sr.StockData { + if stock.Symbol == symbol { + return database.Stock{ + Symbol: stock.Symbol, + Name: stock.Company, + IBDUrl: stock.QuoteUrl, + }, nil + } + } + + return database.Stock{}, ErrSymbolNotFound +} + +type searchResponse struct { + Status int `json:"_status"` + Timestamp string `json:"_timestamp"` + StockData []struct { + Id int `json:"id"` + Symbol string `json:"symbol"` + Company string `json:"company"` + PriceDate string `json:"priceDate"` + Price float64 `json:"price"` + PreviousPrice float64 `json:"previousPrice"` + PriceChange float64 `json:"priceChange"` + PricePctChange float64 `json:"pricePctChange"` + Volume int `json:"volume"` + VolumeChange int `json:"volumeChange"` + VolumePctChange int `json:"volumePctChange"` + QuoteUrl string `json:"quoteUrl"` + } `json:"stockData"` + News []struct { + Title string `json:"title"` + Category string `json:"category"` + Body string `json:"body"` + ImageAlt string `json:"imageAlt"` + ImageUrl string `json:"imageUrl"` + NewsUrl string `json:"newsUrl"` + CategoryUrl string `json:"categoryUrl"` + PublishDate time.Time `json:"publishDate"` + PublishDateUnixts int `json:"publishDateUnixts"` + Stocks []struct { + Id int `json:"id"` + Index int `json:"index"` + Symbol string `json:"symbol"` + PricePctChange string `json:"pricePctChange"` + } `json:"stocks"` + VideoFormat bool `json:"videoFormat"` + } `json:"news"` + FullUrl string `json:"fullUrl"` +} diff --git a/backend/internal/ibd/search_test.go b/backend/internal/ibd/search_test.go new file mode 100644 index 0000000..ac0f578 --- /dev/null +++ b/backend/internal/ibd/search_test.go @@ -0,0 +1,204 @@ +package ibd + +import ( + "context" + "net/http" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +const searchResponseJSON = ` +{ + "_status": 200, + "_timestamp": "1722879439.724106", + "stockData": [ + { + "id": 13717, + "symbol": "AAPL", + "company": 
"Apple", + "priceDate": "2024-08-05T09:18:00", + "price": 212.33, + "previousPrice": 219.86, + "priceChange": -7.53, + "pricePctChange": -3.42, + "volume": 643433, + "volumeChange": -2138, + "volumePctChange": 124, + "quoteUrl": "https://research.investors.com/stock-quotes/nasdaq-apple-aapl.htm" + }, + { + "id": 79964, + "symbol": "AAPU", + "company": "Direxion AAPL Bull 2X", + "priceDate": "2024-08-05T09:18:00", + "price": 32.48, + "previousPrice": 34.9, + "priceChange": -2.42, + "pricePctChange": -6.92, + "volume": 15265, + "volumeChange": -35, + "volumePctChange": 212, + "quoteUrl": "https://research.investors.com/stock-quotes/nasdaq-direxion-aapl-bull-2x-aapu.htm" + }, + { + "id": 80423, + "symbol": "APLY", + "company": "YieldMax AAPL Option Incm", + "priceDate": "2024-08-05T09:11:00", + "price": 17.52, + "previousPrice": 18.15, + "priceChange": -0.63, + "pricePctChange": -3.47, + "volume": 617, + "volumeChange": -2, + "volumePctChange": 97, + "quoteUrl": "https://research.investors.com/stock-quotes/nyse-yieldmax-aapl-option-incm-aply.htm" + }, + { + "id": 79962, + "symbol": "AAPD", + "company": "Direxion Dly AAPL Br 1X", + "priceDate": "2024-08-05T09:18:00", + "price": 18.11, + "previousPrice": 17.53, + "priceChange": 0.58, + "pricePctChange": 3.31, + "volume": 14572, + "volumeChange": -7, + "volumePctChange": 885, + "quoteUrl": "https://research.investors.com/stock-quotes/nasdaq-direxion-dly-aapl-br-1x-aapd.htm" + }, + { + "id": 79968, + "symbol": "AAPB", + "company": "GraniteSh 2x Lg AAPL", + "priceDate": "2024-08-05T09:16:00", + "price": 25.22, + "previousPrice": 27.25, + "priceChange": -2.03, + "pricePctChange": -7.45, + "volume": 2505, + "volumeChange": -7, + "volumePctChange": 151, + "quoteUrl": "https://research.investors.com/stock-quotes/nasdaq-granitesh-2x-lg-aapl-aapb.htm" + } + ], + "news": [ + { + "title": "Warren Buffett Dumped Berkshire Hathaway's Favorite Stocks — Right Before They Plunged", + "category": "News", + "body": "Berkshire Hathaway 
earnings rose solidly in Q2. Warren Buffett sold nearly half his Apple stock stake. Berkshire stock fell...", + "imageAlt": "", + "imageUrl": "https://www.investors.com/wp-content/uploads/2024/06/Stock-WarrenBuffettwave-01-shutt-640x360.jpg", + "newsUrl": "https://investors.com/news/berkshire-hathaway-earnings-q2-2024-warren-buffett-apple/", + "categoryUrl": "https://investors.com/category/news/", + "publishDate": "2024-08-05T15:51:57+00:00", + "publishDateUnixts": 1722858717, + "stocks": [ + { + "id": 13717, + "index": 0, + "symbol": "AAPL", + "pricePctChange": "-3.42" + } + ], + "videoFormat": false + }, + { + "title": "Nvidia Plunges On Report Of AI Chip Flaw; Is It A Buy Now?", + "category": "Research", + "body": "Nvidia will roll out its Blackwell chip at least three months later than planned.", + "imageAlt": "", + "imageUrl": "https://www.investors.com/wp-content/uploads/2024/01/Stock-Nvidia-studio-01-company-640x360.jpg", + "newsUrl": "https://investors.com/research/nvda-stock-is-nvidia-a-buy-2/", + "categoryUrl": "https://investors.com/category/research/", + "publishDate": "2024-08-05T14:59:22+00:00", + "publishDateUnixts": 1722855562, + "stocks": [ + { + "id": 38607, + "index": 0, + "symbol": "NVDA", + "pricePctChange": "-5.18" + } + ], + "videoFormat": false + }, + { + "title": "Magnificent Seven Stocks Roiled: Nvidia Plunges On AI Chip Delay; Apple, Tesla Dive", + "category": "Research", + "body": "Nvidia stock dived Monday, while Apple and Tesla also fell sharply.", + "imageAlt": "", + "imageUrl": "https://www.investors.com/wp-content/uploads/2022/08/Stock-Nvidia-RTXa5500-comp-640x360.jpg", + "newsUrl": "https://investors.com/research/magnificent-seven-stocks-to-buy-and-and-watch/", + "categoryUrl": "https://investors.com/category/research/", + "publishDate": "2024-08-05T14:51:42+00:00", + "publishDateUnixts": 1722855102, + "stocks": [ + { + "id": 13717, + "index": 0, + "symbol": "AAPL", + "pricePctChange": "-3.42" + } + ], + "videoFormat": false + } + 
], + "fullUrl": "https://www.investors.com/search-results/?query=AAPL" +}` + +const emptySearchResponseJSON = ` +{ + "_status": 200, + "_timestamp": "1722879662.804395", + "stockData": [], + "news": [], + "fullUrl": "https://www.investors.com/search-results/?query=abcdefg" +}` + +func TestClient_Search(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + response string + f func(t *testing.T, client *Client) + }{ + { + name: "found", + response: searchResponseJSON, + f: func(t *testing.T, client *Client) { + u, err := client.Search(context.Background(), "AAPL") + require.NoError(t, err) + assert.Equal(t, "AAPL", u.Symbol) + assert.Equal(t, "Apple", u.Name) + assert.Equal(t, "https://research.investors.com/stock-quotes/nasdaq-apple-aapl.htm", u.IBDUrl) + }, + }, + { + name: "not found", + response: emptySearchResponseJSON, + f: func(t *testing.T, client *Client) { + _, err := client.Search(context.Background(), "abcdefg") + assert.Error(t, err) + assert.ErrorIs(t, err, ErrSymbolNotFound) + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + server := newServer(t, http.HandlerFunc(func(writer http.ResponseWriter, r *http.Request) { + _, _ = writer.Write([]byte(tt.response)) + })) + defer server.Close() + + client, err := NewClient(http.DefaultClient, apiKey, new(cookieSourceStub), "", WithBaseURL(server.URL)) + require.NoError(t, err) + + tt.f(t, client) + }) + } +} diff --git a/backend/internal/ibd/stockinfo.go b/backend/internal/ibd/stockinfo.go new file mode 100644 index 0000000..33fea3d --- /dev/null +++ b/backend/internal/ibd/stockinfo.go @@ -0,0 +1,237 @@ +package ibd + +import ( + "context" + "fmt" + "net/http" + "net/url" + "strconv" + "strings" + + "ibd-trader/internal/database" + "ibd-trader/internal/utils" + + "github.com/Rhymond/go-money" + "golang.org/x/net/html" +) + +func (c *Client) StockInfo(ctx context.Context, uri string) (*database.StockInfo, error) { + req, err := http.NewRequestWithContext(ctx, 
http.MethodGet, uri, nil) + if err != nil { + return nil, err + } + + _, cookie, err := c.getCookie(ctx, nil) + if err != nil { + return nil, err + } + req.AddCookie(cookie) + + // Set required query parameters + params := url.Values{} + params.Set("list", "ibd50") + params.Set("type", "weekly") + req.URL.RawQuery = params.Encode() + + resp, err := c.Do(req) + if err != nil { + return nil, err + } + + if resp.Result.StatusCode != http.StatusOK { + return nil, fmt.Errorf( + "unexpected status code %d: %s", + resp.Result.StatusCode, + resp.Result.Content, + ) + } + + node, err := html.Parse(strings.NewReader(resp.Result.Content)) + if err != nil { + return nil, err + } + + name, symbol, err := extractNameAndSymbol(node) + if err != nil { + return nil, fmt.Errorf("failed to extract name and symbol: %w", err) + } + chartAnalysis, err := extractChartAnalysis(node) + if err != nil { + return nil, fmt.Errorf("failed to extract chart analysis: %w", err) + } + ratings, err := extractRatings(node) + if err != nil { + return nil, fmt.Errorf("failed to extract ratings: %w", err) + } + price, err := extractPrice(node) + if err != nil { + return nil, fmt.Errorf("failed to extract price: %w", err) + } + + return &database.StockInfo{ + Symbol: symbol, + Name: name, + ChartAnalysis: chartAnalysis, + Ratings: ratings, + Price: price, + }, nil +} + +func extractNameAndSymbol(node *html.Node) (name string, symbol string, err error) { + // Find span with ID "quote-symbol" + quoteSymbolNode := findId(node, "quote-symbol") + if quoteSymbolNode == nil { + return "", "", fmt.Errorf("could not find `quote-symbol` span") + } + + // Get the text of the quote-symbol span + name = strings.TrimSpace(extractText(quoteSymbolNode)) + + // Find span with ID "qteSymb" + qteSymbNode := findId(node, "qteSymb") + if qteSymbNode == nil { + return "", "", fmt.Errorf("could not find `qteSymb` span") + } + + // Get the text of the qteSymb span + symbol = strings.TrimSpace(extractText(qteSymbNode)) + + // 
Get index of last closing parenthesis + lastParenIndex := strings.LastIndex(name, ")") + if lastParenIndex == -1 { + return + } + + // Find the last opening parenthesis before the closing parenthesis + lastOpenParenIndex := strings.LastIndex(name[:lastParenIndex], "(") + if lastOpenParenIndex == -1 { + return + } + + // Remove the parenthesis pair + name = strings.TrimSpace(name[:lastOpenParenIndex] + name[lastParenIndex+1:]) + return +} + +func extractPrice(node *html.Node) (*money.Money, error) { + // Find the div with the ID "lstPrice" + lstPriceNode := findId(node, "lstPrice") + if lstPriceNode == nil { + return nil, fmt.Errorf("could not find `lstPrice` div") + } + + // Get the text of the lstPrice div + priceStr := strings.TrimSpace(extractText(lstPriceNode)) + + // Parse the price + price, err := utils.ParseMoney(priceStr) + if err != nil { + return nil, fmt.Errorf("failed to parse price: %w", err) + } + + return price, nil +} + +func extractRatings(node *html.Node) (ratings database.Ratings, err error) { + // Find the div with class "smartContent" + smartSelectNode := findClass(node, "smartContent") + if smartSelectNode == nil { + return ratings, fmt.Errorf("could not find `smartContent` div") + } + + // Iterate over children, looking for "smartRating" divs + for c := smartSelectNode.FirstChild; c != nil; c = c.NextSibling { + if !isClass(c, "smartRating") { + continue + } + + err = processSmartRating(c, &ratings) + if err != nil { + return + } + } + return +} + +// processSmartRating extracts the rating from a "smartRating" div and updates the ratings struct. +// +// The node should look like this: +// +// <ul class="smartRating"> +// <li><a><span>Composite Rating</span></a></li> +// <li>94</li> +// ... 
+// </ul> +func processSmartRating(node *html.Node, ratings *database.Ratings) error { + // Check that the node is a ul + if node.Type != html.ElementNode || node.Data != "ul" { + return fmt.Errorf("expected ul node, got %s", node.Data) + } + + // Get all `li` children + children := findChildren(node, func(node *html.Node) bool { + return node.Type == html.ElementNode && node.Data == "li" + }) + + // Extract the rating name + ratingName := strings.TrimSpace(extractText(children[0])) + + // Extract the rating value + ratingValueStr := strings.TrimSpace(extractText(children[1])) + + switch ratingName { + case "Composite Rating": + ratingValue, err := strconv.ParseUint(ratingValueStr, 10, 8) + if err != nil { + return fmt.Errorf("failed to parse Composite Rating: %w", err) + } + ratings.Composite = uint8(ratingValue) + case "EPS Rating": + ratingValue, err := strconv.ParseUint(ratingValueStr, 10, 8) + if err != nil { + return fmt.Errorf("failed to parse EPS Rating: %w", err) + } + ratings.EPS = uint8(ratingValue) + case "RS Rating": + ratingValue, err := strconv.ParseUint(ratingValueStr, 10, 8) + if err != nil { + return fmt.Errorf("failed to parse RS Rating: %w", err) + } + ratings.RelStr = uint8(ratingValue) + case "Group RS Rating": + ratingValue, err := database.LetterRatingFromString(ratingValueStr) + if err != nil { + return fmt.Errorf("failed to parse Group RS Rating: %w", err) + } + ratings.GroupRelStr = ratingValue + case "SMR Rating": + ratingValue, err := database.LetterRatingFromString(ratingValueStr) + if err != nil { + return fmt.Errorf("failed to parse SMR Rating: %w", err) + } + ratings.SMR = ratingValue + case "Acc/Dis Rating": + ratingValue, err := database.LetterRatingFromString(ratingValueStr) + if err != nil { + return fmt.Errorf("failed to parse Acc/Dis Rating: %w", err) + } + ratings.AccDis = ratingValue + default: + return fmt.Errorf("unknown rating name: %s", ratingName) + } + + return nil +} + +func extractChartAnalysis(node *html.Node) 
(string, error) { + // Find the div with class "chartAnalysis" + chartAnalysisNode := findClass(node, "chartAnalysis") + if chartAnalysisNode == nil { + return "", fmt.Errorf("could not find `chartAnalysis` div") + } + + // Get the text of the chart analysis div + chartAnalysis := strings.TrimSpace(extractText(chartAnalysisNode)) + + return chartAnalysis, nil +} diff --git a/backend/internal/ibd/userinfo.go b/backend/internal/ibd/userinfo.go new file mode 100644 index 0000000..ba7a5b5 --- /dev/null +++ b/backend/internal/ibd/userinfo.go @@ -0,0 +1,147 @@ +package ibd + +import ( + "context" + "encoding/json" + "fmt" + "log/slog" + "net/http" +) + +const ( + userInfoUrl = "https://myibd.investors.com/services/userprofile.aspx?format=json" +) + +func (c *Client) UserInfo(ctx context.Context, cookie *http.Cookie) (*UserProfile, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, userInfoUrl, nil) + if err != nil { + return nil, err + } + + req.AddCookie(cookie) + + resp, err := c.Do(req) + if err != nil { + return nil, err + } + + if resp.Result.StatusCode != http.StatusOK { + return nil, fmt.Errorf( + "unexpected status code %d: %s", + resp.Result.StatusCode, + resp.Result.Content, + ) + } + + up := new(UserProfile) + if err = up.UnmarshalJSON([]byte(resp.Result.Content)); err != nil { + return nil, err + } + + return up, nil +} + +type UserStatus string + +const ( + UserStatusUnknown UserStatus = "" + UserStatusVisitor UserStatus = "Visitor" + UserStatusSubscriber UserStatus = "Subscriber" +) + +type UserProfile struct { + DisplayName string + Email string + FirstName string + LastName string + Status UserStatus +} + +func (u *UserProfile) UnmarshalJSON(bytes []byte) error { + var resp userProfileResponse + if err := json.Unmarshal(bytes, &resp); err != nil { + return err + } + + u.DisplayName = resp.UserProfile.UserDisplayName + u.Email = resp.UserProfile.UserEmailAddress + u.FirstName = resp.UserProfile.UserFirstName + u.LastName = 
resp.UserProfile.UserLastName + + switch resp.UserProfile.UserTrialStatus { + case "Visitor": + u.Status = UserStatusVisitor + case "Subscriber": + u.Status = UserStatusSubscriber + default: + slog.Warn("Unknown user status", "status", resp.UserProfile.UserTrialStatus) + u.Status = UserStatusUnknown + } + + return nil +} + +type userProfileResponse struct { + UserProfile userProfile `json:"userProfile"` +} + +type userProfile struct { + UserSubType string `json:"userSubType"` + UserId string `json:"userId"` + UserDisplayName string `json:"userDisplayName"` + Countrycode string `json:"countrycode"` + IsEUCountry string `json:"isEUCountry"` + Log string `json:"log"` + AgeGroup string `json:"ageGroup"` + Gender string `json:"gender"` + InvestingExperience string `json:"investingExperience"` + NumberOfTrades string `json:"numberOfTrades"` + Occupation string `json:"occupation"` + TypeOfInvestments string `json:"typeOfInvestments"` + UserEmailAddress string `json:"userEmailAddress"` + UserEmailAddressSHA1 string `json:"userEmailAddressSHA1"` + UserEmailAddressSHA256 string `json:"userEmailAddressSHA256"` + UserEmailAddressMD5 string `json:"userEmailAddressMD5"` + UserFirstName string `json:"userFirstName"` + UserLastName string `json:"userLastName"` + UserZip string `json:"userZip"` + UserTrialStatus string `json:"userTrialStatus"` + UserProductsOnTrial string `json:"userProductsOnTrial"` + UserProductsOwned string `json:"userProductsOwned"` + UserAdTrade string `json:"userAdTrade"` + UserAdTime string `json:"userAdTime"` + UserAdHold string `json:"userAdHold"` + UserAdJob string `json:"userAdJob"` + UserAdAge string `json:"userAdAge"` + UserAdOutSell string `json:"userAdOutSell"` + UserVisitCount string `json:"userVisitCount"` + RoleLeaderboard bool `json:"role_leaderboard"` + RoleOws bool `json:"role_ows"` + RoleIbdlive bool `json:"role_ibdlive"` + RoleFounderclub bool `json:"role_founderclub"` + RoleEibd bool `json:"role_eibd"` + RoleIcom bool `json:"role_icom"` + 
RoleEtables bool `json:"role_etables"` + RoleTru10 bool `json:"role_tru10"` + RoleMarketsurge bool `json:"role_marketsurge"` + RoleSwingtrader bool `json:"role_swingtrader"` + RoleAdfree bool `json:"role_adfree"` + RoleMarketdiem bool `json:"role_marketdiem"` + RoleWsjPlus bool `json:"role_wsj_plus"` + RoleWsj bool `json:"role_wsj"` + RoleBarrons bool `json:"role_barrons"` + RoleMarketwatch bool `json:"role_marketwatch"` + UserAdRoles string `json:"userAdRoles"` + TrialDailyPrintNeg bool `json:"trial_daily_print_neg"` + TrialDailyPrintNon bool `json:"trial_daily_print_non"` + TrialWeeklyPrintNeg bool `json:"trial_weekly_print_neg"` + TrialWeeklyPrintNon bool `json:"trial_weekly_print_non"` + TrialDailyComboNeg bool `json:"trial_daily_combo_neg"` + TrialDailyComboNon bool `json:"trial_daily_combo_non"` + TrialWeeklyComboNeg bool `json:"trial_weekly_combo_neg"` + TrialWeeklyComboNon bool `json:"trial_weekly_combo_non"` + TrialEibdNeg bool `json:"trial_eibd_neg"` + TrialEibdNon bool `json:"trial_eibd_non"` + UserVideoPreference string `json:"userVideoPreference"` + UserProfessionalStatus bool `json:"userProfessionalStatus"` +} diff --git a/backend/internal/keys/gcp.go b/backend/internal/keys/gcp.go new file mode 100644 index 0000000..9d10fc5 --- /dev/null +++ b/backend/internal/keys/gcp.go @@ -0,0 +1,131 @@ +package keys + +import ( + "context" + "crypto/rand" + "crypto/rsa" + "crypto/sha256" + "crypto/x509" + "encoding/pem" + "errors" + "fmt" + "hash/crc32" + "sync" + + kms "cloud.google.com/go/kms/apiv1" + "cloud.google.com/go/kms/apiv1/kmspb" + "google.golang.org/protobuf/types/known/wrapperspb" +) + +type GoogleKMS struct { + client *kms.KeyManagementClient + + mx sync.RWMutex + keyCache map[string]*rsa.PublicKey +} + +func NewGoogleKMS(ctx context.Context) (*GoogleKMS, error) { + client, err := kms.NewKeyManagementClient(ctx) + if err != nil { + return nil, err + } + + return &GoogleKMS{ + client: client, + keyCache: make(map[string]*rsa.PublicKey), + }, nil +} + 
+func (g *GoogleKMS) checkCache(keyName string) *rsa.PublicKey { + g.mx.RLock() + defer g.mx.RUnlock() + + return g.keyCache[keyName] +} + +func (g *GoogleKMS) getPublicKey(ctx context.Context, keyName string) (*rsa.PublicKey, error) { + if key := g.checkCache(keyName); key != nil { + return key, nil + } + + response, err := g.client.GetPublicKey(ctx, &kmspb.GetPublicKeyRequest{Name: keyName}) + if err != nil { + return nil, err + } + + block, _ := pem.Decode([]byte(response.Pem)) + if block == nil || block.Type != "PUBLIC KEY" { + return nil, errors.New("failed to decode PEM public key") + } + publicKey, err := x509.ParsePKIXPublicKey(block.Bytes) + if err != nil { + return nil, fmt.Errorf("failed to parse public key: %w", err) + } + rsaKey, ok := publicKey.(*rsa.PublicKey) + if !ok { + return nil, errors.New("public key is not an RSA key") + } + + g.mx.Lock() + defer g.mx.Unlock() + g.keyCache[keyName] = rsaKey + + return rsaKey, nil +} + +func (g *GoogleKMS) Encrypt(ctx context.Context, keyName string, plaintext []byte) ([]byte, error) { + publicKey, err := g.getPublicKey(ctx, keyName) + if err != nil { + return nil, err + } + + cipherText, err := rsa.EncryptOAEP(sha256.New(), rand.Reader, publicKey, plaintext, nil) + if err != nil { + return nil, fmt.Errorf("failed to encrypt plaintext: %w", err) + } + + return cipherText, nil +} + +func (g *GoogleKMS) Decrypt(ctx context.Context, keyName string, ciphertext []byte) ([]byte, error) { + req := &kmspb.AsymmetricDecryptRequest{ + Name: keyName, + Ciphertext: ciphertext, + CiphertextCrc32C: wrapperspb.Int64(int64(calcCRC32(ciphertext))), + } + + result, err := g.client.AsymmetricDecrypt(ctx, req) + if err != nil { + return nil, fmt.Errorf("failed to decrypt ciphertext: %w", err) + } + + if !result.VerifiedCiphertextCrc32C { + return nil, errors.New("AsymmetricDecrypt: request corrupted in-transit") + } + if int64(calcCRC32(result.Plaintext)) != result.PlaintextCrc32C.Value { + return nil, 
// calcCRC32 computes the CRC-32 checksum of data using the Castagnoli
// polynomial, which is the variant Cloud KMS uses for its integrity
// (CRC32C) fields.
func calcCRC32(data []byte) uint32 {
	return crc32.Checksum(data, crc32.MakeTable(crc32.Castagnoli))
}
+func Encrypt( + ctx context.Context, + kms KeyManagementService, + keyName string, + plaintext []byte, +) (ciphertext []byte, encryptedKey []byte, err error) { + // Generate a random AES key + aesKey := make([]byte, 32) + if _, err = io.ReadFull(CSRNG, aesKey); err != nil { + return nil, nil, fmt.Errorf("unable to generate AES key: %w", err) + } + + // Encrypt the plaintext using the AES key + ciphertext, err = encrypt(aesKey, plaintext) + if err != nil { + return nil, nil, fmt.Errorf("unable to encrypt plaintext: %w", err) + } + + // Encrypt the AES key using the KMS + encryptedKey, err = kms.Encrypt(ctx, keyName, aesKey) + if err != nil { + return nil, nil, fmt.Errorf("unable to encrypt AES key: %w", err) + } + + return ciphertext, encryptedKey, nil +} + +// EncryptWithKey encrypts the given plaintext using a hybrid encryption scheme. +// +// This works similarly to Encrypt, but instead of generating a new AES key, it uses a given already encrypted AES key. +func EncryptWithKey( + ctx context.Context, + kms KeyManagementService, + keyName string, + encryptedKey []byte, + plaintext []byte, +) ([]byte, error) { + // Decrypt the AES key + aesKey, err := kms.Decrypt(ctx, keyName, encryptedKey) + if err != nil { + return nil, fmt.Errorf("unable to decrypt AES key: %w", err) + } + + // Encrypt the plaintext using the AES key + ciphertext, err := encrypt(aesKey, plaintext) + if err != nil { + return nil, fmt.Errorf("unable to encrypt plaintext: %w", err) + } + + return ciphertext, nil +} + +func encrypt(aesKey []byte, plaintext []byte) ([]byte, error) { + // Create an AES cipher + blockCipher, err := aes.NewCipher(aesKey) + if err != nil { + return nil, fmt.Errorf("unable to create AES cipher: %w", err) + } + + gcm, err := cipher.NewGCM(blockCipher) + if err != nil { + return nil, fmt.Errorf("unable to create GCM: %w", err) + } + + // Generate a random nonce + nonce := make([]byte, gcm.NonceSize()) + if _, err = io.ReadFull(CSRNG, nonce); err != nil { + return nil, 
fmt.Errorf("unable to generate nonce: %w", err) + } + + // Encrypt the plaintext + ciphertext := gcm.Seal(nonce, nonce, plaintext, nil) + return ciphertext, nil +} + +// Decrypt decrypts the given ciphertext using a hybrid encryption scheme. +// +// It first decrypts the AES key using the KMS. +// Then, it decrypts the ciphertext using the decrypted AES key. +// +// It returns the plaintext and any errors that occurred. +func Decrypt( + ctx context.Context, + kms KeyManagementService, + keyName string, + ciphertext []byte, + encryptedKey []byte, +) ([]byte, error) { + // Decrypt the AES key + aesKey, err := kms.Decrypt(ctx, keyName, encryptedKey) + if err != nil { + return nil, fmt.Errorf("unable to decrypt AES key: %w", err) + } + + // Create an AES cipher + blockCipher, err := aes.NewCipher(aesKey) + if err != nil { + return nil, fmt.Errorf("unable to create AES cipher: %w", err) + } + + gcm, err := cipher.NewGCM(blockCipher) + if err != nil { + return nil, fmt.Errorf("unable to create GCM: %w", err) + } + + // Extract the nonce from the ciphertext + nonceSize := gcm.NonceSize() + if len(ciphertext) < nonceSize { + return nil, fmt.Errorf("ciphertext is too short") + } + nonce, ciphertext := ciphertext[:nonceSize], ciphertext[nonceSize:] + + // Decrypt the ciphertext + plaintext, err := gcm.Open(nil, nonce, ciphertext, nil) + if err != nil { + return nil, fmt.Errorf("unable to decrypt ciphertext: %w", err) + } + + return plaintext, nil +} diff --git a/backend/internal/keys/keys_test.go b/backend/internal/keys/keys_test.go new file mode 100644 index 0000000..14bbcc2 --- /dev/null +++ b/backend/internal/keys/keys_test.go @@ -0,0 +1,64 @@ +package keys_test + +import ( + "bytes" + "context" + "encoding/hex" + "testing" + + "ibd-trader/internal/keys" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" +) + +func TestEncrypt(t *testing.T) { + ctrl := gomock.NewController(t) + + // Replace RNG with a deterministic 
RNG + aesKey := []byte("0123456789abcdef0123456789abcdef") + nonce := []byte("0123456789ab") + keys.CSRNG = bytes.NewReader(append(aesKey, nonce...)) + + // Create a mock KMS + kms := NewMockKeyManagementService(ctrl) + keyName := "keyName" + + ctx := context.Background() + plaintext := []byte("plaintext") + + kms.EXPECT(). + Encrypt(ctx, keyName, aesKey). + Return([]byte("encryptedKey"), nil) + + ciphertext, encryptedKey, err := keys.Encrypt(ctx, kms, keyName, plaintext) + require.NoError(t, err) + + encrypted, err := hex.DecodeString("e9c586532dbefd63812293e1c4baf71edb7042a294c49c2020") + require.NoError(t, err) + assert.Equal(t, append(nonce, encrypted...), ciphertext) + assert.Equal(t, []byte("encryptedKey"), encryptedKey) +} + +func TestDecrypt(t *testing.T) { + ctrl := gomock.NewController(t) + + kms := NewMockKeyManagementService(ctrl) + keyName := "keyName" + + ctx := context.Background() + encryptedKey := []byte("encryptedKey") + ciphertext, err := hex.DecodeString("e9c586532dbefd63812293e1c4baf71edb7042a294c49c2020") + require.NoError(t, err) + ciphertext = append([]byte("0123456789ab"), ciphertext...) + + aesKey := []byte("0123456789abcdef0123456789abcdef") + kms.EXPECT(). + Decrypt(ctx, keyName, encryptedKey). + Return(aesKey, nil) + + plaintext, err := keys.Decrypt(ctx, kms, keyName, ciphertext, encryptedKey) + require.NoError(t, err) + assert.Equal(t, []byte("plaintext"), plaintext) +} diff --git a/backend/internal/keys/mock_keys_test.go b/backend/internal/keys/mock_keys_test.go new file mode 100644 index 0000000..5a435a0 --- /dev/null +++ b/backend/internal/keys/mock_keys_test.go @@ -0,0 +1,156 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: ibd-trader/internal/keys (interfaces: KeyManagementService) +// +// Generated by this command: +// +// mockgen -destination mock_keys_test.go -package keys_test -typed . KeyManagementService +// + +// Package keys_test is a generated GoMock package. 
+package keys_test + +import ( + context "context" + reflect "reflect" + + gomock "go.uber.org/mock/gomock" +) + +// MockKeyManagementService is a mock of KeyManagementService interface. +type MockKeyManagementService struct { + ctrl *gomock.Controller + recorder *MockKeyManagementServiceMockRecorder +} + +// MockKeyManagementServiceMockRecorder is the mock recorder for MockKeyManagementService. +type MockKeyManagementServiceMockRecorder struct { + mock *MockKeyManagementService +} + +// NewMockKeyManagementService creates a new mock instance. +func NewMockKeyManagementService(ctrl *gomock.Controller) *MockKeyManagementService { + mock := &MockKeyManagementService{ctrl: ctrl} + mock.recorder = &MockKeyManagementServiceMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockKeyManagementService) EXPECT() *MockKeyManagementServiceMockRecorder { + return m.recorder +} + +// Close mocks base method. +func (m *MockKeyManagementService) Close() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Close") + ret0, _ := ret[0].(error) + return ret0 +} + +// Close indicates an expected call of Close. 
+func (mr *MockKeyManagementServiceMockRecorder) Close() *MockKeyManagementServiceCloseCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockKeyManagementService)(nil).Close)) + return &MockKeyManagementServiceCloseCall{Call: call} +} + +// MockKeyManagementServiceCloseCall wrap *gomock.Call +type MockKeyManagementServiceCloseCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockKeyManagementServiceCloseCall) Return(arg0 error) *MockKeyManagementServiceCloseCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockKeyManagementServiceCloseCall) Do(f func() error) *MockKeyManagementServiceCloseCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockKeyManagementServiceCloseCall) DoAndReturn(f func() error) *MockKeyManagementServiceCloseCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// Decrypt mocks base method. +func (m *MockKeyManagementService) Decrypt(arg0 context.Context, arg1 string, arg2 []byte) ([]byte, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Decrypt", arg0, arg1, arg2) + ret0, _ := ret[0].([]byte) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Decrypt indicates an expected call of Decrypt. 
+func (mr *MockKeyManagementServiceMockRecorder) Decrypt(arg0, arg1, arg2 any) *MockKeyManagementServiceDecryptCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Decrypt", reflect.TypeOf((*MockKeyManagementService)(nil).Decrypt), arg0, arg1, arg2) + return &MockKeyManagementServiceDecryptCall{Call: call} +} + +// MockKeyManagementServiceDecryptCall wrap *gomock.Call +type MockKeyManagementServiceDecryptCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockKeyManagementServiceDecryptCall) Return(arg0 []byte, arg1 error) *MockKeyManagementServiceDecryptCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockKeyManagementServiceDecryptCall) Do(f func(context.Context, string, []byte) ([]byte, error)) *MockKeyManagementServiceDecryptCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockKeyManagementServiceDecryptCall) DoAndReturn(f func(context.Context, string, []byte) ([]byte, error)) *MockKeyManagementServiceDecryptCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// Encrypt mocks base method. +func (m *MockKeyManagementService) Encrypt(arg0 context.Context, arg1 string, arg2 []byte) ([]byte, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Encrypt", arg0, arg1, arg2) + ret0, _ := ret[0].([]byte) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Encrypt indicates an expected call of Encrypt. 
+func (mr *MockKeyManagementServiceMockRecorder) Encrypt(arg0, arg1, arg2 any) *MockKeyManagementServiceEncryptCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Encrypt", reflect.TypeOf((*MockKeyManagementService)(nil).Encrypt), arg0, arg1, arg2) + return &MockKeyManagementServiceEncryptCall{Call: call} +} + +// MockKeyManagementServiceEncryptCall wrap *gomock.Call +type MockKeyManagementServiceEncryptCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockKeyManagementServiceEncryptCall) Return(arg0 []byte, arg1 error) *MockKeyManagementServiceEncryptCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockKeyManagementServiceEncryptCall) Do(f func(context.Context, string, []byte) ([]byte, error)) *MockKeyManagementServiceEncryptCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockKeyManagementServiceEncryptCall) DoAndReturn(f func(context.Context, string, []byte) ([]byte, error)) *MockKeyManagementServiceEncryptCall { + c.Call = c.Call.DoAndReturn(f) + return c +} diff --git a/backend/internal/leader/election/election.go b/backend/internal/leader/election/election.go new file mode 100644 index 0000000..6f83298 --- /dev/null +++ b/backend/internal/leader/election/election.go @@ -0,0 +1,128 @@ +package election + +import ( + "context" + "errors" + "log/slog" + "time" + + "github.com/bsm/redislock" +) + +var defaultLeaderElectionOptions = leaderElectionOptions{ + lockKey: "ibd-leader-election", + lockTTL: 10 * time.Second, +} + +func RunOrDie( + ctx context.Context, + client redislock.RedisClient, + onLeader func(context.Context), + opts ...LeaderElectionOption, +) { + o := defaultLeaderElectionOptions + for _, opt := range opts { + opt(&o) + } + + locker := redislock.New(client) + + // Election loop + for { + lock, err := locker.Obtain(ctx, o.lockKey, o.lockTTL, nil) + if errors.Is(err, 
redislock.ErrNotObtained) { + // Another instance is the leader + } else if err != nil { + slog.ErrorContext(ctx, "failed to obtain lock", "error", err) + } else { + // We are the leader + slog.DebugContext(ctx, "elected leader") + runLeader(ctx, lock, onLeader, o) + } + + // Sleep for a bit before trying again + timer := time.NewTimer(o.lockTTL / 5) + select { + case <-ctx.Done(): + if !timer.Stop() { + <-timer.C + } + return + case <-timer.C: + } + } +} + +func runLeader( + ctx context.Context, + lock *redislock.Lock, + onLeader func(context.Context), + o leaderElectionOptions, +) { + // A context that is canceled when the leader loses the lock + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + // Release the lock when done + defer func() { + // Create new context without cancel if the original context is already canceled + relCtx := ctx + if ctx.Err() != nil { + relCtx = context.WithoutCancel(ctx) + } + + // Add a timeout to the release context + relCtx, cancel := context.WithTimeout(relCtx, o.lockTTL) + defer cancel() + + if err := lock.Release(relCtx); err != nil { + slog.Error("failed to release lock", "error", err) + } + }() + + // Run the leader code + go func(ctx context.Context) { + onLeader(ctx) + + // If the leader code returns, cancel the context to release the lock + cancel() + }(ctx) + + // Refresh the lock periodically + ticker := time.NewTicker(o.lockTTL / 10) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + err := lock.Refresh(ctx, o.lockTTL, nil) + if errors.Is(err, redislock.ErrNotObtained) || errors.Is(err, redislock.ErrLockNotHeld) { + slog.ErrorContext(ctx, "leadership lost", "error", err) + return + } else if err != nil { + slog.ErrorContext(ctx, "failed to refresh lock", "error", err) + } + case <-ctx.Done(): + return + } + } +} + +type leaderElectionOptions struct { + lockKey string + lockTTL time.Duration +} + +type LeaderElectionOption func(*leaderElectionOptions) + +func WithLockKey(key string) 
LeaderElectionOption { + return func(o *leaderElectionOptions) { + o.lockKey = key + } +} + +func WithLockTTL(ttl time.Duration) LeaderElectionOption { + return func(o *leaderElectionOptions) { + o.lockTTL = ttl + } +} diff --git a/backend/internal/leader/manager/ibd/auth/auth.go b/backend/internal/leader/manager/ibd/auth/auth.go new file mode 100644 index 0000000..129bc51 --- /dev/null +++ b/backend/internal/leader/manager/ibd/auth/auth.go @@ -0,0 +1,111 @@ +package auth + +import ( + "context" + "log/slog" + "time" + + "ibd-trader/internal/database" + "ibd-trader/internal/redis/taskqueue" + + "github.com/redis/go-redis/v9" + "github.com/robfig/cron/v3" +) + +const ( + Queue = "auth-queue" + QueueEncoding = taskqueue.EncodingJSON +) + +// Manager is responsible for sending authentication tasks to the workers. +type Manager struct { + queue taskqueue.TaskQueue[TaskInfo] + store database.UserStore + schedule cron.Schedule +} + +func New( + ctx context.Context, + store database.UserStore, + rClient *redis.Client, + schedule cron.Schedule, +) (*Manager, error) { + queue, err := taskqueue.New( + ctx, + rClient, + Queue, + "auth-manager", + taskqueue.WithEncoding[TaskInfo](QueueEncoding), + ) + if err != nil { + return nil, err + } + + return &Manager{ + queue: queue, + store: store, + schedule: schedule, + }, nil +} + +func (m *Manager) Run(ctx context.Context) { + for { + now := time.Now() + // Find the next time + nextTime := m.schedule.Next(now) + if nextTime.IsZero() { + // Sleep until the next day + time.Sleep(time.Until(now.AddDate(0, 0, 1))) + continue + } + + timer := time.NewTimer(nextTime.Sub(now)) + slog.DebugContext(ctx, "waiting for next Auth scrape", "next_exec", nextTime) + + select { + case <-timer.C: + nextExec := m.schedule.Next(nextTime) + m.scrapeCookies(ctx, nextExec) + case <-ctx.Done(): + if !timer.Stop() { + <-timer.C + } + return + } + } +} + +// scrapeCookies scrapes the cookies for every user from the IBD website. 
+// +// This iterates through all users with IBD credentials and checks whether their cookies are still valid. +// If the cookies are invalid or missing, it re-authenticates the user and updates the cookies in the database. +func (m *Manager) scrapeCookies(ctx context.Context, deadline time.Time) { + ctx, cancel := context.WithDeadline(ctx, deadline) + defer cancel() + + // Get all users with IBD credentials + users, err := m.store.ListUsers(ctx, true) + if err != nil { + slog.ErrorContext(ctx, "failed to get users", "error", err) + return + } + + // Create a new task for each user + for _, user := range users { + task := TaskInfo{ + UserSubject: user.Subject, + } + + // Enqueue the task + _, err := m.queue.Enqueue(ctx, task) + if err != nil { + slog.ErrorContext(ctx, "failed to enqueue task", "error", err) + } + } + + slog.InfoContext(ctx, "enqueued tasks for all users") +} + +type TaskInfo struct { + UserSubject string `json:"user_subject"` +} diff --git a/backend/internal/leader/manager/ibd/ibd.go b/backend/internal/leader/manager/ibd/ibd.go new file mode 100644 index 0000000..e2d4fc0 --- /dev/null +++ b/backend/internal/leader/manager/ibd/ibd.go @@ -0,0 +1,8 @@ +package ibd + +type Schedules struct { + // Auth schedule + Auth string + // IBD50 schedule + IBD50 string +} diff --git a/backend/internal/leader/manager/ibd/scrape/scrape.go b/backend/internal/leader/manager/ibd/scrape/scrape.go new file mode 100644 index 0000000..5f2c2a7 --- /dev/null +++ b/backend/internal/leader/manager/ibd/scrape/scrape.go @@ -0,0 +1,140 @@ +package scrape + +import ( + "context" + "errors" + "log/slog" + "time" + + "ibd-trader/internal/database" + "ibd-trader/internal/ibd" + "ibd-trader/internal/redis/taskqueue" + + "github.com/redis/go-redis/v9" + "github.com/robfig/cron/v3" +) + +const ( + Queue = "scrape-queue" + QueueEncoding = taskqueue.EncodingJSON + Channel = "scrape-channel" +) + +// Manager is responsible for sending scraping tasks to the workers. 
+type Manager struct { + client *ibd.Client + store database.StockStore + queue taskqueue.TaskQueue[TaskInfo] + schedule cron.Schedule + pubsub *redis.PubSub +} + +func New( + ctx context.Context, + client *ibd.Client, + store database.StockStore, + redis *redis.Client, + schedule cron.Schedule, +) (*Manager, error) { + queue, err := taskqueue.New( + ctx, + redis, + Queue, + "ibd-manager", + taskqueue.WithEncoding[TaskInfo](QueueEncoding), + ) + if err != nil { + return nil, err + } + + return &Manager{ + client: client, + store: store, + queue: queue, + schedule: schedule, + pubsub: redis.Subscribe(ctx, Channel), + }, nil +} + +func (m *Manager) Close() error { + return m.pubsub.Close() +} + +func (m *Manager) Run(ctx context.Context) { + ch := m.pubsub.Channel() + for { + now := time.Now() + // Find the next time + nextTime := m.schedule.Next(now) + if nextTime.IsZero() { + // Sleep until the next day + time.Sleep(time.Until(now.AddDate(0, 0, 1))) + continue + } + + timer := time.NewTimer(nextTime.Sub(now)) + slog.DebugContext(ctx, "waiting for next IBD50 scrape", "next_exec", nextTime) + + select { + case <-timer.C: + nextExec := m.schedule.Next(nextTime) + m.scrapeIBD50(ctx, nextExec) + case _ = <-ch: + nextExec := m.schedule.Next(time.Now()) + m.scrapeIBD50(ctx, nextExec) + case <-ctx.Done(): + if !timer.Stop() { + <-timer.C + } + return + } + } +} + +func (m *Manager) scrapeIBD50(ctx context.Context, deadline time.Time) { + ctx, cancel := context.WithDeadline(ctx, deadline) + defer cancel() + + stocks, err := m.client.GetIBD50(ctx) + if err != nil { + if errors.Is(err, ibd.ErrNoAvailableCookies) { + slog.WarnContext(ctx, "no available cookies", "error", err) + return + } + slog.ErrorContext(ctx, "failed to get IBD50", "error", err) + return + } + + for _, stock := range stocks { + // Add stock to DB + err = m.store.AddStock(ctx, database.Stock{ + Symbol: stock.Symbol, + Name: stock.Name, + IBDUrl: stock.QuoteURL.String(), + }) + if err != nil { + 
slog.ErrorContext(ctx, "failed to add stock", "error", err) + continue + } + + // Add ranking to Db + err = m.store.AddRanking(ctx, stock.Symbol, int(stock.Rank), 0) + if err != nil { + slog.ErrorContext(ctx, "failed to add ranking", "error", err) + continue + } + + // Add scrape task to queue + task := TaskInfo{Symbol: stock.Symbol} + taskID, err := m.queue.Enqueue(ctx, task) + if err != nil { + slog.ErrorContext(ctx, "failed to enqueue task", "error", err) + } + + slog.DebugContext(ctx, "enqueued scrape task", "task_id", taskID, "symbol", stock.Symbol) + } +} + +type TaskInfo struct { + Symbol string `json:"symbol"` +} diff --git a/backend/internal/leader/manager/manager.go b/backend/internal/leader/manager/manager.go new file mode 100644 index 0000000..b2a9ee9 --- /dev/null +++ b/backend/internal/leader/manager/manager.go @@ -0,0 +1,90 @@ +package manager + +import ( + "context" + "fmt" + "log/slog" + "sync" + + "ibd-trader/internal/config" + "ibd-trader/internal/database" + "ibd-trader/internal/ibd" + "ibd-trader/internal/leader/manager/ibd/auth" + ibd2 "ibd-trader/internal/leader/manager/ibd/scrape" + + "github.com/redis/go-redis/v9" + "github.com/robfig/cron/v3" +) + +type Manager struct { + db database.Database + Monitor *WorkerMonitor + Scraper *ibd2.Manager + Auth *auth.Manager +} + +func New( + ctx context.Context, + cfg *config.Config, + client *redis.Client, + db database.Database, + ibd *ibd.Client, +) (*Manager, error) { + scraperSchedule, err := cron.ParseStandard(cfg.IBD.Schedules.IBD50) + if err != nil { + return nil, fmt.Errorf("unable to parse IBD50 schedule: %w", err) + } + scraper, err := ibd2.New(ctx, ibd, db, client, scraperSchedule) + if err != nil { + return nil, err + } + + authSchedule, err := cron.ParseStandard(cfg.IBD.Schedules.Auth) + if err != nil { + return nil, fmt.Errorf("unable to parse Auth schedule: %w", err) + } + authManager, err := auth.New(ctx, db, client, authSchedule) + if err != nil { + return nil, err + } + + return 
&Manager{ + db: db, + Monitor: NewWorkerMonitor(client), + Scraper: scraper, + Auth: authManager, + }, nil +} + +func (m *Manager) Run(ctx context.Context) error { + if err := m.db.Migrate(ctx); err != nil { + slog.ErrorContext(ctx, "Unable to migrate database", "error", err) + return err + } + + var wg sync.WaitGroup + wg.Add(4) + + go func() { + defer wg.Done() + m.db.Maintenance(ctx) + }() + + go func() { + defer wg.Done() + m.Monitor.Start(ctx) + }() + + go func() { + defer wg.Done() + m.Scraper.Run(ctx) + }() + + go func() { + defer wg.Done() + m.Auth.Run(ctx) + }() + + wg.Wait() + return ctx.Err() +} diff --git a/backend/internal/leader/manager/monitor.go b/backend/internal/leader/manager/monitor.go new file mode 100644 index 0000000..3b2e3ec --- /dev/null +++ b/backend/internal/leader/manager/monitor.go @@ -0,0 +1,164 @@ +package manager + +import ( + "context" + "fmt" + "log/slog" + "time" + + "github.com/buraksezer/consistent" + "github.com/cespare/xxhash/v2" + "github.com/redis/go-redis/v9" +) + +const ( + MonitorInterval = 5 * time.Second + ActiveWorkersSet = "active-workers" +) + +// WorkerMonitor is a struct that monitors workers and their heartbeats over redis. +type WorkerMonitor struct { + client *redis.Client + + // ring is a consistent hash ring that distributes partitions over detected workers. + ring *consistent.Consistent + // layoutChangeCh is a channel that others can listen to for layout changes to the ring. + layoutChangeCh chan struct{} +} + +// NewWorkerMonitor creates a new WorkerMonitor. 
+func NewWorkerMonitor(client *redis.Client) *WorkerMonitor { + var members []consistent.Member + return &WorkerMonitor{ + client: client, + ring: consistent.New(members, consistent.Config{ + Hasher: new(hasher), + PartitionCount: consistent.DefaultPartitionCount, + ReplicationFactor: consistent.DefaultReplicationFactor, + Load: consistent.DefaultLoad, + }), + layoutChangeCh: make(chan struct{}), + } +} + +func (m *WorkerMonitor) Close() error { + close(m.layoutChangeCh) + return nil +} + +func (m *WorkerMonitor) Changes() <-chan struct{} { + return m.layoutChangeCh +} + +func (m *WorkerMonitor) Start(ctx context.Context) { + m.monitorWorkers(ctx) + ticker := time.NewTicker(MonitorInterval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + m.monitorWorkers(ctx) + } + } +} + +func (m *WorkerMonitor) monitorWorkers(ctx context.Context) { + ctx, cancel := context.WithTimeout(ctx, MonitorInterval) + defer cancel() + + // Get all active workers. + workers, err := m.client.SMembers(ctx, ActiveWorkersSet).Result() + if err != nil { + slog.ErrorContext(ctx, "Unable to get active workers", "error", err) + return + } + + // Get existing workers in the ring. + existingWorkers := m.ring.GetMembers() + ewMap := make(map[string]bool) + for _, worker := range existingWorkers { + ewMap[worker.String()] = false + } + + // Check workers' heartbeats. + for _, worker := range workers { + exists, err := m.client.Exists(ctx, WorkerHeartbeatKey(worker)).Result() + if err != nil { + slog.ErrorContext(ctx, "Unable to check worker heartbeat", "worker", worker, "error", err) + continue + } + + if exists == 0 { + slog.WarnContext(ctx, "Worker heartbeat not found", "worker", worker) + + // Remove worker from active workers set. + if err = m.client.SRem(ctx, ActiveWorkersSet, worker).Err(); err != nil { + slog.ErrorContext(ctx, "Unable to remove worker from active workers set", "worker", worker, "error", err) + } + + // Remove worker from the ring. 
+ m.removeWorker(worker) + } else { + // Add worker to the ring if it doesn't exist. + if _, ok := ewMap[worker]; !ok { + slog.InfoContext(ctx, "New worker detected", "worker", worker) + m.addWorker(worker) + } else { + ewMap[worker] = true + } + } + } + + // Check for workers that are not active anymore. + for worker, exists := range ewMap { + if !exists { + slog.WarnContext(ctx, "Worker is not active anymore", "worker", worker) + m.removeWorker(worker) + } + } +} + +func (m *WorkerMonitor) addWorker(worker string) { + m.ring.Add(member{hostname: worker}) + + // Notify others about the layout change. + select { + case m.layoutChangeCh <- struct{}{}: + // Notify others. + default: + // No one is listening. + } +} + +func (m *WorkerMonitor) removeWorker(worker string) { + m.ring.Remove(worker) + + // Notify others about the layout change. + select { + case m.layoutChangeCh <- struct{}{}: + // Notify others. + default: + // No one is listening. + } +} + +func WorkerHeartbeatKey(hostname string) string { + return fmt.Sprintf("worker:%s:heartbeat", hostname) +} + +type hasher struct{} + +func (h *hasher) Sum64(data []byte) uint64 { + return xxhash.Sum64(data) +} + +type member struct { + hostname string +} + +func (m member) String() string { + return m.hostname +} diff --git a/backend/internal/redis/taskqueue/options.go b/backend/internal/redis/taskqueue/options.go new file mode 100644 index 0000000..2d5a23f --- /dev/null +++ b/backend/internal/redis/taskqueue/options.go @@ -0,0 +1,9 @@ +package taskqueue + +type Option[T any] func(*taskQueue[T]) + +func WithEncoding[T any](encoding Encoding) Option[T] { + return func(o *taskQueue[T]) { + o.encoding = encoding + } +} diff --git a/backend/internal/redis/taskqueue/queue.go b/backend/internal/redis/taskqueue/queue.go new file mode 100644 index 0000000..1298a76 --- /dev/null +++ b/backend/internal/redis/taskqueue/queue.go @@ -0,0 +1,494 @@ +package taskqueue + +import ( + "bytes" + "context" + "encoding/base64" + 
"encoding/gob" + "encoding/json" + "errors" + "log/slog" + "reflect" + "strconv" + "strings" + "time" + + "github.com/redis/go-redis/v9" +) + +type Encoding uint8 + +const ( + EncodingJSON Encoding = iota + EncodingGob +) + +var MaxAttempts = 3 +var ErrTaskNotFound = errors.New("task not found") + +type TaskQueue[T any] interface { + // Enqueue adds a task to the queue. + // Returns the generated task ID. + Enqueue(ctx context.Context, data T) (TaskInfo[T], error) + + // Dequeue removes a task from the queue and returns it. + // The task data is placed into dataOut. + // + // Dequeue blocks until a task is available, timeout, or the context is canceled. + // The returned task is placed in a pending state for lockTimeout duration. + // The task must be completed with Complete or extended with Extend before the lock expires. + // If the lock expires, the task is returned to the queue, where it may be picked up by another worker. + Dequeue( + ctx context.Context, + lockTimeout, + timeout time.Duration, + ) (*TaskInfo[T], error) + + // Extend extends the lock on a task. + Extend(ctx context.Context, taskID TaskID) error + + // Complete marks a task as complete. Optionally, an error can be provided to store additional information. + Complete(ctx context.Context, taskID TaskID, err error) error + + // Data returns the info of a task. + Data(ctx context.Context, taskID TaskID) (TaskInfo[T], error) + + // Return returns a task to the queue and returns the new task ID. + // Increments the attempt counter. + // Tasks with too many attempts (MaxAttempts) are considered failed and aren't returned to the queue. + Return(ctx context.Context, taskID TaskID, err error) (TaskID, error) + + // List returns a list of task IDs in the queue. + // The list is ordered by the time the task was added to the queue. The most recent task is first. + // The count parameter limits the number of tasks returned. + // The start and end parameters limit the range of tasks returned. 
// TaskID identifies a task by its Redis stream entry ID, rendered as
// "<unix-millis>-<sequence>".
type TaskID struct {
	timestamp time.Time
	sequence  uint64
}

// NewTaskID builds a TaskID from a timestamp and a sequence number.
func NewTaskID(timestamp time.Time, sequence uint64) TaskID {
	return TaskID{timestamp, sequence}
}

// ParseTaskID parses a Redis stream ID of the form "<millis>-<seq>".
func ParseTaskID(s string) (TaskID, error) {
	parts := strings.SplitN(s, "-", 2)
	if len(parts) != 2 {
		return TaskID{}, errors.New("invalid task ID")
	}

	millis, err := strconv.ParseInt(parts[0], 10, 64)
	if err != nil {
		return TaskID{}, err
	}

	seq, err := strconv.ParseUint(parts[1], 10, 64)
	if err != nil {
		return TaskID{}, err
	}

	return NewTaskID(time.UnixMilli(millis), seq), nil
}

// Timestamp returns the creation time encoded in the ID.
func (t TaskID) Timestamp() time.Time {
	return t.timestamp
}

// String renders the ID back into Redis stream-ID form.
func (t TaskID) String() string {
	return strconv.FormatInt(t.timestamp.UnixMilli(), 10) +
		"-" +
		strconv.FormatUint(t.sequence, 10)
}
stream if it doesn't exist + err := rdb.XGroupCreateMkStream(ctx, tq.streamKey, tq.groupName, "0").Err() + if err != nil && err.Error() != "BUSYGROUP Consumer Group name already exists" { + return nil, err + } + + return tq, nil +} + +func (q *taskQueue[T]) Enqueue(ctx context.Context, data T) (TaskInfo[T], error) { + task := TaskInfo[T]{ + Data: data, + Attempts: 0, + } + + values, err := encode[T](task, q.encoding) + if err != nil { + return TaskInfo[T]{}, err + } + + taskID, err := q.rdb.XAdd(ctx, &redis.XAddArgs{ + Stream: q.streamKey, + Values: values, + }).Result() + if err != nil { + return TaskInfo[T]{}, err + } + + id, err := ParseTaskID(taskID) + if err != nil { + return TaskInfo[T]{}, err + } + task.ID = id + return task, nil +} + +func (q *taskQueue[T]) Dequeue(ctx context.Context, lockTimeout, timeout time.Duration) (*TaskInfo[T], error) { + // Try to recover a task + task, err := q.recover(ctx, lockTimeout) + if err != nil { + return nil, err + } + if task != nil { + return task, nil + } + + // Check for new tasks + ids, err := q.rdb.XReadGroup(ctx, &redis.XReadGroupArgs{ + Group: q.groupName, + Consumer: q.workerName, + Streams: []string{q.streamKey, ">"}, + Count: 1, + Block: timeout, + }).Result() + if err != nil && !errors.Is(err, redis.Nil) { + return nil, err + } + + if len(ids) == 0 || len(ids[0].Messages) == 0 || errors.Is(err, redis.Nil) { + return nil, nil + } + + msg := ids[0].Messages[0] + task = new(TaskInfo[T]) + *task, err = decode[T](&msg, q.encoding) + if err != nil { + return nil, err + } + return task, nil +} + +func (q *taskQueue[T]) Extend(ctx context.Context, taskID TaskID) error { + _, err := q.rdb.XClaim(ctx, &redis.XClaimArgs{ + Stream: q.streamKey, + Group: q.groupName, + Consumer: q.workerName, + MinIdle: 0, + Messages: []string{taskID.String()}, + }).Result() + if err != nil && !errors.Is(err, redis.Nil) { + return err + } + return nil +} + +func (q *taskQueue[T]) Data(ctx context.Context, taskID TaskID) (TaskInfo[T], 
error) { + msg, err := q.rdb.XRange(ctx, q.streamKey, taskID.String(), taskID.String()).Result() + if err != nil { + return TaskInfo[T]{}, err + } + + if len(msg) == 0 { + return TaskInfo[T]{}, ErrTaskNotFound + } + + t, err := decode[T](&msg[0], q.encoding) + if err != nil { + return TaskInfo[T]{}, err + } + + tErr, err := q.rdb.HGet(ctx, q.completedSetKey, taskID.String()).Result() + if err != nil && !errors.Is(err, redis.Nil) { + return TaskInfo[T]{}, err + } + + if errors.Is(err, redis.Nil) { + return t, nil + } + + t.Done = true + t.Error = tErr + return t, nil +} + +func (q *taskQueue[T]) Complete(ctx context.Context, taskID TaskID, err error) error { + _, err = q.rdb.TxPipelined(ctx, func(pipe redis.Pipeliner) error { + pipe.XAck(ctx, q.streamKey, q.groupName, taskID.String()) + //xdel = pipe.XDel(ctx, q.streamKey, taskID.String()) + //pipe.SAdd(ctx, q.completedSetKey, taskID.String()) + if err != nil { + pipe.HSet(ctx, q.completedSetKey, taskID.String(), err.Error()) + } else { + pipe.HSet(ctx, q.completedSetKey, taskID.String(), "") + } + return nil + }) + return err +} + +func (q *taskQueue[T]) Return(ctx context.Context, taskID TaskID, err1 error) (TaskID, error) { + msgs, err := q.rdb.XRange(ctx, q.streamKey, taskID.String(), taskID.String()).Result() + if err != nil { + return TaskID{}, err + } + if len(msgs) == 0 { + return TaskID{}, ErrTaskNotFound + } + + // Complete the task + err = q.Complete(ctx, taskID, err1) + if err != nil { + return TaskID{}, err + } + + msg := msgs[0] + task, err := decode[T](&msg, q.encoding) + if err != nil { + return TaskID{}, err + } + + task.Attempts++ + if int(task.Attempts) >= MaxAttempts { + // Task has failed + slog.ErrorContext(ctx, "task failed completely", + "taskID", taskID, + "data", task.Data, + "attempts", task.Attempts, + "maxAttempts", MaxAttempts, + ) + return TaskID{}, nil + } + + values, err := encode[T](task, q.encoding) + if err != nil { + return TaskID{}, err + } + newTaskId, err := q.rdb.XAdd(ctx, 
&redis.XAddArgs{ + Stream: q.streamKey, + Values: values, + }).Result() + if err != nil { + return TaskID{}, err + } + return ParseTaskID(newTaskId) +} + +func (q *taskQueue[T]) List(ctx context.Context, start, end TaskID, count int64) ([]TaskInfo[T], error) { + if !start.timestamp.IsZero() && !end.timestamp.IsZero() && start.timestamp.After(end.timestamp) { + return nil, errors.New("start must be before end") + } + + var startStr, endStr string + if !start.timestamp.IsZero() { + startStr = start.String() + } else { + startStr = "-" + } + if !end.timestamp.IsZero() { + endStr = "(" + end.String() + } else { + endStr = "+" + } + + msgs, err := q.rdb.XRevRangeN(ctx, q.streamKey, endStr, startStr, count).Result() + if err != nil { + return nil, err + } + if len(msgs) == 0 { + return []TaskInfo[T]{}, nil + } + + ids := make([]string, len(msgs)) + for i, msg := range msgs { + ids[i] = msg.ID + } + errs, err := q.rdb.HMGet(ctx, q.completedSetKey, ids...).Result() + if err != nil { + return nil, err + } + if len(errs) != len(msgs) { + return nil, errors.New("SMIsMember returned wrong number of results") + } + + tasks := make([]TaskInfo[T], len(msgs)) + for i := range msgs { + tasks[i], err = decode[T](&msgs[i], q.encoding) + if err != nil { + return nil, err + } + tasks[i].Done = errs[i] != nil + if tasks[i].Done { + tasks[i].Error = errs[i].(string) + } + } + return tasks, nil +} + +func (q *taskQueue[T]) recover(ctx context.Context, idleTimeout time.Duration) (*TaskInfo[T], error) { + msgs, _, err := q.rdb.XAutoClaim(ctx, &redis.XAutoClaimArgs{ + Stream: q.streamKey, + Group: q.groupName, + MinIdle: idleTimeout, + Start: "0", + Count: 1, + Consumer: q.workerName, + }).Result() + if err != nil { + return nil, err + } + + if len(msgs) == 0 { + return nil, nil + } + + msg := msgs[0] + task, err := decode[T](&msg, q.encoding) + if err != nil { + return nil, err + } + return &task, nil +} + +func decode[T any](msg *redis.XMessage, encoding Encoding) (task TaskInfo[T], err 
error) { + task.ID, err = ParseTaskID(msg.ID) + if err != nil { + return + } + + err = getField(msg, "attempts", &task.Attempts) + if err != nil { + return + } + + var data string + err = getField(msg, "data", &data) + if err != nil { + return + } + + switch encoding { + case EncodingJSON: + err = json.Unmarshal([]byte(data), &task.Data) + case EncodingGob: + var decoded []byte + decoded, err = base64.StdEncoding.DecodeString(data) + if err != nil { + return + } + err = gob.NewDecoder(bytes.NewReader(decoded)).Decode(&task.Data) + default: + err = errors.New("unsupported encoding") + } + return +} + +func getField(msg *redis.XMessage, field string, v any) error { + vVal, ok := msg.Values[field] + if !ok { + return errors.New("missing field") + } + + vStr, ok := vVal.(string) + if !ok { + return errors.New("invalid field type") + } + + value := reflect.ValueOf(v).Elem() + switch value.Kind() { + case reflect.String: + value.SetString(vStr) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + i, err := strconv.ParseInt(vStr, 10, 64) + if err != nil { + return err + } + value.SetInt(i) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + i, err := strconv.ParseUint(vStr, 10, 64) + if err != nil { + return err + } + value.SetUint(i) + case reflect.Bool: + b, err := strconv.ParseBool(vStr) + if err != nil { + return err + } + value.SetBool(b) + default: + return errors.New("unsupported field type") + } + return nil +} + +func encode[T any](task TaskInfo[T], encoding Encoding) (ret map[string]string, err error) { + ret = make(map[string]string) + ret["attempts"] = strconv.FormatUint(uint64(task.Attempts), 10) + + switch encoding { + case EncodingJSON: + var data []byte + data, err = json.Marshal(task.Data) + if err != nil { + return + } + ret["data"] = string(data) + case EncodingGob: + var data bytes.Buffer + err = gob.NewEncoder(&data).Encode(task.Data) + if err != nil { + return + } + ret["data"] = 
base64.StdEncoding.EncodeToString(data.Bytes()) + default: + err = errors.New("unsupported encoding") + } + return +} diff --git a/backend/internal/redis/taskqueue/queue_test.go b/backend/internal/redis/taskqueue/queue_test.go new file mode 100644 index 0000000..b54d22a --- /dev/null +++ b/backend/internal/redis/taskqueue/queue_test.go @@ -0,0 +1,448 @@ +package taskqueue + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/redis/go-redis/v9" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestTaskQueue(t *testing.T) { + if testing.Short() { + t.Skip() + } + + client := redis.NewClient(new(redis.Options)) + defer func(client *redis.Client) { + _ = client.Close() + }(client) + + lockTimeout := 100 * time.Millisecond + + tests := []struct { + name string + f func(t *testing.T) + }{ + { + name: "Create queue", + f: func(t *testing.T) { + q, err := New[string](context.Background(), client, "test", "worker") + require.NoError(t, err) + require.NotNil(t, q) + }, + }, + { + name: "enqueue & dequeue", + f: func(t *testing.T) { + q, err := New[string](context.Background(), client, "test", "worker") + require.NoError(t, err) + require.NotNil(t, q) + + taskId, err := q.Enqueue(context.Background(), "hello") + require.NoError(t, err) + require.NotEmpty(t, taskId) + + task, err := q.Dequeue(context.Background(), lockTimeout, 10*time.Millisecond) + require.NoError(t, err) + require.NotNil(t, task) + require.Equal(t, "hello", task.Data) + }, + }, + { + name: "complex data", + f: func(t *testing.T) { + type foo struct { + A int + B string + } + + q, err := New[foo](context.Background(), client, "test", "worker") + require.NoError(t, err) + require.NotNil(t, q) + + taskId, err := q.Enqueue(context.Background(), foo{A: 42, B: "hello"}) + require.NoError(t, err) + require.NotEmpty(t, taskId) + + task, err := q.Dequeue(context.Background(), lockTimeout, 10*time.Millisecond) + require.NoError(t, err) + require.NotNil(t, task) 
+ require.Equal(t, foo{A: 42, B: "hello"}, task.Data) + }, + }, + { + name: "different workers", + f: func(t *testing.T) { + q1, err := New[string](context.Background(), client, "test", "worker1") + require.NoError(t, err) + require.NotNil(t, q1) + + q2, err := New[string](context.Background(), client, "test", "worker2") + require.NoError(t, err) + require.NotNil(t, q2) + + taskId, err := q1.Enqueue(context.Background(), "hello") + require.NoError(t, err) + require.NotEmpty(t, taskId) + + task, err := q2.Dequeue(context.Background(), lockTimeout, 10*time.Millisecond) + require.NoError(t, err) + require.NotNil(t, task) + assert.Equal(t, "hello", task.Data) + }, + }, + { + name: "complete", + f: func(t *testing.T) { + q, err := New[string](context.Background(), client, "test", "worker") + require.NoError(t, err) + require.NotNil(t, q) + + // Enqueue a task + taskId, err := q.Enqueue(context.Background(), "hello") + require.NoError(t, err) + require.NotEmpty(t, taskId) + + // Dequeue the task + task, err := q.Dequeue(context.Background(), lockTimeout, 10*time.Millisecond) + require.NoError(t, err) + require.NotNil(t, task) + assert.Equal(t, "hello", task.Data) + + // Complete the task + err = q.Complete(context.Background(), task.ID, nil) + require.NoError(t, err) + + // Try to dequeue the task again + task, err = q.Dequeue(context.Background(), lockTimeout, 10*time.Millisecond) + require.NoError(t, err) + require.Nil(t, task) + }, + }, + { + name: "timeout", + f: func(t *testing.T) { + q, err := New[string](context.Background(), client, "test", "worker") + require.NoError(t, err) + require.NotNil(t, q) + + // Enqueue a task + taskId, err := q.Enqueue(context.Background(), "hello") + require.NoError(t, err) + require.NotEmpty(t, taskId) + + // Dequeue the task + task, err := q.Dequeue(context.Background(), lockTimeout, 10*time.Millisecond) + require.NoError(t, err) + require.NotNil(t, task) + assert.Equal(t, "hello", task.Data) + + // Wait for the lock to expire + 
time.Sleep(lockTimeout + 10*time.Millisecond) + + // Try to dequeue the task again + task, err = q.Dequeue(context.Background(), lockTimeout, 10*time.Millisecond) + require.NoError(t, err) + require.NotNil(t, task) + assert.Equal(t, "hello", task.Data) + }, + }, + { + name: "extend", + f: func(t *testing.T) { + q, err := New[string](context.Background(), client, "test", "worker") + require.NoError(t, err) + require.NotNil(t, q) + + // Enqueue a task + taskId, err := q.Enqueue(context.Background(), "hello") + require.NoError(t, err) + require.NotEmpty(t, taskId) + + // Dequeue the task + task, err := q.Dequeue(context.Background(), lockTimeout, 10*time.Millisecond) + require.NoError(t, err) + require.NotNil(t, task) + assert.Equal(t, "hello", task.Data) + + // Wait for the lock to expire + time.Sleep(lockTimeout + 10*time.Millisecond) + + // Extend the lock + err = q.Extend(context.Background(), task.ID) + require.NoError(t, err) + + // Try to dequeue the task again + task, err = q.Dequeue(context.Background(), lockTimeout, 10*time.Millisecond) + require.NoError(t, err) + require.Nil(t, task) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := client.FlushDB(context.Background()).Err(); err != nil { + t.Fatal(err) + } + + tt.f(t) + }) + } + + _ = client.FlushDB(context.Background()) +} + +func TestTaskQueue_List(t *testing.T) { + if testing.Short() { + t.Skip() + } + + client := redis.NewClient(new(redis.Options)) + defer func(client *redis.Client) { + _ = client.Close() + }(client) + + tests := []struct { + name string + f func(t *testing.T) + }{ + { + name: "empty", + f: func(t *testing.T) { + q, err := New[any](context.Background(), client, "test", "worker") + require.NoError(t, err) + require.NotNil(t, q) + + tasks, err := q.List(context.Background(), TaskID{}, TaskID{}, 1) + require.NoError(t, err) + assert.Empty(t, tasks) + }, + }, + { + name: "single", + f: func(t *testing.T) { + q, err := 
New[string](context.Background(), client, "test", "worker") + require.NoError(t, err) + require.NotNil(t, q) + + taskID, err := q.Enqueue(context.Background(), "hello") + require.NoError(t, err) + + tasks, err := q.List(context.Background(), TaskID{}, TaskID{}, 1) + require.NoError(t, err) + assert.Len(t, tasks, 1) + assert.Equal(t, taskID, tasks[0]) + }, + }, + { + name: "multiple", + f: func(t *testing.T) { + q, err := New[string](context.Background(), client, "test", "worker") + require.NoError(t, err) + require.NotNil(t, q) + + taskID, err := q.Enqueue(context.Background(), "hello") + require.NoError(t, err) + taskID2, err := q.Enqueue(context.Background(), "world") + require.NoError(t, err) + + tasks, err := q.List(context.Background(), TaskID{}, TaskID{}, 2) + require.NoError(t, err) + assert.Len(t, tasks, 2) + assert.Equal(t, taskID, tasks[1]) + assert.Equal(t, taskID2, tasks[0]) + }, + }, + { + name: "multiple limited", + f: func(t *testing.T) { + q, err := New[string](context.Background(), client, "test", "worker") + require.NoError(t, err) + require.NotNil(t, q) + + _, err = q.Enqueue(context.Background(), "hello") + require.NoError(t, err) + taskID2, err := q.Enqueue(context.Background(), "world") + require.NoError(t, err) + + tasks, err := q.List(context.Background(), TaskID{}, TaskID{}, 1) + require.NoError(t, err) + assert.Len(t, tasks, 1) + assert.Equal(t, taskID2, tasks[0]) + }, + }, + { + name: "multiple time range", + f: func(t *testing.T) { + q, err := New[string](context.Background(), client, "test", "worker") + require.NoError(t, err) + require.NotNil(t, q) + + taskID, err := q.Enqueue(context.Background(), "hello") + require.NoError(t, err) + time.Sleep(10 * time.Millisecond) + taskID2, err := q.Enqueue(context.Background(), "world") + require.NoError(t, err) + + tasks, err := q.List(context.Background(), TaskID{}, taskID2.ID, 100) + require.NoError(t, err) + assert.Len(t, tasks, 1) + assert.Equal(t, taskID, tasks[0]) + + tasks, err = 
q.List(context.Background(), taskID2.ID, TaskID{}, 100) + require.NoError(t, err) + assert.Len(t, tasks, 1) + assert.Equal(t, taskID2, tasks[0]) + }, + }, + { + name: "completed tasks", + f: func(t *testing.T) { + q, err := New[string](context.Background(), client, "test", "worker") + require.NoError(t, err) + require.NotNil(t, q) + + task1, err := q.Enqueue(context.Background(), "hello") + require.NoError(t, err) + task2, err := q.Enqueue(context.Background(), "world") + require.NoError(t, err) + + err = q.Complete(context.Background(), task1.ID, nil) + require.NoError(t, err) + + tasks, err := q.List(context.Background(), TaskID{}, TaskID{}, 100) + require.NoError(t, err) + assert.Len(t, tasks, 2) + assert.Equal(t, task2, tasks[0]) + + assert.Equal(t, "hello", tasks[1].Data) + assert.Equal(t, true, tasks[1].Done) + assert.Equal(t, "", tasks[1].Error) + }, + }, + { + name: "failed tasks", + f: func(t *testing.T) { + q, err := New[string](context.Background(), client, "test", "worker") + require.NoError(t, err) + require.NotNil(t, q) + + task1, err := q.Enqueue(context.Background(), "hello") + require.NoError(t, err) + task2, err := q.Enqueue(context.Background(), "world") + require.NoError(t, err) + + err = q.Complete(context.Background(), task1.ID, errors.New("failed")) + require.NoError(t, err) + + tasks, err := q.List(context.Background(), TaskID{}, TaskID{}, 100) + require.NoError(t, err) + assert.Len(t, tasks, 2) + assert.Equal(t, task2, tasks[0]) + + assert.Equal(t, "hello", tasks[1].Data) + assert.Equal(t, true, tasks[1].Done) + assert.Equal(t, "failed", tasks[1].Error) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := client.FlushDB(context.Background()).Err(); err != nil { + t.Fatal(err) + } + + tt.f(t) + }) + } + + _ = client.FlushDB(context.Background()) +} + +func TestTaskQueue_Return(t *testing.T) { + + if testing.Short() { + t.Skip() + } + + client := redis.NewClient(new(redis.Options)) + defer func(client 
*redis.Client) { + _ = client.Close() + }(client) + + lockTimeout := 100 * time.Millisecond + + tests := []struct { + name string + f func(t *testing.T) + }{ + { + name: "simple", + f: func(t *testing.T) { + q, err := New[string](context.Background(), client, "test", "worker") + require.NoError(t, err) + require.NotNil(t, q) + + task1, err := q.Enqueue(context.Background(), "hello") + require.NoError(t, err) + + id := claimAndFail(t, q, lockTimeout) + + task3, err := q.Dequeue(context.Background(), lockTimeout, 10*time.Millisecond) + require.NoError(t, err) + require.NotNil(t, task3) + assert.Equal(t, task3.ID, id) + assert.Equal(t, task1.Data, task3.Data) + assert.Equal(t, uint8(1), task3.Attempts) + }, + }, + { + name: "failure", + f: func(t *testing.T) { + q, err := New[string](context.Background(), client, "test", "worker") + require.NoError(t, err) + require.NotNil(t, q) + + _, err = q.Enqueue(context.Background(), "hello") + require.NoError(t, err) + + claimAndFail(t, q, lockTimeout) + claimAndFail(t, q, lockTimeout) + claimAndFail(t, q, lockTimeout) + + task3, err := q.Dequeue(context.Background(), lockTimeout, 10*time.Millisecond) + require.NoError(t, err) + assert.Nil(t, task3) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := client.FlushDB(context.Background()).Err(); err != nil { + t.Fatal(err) + } + + tt.f(t) + }) + } + + _ = client.FlushDB(context.Background()) +} + +func claimAndFail[T any](t *testing.T, q TaskQueue[T], lockTimeout time.Duration) TaskID { + task2, err := q.Dequeue(context.Background(), lockTimeout, 10*time.Millisecond) + require.NoError(t, err) + require.NotNil(t, task2) + + id, err := q.Return(context.Background(), task2.ID, errors.New("failed")) + require.NoError(t, err) + assert.NotEqual(t, task2.ID, id) + return id +} diff --git a/backend/internal/server/api/ibd/creds/creds.go b/backend/internal/server/api/ibd/creds/creds.go new file mode 100644 index 0000000..a8a05ab --- /dev/null +++ 
b/backend/internal/server/api/ibd/creds/creds.go @@ -0,0 +1,51 @@ +package creds + +import ( + "context" + "encoding/json" + "log/slog" + "net/http" + "time" + + "ibd-trader/internal/database" +) + +func Handler( + logger *slog.Logger, + db database.UserStore, +) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + var b body + err := json.NewDecoder(r.Body).Decode(&b) + if err != nil { + logger.Error("unable to decode request body", "error", err) + http.Error(w, "unable to decode request body", http.StatusBadRequest) + return + } + + ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second) + defer cancel() + + // Get session from context + session, ok := ctx.Value("session").(*database.Session) + if !ok { + http.Error(w, "Unauthorized", http.StatusUnauthorized) + return + } + + // Add IBD creds to user + err = db.AddIBDCreds(ctx, session.Subject, b.Username, b.Password) + if err != nil { + logger.ErrorContext(ctx, "unable to add IBD creds", "error", err) + http.Error(w, "unable to add IBD creds", http.StatusInternalServerError) + return + } + + w.WriteHeader(http.StatusCreated) + } +} + +type body struct { + Username string `json:"username"` + Password string `json:"password"` +} diff --git a/backend/internal/server/api/ibd/ibd50/ibd50.go b/backend/internal/server/api/ibd/ibd50/ibd50.go new file mode 100644 index 0000000..fc13bdf --- /dev/null +++ b/backend/internal/server/api/ibd/ibd50/ibd50.go @@ -0,0 +1,27 @@ +package ibd50 + +import ( + "encoding/json" + "log/slog" + "net/http" + + "ibd-trader/internal/ibd" +) + +func Handler( + logger *slog.Logger, + client *ibd.Client, +) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + list, err := client.GetIBD50(r.Context()) + if err != nil { + logger.Error("unable to get IBD50", "error", err) + http.Error(w, "unable to get IBD50", http.StatusInternalServerError) + return + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _ 
= json.NewEncoder(w).Encode(list) + } +} diff --git a/backend/internal/server/api/ibd/scrape/scrape.go b/backend/internal/server/api/ibd/scrape/scrape.go new file mode 100644 index 0000000..59ad0a7 --- /dev/null +++ b/backend/internal/server/api/ibd/scrape/scrape.go @@ -0,0 +1,27 @@ +package scrape + +import ( + "log/slog" + "net/http" + + "ibd-trader/internal/leader/manager/ibd/scrape" + + "github.com/redis/go-redis/v9" +) + +func Handler( + logger *slog.Logger, + client *redis.Client, +) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + // Publish to the scrape channel to force a scrape. + err := client.Publish(r.Context(), scrape.Channel, "").Err() + if err != nil { + logger.Error("failed to publish to scrape channel", "error", err) + w.WriteHeader(http.StatusInternalServerError) + return + } + + w.WriteHeader(http.StatusCreated) + } +} diff --git a/backend/internal/server/auth/callback/callback.go b/backend/internal/server/auth/callback/callback.go new file mode 100644 index 0000000..f0a3413 --- /dev/null +++ b/backend/internal/server/auth/callback/callback.go @@ -0,0 +1,93 @@ +package callback + +import ( + "context" + "log/slog" + "net/http" + "time" + + "ibd-trader/internal/auth" + "ibd-trader/internal/database" + "ibd-trader/internal/server/middleware" +) + +func Handler( + logger *slog.Logger, + userStore database.UserStore, + sessionStore database.SessionStore, + auth *auth.Authenticator, +) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + // Timeout callback operations after 10 seconds + ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second) + defer cancel() + + // Check state + state := r.URL.Query().Get("state") + if state == "" { + http.Error(w, "No state provided", http.StatusBadRequest) + return + } + + exists, err := sessionStore.CheckState(ctx, state) + if err != nil { + logger.ErrorContext(ctx, "Failed to check state", "error", err) + http.Error(w, "Failed to check state", 
http.StatusInternalServerError) + return + } + if !exists { + http.Error(w, "Invalid state", http.StatusBadRequest) + return + } + + // Exchange code for token + token, err := auth.Exchange(ctx, r.URL.Query().Get("code")) + if err != nil { + logger.ErrorContext(ctx, "Failed to exchange code", "error", err) + http.Error(w, "Failed to exchange code", http.StatusUnauthorized) + return + } + + // Verify token + idToken, err := auth.VerifyIDToken(ctx, token) + if err != nil { + logger.ErrorContext(ctx, "Failed to verify ID token", "error", err) + http.Error(w, "Failed to verify ID token", http.StatusInternalServerError) + return + } + + // Add user to database + if err := userStore.AddUser(ctx, idToken.Subject); err != nil { + logger.ErrorContext(ctx, "Failed to add user", "error", err) + http.Error(w, "Failed to add user", http.StatusInternalServerError) + return + } + + // Create session + session, err := sessionStore.CreateSession(ctx, token, idToken) + if err != nil { + logger.ErrorContext(ctx, "Failed to create session", "error", err) + http.Error(w, "Failed to create session", http.StatusInternalServerError) + return + } + + // Set session cookie + http.SetCookie(w, &http.Cookie{ + Name: middleware.SessionCookie, + Value: session, + Path: "/", + Domain: "", + Expires: token.Expiry, + RawExpires: "", + MaxAge: 0, + Secure: true, + HttpOnly: true, + SameSite: http.SameSiteLaxMode, + Raw: "", + Unparsed: nil, + }) + + // Redirect + http.Redirect(w, r, "/", http.StatusTemporaryRedirect) + } +} diff --git a/backend/internal/server/auth/login/login.go b/backend/internal/server/auth/login/login.go new file mode 100644 index 0000000..102e3d4 --- /dev/null +++ b/backend/internal/server/auth/login/login.go @@ -0,0 +1,28 @@ +package login + +import ( + "context" + "log/slog" + "net/http" + "time" + + "ibd-trader/internal/auth" + "ibd-trader/internal/database" +) + +func Handler(logger *slog.Logger, store database.SessionStore, auth *auth.Authenticator) http.HandlerFunc { + 
return func(w http.ResponseWriter, r *http.Request) { + // Save state in session table w/o user id + ctx, cancel := context.WithTimeout(r.Context(), 5*time.Second) + defer cancel() + state, err := store.CreateState(ctx) + if err != nil { + logger.ErrorContext(ctx, "Failed to create state", "error", err) + http.Error(w, "Failed to create state", http.StatusInternalServerError) + return + } + + // Redirect to oauth provider + http.Redirect(w, r, auth.AuthCodeURL(state), http.StatusTemporaryRedirect) + } +} diff --git a/backend/internal/server/auth/user/user.go b/backend/internal/server/auth/user/user.go new file mode 100644 index 0000000..526329d --- /dev/null +++ b/backend/internal/server/auth/user/user.go @@ -0,0 +1,45 @@ +package user + +import ( + "context" + "encoding/json" + "log/slog" + "net/http" + "time" + + "ibd-trader/internal/auth" + "ibd-trader/internal/database" +) + +func Handler( + logger *slog.Logger, + auth *auth.Authenticator, +) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second) + defer cancel() + + // Get session from context + session, ok := ctx.Value("session").(*database.Session) + if !ok { + http.Error(w, "Unauthorized", http.StatusUnauthorized) + return + } + + // Create token source + ts := auth.TokenSource(ctx, &session.OAuthToken) + + // Get user info + userInfo, err := auth.UserInfo(ctx, ts) + if err != nil { + logger.ErrorContext(ctx, "Failed to get user info", "error", err) + http.Error(w, "Failed to get user info", http.StatusInternalServerError) + return + } + + // Write user info to response + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(userInfo) + } +} diff --git a/backend/internal/server/middleware/auth.go b/backend/internal/server/middleware/auth.go new file mode 100644 index 0000000..f01e4b9 --- /dev/null +++ b/backend/internal/server/middleware/auth.go @@ -0,0 +1,46 @@ 
+package middleware + +import ( + "context" + "net/http" + "time" + + "ibd-trader/internal/database" +) + +const SessionCookie = "_session" + +func Auth(store database.SessionStore) func(next http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Get session cookie + cookie, err := r.Cookie(SessionCookie) + if err != nil { + http.Error(w, "Unauthorized", http.StatusUnauthorized) + return + } + + // Check session + session, err := store.GetSession(r.Context(), cookie.Value) + if err != nil { + http.Error(w, "Error getting session", http.StatusInternalServerError) + return + } + if session == nil { + http.Error(w, "Unauthorized", http.StatusUnauthorized) + return + } + + // Check session expiry + if session.OAuthToken.Expiry.Before(time.Now()) { + http.Error(w, "Session expired", http.StatusUnauthorized) + return + } + + // Add session to context + ctx := context.WithValue(r.Context(), "session", session) + + next.ServeHTTP(w, r.WithContext(ctx)) + }) + } +} diff --git a/backend/internal/server/server.go b/backend/internal/server/server.go new file mode 100644 index 0000000..7270b56 --- /dev/null +++ b/backend/internal/server/server.go @@ -0,0 +1,130 @@ +package server + +import ( + "context" + "errors" + "fmt" + "log/slog" + "net/http" + "time" + + "ibd-trader/internal/auth" + "ibd-trader/internal/config" + "ibd-trader/internal/database" + "ibd-trader/internal/ibd" + "ibd-trader/internal/server/api/ibd/creds" + "ibd-trader/internal/server/api/ibd/ibd50" + "ibd-trader/internal/server/api/ibd/scrape" + "ibd-trader/internal/server/auth/callback" + "ibd-trader/internal/server/auth/login" + "ibd-trader/internal/server/auth/user" + middleware2 "ibd-trader/internal/server/middleware" + + "github.com/go-chi/chi/v5" + "github.com/go-chi/chi/v5/middleware" + "github.com/redis/go-redis/v9" +) + +func StartServer( + ctx context.Context, + cfg *config.Config, + logger 
// StartServer builds the chi router, wires the auth and API routes, and
// runs the HTTP server until ctx is canceled or the listener fails. On
// cancellation it attempts a graceful shutdown with a 10-second budget.
func StartServer(
	ctx context.Context,
	cfg *config.Config,
	logger *slog.Logger,
	db database.Database,
	auth *auth.Authenticator,
	client *ibd.Client,
	rClient *redis.Client,
) error {
	r := chi.NewRouter()

	// Global middleware: client-IP resolution, request IDs, panic
	// recovery, and a liveness endpoint at /healthz.
	r.Use(middleware.RealIP)
	r.Use(middleware.RequestID)
	r.Use(middleware.Recoverer)
	r.Use(middleware.Heartbeat("/healthz"))

	// Registers /readyz on the router as a side effect.
	_ = NewMainHandler(logger, db, r)
	r.Route("/auth", func(r chi.Router) {
		r.Get("/login", login.Handler(logger, db, auth))
		r.Get("/callback", callback.Handler(logger, db, db, auth))
		r.Route("/user", func(r chi.Router) {
			// /auth/user requires a valid session cookie.
			r.Use(middleware2.Auth(db))
			r.Get("/", user.Handler(logger, auth))
		})
	})
	r.Route("/api", func(r chi.Router) {
		// All API routes are uncached and session-protected.
		r.Use(middleware.NoCache)
		r.Use(middleware2.Auth(db))
		r.Route("/ibd", func(r chi.Router) {
			r.Put("/creds", creds.Handler(logger, db))
			r.Get("/ibd50", ibd50.Handler(logger, client))
			r.Put("/scrape", scrape.Handler(logger, rClient))
		})
	})

	logger.Info("Starting server", "port", cfg.Server.Port)
	srv := &http.Server{
		Addr:    fmt.Sprintf("0.0.0.0:%d", cfg.Server.Port),
		Handler: r,
		// NOTE(review): no Read/WriteTimeout is set; consider enabling
		// these to guard against slow clients.
		//ReadTimeout: 1 * time.Minute,
		//WriteTimeout: 1 * time.Minute,
		ErrorLog: slog.NewLogLogger(logger.Handler(), slog.LevelError),
	}

	// The serve goroutine reports its terminal error exactly once on
	// finishedCh, then closes the channel.
	finishedCh := make(chan error)
	go func() {
		err := srv.ListenAndServe()
		if err != nil && !errors.Is(err, http.ErrServerClosed) {
			logger.Error("Server failed", "error", err)
		}
		finishedCh <- err
		close(finishedCh)
	}()

	// Wait for whichever happens first: server failure or cancellation.
	select {
	case err := <-finishedCh:
		// Server failed
		return err
	case <-ctx.Done():
		logger.Info("Shutting down server")
	}

	// Shutdown server
	// A fresh context is needed here because the original ctx is
	// already done; reuse would make Shutdown return immediately.
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()
	if err := srv.Shutdown(ctx); err != nil {
		logger.Error("Failed to shutdown server", "error", err)
		return err
	}

	// Wait for the server to finish
	// ErrServerClosed here means the graceful shutdown completed.
	err := <-finishedCh
	if errors.Is(err, http.ErrServerClosed) {
		return nil
	}
	return err
}
NewMainHandler(logger *slog.Logger, db database.Database, r *chi.Mux) *MainHandler { + h := &MainHandler{logger, db} + r.Get("/readyz", h.Ready) + + return h +} + +func (h *MainHandler) Ready(w http.ResponseWriter, r *http.Request) { + // Check we can ping DB + ctx, cancel := context.WithTimeout(r.Context(), 5*time.Second) + defer cancel() + err := h.db.Ping(ctx) + if err != nil { + http.Error(w, "DB not ready", http.StatusInternalServerError) + return + } + + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("OK")) +} diff --git a/backend/internal/server2/idb/stock/v1/stock.go b/backend/internal/server2/idb/stock/v1/stock.go new file mode 100644 index 0000000..3a94c82 --- /dev/null +++ b/backend/internal/server2/idb/stock/v1/stock.go @@ -0,0 +1,63 @@ +package stock + +import ( + "context" + "fmt" + "log/slog" + + pb "ibd-trader/api/gen/idb/stock/v1" + "ibd-trader/internal/database" + "ibd-trader/internal/leader/manager/ibd/scrape" + "ibd-trader/internal/redis/taskqueue" + + "cloud.google.com/go/longrunning/autogen/longrunningpb" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + "google.golang.org/protobuf/types/known/anypb" + "google.golang.org/protobuf/types/known/timestamppb" +) + +const ScrapeOperationPrefix = "scrape" + +type Server struct { + pb.UnimplementedStockServiceServer + + db database.StockStore + queue taskqueue.TaskQueue[scrape.TaskInfo] +} + +func New(db database.StockStore, queue taskqueue.TaskQueue[scrape.TaskInfo]) *Server { + return &Server{db: db, queue: queue} +} + +func (s *Server) CreateStock(ctx context.Context, request *pb.CreateStockRequest) (*longrunningpb.Operation, error) { + task, err := s.queue.Enqueue(ctx, scrape.TaskInfo{Symbol: request.Symbol}) + if err != nil { + slog.ErrorContext(ctx, "failed to enqueue task", "err", err) + return nil, status.New(codes.Internal, "failed to enqueue task").Err() + } + op := &longrunningpb.Operation{ + Name: fmt.Sprintf("%s/%s", ScrapeOperationPrefix, task.ID.String()), + 
Metadata: new(anypb.Any), + Done: false, + Result: nil, + } + err = op.Metadata.MarshalFrom(&pb.StockScrapeOperationMetadata{ + Symbol: request.Symbol, + StartTime: timestamppb.New(task.ID.Timestamp()), + }) + if err != nil { + slog.ErrorContext(ctx, "failed to marshal metadata", "err", err) + return nil, status.New(codes.Internal, "failed to marshal metadata").Err() + } + return op, nil +} + +func (s *Server) GetStock(ctx context.Context, request *pb.GetStockRequest) (*pb.GetStockResponse, error) { + +} + +func (s *Server) ListStocks(ctx context.Context, request *pb.ListStocksRequest) (*pb.ListStocksResponse, error) { + //TODO implement me + panic("implement me") +} diff --git a/backend/internal/server2/idb/user/v1/user.go b/backend/internal/server2/idb/user/v1/user.go new file mode 100644 index 0000000..1866944 --- /dev/null +++ b/backend/internal/server2/idb/user/v1/user.go @@ -0,0 +1,94 @@ +package user + +import ( + "context" + "errors" + + pb "ibd-trader/api/gen/idb/user/v1" + "ibd-trader/internal/database" + + "github.com/mennanov/fmutils" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + "google.golang.org/protobuf/proto" +) + +type Server struct { + pb.UnimplementedUserServiceServer + + db database.UserStore +} + +func New(db database.UserStore) *Server { + return &Server{db: db} +} + +func (u *Server) CreateUser(ctx context.Context, request *pb.CreateUserRequest) (*pb.CreateUserResponse, error) { + err := u.db.AddUser(ctx, request.Subject) + if err != nil { + return nil, status.Errorf(codes.Internal, "unable to create user: %v", err) + } + + user, err := u.db.GetUser(ctx, request.Subject) + if err != nil { + return nil, status.Errorf(codes.Internal, "unable to get user: %v", err) + } + + return &pb.CreateUserResponse{ + User: &pb.User{ + Subject: user.Subject, + IbdUsername: user.IBDUsername, + IbdPassword: nil, + }, + }, nil +} + +func (u *Server) GetUser(ctx context.Context, request *pb.GetUserRequest) (*pb.GetUserResponse, error) { + 
user, err := u.db.GetUser(ctx, request.Subject) + if errors.Is(err, database.ErrUserNotFound) { + return nil, status.New(codes.NotFound, "user not found").Err() + } + if err != nil { + return nil, status.Errorf(codes.Internal, "unable to get user: %v", err) + } + + return &pb.GetUserResponse{ + User: &pb.User{ + Subject: user.Subject, + IbdUsername: user.IBDUsername, + IbdPassword: nil, + }, + }, nil +} + +func (u *Server) UpdateUser(ctx context.Context, request *pb.UpdateUserRequest) (*pb.UpdateUserResponse, error) { + request.UpdateMask.Normalize() + if !request.UpdateMask.IsValid(request.User) { + return nil, status.Errorf(codes.InvalidArgument, "invalid update mask") + } + + existingUserRes, err := u.GetUser(ctx, &pb.GetUserRequest{Subject: request.User.Subject}) + if err != nil { + return nil, err + } + existingUser := existingUserRes.User + + newUser := proto.Clone(existingUser).(*pb.User) + fmutils.Overwrite(request.User, newUser, request.UpdateMask.Paths) + + // if IDB creds are both set and are different, update them + if (newUser.IbdPassword != nil && newUser.IbdUsername != nil) && + (newUser.IbdPassword != existingUser.IbdPassword || + newUser.IbdUsername != existingUser.IbdUsername) { + // Update IBD creds + err = u.db.AddIBDCreds(ctx, newUser.Subject, *newUser.IbdUsername, *newUser.IbdPassword) + if err != nil { + return nil, status.Errorf(codes.Internal, "unable to update user: %v", err) + } + } + + newUser.IbdPassword = nil + return &pb.UpdateUserResponse{ + User: newUser, + }, nil +} diff --git a/backend/internal/server2/operations.go b/backend/internal/server2/operations.go new file mode 100644 index 0000000..c632cd1 --- /dev/null +++ b/backend/internal/server2/operations.go @@ -0,0 +1,130 @@ +package server2 + +import ( + "context" + "errors" + "fmt" + "log/slog" + "strings" + + spb "ibd-trader/api/gen/idb/stock/v1" + "ibd-trader/internal/leader/manager/ibd/scrape" + "ibd-trader/internal/redis/taskqueue" + 
"ibd-trader/internal/server2/idb/stock/v1" + + "cloud.google.com/go/longrunning/autogen/longrunningpb" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + "google.golang.org/protobuf/types/known/anypb" + "google.golang.org/protobuf/types/known/timestamppb" +) + +type operationServer struct { + longrunningpb.UnimplementedOperationsServer + + scrape taskqueue.TaskQueue[scrape.TaskInfo] +} + +func newOperationServer(scrapeQueue taskqueue.TaskQueue[scrape.TaskInfo]) *operationServer { + return &operationServer{scrape: scrapeQueue} +} + +func (o *operationServer) ListOperations( + ctx context.Context, + req *longrunningpb.ListOperationsRequest, +) (*longrunningpb.ListOperationsResponse, error) { + var end taskqueue.TaskID + if req.PageToken != "" { + var err error + end, err = taskqueue.ParseTaskID(req.PageToken) + if err != nil { + return nil, status.New(codes.InvalidArgument, err.Error()).Err() + } + } else { + end = taskqueue.TaskID{} + } + + switch req.Name { + case stock.ScrapeOperationPrefix: + tasks, err := o.scrape.List(ctx, taskqueue.TaskID{}, end, int64(req.PageSize)) + if err != nil { + return nil, status.New(codes.Internal, "unable to list IDs").Err() + } + + ops := make([]*longrunningpb.Operation, len(tasks)) + for i, task := range tasks { + ops[i] = &longrunningpb.Operation{ + Name: fmt.Sprintf("%s/%s", stock.ScrapeOperationPrefix, task.ID.String()), + Metadata: new(anypb.Any), + Done: task.Done, + Result: nil, + } + err = ops[i].Metadata.MarshalFrom(&spb.StockScrapeOperationMetadata{ + Symbol: task.Data.Symbol, + StartTime: timestamppb.New(task.ID.Timestamp()), + }) + if err != nil { + return nil, status.New(codes.Internal, "unable to marshal metadata").Err() + } + + if task.Done && task.Error != "" { + s := status.New(codes.Unknown, task.Error) + ops[i].Result = &longrunningpb.Operation_Error{Error: s.Proto()} + } + } + + var nextPageToken string + if len(tasks) == int(req.PageSize) { + nextPageToken = tasks[len(tasks)-1].ID.String() + } 
else { + nextPageToken = "" + } + + return &longrunningpb.ListOperationsResponse{ + Operations: ops, + NextPageToken: nextPageToken, + }, nil + default: + return nil, status.New(codes.NotFound, "unknown operation type").Err() + } +} + +func (o *operationServer) GetOperation(ctx context.Context, req *longrunningpb.GetOperationRequest) (*longrunningpb.Operation, error) { + prefix, id, ok := strings.Cut(req.Name, "/") + if !ok || prefix == "" || id == "" { + return nil, status.New(codes.InvalidArgument, "invalid operation name").Err() + } + + taskID, err := taskqueue.ParseTaskID(id) + if err != nil { + return nil, status.New(codes.InvalidArgument, err.Error()).Err() + } + + switch prefix { + case stock.ScrapeOperationPrefix: + task, err := o.scrape.Data(ctx, taskID) + if errors.Is(err, taskqueue.ErrTaskNotFound) { + return nil, status.New(codes.NotFound, "operation not found").Err() + } + if err != nil { + slog.ErrorContext(ctx, "unable to get operation", "error", err) + return nil, status.New(codes.Internal, "unable to get operation").Err() + } + op := &longrunningpb.Operation{ + Name: req.Name, + Metadata: new(anypb.Any), + Done: task.Done, + Result: nil, + } + err = op.Metadata.MarshalFrom(&spb.StockScrapeOperationMetadata{ + Symbol: task.Data.Symbol, + StartTime: timestamppb.New(task.ID.Timestamp()), + }) + if err != nil { + return nil, status.New(codes.Internal, "unable to marshal metadata").Err() + } + return op, nil + default: + return nil, status.New(codes.NotFound, "unknown operation type").Err() + } +} diff --git a/backend/internal/server2/server.go b/backend/internal/server2/server.go new file mode 100644 index 0000000..4731bdd --- /dev/null +++ b/backend/internal/server2/server.go @@ -0,0 +1,71 @@ +package server2 + +import ( + "context" + "fmt" + "log/slog" + "net" + + spb "ibd-trader/api/gen/idb/stock/v1" + upb "ibd-trader/api/gen/idb/user/v1" + "ibd-trader/internal/database" + "ibd-trader/internal/leader/manager/ibd/scrape" + 
"ibd-trader/internal/redis/taskqueue" + "ibd-trader/internal/server2/idb/stock/v1" + "ibd-trader/internal/server2/idb/user/v1" + + "cloud.google.com/go/longrunning/autogen/longrunningpb" + "github.com/redis/go-redis/v9" + "google.golang.org/grpc" + "google.golang.org/grpc/reflection" +) + +type Server struct { + s *grpc.Server + port uint16 +} + +func New( + ctx context.Context, + port uint16, + db database.Database, + rClient *redis.Client, +) (*Server, error) { + scrapeQueue, err := taskqueue.New( + ctx, + rClient, + scrape.Queue, + "grpc-server", + taskqueue.WithEncoding[scrape.TaskInfo](scrape.QueueEncoding)) + if err != nil { + return nil, err + } + + s := grpc.NewServer() + upb.RegisterUserServiceServer(s, user.New(db)) + spb.RegisterStockServiceServer(s, stock.New(db, scrapeQueue)) + longrunningpb.RegisterOperationsServer(s, newOperationServer(scrapeQueue)) + reflection.Register(s) + return &Server{s, port}, nil +} + +func (s *Server) Serve(ctx context.Context) error { + lis, err := net.Listen("tcp", fmt.Sprintf(":%d", s.port)) + if err != nil { + return err + } + + // Graceful shutdown + go func() { + <-ctx.Done() + slog.ErrorContext(ctx, + "Shutting down server", + "err", ctx.Err(), + "cause", context.Cause(ctx), + ) + s.s.GracefulStop() + }() + + slog.InfoContext(ctx, "Starting gRPC server", "port", s.port) + return s.s.Serve(lis) +} diff --git a/backend/internal/utils/money.go b/backend/internal/utils/money.go new file mode 100644 index 0000000..2dc2286 --- /dev/null +++ b/backend/internal/utils/money.go @@ -0,0 +1,99 @@ +package utils + +import ( + "fmt" + "strconv" + "strings" + + "github.com/Rhymond/go-money" +) + +// supported currencies +var currencies = money.Currencies{ + "USD": money.GetCurrency(money.USD), + "EUR": money.GetCurrency(money.EUR), + "GBP": money.GetCurrency(money.GBP), + "JPY": money.GetCurrency(money.JPY), + "CNY": money.GetCurrency(money.CNY), +} + +func ParseMoney(s string) (*money.Money, error) { + for _, c := range currencies 
// isCurrency tries to consume s according to currency c's display
// template (e.g. "$1" for a leading grapheme, "1 $" for a trailing
// one). On success it returns the amount rendered as a string of minor
// units — padded out to c.Fraction decimal places, so "$1.5" for USD
// yields "150" — and true. On any mismatch it returns ("", false).
//
// NOTE(review): separator comparisons use string(s[0]), a single byte;
// multi-byte Thousand/Decimal separators would not match — confirm all
// supported currencies use single-byte separators.
func isCurrency(s string, c *money.Currency) (string, bool) {
	var numPart string
	// Walk the template one rune at a time; each template rune dictates
	// what must appear next in s.
	for _, tp := range c.Template {
		switch tp {
		case '$':
			// There should be a matching grapheme in the s at this position
			remaining, ok := strings.CutPrefix(s, c.Grapheme)
			if !ok {
				return "", false
			}
			s = remaining
		case '1':
			// There should be a number, thousands, or decimal separator in the s at this position
			// Number of expected decimal places
			decimalFound := -1
			// Read from string until a non-number, non-thousands, non-decimal, or EOF is found
			for len(s) > 0 && (string(s[0]) == c.Thousand || string(s[0]) == c.Decimal || '0' <= s[0] && s[0] <= '9') {
				// If the character is a number
				if '0' <= s[0] && s[0] <= '9' {
					// If we've hit decimal limit, break
					if decimalFound == 0 {
						break
					}
					// add the number to the numPart
					numPart += string(s[0])
					// Decrement the decimal count
					// If the decimal has been found, `decimalFound` is positive
					// If the decimal hasn't been found, `decimalFound` is negative, and decrementing it does nothing
					decimalFound--
				}
				// If decimal has been found (>= 0) and the character is a thousand separator or decimal separator,
				// then the number is invalid
				if decimalFound >= 0 && (string(s[0]) == c.Thousand || string(s[0]) == c.Decimal) {
					return "", false
				}
				// If the character is a decimal separator, set `decimalFound` to the number of
				// expected decimal places for the currency
				if string(s[0]) == c.Decimal {
					decimalFound = c.Fraction
				}
				// Move to the next character
				s = s[1:]
			}
			if decimalFound > 0 {
				// If there should be more decimal places, add them
				numPart += strings.Repeat("0", decimalFound)
			} else if decimalFound < 0 {
				// If no decimal was found, add the expected number of decimal places
				numPart += strings.Repeat("0", c.Fraction)
			}
		case ' ':
			// There should be a space in the s at this position
			if len(s) == 0 || s[0] != ' ' {
				return "", false
			}
			s = s[1:]
		default:
			// Only '$', '1', and ' ' template characters are handled;
			// anything else means the currencies table holds an
			// unsupported entry, which is a programming error.
			panic(fmt.Sprintf("unsupported template character: %c", tp))
		}
	}
	// NOTE(review): trailing unconsumed input is accepted ("123abc"
	// still matches USD via "$123abc" → "12300") — confirm intended.
	return numPart, true
}
strings.Repeat("0", decimalFound) + } else if decimalFound < 0 { + // If no decimal was found, add the expected number of decimal places + numPart += strings.Repeat("0", c.Fraction) + } + case ' ': + // There should be a space in the s at this position + if len(s) == 0 || s[0] != ' ' { + return "", false + } + s = s[1:] + default: + panic(fmt.Sprintf("unsupported template character: %c", tp)) + } + } + return numPart, true +} diff --git a/backend/internal/utils/money_test.go b/backend/internal/utils/money_test.go new file mode 100644 index 0000000..27ace06 --- /dev/null +++ b/backend/internal/utils/money_test.go @@ -0,0 +1,106 @@ +package utils + +import ( + "testing" + + "github.com/Rhymond/go-money" + "github.com/stretchr/testify/assert" +) + +func TestParseMoney(t *testing.T) { + tests := []struct { + name string + input string + want *money.Money + }{ + { + name: "en-US int no comma", + input: "$123", + want: money.New(12300, money.USD), + }, + { + name: "en-US int comma", + input: "$1,123", + want: money.New(112300, money.USD), + }, + { + name: "en-US decimal comma", + input: "$1,123.45", + want: money.New(112345, money.USD), + }, + { + name: "zh-CN decimal comma", + input: "1,234.56 \u5143", + want: money.New(123456, money.CNY), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + m, err := ParseMoney(tt.input) + assert.NoError(t, err) + assert.Equal(t, tt.want, m) + }) + } +} + +func Test_isCurrency(t *testing.T) { + tests := []struct { + name string + input string + currency string + numPart string + ok bool + }{ + { + name: "en-US int no comma", + input: "$123", + currency: money.USD, + numPart: "12300", + ok: true, + }, + { + name: "en-US int comma", + input: "$1,123", + currency: money.USD, + numPart: "112300", + ok: true, + }, + { + name: "en-US decimal comma", + input: "$1,123.45", + currency: money.USD, + numPart: "112345", + ok: true, + }, + { + name: "en-US 1 decimal comma", + input: "$1,123.5", + currency: money.USD, + 
numPart: "112350", + ok: true, + }, + { + name: "en-US no grapheme", + input: "1,234.56", + currency: money.USD, + numPart: "", + ok: false, + }, + { + name: "zh-CN decimal comma", + input: "1,234.56 \u5143", + currency: money.CNY, + numPart: "123456", + ok: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := money.GetCurrency(tt.currency) + numPart, ok := isCurrency(tt.input, c) + assert.Equal(t, tt.ok, ok) + assert.Equal(t, tt.numPart, numPart) + }) + } +} diff --git a/backend/internal/worker/analyzer/analyzer.go b/backend/internal/worker/analyzer/analyzer.go new file mode 100644 index 0000000..924e571 --- /dev/null +++ b/backend/internal/worker/analyzer/analyzer.go @@ -0,0 +1,135 @@ +package analyzer + +import ( + "context" + "log/slog" + "time" + + "ibd-trader/internal/analyzer" + "ibd-trader/internal/database" + "ibd-trader/internal/redis/taskqueue" + + "github.com/redis/go-redis/v9" +) + +const ( + Queue = "analyzer" + QueueEncoding = taskqueue.EncodingJSON + + lockTimeout = 1 * time.Minute + dequeueTimeout = 5 * time.Second +) + +func RunAnalyzer( + ctx context.Context, + redis *redis.Client, + analyzer analyzer.Analyzer, + db database.StockStore, + name string, +) error { + queue, err := taskqueue.New( + ctx, + redis, + Queue, + name, + taskqueue.WithEncoding[TaskInfo](QueueEncoding), + ) + if err != nil { + return err + } + + for { + select { + case <-ctx.Done(): + return ctx.Err() + default: + waitForTask(ctx, queue, analyzer, db) + } + } +} + +func waitForTask( + ctx context.Context, + queue taskqueue.TaskQueue[TaskInfo], + analyzer analyzer.Analyzer, + db database.StockStore, +) { + task, err := queue.Dequeue(ctx, lockTimeout, dequeueTimeout) + if err != nil { + slog.ErrorContext(ctx, "Failed to dequeue task", "error", err) + return + } + if task == nil { + // No task available. 
+ return + } + + ch := make(chan error) + defer close(ch) + go func() { + ch <- analyzeStock(ctx, analyzer, db, task.Data.ID) + }() + + ticker := time.NewTicker(lockTimeout / 5) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + // Context was canceled. Return early. + return + case <-ticker.C: + // Extend the lock periodically. + func() { + ctx, cancel := context.WithTimeout(ctx, lockTimeout/5) + defer cancel() + + err := queue.Extend(ctx, task.ID) + if err != nil { + slog.ErrorContext(ctx, "Failed to extend lock", "error", err) + } + }() + case err = <-ch: + // scrapeUrl has completed. + if err != nil { + slog.ErrorContext(ctx, "Failed to analyze", "error", err) + _, err = queue.Return(ctx, task.ID, err) + if err != nil { + slog.ErrorContext(ctx, "Failed to return task", "error", err) + return + } + } else { + slog.DebugContext(ctx, "Analyzed ID", "id", task.Data.ID) + err = queue.Complete(ctx, task.ID, nil) + if err != nil { + slog.ErrorContext(ctx, "Failed to complete task", "error", err) + return + } + } + return + } + } +} + +func analyzeStock(ctx context.Context, a analyzer.Analyzer, db database.StockStore, id string) error { + info, err := db.GetStockInfo(ctx, id) + if err != nil { + return err + } + + analysis, err := a.Analyze( + ctx, + info.Symbol, + info.Price, + info.ChartAnalysis, + ) + if err != nil { + return err + } + + return db.AddAnalysis(ctx, id, analysis) +} + +type TaskInfo struct { + ID string `json:"id"` +} diff --git a/backend/internal/worker/auth/auth.go b/backend/internal/worker/auth/auth.go new file mode 100644 index 0000000..e1c6661 --- /dev/null +++ b/backend/internal/worker/auth/auth.go @@ -0,0 +1,228 @@ +package auth + +import ( + "context" + "fmt" + "log/slog" + "time" + + "ibd-trader/internal/database" + "ibd-trader/internal/ibd" + "ibd-trader/internal/leader/manager/ibd/auth" + "ibd-trader/internal/redis/taskqueue" + + "github.com/redis/go-redis/v9" +) + +const ( + lockTimeout = 1 * time.Minute + dequeueTimeout = 5 
* time.Second +) + +func RunAuthScraper( + ctx context.Context, + client *ibd.Client, + redis *redis.Client, + users database.UserStore, + cookies database.CookieStore, + name string, +) error { + queue, err := taskqueue.New( + ctx, + redis, + auth.Queue, + name, + taskqueue.WithEncoding[auth.TaskInfo](auth.QueueEncoding), + ) + if err != nil { + return err + } + + for { + select { + case <-ctx.Done(): + return ctx.Err() + default: + waitForTask(ctx, queue, client, users, cookies) + } + } +} + +func waitForTask( + ctx context.Context, + queue taskqueue.TaskQueue[auth.TaskInfo], + client *ibd.Client, + users database.UserStore, + cookies database.CookieStore, +) { + task, err := queue.Dequeue(ctx, lockTimeout, dequeueTimeout) + if err != nil { + slog.ErrorContext(ctx, "Failed to dequeue task", "error", err) + return + } + if task == nil { + // No task available. + return + } + slog.DebugContext(ctx, "Picked up auth task", "task", task.ID, "user", task.Data.UserSubject) + + ch := make(chan error) + defer close(ch) + go func() { + ch <- scrapeCookies(ctx, client, users, cookies, task.Data.UserSubject) + }() + + ticker := time.NewTicker(lockTimeout / 5) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + // The context was canceled. Return early. + return + case <-ticker.C: + // Extend the lock periodically. + func() { + ctx, cancel := context.WithTimeout(ctx, lockTimeout/5) + defer cancel() + + err := queue.Extend(ctx, task.ID) + if err != nil { + slog.ErrorContext(ctx, "Failed to extend lock", "error", err) + } + }() + case err = <-ch: + // scrapeCookies has completed. 
+ if err != nil { + slog.ErrorContext(ctx, "Failed to scrape cookies", "error", err) + _, err = queue.Return(ctx, task.ID, err) + if err != nil { + slog.ErrorContext(ctx, "Failed to return task", "error", err) + return + } + } else { + err = queue.Complete(ctx, task.ID, nil) + if err != nil { + slog.ErrorContext(ctx, "Failed to complete task", "error", err) + return + } + slog.DebugContext(ctx, "Authenticated user", "user", task.Data.UserSubject) + } + return + } + } +} + +func scrapeCookies( + ctx context.Context, + client *ibd.Client, + users database.UserStore, + store database.CookieStore, + user string, +) error { + ctx, cancel := context.WithTimeout(ctx, lockTimeout) + defer cancel() + + // Check if the user has valid cookies + done, err := hasValidCookies(ctx, store, user) + if err != nil { + return fmt.Errorf("failed to check cookies: %w", err) + } + if done { + return nil + } + + // Health check degraded cookies + done, err = healthCheckDegradedCookies(ctx, client, store, user) + if err != nil { + return fmt.Errorf("failed to health check cookies: %w", err) + } + if done { + return nil + } + + // No cookies are valid, so scrape new cookies + return scrapeNewCookies(ctx, client, users, store, user) +} + +func hasValidCookies(ctx context.Context, store database.CookieStore, user string) (bool, error) { + // Check if the user has non-degraded cookies + cookies, err := store.GetCookies(ctx, user, false) + if err != nil { + return false, fmt.Errorf("failed to get non-degraded cookies: %w", err) + } + + // If the user has non-degraded cookies, return true + if cookies != nil && len(cookies) > 0 { + return true, nil + } + return false, nil +} + +func healthCheckDegradedCookies( + ctx context.Context, + client *ibd.Client, + store database.CookieStore, + user string, +) (bool, error) { + // Check if the user has degraded cookies + cookies, err := store.GetCookies(ctx, user, true) + if err != nil { + return false, fmt.Errorf("failed to get degraded cookies: %w", 
err) + } + + valid := false + for _, cookie := range cookies { + slog.DebugContext(ctx, "Health checking cookie", "cookie", cookie.ID) + + // Health check the cookie + up, err := client.UserInfo(ctx, cookie.ToHTTPCookie()) + if err != nil { + slog.ErrorContext(ctx, "Failed to health check cookie", "error", err) + continue + } + + if up.Status != ibd.UserStatusSubscriber { + continue + } + + // Cookie is valid + valid = true + + // Update the cookie + err = store.RepairCookie(ctx, cookie.ID) + if err != nil { + slog.ErrorContext(ctx, "Failed to repair cookie", "error", err) + } + } + + return valid, nil +} + +func scrapeNewCookies( + ctx context.Context, + client *ibd.Client, + users database.UserStore, + store database.CookieStore, + user string, +) error { + // Get the user's credentials + username, password, err := users.GetIBDCreds(ctx, user) + if err != nil { + return fmt.Errorf("failed to get IBD credentials: %w", err) + } + + // Scrape the user's cookies + cookie, err := client.Authenticate(ctx, username, password) + if err != nil { + return fmt.Errorf("failed to authenticate user: %w", err) + } + + // Store the cookie + err = store.AddCookie(ctx, user, cookie) + if err != nil { + return fmt.Errorf("failed to store cookie: %w", err) + } + + return nil +} diff --git a/backend/internal/worker/scraper/scraper.go b/backend/internal/worker/scraper/scraper.go new file mode 100644 index 0000000..a83d9ae --- /dev/null +++ b/backend/internal/worker/scraper/scraper.go @@ -0,0 +1,191 @@ +package scraper + +import ( + "context" + "errors" + "fmt" + "log/slog" + "time" + + "ibd-trader/internal/database" + "ibd-trader/internal/ibd" + "ibd-trader/internal/leader/manager/ibd/scrape" + "ibd-trader/internal/redis/taskqueue" + "ibd-trader/internal/worker/analyzer" + + "github.com/redis/go-redis/v9" +) + +const ( + lockTimeout = 1 * time.Minute + dequeueTimeout = 5 * time.Second +) + +func RunScraper( + ctx context.Context, + redis *redis.Client, + client *ibd.Client, + store 
database.StockStore, + name string, +) error { + queue, err := taskqueue.New( + ctx, + redis, + scrape.Queue, + name, + taskqueue.WithEncoding[scrape.TaskInfo](scrape.QueueEncoding), + ) + if err != nil { + return err + } + + aQueue, err := taskqueue.New( + ctx, + redis, + analyzer.Queue, + name, + taskqueue.WithEncoding[analyzer.TaskInfo](analyzer.QueueEncoding), + ) + if err != nil { + return err + } + + for { + select { + case <-ctx.Done(): + return ctx.Err() + default: + waitForTask(ctx, queue, aQueue, client, store) + } + } +} + +func waitForTask( + ctx context.Context, + queue taskqueue.TaskQueue[scrape.TaskInfo], + aQueue taskqueue.TaskQueue[analyzer.TaskInfo], + client *ibd.Client, + store database.StockStore, +) { + task, err := queue.Dequeue(ctx, lockTimeout, dequeueTimeout) + if err != nil { + slog.ErrorContext(ctx, "Failed to dequeue task", "error", err) + return + } + if task == nil { + // No task available. + return + } + + ch := make(chan error) + go func() { + defer close(ch) + ch <- scrapeUrl(ctx, client, store, aQueue, task.Data.Symbol) + }() + + ticker := time.NewTicker(lockTimeout / 5) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + // Context was canceled. Return early. + return + case <-ticker.C: + // Extend the lock periodically. + func() { + ctx, cancel := context.WithTimeout(ctx, lockTimeout/5) + defer cancel() + + err := queue.Extend(ctx, task.ID) + if err != nil { + slog.ErrorContext(ctx, "Failed to extend lock", "error", err) + } + }() + case err = <-ch: + // scrapeUrl has completed. 
+			if err != nil {
+				slog.ErrorContext(ctx, "Failed to scrape URL", "error", err)
+				// Put the task back on the queue so it can be retried.
+				_, err = queue.Return(ctx, task.ID, err)
+				if err != nil {
+					slog.ErrorContext(ctx, "Failed to return task", "error", err)
+					return
+				}
+			} else {
+				slog.DebugContext(ctx, "Scraped URL", "symbol", task.Data.Symbol)
+				err = queue.Complete(ctx, task.ID, nil)
+				if err != nil {
+					slog.ErrorContext(ctx, "Failed to complete task", "error", err)
+					return
+				}
+			}
+			return
+		}
+	}
+}
+
+// scrapeUrl scrapes current IBD stock info for symbol, stores it in the
+// database, and enqueues an analysis task for the stored record.
+func scrapeUrl(
+	ctx context.Context,
+	client *ibd.Client,
+	store database.StockStore,
+	aQueue taskqueue.TaskQueue[analyzer.TaskInfo],
+	symbol string,
+) error {
+	// Bound the whole operation by the queue lock timeout so the work cannot
+	// outlive the lock held on the task.
+	ctx, cancel := context.WithTimeout(ctx, lockTimeout)
+	defer cancel()
+
+	stockUrl, err := getStockUrl(ctx, store, client, symbol)
+	if err != nil {
+		return fmt.Errorf("failed to get stock url: %w", err)
+	}
+
+	// Scrape the stock info.
+	info, err := client.StockInfo(ctx, stockUrl)
+	if err != nil {
+		return fmt.Errorf("failed to get stock info: %w", err)
+	}
+
+	// Add stock info to the database.
+	id, err := store.AddStockInfo(ctx, info)
+	if err != nil {
+		return fmt.Errorf("failed to add stock info: %w", err)
+	}
+
+	// Add the stock to the analyzer queue.
+	// NOTE(review): if this enqueue fails after AddStockInfo succeeded, a task
+	// retry will store the info again — confirm AddStockInfo is idempotent or
+	// deduplicated.
+	_, err = aQueue.Enqueue(ctx, analyzer.TaskInfo{ID: id})
+	if err != nil {
+		return fmt.Errorf("failed to enqueue analysis task: %w", err)
+	}
+
+	slog.DebugContext(ctx, "Added stock info", "id", id)
+
+	return nil
+}
+
+// getStockUrl resolves the IBD quote-page URL for symbol. It checks the
+// database first; on ErrStockNotFound it searches IBD and caches the result
+// via store.AddStock before returning the URL.
+func getStockUrl(ctx context.Context, store database.StockStore, client *ibd.Client, symbol string) (string, error) {
+	// Get the stock from the database.
+	stock, err := store.GetStock(ctx, symbol)
+	if err == nil {
+		return stock.IBDUrl, nil
+	}
+	if !errors.Is(err, database.ErrStockNotFound) {
+		return "", fmt.Errorf("failed to get stock: %w", err)
+	}
+
+	// If stock isn't found in the database, get the stock from IBD.
+	stock, err = client.Search(ctx, symbol)
+	if errors.Is(err, ibd.ErrSymbolNotFound) {
+		return "", fmt.Errorf("symbol not found: %w", err)
+	}
+	if err != nil {
+		return "", fmt.Errorf("failed to search for symbol: %w", err)
+	}
+
+	// Add the stock to the database.
+	err = store.AddStock(ctx, stock)
+	if err != nil {
+		return "", fmt.Errorf("failed to add stock: %w", err)
+	}
+
+	return stock.IBDUrl, nil
+}
diff --git a/backend/internal/worker/worker.go b/backend/internal/worker/worker.go
new file mode 100644
index 0000000..5858115
--- /dev/null
+++ b/backend/internal/worker/worker.go
@@ -0,0 +1,149 @@
+package worker
+
+import (
+	"context"
+	"crypto/rand"
+	"encoding/base64"
+	"io"
+	"log/slog"
+	"os"
+	"time"
+
+	"ibd-trader/internal/analyzer"
+	"ibd-trader/internal/database"
+	"ibd-trader/internal/ibd"
+	"ibd-trader/internal/leader/manager"
+	analyzer2 "ibd-trader/internal/worker/analyzer"
+	"ibd-trader/internal/worker/auth"
+	"ibd-trader/internal/worker/scraper"
+
+	"github.com/redis/go-redis/v9"
+	"golang.org/x/sync/errgroup"
+)
+
+const (
+	// HeartbeatInterval is how often the worker refreshes its heartbeat key.
+	HeartbeatInterval = 5 * time.Second
+	// HeartbeatTTL is how long a heartbeat key lives without being refreshed;
+	// presumably the leader treats a worker as dead after this — verify in
+	// the manager package.
+	HeartbeatTTL      = 30 * time.Second
+)
+
+// StartWorker runs all of this worker's loops — registration heartbeats, the
+// scrape-task consumer, the auth (cookie) scraper, and the analyzer — under a
+// single errgroup, returning the first error any of them produces.
+func StartWorker(
+	ctx context.Context,
+	ibdClient *ibd.Client,
+	client *redis.Client,
+	db database.Database,
+	a analyzer.Analyzer,
+) error {
+	// Get the worker name. 
+	name, err := workerName()
+	if err != nil {
+		return err
+	}
+	slog.InfoContext(ctx, "Starting worker", "worker", name)
+
+	// All loops share one errgroup context: the first loop to fail (or the
+	// parent ctx being canceled) stops the others.
+	g, ctx := errgroup.WithContext(ctx)
+
+	g.Go(func() error {
+		return workerRegistrationLoop(ctx, client, name)
+	})
+	g.Go(func() error {
+		return scraper.RunScraper(ctx, client, ibdClient, db, name)
+	})
+	g.Go(func() error {
+		return auth.RunAuthScraper(ctx, ibdClient, client, db, db, name)
+	})
+	g.Go(func() error {
+		return analyzer2.RunAnalyzer(ctx, client, a, db, name)
+	})
+
+	return g.Wait()
+}
+
+// workerRegistrationLoop keeps this worker registered: it sends a heartbeat
+// immediately, then every HeartbeatInterval, and deregisters the worker when
+// ctx is canceled (always returning ctx.Err()).
+func workerRegistrationLoop(ctx context.Context, client *redis.Client, name string) error {
+	sendHeartbeat(ctx, client, name)
+
+	ticker := time.NewTicker(HeartbeatInterval)
+	defer ticker.Stop()
+
+	for {
+		select {
+		case <-ticker.C:
+			sendHeartbeat(ctx, client, name)
+		case <-ctx.Done():
+			removeWorker(ctx, client, name)
+			return ctx.Err()
+		}
+	}
+}
+
+// sendHeartbeat sends a heartbeat for the worker.
+// It ensures that the worker is in the active workers set and its heartbeat exists.
+// Failures are logged, not returned: the next tick retries.
+func sendHeartbeat(ctx context.Context, client *redis.Client, name string) {
+	// Bound the whole heartbeat by one interval so attempts never overlap.
+	ctx, cancel := context.WithTimeout(ctx, HeartbeatInterval)
+	defer cancel()
+
+	// Add the worker to the active workers set.
+	if err := client.SAdd(ctx, manager.ActiveWorkersSet, name).Err(); err != nil {
+		slog.ErrorContext(ctx,
+			"Unable to add worker to active workers set",
+			"worker", name,
+			"error", err,
+		)
+		return
+	}
+
+	// Set the worker's heartbeat.
+	heartbeatKey := manager.WorkerHeartbeatKey(name)
+	if err := client.Set(ctx, heartbeatKey, time.Now().Unix(), HeartbeatTTL).Err(); err != nil {
+		slog.ErrorContext(ctx,
+			"Unable to set worker heartbeat",
+			"worker", name,
+			"error", err,
+		)
+		return
+	}
+}
+
+// removeWorker removes the worker from the active workers set.
+func removeWorker(ctx context.Context, client *redis.Client, name string) {
+	if ctx.Err() != nil {
+		// If the context is canceled, create a new uncanceled context.
+		// (This runs during shutdown, so the parent ctx is usually already
+		// done; WithoutCancel keeps its values but drops the cancellation.)
+		ctx = context.WithoutCancel(ctx)
+	}
+	ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
+	defer cancel()
+
+	// Remove the worker from the active workers set.
+	if err := client.SRem(ctx, manager.ActiveWorkersSet, name).Err(); err != nil {
+		slog.ErrorContext(ctx,
+			"Unable to remove worker from active workers set",
+			"worker", name,
+			"error", err,
+		)
+		return
+	}
+
+	// Remove the worker's heartbeat.
+	heartbeatKey := manager.WorkerHeartbeatKey(name)
+	if err := client.Del(ctx, heartbeatKey).Err(); err != nil {
+		slog.ErrorContext(ctx,
+			"Unable to remove worker heartbeat",
+			"worker", name,
+			"error", err,
+		)
+		return
+	}
+}
+
+// workerName returns a unique name for this worker instance: the hostname
+// plus a random 12-byte base64url suffix, so multiple workers on one host
+// do not collide.
+func workerName() (string, error) {
+	hostname, err := os.Hostname()
+	if err != nil {
+		return "", err
+	}
+
+	bytes := make([]byte, 12)
+	if _, err = io.ReadFull(rand.Reader, bytes); err != nil {
+		return "", err
+	}
+
+	return hostname + "-" + base64.URLEncoding.EncodeToString(bytes), nil
+}
|