diff options
author | 2024-08-11 13:15:50 -0700 | |
---|---|---|
committer | 2024-08-11 13:15:50 -0700 | |
commit | 6a3c21fb0b1c126849f2bbff494403bbe901448e (patch) | |
tree | 5d7805524357c2c8a9819c39d2051a4e3633a1d5 /backend/internal | |
parent | 29c6040a51616e9e4cf6c70ee16391b2a3b238c9 (diff) | |
parent | f34b92ded11b07f78575ac62c260a380c468e5ea (diff) | |
download | ibd-trader-6a3c21fb0b1c126849f2bbff494403bbe901448e.tar.gz ibd-trader-6a3c21fb0b1c126849f2bbff494403bbe901448e.tar.zst ibd-trader-6a3c21fb0b1c126849f2bbff494403bbe901448e.zip |
Merge remote-tracking branch 'backend/main'
Diffstat (limited to 'backend/internal')
53 files changed, 7377 insertions, 0 deletions
diff --git a/backend/internal/analyzer/analyzer.go b/backend/internal/analyzer/analyzer.go new file mode 100644 index 0000000..c055647 --- /dev/null +++ b/backend/internal/analyzer/analyzer.go @@ -0,0 +1,32 @@ +package analyzer + +import ( + "context" + + "github.com/Rhymond/go-money" +) + +type Analyzer interface { + Analyze( + ctx context.Context, + symbol string, + price *money.Money, + rawAnalysis string, + ) (*Analysis, error) +} + +type ChartAction string + +const ( + Buy ChartAction = "buy" + Sell ChartAction = "sell" + Hold ChartAction = "hold" + Unknown ChartAction = "unknown" +) + +type Analysis struct { + Action ChartAction + Price *money.Money + Reason string + Confidence uint8 +} diff --git a/backend/internal/analyzer/openai/openai.go b/backend/internal/analyzer/openai/openai.go new file mode 100644 index 0000000..0419c57 --- /dev/null +++ b/backend/internal/analyzer/openai/openai.go @@ -0,0 +1,126 @@ +package openai + +import ( + "context" + "encoding/json" + "fmt" + "math" + "strconv" + "strings" + "time" + + "github.com/ansg191/ibd-trader-backend/internal/analyzer" + "github.com/ansg191/ibd-trader-backend/internal/utils" + + "github.com/Rhymond/go-money" + "github.com/sashabaranov/go-openai" +) + +type Client interface { + CreateChatCompletion( + ctx context.Context, + request openai.ChatCompletionRequest, + ) (response openai.ChatCompletionResponse, err error) +} + +type Analyzer struct { + client Client + model string + systemMsg string + temperature float32 +} + +func NewAnalyzer(opts ...Option) *Analyzer { + a := &Analyzer{ + client: nil, + model: defaultModel, + systemMsg: defaultSystemMsg, + temperature: defaultTemperature, + } + for _, option := range opts { + option(a) + } + if a.client == nil { + panic("client is required") + } + + return a +} + +func (a *Analyzer) Analyze( + ctx context.Context, + symbol string, + price *money.Money, + rawAnalysis string, +) (*analyzer.Analysis, error) { + usrMsg := fmt.Sprintf( + "%s\n%s\n%s\n%s\n", + 
time.Now().Format(time.RFC3339), + symbol, + price.Display(), + rawAnalysis, + ) + res, err := a.client.CreateChatCompletion(ctx, openai.ChatCompletionRequest{ + Model: a.model, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleSystem, + Content: a.systemMsg, + }, + { + Role: openai.ChatMessageRoleUser, + Content: usrMsg, + }, + }, + MaxTokens: 0, + Temperature: a.temperature, + ResponseFormat: &openai.ChatCompletionResponseFormat{Type: openai.ChatCompletionResponseFormatTypeJSONObject}, + }) + if err != nil { + return nil, err + } + + var resp response + if err = json.Unmarshal([]byte(res.Choices[0].Message.Content), &resp); err != nil { + return nil, fmt.Errorf("failed to unmarshal gpt response: %w", err) + } + + var action analyzer.ChartAction + switch strings.ToLower(resp.Action) { + case "buy": + action = analyzer.Buy + case "sell": + action = analyzer.Sell + case "hold": + action = analyzer.Hold + default: + action = analyzer.Unknown + } + + m, err := utils.ParseMoney(resp.Price) + if err != nil { + return nil, fmt.Errorf("failed to parse price: %w", err) + } + + confidence, err := strconv.ParseFloat(resp.Confidence, 64) + if err != nil { + return nil, fmt.Errorf("failed to parse confidence: %w", err) + } + if confidence < 0 || confidence > 100 { + return nil, fmt.Errorf("confidence must be between 0 and 100, got %f", confidence) + } + + return &analyzer.Analysis{ + Action: action, + Price: m, + Reason: resp.Reason, + Confidence: uint8(math.Floor(confidence)), + }, nil +} + +type response struct { + Action string `json:"action"` + Price string `json:"price"` + Reason string `json:"reason"` + Confidence string `json:"confidence"` +} diff --git a/backend/internal/analyzer/openai/openai_test.go b/backend/internal/analyzer/openai/openai_test.go new file mode 100644 index 0000000..0aac709 --- /dev/null +++ b/backend/internal/analyzer/openai/openai_test.go @@ -0,0 +1 @@ +package openai diff --git 
a/backend/internal/analyzer/openai/options.go b/backend/internal/analyzer/openai/options.go new file mode 100644 index 0000000..11d691f --- /dev/null +++ b/backend/internal/analyzer/openai/options.go @@ -0,0 +1,45 @@ +package openai + +import ( + _ "embed" + + "github.com/sashabaranov/go-openai" +) + +//go:embed system.txt +var defaultSystemMsg string + +const defaultModel = openai.GPT4o +const defaultTemperature = 0.25 + +type Option func(*Analyzer) + +func WithClientConfig(cfg openai.ClientConfig) Option { + return func(a *Analyzer) { + a.client = openai.NewClientWithConfig(cfg) + } +} + +func WithDefaultConfig(apiKey string) Option { + return func(a *Analyzer) { + a.client = openai.NewClient(apiKey) + } +} + +func WithModel(model string) Option { + return func(a *Analyzer) { + a.model = model + } +} + +func WithSystemMsg(msg string) Option { + return func(a *Analyzer) { + a.systemMsg = msg + } +} + +func WithTemperature(temp float32) Option { + return func(a *Analyzer) { + a.temperature = temp + } +} diff --git a/backend/internal/analyzer/openai/system.txt b/backend/internal/analyzer/openai/system.txt new file mode 100644 index 0000000..82e3b2a --- /dev/null +++ b/backend/internal/analyzer/openai/system.txt @@ -0,0 +1,34 @@ +You're a stock analyzer. +You will be given a stock symbol, its current price, and its chart analysis. +Your job is to determine the best course of action to do with the stock (buy, sell, hold, or unknown), +the price at which the action should be taken, the reason for the action, and the confidence (0-100) +level you have in the action. + +The reason should be a paragraph explaining why you chose the action. + +The date the chart analysis was done may be mentioned in the chart analysis. +Make sure to take that into account when making your decision. +If the chart analysis is older than 1 week, lower your confidence accordingly and mention that in the reason. +If the chart analysis is too outdated, set the action to "unknown". 
+ +The information will be given in the following format: +``` +<current datetime> +<stock symbol> +<current price> +<chart analysis> +``` + +Your response should be in the following JSON format: +``` +{ + "action": "<action>", + "price": "<price>", + "reason": "<reason>", + "confidence": "<confidence>" +} +``` +All fields are required and must be strings. +`action` must be one of the following: "buy", "sell", "hold", or "unknown". +`price` must contain the symbol for the currency and the price (e.g. "$100"). +The system WILL validate your response.
\ No newline at end of file diff --git a/backend/internal/auth/auth.go b/backend/internal/auth/auth.go new file mode 100644 index 0000000..edad914 --- /dev/null +++ b/backend/internal/auth/auth.go @@ -0,0 +1,55 @@ +package auth + +import ( + "context" + "errors" + + "github.com/ansg191/ibd-trader-backend/internal/config" + + "github.com/coreos/go-oidc/v3/oidc" + "golang.org/x/oauth2" +) + +// Authenticator is used to authenticate our users. +type Authenticator struct { + *oidc.Provider + oauth2.Config +} + +// New instantiates the *Authenticator. +func New(cfg *config.Config) (*Authenticator, error) { + provider, err := oidc.NewProvider( + context.Background(), + "https://"+cfg.Auth.Domain+"/", + ) + if err != nil { + return nil, err + } + + conf := oauth2.Config{ + ClientID: cfg.Auth.ClientID, + ClientSecret: cfg.Auth.ClientSecret, + RedirectURL: cfg.Auth.CallbackURL, + Endpoint: provider.Endpoint(), + Scopes: []string{oidc.ScopeOpenID, "profile", "email"}, + } + + return &Authenticator{ + Provider: provider, + Config: conf, + }, nil +} + +// VerifyIDToken verifies that an *oauth2.Token is a valid *oidc.IDToken. 
+func (a *Authenticator) VerifyIDToken(ctx context.Context, token *oauth2.Token) (*oidc.IDToken, error) { + rawIDToken, ok := token.Extra("id_token").(string) + if !ok { + return nil, errors.New("no id_token field in oauth2 token") + } + + oidcConfig := &oidc.Config{ + ClientID: a.ClientID, + } + + return a.Verifier(oidcConfig).Verify(ctx, rawIDToken) +} diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go new file mode 100644 index 0000000..c37588b --- /dev/null +++ b/backend/internal/config/config.go @@ -0,0 +1,114 @@ +package config + +import ( + "github.com/ansg191/ibd-trader-backend/internal/keys" + "github.com/ansg191/ibd-trader-backend/internal/leader/manager/ibd" + + "github.com/spf13/viper" +) + +type Config struct { + // Logging configuration + Log struct { + // Log level + Level string + // Add source info to log messages + AddSource bool + // Enable colorized output + Color bool + } + // Database configuration + DB struct { + // Database URL + URL string + } + // Redis configuration + Redis struct { + // Redis address + Addr string + // Redis password + Password string + } + // KMS configuration + KMS struct { + // GCP KMS configuration + GCP *keys.GCPKeyName + } + // Server configuration + Server struct { + // Server port + Port uint16 + } + // OAuth 2.0 configuration + Auth struct { + // OAuth 2.0 domain + Domain string + // OAuth 2.0 client ID + ClientID string + // OAuth 2.0 client secret + ClientSecret string + // OAuth 2.0 callback URL + CallbackURL string + } + // IBD configuration + IBD struct { + // Scraper API Key + APIKey string + // Proxy URL + ProxyURL string + // Scrape schedules. In cron format. 
+ Schedules ibd.Schedules + } + // Analyzer configuration + Analyzer struct { + // Use OpenAI for analysis + OpenAI *struct { + // OpenAI API Key + APIKey string + } + } +} + +func New() (*Config, error) { + v := viper.New() + + v.SetDefault("log.level", "INFO") + v.SetDefault("log.addSource", false) + v.SetDefault("log.color", false) + v.SetDefault("server.port", 8000) + + v.SetConfigName("config") + v.AddConfigPath("/etc/ibd-trader/") + v.AddConfigPath("$HOME/.ibd-trader") + v.AddConfigPath(".") + err := v.ReadInConfig() + if err != nil { + if _, ok := err.(viper.ConfigFileNotFoundError); ok { + // Config file not found; ignore error + } else { + return nil, err + } + } + + v.MustBindEnv("db.url", "DATABASE_URL") + v.MustBindEnv("redis.addr", "REDIS_ADDR") + v.MustBindEnv("redis.password", "REDIS_PASSWORD") + v.MustBindEnv("log.level", "LOG_LEVEL") + v.MustBindEnv("server.port", "SERVER_PORT") + v.MustBindEnv("auth.domain", "AUTH_DOMAIN") + v.MustBindEnv("auth.clientID", "AUTH_CLIENT_ID") + v.MustBindEnv("auth.clientSecret", "AUTH_CLIENT_SECRET") + v.MustBindEnv("auth.callbackURL", "AUTH_CALLBACK_URL") + + config := new(Config) + err = v.Unmarshal(config) + if err != nil { + return nil, err + } + + return config, config.assert() +} + +func (c *Config) assert() error { + return nil +} diff --git a/backend/internal/database/cookies.go b/backend/internal/database/cookies.go new file mode 100644 index 0000000..3ea21d0 --- /dev/null +++ b/backend/internal/database/cookies.go @@ -0,0 +1,189 @@ +package database + +import ( + "context" + "database/sql" + "errors" + "fmt" + "net/http" + "time" + + "github.com/ansg191/ibd-trader-backend/internal/keys" +) + +func GetAnyCookie(ctx context.Context, exec Executor, kms keys.KeyManagementService) (*IBDCookie, error) { + row := exec.QueryRowContext(ctx, ` +SELECT ibd_tokens.id, token, encrypted_key, kms_key_name, expires_at +FROM ibd_tokens + INNER JOIN keys ON encryption_key = keys.id +WHERE expires_at > NOW() + AND degraded = 
FALSE +ORDER BY random() +LIMIT 1;`) + + var id uint + var encryptedToken, encryptedKey []byte + var keyName string + var expiry time.Time + err := row.Scan(&id, &encryptedToken, &encryptedKey, &keyName, &expiry) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, nil + } + return nil, fmt.Errorf("unable to scan sql row into ibd cookie: %w", err) + } + + // Set the expiry to UTC explicitly. + // For some reason, the expiry time is set to location="". + expiry = expiry.UTC() + + token, err := keys.Decrypt(ctx, kms, keyName, encryptedToken, encryptedKey) + if err != nil { + return nil, fmt.Errorf("unable to decrypt token: %w", err) + } + return &IBDCookie{ + Token: string(token), + Expiry: expiry, + }, nil +} + +func GetCookies( + ctx context.Context, + exec Executor, + kms keys.KeyManagementService, + subject string, + degraded bool, +) ([]IBDCookie, error) { + rows, err := exec.QueryContext(ctx, ` +SELECT ibd_tokens.id, token, encrypted_key, kms_key_name, expires_at +FROM ibd_tokens + INNER JOIN keys ON encryption_key = keys.id +WHERE user_subject = $1 + AND expires_at > NOW() + AND degraded = $2 +ORDER BY expires_at DESC;`, subject, degraded) + if err != nil { + return nil, fmt.Errorf("unable to get ibd cookies: %w", err) + } + + cookies := make([]IBDCookie, 0) + for rows.Next() { + var id uint + var encryptedToken, encryptedKey []byte + var keyName string + var expiry time.Time + err = rows.Scan(&id, &encryptedToken, &encryptedKey, &keyName, &expiry) + if err != nil { + return nil, fmt.Errorf("unable to scan sql row into ibd cookie: %w", err) + } + + // Set the expiry to UTC explicitly. + // For some reason, the expiry time is set to location="". 
+ expiry = expiry.UTC() + + token, err := keys.Decrypt(ctx, kms, keyName, encryptedToken, encryptedKey) + if err != nil { + return nil, fmt.Errorf("unable to decrypt token: %w", err) + } + cookie := IBDCookie{ + ID: id, + Token: string(token), + Expiry: expiry, + } + cookies = append(cookies, cookie) + } + + return cookies, nil +} + +func AddCookie( + ctx context.Context, + exec TransactionExecutor, + kms keys.KeyManagementService, + subject string, + cookie *http.Cookie, +) error { + tx, err := exec.BeginTx(ctx, nil) + if err != nil { + return err + } + + // Get the key ID for the user + user, err := GetUser(ctx, tx, subject) + if err != nil { + return fmt.Errorf("unable to get user: %w", err) + } + if user.EncryptionKeyID == nil { + return errors.New("user does not have an encryption key") + } + + // Get the key + var keyName string + var key []byte + err = tx.QueryRowContext(ctx, ` +SELECT kms_key_name, encrypted_key +FROM keys +WHERE id = $1;`, + *user.EncryptionKeyID, + ).Scan(&keyName, &key) + if err != nil { + return fmt.Errorf("unable to get key: %w", err) + } + + // Encrypt the token + encryptedToken, err := keys.EncryptWithKey(ctx, kms, keyName, key, []byte(cookie.Value)) + if err != nil { + return fmt.Errorf("unable to encrypt token: %w", err) + } + + // Add the cookie to the database + _, err = exec.ExecContext(ctx, ` +INSERT INTO ibd_tokens (token, expires_at, user_subject, encryption_key) +VALUES ($1, $2, $3, $4)`, encryptedToken, cookie.Expires, subject, *user.EncryptionKeyID) + if err != nil { + return fmt.Errorf("unable to add cookie: %w", err) + } + + return nil +} + +func ReportCookieFailure(ctx context.Context, exec Executor, id uint) error { + _, err := exec.ExecContext(ctx, ` +UPDATE ibd_tokens +SET degraded = TRUE +WHERE id = $1;`, id) + if err != nil { + return fmt.Errorf("unable to report cookie failure: %w", err) + } + return nil +} + +func RepairCookie(ctx context.Context, exec Executor, id uint) error { + _, err := exec.ExecContext(ctx, 
` +UPDATE ibd_tokens +SET degraded = FALSE +WHERE id = $1;`, id) + if err != nil { + return fmt.Errorf("unable to report cookie failure: %w", err) + } + return nil +} + +type IBDCookie struct { + ID uint + Token string + Expiry time.Time +} + +func (c *IBDCookie) ToHTTPCookie() *http.Cookie { + return &http.Cookie{ + Name: ".ASPXAUTH", + Value: c.Token, + Path: "/", + Domain: "investors.com", + Expires: c.Expiry, + Secure: true, + HttpOnly: false, + SameSite: http.SameSiteLaxMode, + } +} diff --git a/backend/internal/database/database.go b/backend/internal/database/database.go new file mode 100644 index 0000000..409dd3c --- /dev/null +++ b/backend/internal/database/database.go @@ -0,0 +1,166 @@ +package database + +import ( + "context" + "database/sql" + "database/sql/driver" + "errors" + "io" + "log/slog" + "sync" + "time" + + "github.com/ansg191/ibd-trader-backend/db" + "github.com/ansg191/ibd-trader-backend/internal/keys" + + "github.com/golang-migrate/migrate/v4" + _ "github.com/golang-migrate/migrate/v4/database/postgres" + "github.com/golang-migrate/migrate/v4/source/iofs" + _ "github.com/lib/pq" +) + +type Database interface { + io.Closer + TransactionExecutor + driver.Pinger + + Migrate(ctx context.Context) error + Maintenance(ctx context.Context) +} + +type database struct { + logger *slog.Logger + + db *sql.DB + url string + + kms keys.KeyManagementService + keyName string +} + +func New(ctx context.Context, logger *slog.Logger, url string, kms keys.KeyManagementService, keyName string) (Database, error) { + sqlDB, err := sql.Open("postgres", url) + if err != nil { + return nil, err + } + + err = sqlDB.PingContext(ctx) + if err != nil { + // Ping failed. Don't error, but give a warning. 
+ logger.WarnContext(ctx, "Unable to ping database", "error", err) + } + + return &database{ + logger: logger, + db: sqlDB, + url: url, + kms: kms, + keyName: keyName, + }, nil +} + +func (d *database) Close() error { + return d.db.Close() +} + +func (d *database) Migrate(ctx context.Context) error { + return Migrate(ctx, d.url) +} + +func (d *database) Maintenance(ctx context.Context) { + ticker := time.NewTicker(15 * time.Minute) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + func() { + var wg sync.WaitGroup + wg.Add(1) + + _, cancel := context.WithTimeout(context.Background(), 10*time.Minute) + defer cancel() + + wg.Wait() + }() + case <-ctx.Done(): + return + } + } +} + +func Migrate(ctx context.Context, url string) error { + fs, err := iofs.New(db.Migrations, "migrations") + if err != nil { + return err + } + + m, err := migrate.NewWithSourceInstance("iofs", fs, url) + if err != nil { + return err + } + + slog.InfoContext(ctx, "Running DB migration") + err = m.Up() + if err != nil && !errors.Is(err, migrate.ErrNoChange) { + slog.ErrorContext(ctx, "DB migration failed", "error", err) + return err + } + + return nil +} + +func (d *database) Ping(ctx context.Context) error { + return d.db.PingContext(ctx) +} + +type Executor interface { + ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) + QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) + QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row +} + +type TransactionExecutor interface { + Executor + BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) +} + +func (d *database) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) { + d.logger.DebugContext(ctx, "Executing query", "query", query) + + now := time.Now() + ret, err := d.db.ExecContext(ctx, query, args...) 
+ if err != nil { + return nil, err + } + + d.logger.DebugContext(ctx, "Query executed successfully", "duration", time.Since(now)) + return ret, nil +} + +func (d *database) QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) { + d.logger.DebugContext(ctx, "Executing query", "query", query) + + now := time.Now() + ret, err := d.db.QueryContext(ctx, query, args...) + if err != nil { + return nil, err + } + + d.logger.DebugContext(ctx, "Query executed successfully", "duration", time.Since(now)) + return ret, nil +} + +func (d *database) QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row { + d.logger.DebugContext(ctx, "Executing query", "query", query) + + now := time.Now() + ret := d.db.QueryRowContext(ctx, query, args...) + + d.logger.DebugContext(ctx, "Query executed successfully", "duration", time.Since(now)) + return ret +} + +func (d *database) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) { + return d.db.BeginTx(ctx, opts) +} diff --git a/backend/internal/database/database_test.go b/backend/internal/database/database_test.go new file mode 100644 index 0000000..407a09a --- /dev/null +++ b/backend/internal/database/database_test.go @@ -0,0 +1,79 @@ +package database + +import ( + "context" + "database/sql" + "fmt" + "log" + "testing" + "time" + + "github.com/ory/dockertest/v3" + "github.com/ory/dockertest/v3/docker" +) + +var exec *sql.DB + +func TestMain(m *testing.M) { + pool, err := dockertest.NewPool("") + if err != nil { + log.Fatalf("Could not create pool: %s", err) + } + + err = pool.Client.Ping() + if err != nil { + log.Fatalf("Could not connect to Docker: %s", err) + } + + resource, err := pool.RunWithOptions(&dockertest.RunOptions{ + Repository: "postgres", + Tag: "16", + Env: []string{ + "POSTGRES_PASSWORD=secret", + "POSTGRES_USER=ibd-client-test", + "POSTGRES_DB=ibd-client-test", + "listen_addresses='*'", + }, + Cmd: []string{ + "postgres", + "-c", + "log_statement=all", + }, + }, 
func(config *docker.HostConfig) { + config.AutoRemove = true + config.RestartPolicy = docker.RestartPolicy{Name: "no"} + }) + if err != nil { + log.Fatalf("Could not start resource: %s", err) + } + + hostAndPort := resource.GetHostPort("5432/tcp") + databaseUrl := fmt.Sprintf("postgres://ibd-client-test:secret@%s/ibd-client-test?sslmode=disable", hostAndPort) + + // Kill container after 120 seconds + _ = resource.Expire(120) + + pool.MaxWait = 120 * time.Second + if err = pool.Retry(func() error { + exec, err = sql.Open("postgres", databaseUrl) + if err != nil { + return err + } + return exec.Ping() + }); err != nil { + log.Fatalf("Could not connect to database: %s", err) + } + + err = Migrate(context.Background(), databaseUrl) + if err != nil { + log.Fatalf("Could not migrate database: %s", err) + } + + defer func() { + if err := pool.Purge(resource); err != nil { + log.Fatalf("Could not purge resource: %s", err) + } + }() + + m.Run() +} diff --git a/backend/internal/database/stocks.go b/backend/internal/database/stocks.go new file mode 100644 index 0000000..24f5fe7 --- /dev/null +++ b/backend/internal/database/stocks.go @@ -0,0 +1,293 @@ +package database + +import ( + "context" + "database/sql" + "database/sql/driver" + "errors" + + pb "github.com/ansg191/ibd-trader-backend/api/gen/idb/stock/v1" + "github.com/ansg191/ibd-trader-backend/internal/analyzer" + "github.com/ansg191/ibd-trader-backend/internal/utils" + + "github.com/Rhymond/go-money" +) + +var ErrStockNotFound = errors.New("stock not found") + +func GetStock(ctx context.Context, exec Executor, symbol string) (Stock, error) { + row := exec.QueryRowContext(ctx, ` +SELECT symbol, name, ibd_url +FROM stocks +WHERE symbol = $1; +`, symbol) + + var stock Stock + if err := row.Scan(&stock.Symbol, &stock.Name, &stock.IBDUrl); err != nil { + if errors.Is(err, sql.ErrNoRows) { + return Stock{}, ErrStockNotFound + } + return Stock{}, err + } + + return stock, nil +} + +func AddStock(ctx context.Context, exec 
Executor, stock Stock) error { + _, err := exec.ExecContext(ctx, ` +INSERT INTO stocks (symbol, name, ibd_url) +VALUES ($1, $2, $3) +ON CONFLICT (symbol) + DO UPDATE SET name = $2, + ibd_url = $3;`, stock.Symbol, stock.Name, stock.IBDUrl) + return err +} + +func AddRanking(ctx context.Context, exec Executor, symbol string, ibd50, cap20 int) error { + if ibd50 > 0 { + _, err := exec.ExecContext(ctx, ` +INSERT INTO stock_rank (symbol, rank_type, rank) +VALUES ($1, $2, $3)`, symbol, "ibd50", ibd50) + if err != nil { + return err + } + } + if cap20 > 0 { + _, err := exec.ExecContext(ctx, ` +INSERT INTO stock_rank (symbol, rank_type, rank) +VALUES ($1, $2, $3)`, symbol, "cap20", cap20) + if err != nil { + return err + } + } + return nil +} + +func AddStockInfo(ctx context.Context, exec TransactionExecutor, info *StockInfo) (string, error) { + tx, err := exec.BeginTx(ctx, nil) + if err != nil { + return "", err + } + defer func(tx *sql.Tx) { + _ = tx.Rollback() + }(tx) + + // Add raw chart analysis + row := tx.QueryRowContext(ctx, ` +INSERT INTO chart_analysis (raw_analysis) +VALUES ($1) +RETURNING id;`, info.ChartAnalysis) + + var chartAnalysisID string + if err = row.Scan(&chartAnalysisID); err != nil { + return "", err + } + + // Add stock info + row = tx.QueryRowContext(ctx, + ` +INSERT INTO ratings (symbol, composite, eps, rel_str, group_rel_str, smr, acc_dis, chart_analysis, price) +VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) +RETURNING id;`, + info.Symbol, + info.Ratings.Composite, + info.Ratings.EPS, + info.Ratings.RelStr, + info.Ratings.GroupRelStr, + info.Ratings.SMR, + info.Ratings.AccDis, + chartAnalysisID, + info.Price.Display(), + ) + + var ratingsID string + if err = row.Scan(&ratingsID); err != nil { + return "", err + } + + return ratingsID, tx.Commit() +} + +func GetStockInfo(ctx context.Context, exec Executor, id string) (*StockInfo, error) { + row := exec.QueryRowContext(ctx, ` +SELECT r.symbol, + s.name, + ca.raw_analysis, + r.composite, + r.eps, + 
r.rel_str, + r.group_rel_str, + r.smr, + r.acc_dis, + r.price +FROM ratings r + INNER JOIN stocks s on r.symbol = s.symbol + INNER JOIN chart_analysis ca on r.chart_analysis = ca.id +WHERE r.id = $1;`, id) + + var info StockInfo + var priceStr string + err := row.Scan( + &info.Symbol, + &info.Name, + &info.ChartAnalysis, + &info.Ratings.Composite, + &info.Ratings.EPS, + &info.Ratings.RelStr, + &info.Ratings.GroupRelStr, + &info.Ratings.SMR, + &info.Ratings.AccDis, + &priceStr, + ) + if err != nil { + return nil, err + } + + info.Price, err = utils.ParseMoney(priceStr) + if err != nil { + return nil, err + } + + return &info, nil +} + +func AddAnalysis( + ctx context.Context, + exec Executor, + ratingId string, + analysis *analyzer.Analysis, +) (id string, err error) { + err = exec.QueryRowContext(ctx, ` +UPDATE chart_analysis ca +SET processed = true, + action = $2, + price = $3, + reason = $4, + confidence = $5 +FROM ratings r +WHERE r.id = $1 + AND r.chart_analysis = ca.id +RETURNING ca.id;`, + ratingId, + analysis.Action, + analysis.Price.Display(), + analysis.Reason, + analysis.Confidence, + ).Scan(&id) + return id, err +} + +type Stock struct { + Symbol string + Name string + IBDUrl string +} + +type StockInfo struct { + Symbol string + Name string + ChartAnalysis string + Ratings Ratings + Price *money.Money +} + +type Ratings struct { + Composite uint8 + EPS uint8 + RelStr uint8 + GroupRelStr LetterRating + SMR LetterRating + AccDis LetterRating +} + +type LetterRating pb.LetterGrade + +func (r LetterRating) String() string { + switch pb.LetterGrade(r) { + case pb.LetterGrade_LETTER_GRADE_E: + return "E" + case pb.LetterGrade_LETTER_GRADE_E_PLUS: + return "E+" + case pb.LetterGrade_LETTER_GRADE_D_MINUS: + return "D-" + case pb.LetterGrade_LETTER_GRADE_D: + return "D" + case pb.LetterGrade_LETTER_GRADE_D_PLUS: + return "D+" + case pb.LetterGrade_LETTER_GRADE_C_MINUS: + return "C-" + case pb.LetterGrade_LETTER_GRADE_C: + return "C" + case 
pb.LetterGrade_LETTER_GRADE_C_PLUS: + return "C+" + case pb.LetterGrade_LETTER_GRADE_B_MINUS: + return "B-" + case pb.LetterGrade_LETTER_GRADE_B: + return "B" + case pb.LetterGrade_LETTER_GRADE_B_PLUS: + return "B+" + case pb.LetterGrade_LETTER_GRADE_A_MINUS: + return "A-" + case pb.LetterGrade_LETTER_GRADE_A: + return "A" + case pb.LetterGrade_LETTER_GRADE_A_PLUS: + return "A+" + default: + return "NA" + } +} + +func LetterRatingFromString(str string) LetterRating { + switch str { + case "E": + return LetterRating(pb.LetterGrade_LETTER_GRADE_E) + case "E+": + return LetterRating(pb.LetterGrade_LETTER_GRADE_E_PLUS) + case "D-": + return LetterRating(pb.LetterGrade_LETTER_GRADE_D_MINUS) + case "D": + return LetterRating(pb.LetterGrade_LETTER_GRADE_D) + case "D+": + return LetterRating(pb.LetterGrade_LETTER_GRADE_D_PLUS) + case "C-": + return LetterRating(pb.LetterGrade_LETTER_GRADE_C_MINUS) + case "C": + return LetterRating(pb.LetterGrade_LETTER_GRADE_C) + case "C+": + return LetterRating(pb.LetterGrade_LETTER_GRADE_C_PLUS) + case "B-": + return LetterRating(pb.LetterGrade_LETTER_GRADE_B_MINUS) + case "B": + return LetterRating(pb.LetterGrade_LETTER_GRADE_B) + case "B+": + return LetterRating(pb.LetterGrade_LETTER_GRADE_B_PLUS) + case "A-": + return LetterRating(pb.LetterGrade_LETTER_GRADE_A_MINUS) + case "A": + return LetterRating(pb.LetterGrade_LETTER_GRADE_A) + case "A+": + return LetterRating(pb.LetterGrade_LETTER_GRADE_A_PLUS) + case "NA": + fallthrough + default: + return LetterRating(pb.LetterGrade_LETTER_GRADE_UNSPECIFIED) + } +} + +func (r LetterRating) Value() (driver.Value, error) { + return r.String(), nil +} + +func (r *LetterRating) Scan(src any) error { + var source string + switch v := src.(type) { + case string: + source = v + case []byte: + source = string(v) + default: + return errors.New("incompatible type for LetterRating") + } + *r = LetterRatingFromString(source) + return nil +} diff --git a/backend/internal/database/users.go 
b/backend/internal/database/users.go new file mode 100644 index 0000000..f7998fb --- /dev/null +++ b/backend/internal/database/users.go @@ -0,0 +1,151 @@ +package database + +import ( + "context" + "database/sql" + "errors" + "fmt" + + "github.com/ansg191/ibd-trader-backend/internal/keys" +) + +var ErrUserNotFound = fmt.Errorf("user not found") +var ErrIBDCredsNotFound = fmt.Errorf("ibd creds not found") + +func AddUser(ctx context.Context, exec Executor, subject string) (err error) { + _, err = exec.ExecContext(ctx, ` +INSERT INTO users (subject) +VALUES ($1) +ON CONFLICT DO NOTHING;`, subject) + return +} + +func GetUser(ctx context.Context, exec Executor, subject string) (*User, error) { + row := exec.QueryRowContext(ctx, ` +SELECT subject, ibd_username, ibd_password, encryption_key +FROM users +WHERE subject = $1;`, subject) + + user := &User{} + err := row.Scan(&user.Subject, &user.IBDUsername, &user.EncryptedIBDPassword, &user.EncryptionKeyID) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, ErrUserNotFound + } + return nil, fmt.Errorf("unable to scan sql row into user: %w", err) + } + + return user, nil +} + +func ListUsers(ctx context.Context, exec Executor, hasIBDCreds bool) ([]User, error) { + rows, err := exec.QueryContext(ctx, ` +SELECT subject, ibd_username, ibd_password, encryption_key +FROM users; +`) + if err != nil { + return nil, fmt.Errorf("unable to list users: %w", err) + } + + users := make([]User, 0) + for rows.Next() { + user := User{} + err = rows.Scan(&user.Subject, &user.IBDUsername, &user.EncryptedIBDPassword, &user.EncryptionKeyID) + if err != nil { + return nil, fmt.Errorf("unable to scan sql row into user: %w", err) + } + + if hasIBDCreds && user.IBDUsername == nil { + continue + } + users = append(users, user) + } + + return users, nil +} + +func AddIBDCreds( + ctx context.Context, + exec TransactionExecutor, + kms keys.KeyManagementService, + keyName, subject, username, password string, +) error { + 
encryptedPass, encryptedKey, err := keys.Encrypt(ctx, kms, keyName, []byte(password)) + if err != nil { + return fmt.Errorf("unable to encrypt password: %w", err) + } + + tx, err := exec.BeginTx(ctx, nil) + if err != nil { + return err + } + defer func(tx *sql.Tx) { + _ = tx.Rollback() + }(tx) + + var keyId int + err = tx.QueryRowContext(ctx, ` +INSERT INTO keys (kms_key_name, encrypted_key) +VALUES ($1, $2) +RETURNING id;`, keyName, encryptedKey).Scan(&keyId) + if err != nil { + return fmt.Errorf("unable to add ibd creds key: %w", err) + } + + _, err = exec.ExecContext(ctx, ` +UPDATE users +SET ibd_username = $2, + ibd_password = $3, + encryption_key = $4 +WHERE subject = $1;`, subject, username, encryptedPass, keyId) + if err != nil { + return fmt.Errorf("unable to add ibd creds to user: %w", err) + } + + if err = tx.Commit(); err != nil { + return fmt.Errorf("unable to commit transaction: %w", err) + } + + return nil +} + +func GetIBDCreds( + ctx context.Context, + exec Executor, + kms keys.KeyManagementService, + subject string, +) ( + username string, + password string, + err error, +) { + row := exec.QueryRowContext(ctx, ` +SELECT ibd_username, ibd_password, encrypted_key, kms_key_name +FROM users +INNER JOIN public.keys k on k.id = users.encryption_key +WHERE subject = $1;`, subject) + + var encryptedPass, encryptedKey []byte + var keyName string + err = row.Scan(&username, &encryptedPass, &encryptedKey, &keyName) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return "", "", ErrIBDCredsNotFound + } + return "", "", fmt.Errorf("unable to scan sql row into ibd creds: %w", err) + } + + passwordBytes, err := keys.Decrypt(ctx, kms, keyName, encryptedPass, encryptedKey) + if err != nil { + return "", "", fmt.Errorf("unable to decrypt password: %w", err) + } + + return username, string(passwordBytes), nil +} + +type User struct { + Subject string + IBDUsername *string + EncryptedIBDPassword *string + EncryptionKeyID *int +} diff --git 
a/backend/internal/ibd/auth.go b/backend/internal/ibd/auth.go new file mode 100644 index 0000000..7b82057 --- /dev/null +++ b/backend/internal/ibd/auth.go @@ -0,0 +1,333 @@ +package ibd + +import ( + "bytes" + "context" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "strings" + + "github.com/ansg191/ibd-trader-backend/internal/ibd/transport" + "golang.org/x/net/html" +) + +const ( + signInUrl = "https://myibd.investors.com/secure/signin.aspx?eurl=https%3A%2F%2Fwww.investors.com" + authenticateUrl = "https://sso.accounts.dowjones.com/authenticate" + postAuthUrl = "https://sso.accounts.dowjones.com/postauth/handler" + cookieName = ".ASPXAUTH" +) + +var ErrAuthCookieNotFound = errors.New("cookie not found") +var ErrBadCredentials = errors.New("bad credentials") + +func (c *Client) Authenticate( + ctx context.Context, + username, + password string, +) (*http.Cookie, error) { + cfg, err := c.getLoginPage(ctx) + if err != nil { + return nil, err + } + + token, params, err := c.sendAuthRequest(ctx, cfg, username, password) + if err != nil { + return nil, err + } + + return c.sendPostAuth(ctx, token, params) +} + +func (c *Client) getLoginPage(ctx context.Context) (*authConfig, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, signInUrl, nil) + if err != nil { + return nil, err + } + + resp, err := c.Do(req, withRequiredProps(transport.PropertiesReliable)) + if err != nil { + return nil, err + } + defer func(Body io.ReadCloser) { + _ = Body.Close() + }(resp.Body) + + if resp.StatusCode != http.StatusOK { + content, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + return nil, fmt.Errorf( + "unexpected status code %d: %s", + resp.StatusCode, + string(content), + ) + } + + node, err := html.Parse(resp.Body) + if err != nil { + return nil, err + } + + cfg, err := extractAuthConfig(node) + if err != nil { + return nil, fmt.Errorf("failed to extract auth 
config: %w", err) + } + + return cfg, nil +} + +func (c *Client) sendAuthRequest(ctx context.Context, cfg *authConfig, username, password string) (string, string, error) { + body := authRequestBody{ + ClientId: cfg.ClientID, + RedirectUri: cfg.CallbackURL, + Tenant: "sso", + ResponseType: cfg.ExtraParams.ResponseType, + Username: username, + Password: password, + Scope: cfg.ExtraParams.Scope, + State: cfg.ExtraParams.State, + Headers: struct { + XRemoteUser string `json:"x-_remote-_user"` + }(struct{ XRemoteUser string }{ + XRemoteUser: username, + }), + XOidcProvider: "localop", + Protocol: cfg.ExtraParams.Protocol, + Nonce: cfg.ExtraParams.Nonce, + UiLocales: cfg.ExtraParams.UiLocales, + Csrf: cfg.ExtraParams.Csrf, + Intstate: cfg.ExtraParams.Intstate, + Connection: "DJldap", + } + bodyJson, err := json.Marshal(body) + if err != nil { + return "", "", fmt.Errorf("failed to marshal auth request body: %w", err) + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, authenticateUrl, bytes.NewReader(bodyJson)) + if err != nil { + return "", "", err + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Auth0-Client", "eyJuYW1lIjoiYXV0aDAuanMtdWxwIiwidmVyc2lvbiI6IjkuMjQuMSJ9") + + resp, err := c.Do(req, + withRequiredProps(transport.PropertiesReliable), + withExpectedStatuses(http.StatusOK, http.StatusUnauthorized)) + if err != nil { + return "", "", err + } + defer func(Body io.ReadCloser) { + _ = Body.Close() + }(resp.Body) + + if resp.StatusCode == http.StatusUnauthorized { + return "", "", ErrBadCredentials + } else if resp.StatusCode != http.StatusOK { + content, err := io.ReadAll(resp.Body) + if err != nil { + return "", "", fmt.Errorf("failed to read response body: %w", err) + } + return "", "", fmt.Errorf( + "unexpected status code %d: %s", + resp.StatusCode, + string(content), + ) + } + + node, err := html.Parse(resp.Body) + if err != nil { + return "", "", err + } + + return extractTokenParams(node) +} + +func (c *Client) 
sendPostAuth(ctx context.Context, token, params string) (*http.Cookie, error) { + body := fmt.Sprintf("token=%s¶ms=%s", token, params) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, postAuthUrl, strings.NewReader(body)) + if err != nil { + return nil, err + } + + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + resp, err := c.Do(req, withRequiredProps(transport.PropertiesReliable)) + if err != nil { + return nil, err + } + defer func(Body io.ReadCloser) { + _ = Body.Close() + }(resp.Body) + + if resp.StatusCode != http.StatusOK { + content, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + return nil, fmt.Errorf( + "unexpected status code %d: %s", + resp.StatusCode, + string(content), + ) + } + + // Extract cookie + for _, cookie := range resp.Cookies() { + if cookie.Name == cookieName { + return cookie, nil + } + } + + return nil, ErrAuthCookieNotFound +} + +func extractAuthConfig(node *html.Node) (*authConfig, error) { + // Find `root` element + root := findId(node, "root") + if root == nil { + return nil, fmt.Errorf("root element not found") + } + + // Get adjacent script element + var script *html.Node + for s := root.NextSibling; s != nil; s = s.NextSibling { + if s.Type == html.ElementNode && s.Data == "script" { + script = s + break + } + } + + if script == nil { + return nil, fmt.Errorf("script element not found") + } + + // Get script content + content := extractText(script) + + // Find `AUTH_CONFIG` variable + const authConfigVar = "const AUTH_CONFIG = '" + i := strings.Index(content, authConfigVar) + if i == -1 { + return nil, fmt.Errorf("AUTH_CONFIG not found") + } + + // Find end of `AUTH_CONFIG` variable + j := strings.Index(content[i+len(authConfigVar):], "'") + + // Extract `AUTH_CONFIG` value + authConfigJSONB64 := content[i+len(authConfigVar) : i+len(authConfigVar)+j] + + // Decode `AUTH_CONFIG` value + authConfigJSON, err := 
base64.StdEncoding.DecodeString(authConfigJSONB64) + if err != nil { + return nil, fmt.Errorf("failed to decode AUTH_CONFIG: %w", err) + } + + // Unmarshal `AUTH_CONFIG` value + var cfg authConfig + if err = json.Unmarshal(authConfigJSON, &cfg); err != nil { + return nil, fmt.Errorf("failed to unmarshal AUTH_CONFIG: %w", err) + } + + return &cfg, nil +} + +type authConfig struct { + Auth0Domain string `json:"auth0Domain"` + CallbackURL string `json:"callbackURL"` + ClientID string `json:"clientID"` + ExtraParams struct { + Protocol string `json:"protocol"` + Scope string `json:"scope"` + ResponseType string `json:"response_type"` + Nonce string `json:"nonce"` + UiLocales string `json:"ui_locales"` + Csrf string `json:"_csrf"` + Intstate string `json:"_intstate"` + State string `json:"state"` + } `json:"extraParams"` + InternalOptions struct { + ResponseType string `json:"response_type"` + ClientId string `json:"client_id"` + Scope string `json:"scope"` + RedirectUri string `json:"redirect_uri"` + UiLocales string `json:"ui_locales"` + Eurl string `json:"eurl"` + Nonce string `json:"nonce"` + State string `json:"state"` + Resource string `json:"resource"` + Protocol string `json:"protocol"` + Client string `json:"client"` + } `json:"internalOptions"` + IsThirdPartyClient bool `json:"isThirdPartyClient"` + AuthorizationServer struct { + Url string `json:"url"` + Issuer string `json:"issuer"` + } `json:"authorizationServer"` +} + +func extractTokenParams(node *html.Node) (token string, params string, err error) { + inputs := findChildrenRecursive(node, func(node *html.Node) bool { + return node.Type == html.ElementNode && node.Data == "input" + }) + + var tokenNode, paramsNode *html.Node + for _, input := range inputs { + for _, attr := range input.Attr { + if attr.Key == "name" && attr.Val == "token" { + tokenNode = input + } else if attr.Key == "name" && attr.Val == "params" { + paramsNode = input + } + } + } + + if tokenNode == nil { + return "", "", 
fmt.Errorf("token input not found") + } + if paramsNode == nil { + return "", "", fmt.Errorf("params input not found") + } + + for _, attr := range tokenNode.Attr { + if attr.Key == "value" { + token = attr.Val + } + } + for _, attr := range paramsNode.Attr { + if attr.Key == "value" { + params = attr.Val + } + } + + return +} + +type authRequestBody struct { + ClientId string `json:"client_id"` + RedirectUri string `json:"redirect_uri"` + Tenant string `json:"tenant"` + ResponseType string `json:"response_type"` + Username string `json:"username"` + Password string `json:"password"` + Scope string `json:"scope"` + State string `json:"state"` + Headers struct { + XRemoteUser string `json:"x-_remote-_user"` + } `json:"headers"` + XOidcProvider string `json:"x-_oidc-_provider"` + Protocol string `json:"protocol"` + Nonce string `json:"nonce"` + UiLocales string `json:"ui_locales"` + Csrf string `json:"_csrf"` + Intstate string `json:"_intstate"` + Connection string `json:"connection"` +} diff --git a/backend/internal/ibd/auth_test.go b/backend/internal/ibd/auth_test.go new file mode 100644 index 0000000..157b507 --- /dev/null +++ b/backend/internal/ibd/auth_test.go @@ -0,0 +1,215 @@ +package ibd + +import ( + "context" + "encoding/json" + "net/http" + "net/url" + "strings" + "testing" + "time" + + "github.com/ansg191/ibd-trader-backend/internal/ibd/transport" + "github.com/jarcoal/httpmock" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/net/html" +) + +const extractAuthHtml = ` +<!doctype html> +<html lang="en"> + <head> + <title>Log in · Dow Jones</title> + <meta charset="UTF-8"/> + <meta name="theme-color" content="white"/> + <meta name="viewport" content="width=device-width,initial-scale=1"/> + <meta name="description" content="Dow Jones One Identity Login page"/> + <link rel="apple-touch-icon" sizes="180x180" href="/one_identity_login_pages/login/0ce1520be322adcd762319573804f56d/images/apple-touch-icon.png"/> + 
<link rel="icon" type="image/png" sizes="32x32" href="/one_identity_login_pages/login/0ce1520be322adcd762319573804f56d/images/favicon-32x32.png"/> + <link rel="icon" type="image/png" sizes="16x16" href="/one_identity_login_pages/login/0ce1520be322adcd762319573804f56d/images/favicon-16x16.png"/> + <link rel="icon" type="image/png" sizes="192x192" href="/one_identity_login_pages/login/0ce1520be322adcd762319573804f56d/images/android-chrome-192x192.png"/> + <link rel="icon" type="image/png" sizes="512x512" href="/one_identity_login_pages/login/0ce1520be322adcd762319573804f56d/images/android-chrome-512x512.png"/> + <link rel="prefetch" href="https://cdn.optimizely.com/js/14856860742.js"/> + <link rel="preconnect" href="//cdn.optimizely.com"/> + <link rel="preconnect" href="//logx.optimizely.com"/> + <script type="module" crossorigin src="/one_identity_login_pages/login/0ce1520be322adcd762319573804f56d/js/index.js"></script> + <link rel="modulepreload" crossorigin href="/one_identity_login_pages/login/0ce1520be322adcd762319573804f56d/js/vendor.js"> + <link rel="modulepreload" crossorigin href="/one_identity_login_pages/login/0ce1520be322adcd762319573804f56d/js/auth.js"> + <link rel="modulepreload" crossorigin href="/one_identity_login_pages/login/0ce1520be322adcd762319573804f56d/js/router.js"> + <link rel="stylesheet" crossorigin href="/one_identity_login_pages/login/0ce1520be322adcd762319573804f56d/assets/styles.css"> + </head> + <body> + <div id="root" vaul-drawer-wrapper="" class="root-container"></div> + <script> + const AUTH_CONFIG = 
'eyJhdXRoMERvbWFpbiI6InNzby5hY2NvdW50cy5kb3dqb25lcy5jb20iLCJjYWxsYmFja1VSTCI6Imh0dHBzOi8vbXlpYmQuaW52ZXN0b3JzLmNvbS9vaWRjL2NhbGxiYWNrIiwiY2xpZW50SUQiOiJHU1UxcEcyQnJnZDNQdjJLQm5BWjI0enZ5NXVXU0NRbiIsImV4dHJhUGFyYW1zIjp7InByb3RvY29sIjoib2F1dGgyIiwic2NvcGUiOiJvcGVuaWQgaWRwX2lkIHJvbGVzIGVtYWlsIGdpdmVuX25hbWUgZmFtaWx5X25hbWUgdXVpZCBkalVzZXJuYW1lIGRqU3RhdHVzIHRyYWNraWQgdGFncyBwcnRzIHVwZGF0ZWRfYXQgY3JlYXRlZF9hdCBvZmZsaW5lX2FjY2VzcyBkamlkIiwicmVzcG9uc2VfdHlwZSI6ImNvZGUiLCJub25jZSI6IjY0MDJmYWJiLTFiNzUtNGEyYy1hODRmLTExYWQ2MWFhZGI2YiIsInVpX2xvY2FsZXMiOiJlbi11cy14LWliZC0yMy03IiwiX2NzcmYiOiJOZFVSZ3dPQ3VYRU5URXFDcDhNV25tcGtxd3lva2JjU2E2VV9fLTVib3lWc1NzQVNWTkhLU0EiLCJfaW50c3RhdGUiOiJkZXByZWNhdGVkIiwic3RhdGUiOiJlYXJjN3E2UnE2a3lHS3h5LlltbGlxOU4xRXZvU1V0ejhDVjhuMFZBYzZWc1V4RElSTTRTcmxtSWJXMmsifSwiaW50ZXJuYWxPcHRpb25zIjp7InJlc3BvbnNlX3R5cGUiOiJjb2RlIiwiY2xpZW50X2lkIjoiR1NVMXBHMkJyZ2QzUHYyS0JuQVoyNHp2eTV1V1NDUW4iLCJzY29wZSI6Im9wZW5pZCBpZHBfaWQgcm9sZXMgZW1haWwgZ2l2ZW5fbmFtZSBmYW1pbHlfbmFtZSB1dWlkIGRqVXNlcm5hbWUgZGpTdGF0dXMgdHJhY2tpZCB0YWdzIHBydHMgdXBkYXRlZF9hdCBjcmVhdGVkX2F0IG9mZmxpbmVfYWNjZXNzIGRqaWQiLCJyZWRpcmVjdF91cmkiOiJodHRwczovL215aWJkLmludmVzdG9ycy5jb20vb2lkYy9jYWxsYmFjayIsInVpX2xvY2FsZXMiOiJlbi11cy14LWliZC0yMy03IiwiZXVybCI6Imh0dHBzOi8vd3d3LmludmVzdG9ycy5jb20iLCJub25jZSI6IjY0MDJmYWJiLTFiNzUtNGEyYy1hODRmLTExYWQ2MWFhZGI2YiIsInN0YXRlIjoiZWFyYzdxNlJxNmt5R0t4eS5ZbWxpcTlOMUV2b1NVdHo4Q1Y4bjBWQWM2VnNVeERJUk00U3JsbUliVzJrIiwicmVzb3VyY2UiOiJodHRwcyUzQSUyRiUyRnd3dy5pbnZlc3RvcnMuY29tIiwicHJvdG9jb2wiOiJvYXV0aDIiLCJjbGllbnQiOiJHU1UxcEcyQnJnZDNQdjJLQm5BWjI0enZ5NXVXU0NRbiJ9LCJpc1RoaXJkUGFydHlDbGllbnQiOmZhbHNlLCJhdXRob3JpemF0aW9uU2VydmVyIjp7InVybCI6Imh0dHBzOi8vc3NvLmFjY291bnRzLmRvd2pvbmVzLmNvbSIsImlzc3VlciI6Imh0dHBzOi8vc3NvLmFjY291bnRzLmRvd2pvbmVzLmNvbS8ifX0=' + const ENV_CONFIG = 'production' + + window.sessionStorage.setItem('auth-config', AUTH_CONFIG) + window.sessionStorage.setItem('env-config', ENV_CONFIG) + </script> + <script src="https://cdn.optimizely.com/js/14856860742.js" 
crossorigin="anonymous"></script> + <script type="text/javascript" src="https://dcdd29eaa743c493e732-7dc0216bc6cc2f4ed239035dfc17235b.ssl.cf3.rackcdn.com/tags/wsj/hokbottom.js"></script> + <script type="text/javascript" src="/R8As7u5b/init.js"></script> + </body> +</html> +` + +func Test_extractAuthConfig(t *testing.T) { + t.Parallel() + expectedJSON := ` +{ + "auth0Domain": "sso.accounts.dowjones.com", + "callbackURL": "https://myibd.investors.com/oidc/callback", + "clientID": "GSU1pG2Brgd3Pv2KBnAZ24zvy5uWSCQn", + "extraParams": { + "protocol": "oauth2", + "scope": "openid idp_id roles email given_name family_name uuid djUsername djStatus trackid tags prts updated_at created_at offline_access djid", + "response_type": "code", + "nonce": "6402fabb-1b75-4a2c-a84f-11ad61aadb6b", + "ui_locales": "en-us-x-ibd-23-7", + "_csrf": "NdURgwOCuXENTEqCp8MWnmpkqwyokbcSa6U__-5boyVsSsASVNHKSA", + "_intstate": "deprecated", + "state": "earc7q6Rq6kyGKxy.Ymliq9N1EvoSUtz8CV8n0VAc6VsUxDIRM4SrlmIbW2k" + }, + "internalOptions": { + "response_type": "code", + "client_id": "GSU1pG2Brgd3Pv2KBnAZ24zvy5uWSCQn", + "scope": "openid idp_id roles email given_name family_name uuid djUsername djStatus trackid tags prts updated_at created_at offline_access djid", + "redirect_uri": "https://myibd.investors.com/oidc/callback", + "ui_locales": "en-us-x-ibd-23-7", + "eurl": "https://www.investors.com", + "nonce": "6402fabb-1b75-4a2c-a84f-11ad61aadb6b", + "state": "earc7q6Rq6kyGKxy.Ymliq9N1EvoSUtz8CV8n0VAc6VsUxDIRM4SrlmIbW2k", + "resource": "https%3A%2F%2Fwww.investors.com", + "protocol": "oauth2", + "client": "GSU1pG2Brgd3Pv2KBnAZ24zvy5uWSCQn" + }, + "isThirdPartyClient": false, + "authorizationServer": { + "url": "https://sso.accounts.dowjones.com", + "issuer": "https://sso.accounts.dowjones.com/" + } +}` + var expectedCfg authConfig + err := json.Unmarshal([]byte(expectedJSON), &expectedCfg) + require.NoError(t, err) + + node, err := html.Parse(strings.NewReader(extractAuthHtml)) + require.NoError(t, 
err) + + cfg, err := extractAuthConfig(node) + require.NoError(t, err) + require.NotNil(t, cfg) + + assert.Equal(t, expectedCfg, *cfg) +} + +const extractTokenParamsHtml = ` +<form method="post" name="hiddenform" action="https://sso.accounts.dowjones.com/postauth/handler"> + <input type="hidden" name="token" value="eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJkalVzZXJuYW1lIjoiYW5zZzE5MUB5YWhvby5jb20iLCJpZCI6IjAxZWFmNTE5LTA0OWItNGIyOS04ZjZhLWQyNjIyZjNiMWJjNiIsImdpdmVuX25hbWUiOiJBbnNodWwiLCJmYW1pbHlfbmFtZSI6Ikd1cHRhIiwibmFtZSI6IkFuc2h1bCBHdXB0YSIsImVtYWlsIjoiYW5zZzE5MUB5YWhvby5jb20iLCJlbWFpbF92ZXJpZmllZCI6dHJ1ZSwiYWNjb3VudF9pZCI6Ijk5NzI5Mzc0NDIxMiIsImRqaWQiOiIwMWVhZjUxOS0wNDliLTRiMjktOGY2YS1kMjYyMmYzYjFiYzYiLCJ0cmFja2lkIjoiMWM0NGQyMTRmM2VlYTZiMzcyNDYxNDc3NDc0NDMyODJmMTRmY2ZjYmI4NmE4NmVjYTI0MDc2ZDVlMzU4ZmUzZCIsInVwZGF0ZWRfYXQiOjE3MTI3OTQxNTYsImNyZWF0ZWRfYXQiOjE3MTI3OTQxNTYsInVhdCI6MTcyMjU1MjMzOSwicm9sZXMiOlsiQkFSUk9OUy1DSEFOR0VQQVNTV09SRCIsIkZSRUVSRUctQkFTRSIsIkZSRUVSRUctSU5ESVZJRFVBTCIsIldTSi1DSEFOR0VQQVNTV09SRCIsIldTSi1BUkNISVZFIiwiV1NKLVNFTEZTRVJWIiwiSUJELUlORElWSURVQUwiLCJJQkQtSUNBIiwiSUJELUFFSSJdLCJkalN0YXR1cyI6WyJJQkRfVVNFUlMiXSwicHJ0cyI6IjIwMjQwNDEwMTcwOTE2LTA0MDAiLCJjcmVhdGVUaW1lc3RhbXAiOiIyMDI0MDQxMTAwMDkxNloiLCJzdXVpZCI6Ik1ERmxZV1kxTVRrdE1EUTVZaTAwWWpJNUxUaG1ObUV0WkRJMk1qSm1NMkl4WW1NMi50S09fM014VkVReks3dE5qTkdxUXNZMlBNbXp5cUxGRkxySnBrZGhrcDZrIiwic3ViIjoiMDFlYWY1MTktMDQ5Yi00YjI5LThmNmEtZDI2MjJmM2IxYmM2IiwiYXVkIjoiR1NVMXBHMkJyZ2QzUHYyS0JuQVoyNHp2eTV1V1NDUW4iLCJpc3MiOiJodHRwczovL3Nzby5hY2NvdW50cy5kb3dqb25lcy5jb20vIiwiaWF0IjoxNzIyNTUyMzM5MTI0LCJleHAiOjE3MjI1NTI3NzExMjR9.HVn33IFttQrG1JKEV2oElIy3mm8TJ-3GpV_jqZE81_cY22z4IMWPz7zUGz0WgOoUuQGyrYXiaNrfxD6GaoimRL6wxrH0Fy5iYC3dOEdlGfldswfgEOwSiZkBJRc2wWTVQLm93EeJ5ZZyKIXGY_ZkwcYfhrwaTAz8McBBnRmZkm0eiNJQ5YK-QZL-yFa3DxMdPPW91jLA2rjOIVnJ-I_0nMwaJ4ZwXHG2Sw4aAXxtbFqIqarKwIdOUSpRFOCSYpeWcxmbliurKlP1djrKrYgYSZxsKOHZhnbikZDtoDCAlPRlfbKOO4u36KXooDYGJ6p__s2kGCLOLLkP_QLHMNU8Jg"> + <input type="hidden" name="params" 
value="%7B%22response_type%22%3A%22code%22%2C%22client_id%22%3A%22GSU1pG2Brgd3Pv2KBnAZ24zvy5uWSCQn%22%2C%22redirect_uri%22%3A%22https%3A%2F%2Fmyibd.investors.com%2Foidc%2Fcallback%22%2C%22state%22%3A%22J-ihUYZIYzey682D.aOLszineC9qjPkM6Y6wWgFC61ABYBiuK9u48AHTFS5I%22%2C%22scope%22%3A%22openid%20idp_id%20roles%20email%20given_name%20family_name%20uuid%20djUsername%20djStatus%20trackid%20tags%20prts%20updated_at%20created_at%20offline_access%20djid%22%2C%22nonce%22%3A%22457bb517-f490-43b6-a55f-d93f90d698ad%22%7D"> + <noscript> + <p>Script is disabled. Click Submit to continue.</p> + <input type="submit" value="Submit"> + </noscript> +</form> +` +const extractTokenExpectedToken = "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJkalVzZXJuYW1lIjoiYW5zZzE5MUB5YWhvby5jb20iLCJpZCI6IjAxZWFmNTE5LTA0OWItNGIyOS04ZjZhLWQyNjIyZjNiMWJjNiIsImdpdmVuX25hbWUiOiJBbnNodWwiLCJmYW1pbHlfbmFtZSI6Ikd1cHRhIiwibmFtZSI6IkFuc2h1bCBHdXB0YSIsImVtYWlsIjoiYW5zZzE5MUB5YWhvby5jb20iLCJlbWFpbF92ZXJpZmllZCI6dHJ1ZSwiYWNjb3VudF9pZCI6Ijk5NzI5Mzc0NDIxMiIsImRqaWQiOiIwMWVhZjUxOS0wNDliLTRiMjktOGY2YS1kMjYyMmYzYjFiYzYiLCJ0cmFja2lkIjoiMWM0NGQyMTRmM2VlYTZiMzcyNDYxNDc3NDc0NDMyODJmMTRmY2ZjYmI4NmE4NmVjYTI0MDc2ZDVlMzU4ZmUzZCIsInVwZGF0ZWRfYXQiOjE3MTI3OTQxNTYsImNyZWF0ZWRfYXQiOjE3MTI3OTQxNTYsInVhdCI6MTcyMjU1MjMzOSwicm9sZXMiOlsiQkFSUk9OUy1DSEFOR0VQQVNTV09SRCIsIkZSRUVSRUctQkFTRSIsIkZSRUVSRUctSU5ESVZJRFVBTCIsIldTSi1DSEFOR0VQQVNTV09SRCIsIldTSi1BUkNISVZFIiwiV1NKLVNFTEZTRVJWIiwiSUJELUlORElWSURVQUwiLCJJQkQtSUNBIiwiSUJELUFFSSJdLCJkalN0YXR1cyI6WyJJQkRfVVNFUlMiXSwicHJ0cyI6IjIwMjQwNDEwMTcwOTE2LTA0MDAiLCJjcmVhdGVUaW1lc3RhbXAiOiIyMDI0MDQxMTAwMDkxNloiLCJzdXVpZCI6Ik1ERmxZV1kxTVRrdE1EUTVZaTAwWWpJNUxUaG1ObUV0WkRJMk1qSm1NMkl4WW1NMi50S09fM014VkVReks3dE5qTkdxUXNZMlBNbXp5cUxGRkxySnBrZGhrcDZrIiwic3ViIjoiMDFlYWY1MTktMDQ5Yi00YjI5LThmNmEtZDI2MjJmM2IxYmM2IiwiYXVkIjoiR1NVMXBHMkJyZ2QzUHYyS0JuQVoyNHp2eTV1V1NDUW4iLCJpc3MiOiJodHRwczovL3Nzby5hY2NvdW50cy5kb3dqb25lcy5jb20vIiwiaWF0IjoxNzIyNTUyMzM5MTI0LCJleHAiOjE3MjI1NTI3NzExMjR9.HVn33IFttQrG1JKEV2oElIy3mm8TJ-3GpV_jq
ZE81_cY22z4IMWPz7zUGz0WgOoUuQGyrYXiaNrfxD6GaoimRL6wxrH0Fy5iYC3dOEdlGfldswfgEOwSiZkBJRc2wWTVQLm93EeJ5ZZyKIXGY_ZkwcYfhrwaTAz8McBBnRmZkm0eiNJQ5YK-QZL-yFa3DxMdPPW91jLA2rjOIVnJ-I_0nMwaJ4ZwXHG2Sw4aAXxtbFqIqarKwIdOUSpRFOCSYpeWcxmbliurKlP1djrKrYgYSZxsKOHZhnbikZDtoDCAlPRlfbKOO4u36KXooDYGJ6p__s2kGCLOLLkP_QLHMNU8Jg" +const extractTokenExpectedParams = "%7B%22response_type%22%3A%22code%22%2C%22client_id%22%3A%22GSU1pG2Brgd3Pv2KBnAZ24zvy5uWSCQn%22%2C%22redirect_uri%22%3A%22https%3A%2F%2Fmyibd.investors.com%2Foidc%2Fcallback%22%2C%22state%22%3A%22J-ihUYZIYzey682D.aOLszineC9qjPkM6Y6wWgFC61ABYBiuK9u48AHTFS5I%22%2C%22scope%22%3A%22openid%20idp_id%20roles%20email%20given_name%20family_name%20uuid%20djUsername%20djStatus%20trackid%20tags%20prts%20updated_at%20created_at%20offline_access%20djid%22%2C%22nonce%22%3A%22457bb517-f490-43b6-a55f-d93f90d698ad%22%7D" + +func Test_extractTokenParams(t *testing.T) { + t.Parallel() + + node, err := html.Parse(strings.NewReader(extractTokenParamsHtml)) + require.NoError(t, err) + + token, params, err := extractTokenParams(node) + require.NoError(t, err) + assert.Equal(t, extractTokenExpectedToken, token) + assert.Equal(t, extractTokenExpectedParams, params) +} + +func TestClient_Authenticate(t *testing.T) { + t.Parallel() + + expectedVal := "test-cookie" + expectedExp := time.Now().Add(time.Hour).Round(time.Second).In(time.UTC) + + tp := httpmock.NewMockTransport() + tp.RegisterResponder("GET", signInUrl, + httpmock.NewStringResponder(http.StatusOK, extractAuthHtml)) + tp.RegisterResponder("POST", authenticateUrl, + func(request *http.Request) (*http.Response, error) { + var body authRequestBody + require.NoError(t, json.NewDecoder(request.Body).Decode(&body)) + assert.Equal(t, "abc", body.Username) + assert.Equal(t, "xyz", body.Password) + + return httpmock.NewStringResponse(http.StatusOK, extractTokenParamsHtml), nil + }) + tp.RegisterResponder("POST", postAuthUrl, + func(request *http.Request) (*http.Response, error) { + require.NoError(t, 
request.ParseForm()) + assert.Equal(t, extractTokenExpectedToken, request.Form.Get("token")) + + params, err := url.QueryUnescape(extractTokenExpectedParams) + require.NoError(t, err) + assert.Equal(t, params, request.Form.Get("params")) + + resp := httpmock.NewStringResponse(http.StatusOK, "OK") + cookie := &http.Cookie{Name: cookieName, Value: expectedVal, Expires: expectedExp} + resp.Header.Set("Set-Cookie", cookie.String()) + return resp, nil + }) + + client := NewClient(nil, nil, newTransport(tp)) + + cookie, err := client.Authenticate(context.Background(), "abc", "xyz") + require.NoError(t, err) + require.NotNil(t, cookie) + + assert.Equal(t, expectedVal, cookie.Value) + assert.Equal(t, expectedExp, cookie.Expires) +} + +func TestClient_Authenticate_401(t *testing.T) { + t.Parallel() + + tp := httpmock.NewMockTransport() + tp.RegisterResponder("GET", signInUrl, + httpmock.NewStringResponder(http.StatusOK, extractAuthHtml)) + tp.RegisterResponder("POST", authenticateUrl, + func(request *http.Request) (*http.Response, error) { + var body authRequestBody + require.NoError(t, json.NewDecoder(request.Body).Decode(&body)) + assert.Equal(t, "abc", body.Username) + assert.Equal(t, "xyz", body.Password) + + return httpmock.NewStringResponse(http.StatusUnauthorized, `{"name":"ValidationError","code":"ERR016","message":"Wrong username or password","description":"Wrong username or password"}`), nil + }) + + client := NewClient(nil, nil, newTransport(tp)) + + cookie, err := client.Authenticate(context.Background(), "abc", "xyz") + assert.Nil(t, cookie) + assert.ErrorIs(t, err, ErrBadCredentials) +} + +type testReliableTransport http.Client + +func newTransport(tp *httpmock.MockTransport) *testReliableTransport { + return (*testReliableTransport)(&http.Client{Transport: tp}) +} + +func (t *testReliableTransport) String() string { + return "testReliableTransport" +} + +func (t *testReliableTransport) Do(req *http.Request) (*http.Response, error) { + return 
(*http.Client)(t).Do(req) +} + +func (t *testReliableTransport) Properties() transport.Properties { + return transport.PropertiesFree | transport.PropertiesReliable +} diff --git a/backend/internal/ibd/check_ibd_username.go b/backend/internal/ibd/check_ibd_username.go new file mode 100644 index 0000000..b026151 --- /dev/null +++ b/backend/internal/ibd/check_ibd_username.go @@ -0,0 +1,68 @@ +package ibd + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" +) + +const ( + checkUsernameUrl = "https://sso.accounts.dowjones.com/getuser" +) + +func (c *Client) CheckIBDUsername(ctx context.Context, username string) (bool, error) { + cfg, err := c.getLoginPage(ctx) + if err != nil { + return false, err + } + + return c.checkIBDUsername(ctx, cfg, username) +} + +func (c *Client) checkIBDUsername(ctx context.Context, cfg *authConfig, username string) (bool, error) { + body := map[string]string{ + "username": username, + "csrf": cfg.ExtraParams.Csrf, + } + bodyJson, err := json.Marshal(body) + if err != nil { + return false, err + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, checkUsernameUrl, bytes.NewReader(bodyJson)) + if err != nil { + return false, err + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("X-REMOTE-USER", username) + req.Header.Set("X-REQUEST-EDITIONID", "IBD-EN_US") + req.Header.Set("X-REQUEST-SCHEME", "https") + + resp, err := c.Do(req, withExpectedStatuses(http.StatusOK, http.StatusUnauthorized)) + if err != nil { + return false, err + } + defer func(Body io.ReadCloser) { + _ = Body.Close() + }(resp.Body) + + if resp.StatusCode == http.StatusUnauthorized { + return false, nil + } else if resp.StatusCode != http.StatusOK { + contentBytes, err := io.ReadAll(resp.Body) + if err != nil { + return false, fmt.Errorf("failed to read response body: %w", err) + } + content := string(contentBytes) + return false, fmt.Errorf( + "unexpected status code %d: %s", + resp.StatusCode, + content, + ) 
+ } + return true, nil +} diff --git a/backend/internal/ibd/client.go b/backend/internal/ibd/client.go new file mode 100644 index 0000000..c8575e3 --- /dev/null +++ b/backend/internal/ibd/client.go @@ -0,0 +1,97 @@ +package ibd + +import ( + "context" + "errors" + "log/slog" + "net/http" + "slices" + + "github.com/ansg191/ibd-trader-backend/internal/database" + "github.com/ansg191/ibd-trader-backend/internal/ibd/transport" + "github.com/ansg191/ibd-trader-backend/internal/keys" +) + +var ErrNoAvailableCookies = errors.New("no available cookies") +var ErrNoAvailableTransports = errors.New("no available transports") + +type Client struct { + transports []transport.Transport + db database.Executor + kms keys.KeyManagementService +} + +func NewClient( + db database.Executor, + kms keys.KeyManagementService, + transports ...transport.Transport, +) *Client { + return &Client{transports, db, kms} +} + +func (c *Client) getCookie(ctx context.Context, subject *string) (uint, *http.Cookie, error) { + if subject == nil { + // No subject requirement, get any cookie + cookie, err := database.GetAnyCookie(ctx, c.db, c.kms) + if err != nil { + return 0, nil, err + } + if cookie == nil { + return 0, nil, ErrNoAvailableCookies + } + + return cookie.ID, cookie.ToHTTPCookie(), nil + } + + // Get cookie by subject + cookies, err := database.GetCookies(ctx, c.db, c.kms, *subject, false) + if err != nil { + return 0, nil, err + } + + if len(cookies) == 0 { + return 0, nil, ErrNoAvailableCookies + } + + cookie := cookies[0] + + return cookie.ID, cookie.ToHTTPCookie(), nil +} + +func (c *Client) Do(req *http.Request, opts ...optionFunc) (*http.Response, error) { + o := defaultOptions + for _, opt := range opts { + opt(&o) + } + + // Sort and filter transports by properties + transports := transport.FilterTransports(c.transports, o.requiredProps) + transport.SortTransports(transports) + + for _, tp := range transports { + resp, err := tp.Do(req) + if errors.Is(err, 
transport.ErrUnsupportedRequest) { + // Skip unsupported transport + continue + } + if err != nil { + slog.ErrorContext(req.Context(), "transport error", + "transport", tp.String(), + "error", err, + ) + continue + } + if slices.Contains(o.expectedStatuses, resp.StatusCode) { + return resp, nil + } else { + slog.ErrorContext(req.Context(), "unexpected status code", + "transport", tp.String(), + "expected", o.expectedStatuses, + "actual", resp.StatusCode, + ) + continue + } + } + + return nil, ErrNoAvailableTransports +} diff --git a/backend/internal/ibd/client_test.go b/backend/internal/ibd/client_test.go new file mode 100644 index 0000000..2368a31 --- /dev/null +++ b/backend/internal/ibd/client_test.go @@ -0,0 +1,201 @@ +package ibd + +import ( + "context" + "database/sql" + "fmt" + "log" + "math/rand/v2" + "testing" + "time" + + "github.com/ansg191/ibd-trader-backend/internal/database" + "github.com/ansg191/ibd-trader-backend/internal/keys" + _ "github.com/lib/pq" + "github.com/ory/dockertest/v3" + "github.com/ory/dockertest/v3/docker" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +var ( + db *sql.DB + maxTime = time.Date(2100, 1, 1, 0, 0, 0, 0, time.UTC) + letterRunes = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ") +) + +func TestMain(m *testing.M) { + pool, err := dockertest.NewPool("") + if err != nil { + log.Fatalf("Could not create pool: %s", err) + } + + err = pool.Client.Ping() + if err != nil { + log.Fatalf("Could not connect to Docker: %s", err) + } + + resource, err := pool.RunWithOptions(&dockertest.RunOptions{ + Repository: "postgres", + Tag: "16", + Env: []string{ + "POSTGRES_PASSWORD=secret", + "POSTGRES_USER=ibd-client-test", + "POSTGRES_DB=ibd-client-test", + "listen_addresses='*'", + }, + Cmd: []string{ + "postgres", + "-c", + "log_statement=all", + }, + }, func(config *docker.HostConfig) { + config.AutoRemove = true + config.RestartPolicy = docker.RestartPolicy{Name: "no"} + }) + if err != 
nil { + log.Fatalf("Could not start resource: %s", err) + } + + hostAndPort := resource.GetHostPort("5432/tcp") + databaseUrl := fmt.Sprintf("postgres://ibd-client-test:secret@%s/ibd-client-test?sslmode=disable", hostAndPort) + + // Kill container after 120 seconds + _ = resource.Expire(120) + + pool.MaxWait = 120 * time.Second + if err = pool.Retry(func() error { + db, err = sql.Open("postgres", databaseUrl) + if err != nil { + return err + } + return db.Ping() + }); err != nil { + log.Fatalf("Could not connect to database: %s", err) + } + + err = database.Migrate(context.Background(), databaseUrl) + if err != nil { + log.Fatalf("Could not migrate database: %s", err) + } + + defer func() { + if err := pool.Purge(resource); err != nil { + log.Fatalf("Could not purge resource: %s", err) + } + }() + + m.Run() +} + +func randStringRunes(n int) string { + b := make([]rune, n) + for i := range b { + b[i] = letterRunes[rand.IntN(len(letterRunes))] + } + return string(b) +} + +func addCookie(t *testing.T) (user, token string) { + t.Helper() + + // Randomly generate a user and token + user = randStringRunes(8) + token = randStringRunes(16) + + ciphertext, key, err := keys.Encrypt(context.Background(), new(kmsStub), "", []byte(token)) + require.NoError(t, err) + + tx, err := db.Begin() + require.NoError(t, err) + + var keyID uint + err = tx.QueryRow(` +INSERT INTO keys (kms_key_name, encrypted_key) + VALUES ('', $1) + RETURNING id; +`, key).Scan(&keyID) + require.NoError(t, err) + + _, err = tx.Exec(` +INSERT +INTO users (subject, encryption_key) +VALUES ($1, $2); +`, user, keyID) + require.NoError(t, err) + + _, err = tx.Exec(` +INSERT +INTO ibd_tokens (user_subject, token, encryption_key, expires_at) +VALUES ($1, $2, $3, $4);`, + user, + ciphertext, + keyID, + maxTime, + ) + require.NoError(t, err) + + err = tx.Commit() + require.NoError(t, err) + + return user, token +} + +func TestClient_getCookie(t *testing.T) { + t.Run("no cookies", func(t *testing.T) { + client := 
NewClient(db, new(kmsStub)) + + _, _, err := client.getCookie(context.Background(), nil) + assert.ErrorIs(t, err, ErrNoAvailableCookies) + }) + + t.Run("no cookies by subject", func(t *testing.T) { + client := NewClient(db, new(kmsStub)) + + subject := "test" + _, _, err := client.getCookie(context.Background(), &subject) + assert.ErrorIs(t, err, ErrNoAvailableCookies) + }) + + t.Run("get any cookie", func(t *testing.T) { + _, token := addCookie(t) + + client := NewClient(db, new(kmsStub)) + + _, cookie, err := client.getCookie(context.Background(), nil) + require.NoError(t, err) + assert.Equal(t, cookieName, cookie.Name) + assert.Equal(t, token, cookie.Value) + assert.Equal(t, "/", cookie.Path) + assert.Equal(t, maxTime, cookie.Expires) + assert.Equal(t, "investors.com", cookie.Domain) + }) + + t.Run("get cookie by subject", func(t *testing.T) { + subject, token := addCookie(t) + + client := NewClient(db, new(kmsStub)) + + _, cookie, err := client.getCookie(context.Background(), &subject) + require.NoError(t, err) + assert.Equal(t, cookieName, cookie.Name) + assert.Equal(t, token, cookie.Value) + assert.Equal(t, "/", cookie.Path) + assert.Equal(t, maxTime, cookie.Expires) + assert.Equal(t, "investors.com", cookie.Domain) + }) +} + +type kmsStub struct{} + +func (k *kmsStub) Close() error { + return nil +} + +func (k *kmsStub) Encrypt(_ context.Context, _ string, plaintext []byte) ([]byte, error) { + return plaintext, nil +} + +func (k *kmsStub) Decrypt(_ context.Context, _ string, ciphertext []byte) ([]byte, error) { + return ciphertext, nil +} diff --git a/backend/internal/ibd/html_helpers.go b/backend/internal/ibd/html_helpers.go new file mode 100644 index 0000000..0176bc5 --- /dev/null +++ b/backend/internal/ibd/html_helpers.go @@ -0,0 +1,99 @@ +package ibd + +import ( + "strings" + + "golang.org/x/net/html" +) + +func findChildren(node *html.Node, f func(node *html.Node) bool) (found []*html.Node) { + for c := node.FirstChild; c != nil; c = c.NextSibling { + 
if f(c) { + found = append(found, c) + } + } + return +} + +func findChildrenRecursive(node *html.Node, f func(node *html.Node) bool) (found []*html.Node) { + if f(node) { + found = append(found, node) + } + + for c := node.FirstChild; c != nil; c = c.NextSibling { + found = append(found, findChildrenRecursive(c, f)...) + } + + return +} + +func findClass(node *html.Node, className string) (found *html.Node) { + if isClass(node, className) { + return node + } + + for c := node.FirstChild; c != nil; c = c.NextSibling { + if found = findClass(c, className); found != nil { + return + } + } + + return +} + +func isClass(node *html.Node, className string) bool { + if node.Type == html.ElementNode { + for _, attr := range node.Attr { + if attr.Key != "class" { + continue + } + classes := strings.Fields(attr.Val) + for _, class := range classes { + if class == className { + return true + } + } + } + } + return false +} + +func extractText(node *html.Node) string { + var result strings.Builder + extractTextInner(node, &result) + return result.String() +} + +func extractTextInner(node *html.Node, result *strings.Builder) { + if node.Type == html.TextNode { + result.WriteString(node.Data) + } + for c := node.FirstChild; c != nil; c = c.NextSibling { + extractTextInner(c, result) + } +} + +func findId(node *html.Node, id string) (found *html.Node) { + if isId(node, id) { + return node + } + + for c := node.FirstChild; c != nil; c = c.NextSibling { + if found = findId(c, id); found != nil { + return + } + } + + return +} + +func isId(node *html.Node, id string) bool { + if node.Type == html.ElementNode { + for _, attr := range node.Attr { + if attr.Key == "id" && attr.Val == id { + return true + } + } + } + return false +} diff --git a/backend/internal/ibd/html_helpers_test.go b/backend/internal/ibd/html_helpers_test.go new file mode 100644 index 0000000..d251c39 --- /dev/null +++ b/backend/internal/ibd/html_helpers_test.go @@ -0,0 +1,79 @@ +package ibd + +import ( + "strings" 
+ "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/net/html" +) + +func Test_findClass(t *testing.T) { + t.Parallel() + tests := []struct { + name string + html string + className string + found bool + expData string + }{ + { + name: "class exists", + html: `<div class="foo"></div>`, + className: "foo", + found: true, + expData: "div", + }, + { + name: "class exists nested", + html: `<div class="foo"><a class="abc"></a></div>`, + className: "abc", + found: true, + expData: "a", + }, + { + name: "class exists multiple", + html: `<div class="foo"><a class="foo"></a></div>`, + className: "foo", + found: true, + expData: "div", + }, + { + name: "class missing", + html: `<div class="abc"><a class="xyz"></a></div>`, + className: "foo", + found: false, + expData: "", + }, + { + name: "class missing", + html: `<div id="foo"><a abc="xyz"></a></div>`, + className: "foo", + found: false, + expData: "", + }, + { + name: "class exists multiple save div", + html: `<div class="foo bar"></div>`, + className: "bar", + found: true, + expData: "div", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + node, err := html.Parse(strings.NewReader(tt.html)) + require.NoError(t, err) + + got := findClass(node, tt.className) + if !tt.found { + require.Nil(t, got) + return + } + require.NotNil(t, got) + assert.Equal(t, tt.expData, got.Data) + }) + } +} diff --git a/backend/internal/ibd/ibd50.go b/backend/internal/ibd/ibd50.go new file mode 100644 index 0000000..52e28aa --- /dev/null +++ b/backend/internal/ibd/ibd50.go @@ -0,0 +1,182 @@ +package ibd + +import ( + "context" + "encoding/json" + "errors" + "io" + "log/slog" + "net/http" + "net/url" + "strconv" + + "github.com/ansg191/ibd-trader-backend/internal/database" +) + +const ibd50Url = "https://research.investors.com/Services/SiteAjaxService.asmx/GetIBD50?sortcolumn1=%22ibd100rank%22&sortOrder1=%22asc%22&sortcolumn2=%22%22&sortOrder2=%22ASC%22" + +// 
GetIBD50 returns the IBD50 list.
func (c *Client) GetIBD50(ctx context.Context) ([]*Stock, error) {
	// We cannot use the scraper here because scrapfly does not support
	// Content-Type in GET requests.
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, ibd50Url, nil)
	if err != nil {
		return nil, err
	}

	// Authenticate with any available IBD cookie (nil = no specific subject).
	// cookieId is kept so a failure can be reported against this cookie.
	cookieId, cookie, err := c.getCookie(ctx, nil)
	if err != nil {
		return nil, err
	}
	req.AddCookie(cookie)

	req.Header.Add("content-type", "application/json; charset=utf-8")
	// Add browser-emulating headers
	// NOTE(review): the newrelic/traceparent/tracestate values are captured
	// from a real browser session and are static here.
	req.Header.Add("accept", "*/*")
	req.Header.Add("accept-language", "en-US,en;q=0.9")
	req.Header.Add("newrelic", "eyJ2IjpbMCwxXSwiZCI6eyJ0eSI6IkJyb3dzZXIiLCJhYyI6IjMzOTYxMDYiLCJhcCI6IjEzODU5ODMwMDEiLCJpZCI6IjM1Zjk5NmM2MzNjYTViMWYiLCJ0ciI6IjM3ZmRhZmJlOGY2YjhmYTMwYWMzOTkzOGNlMmM0OWMxIiwidGkiOjE3MjIyNzg0NTk3MjUsInRrIjoiMTAyMjY4MSJ9fQ==")
	req.Header.Add("priority", "u=1, i")
	req.Header.Add("referer", "https://research.investors.com/stock-lists/ibd-50/")
	req.Header.Add("sec-ch-ua", "\"Not/A)Brand\";v=\"8\", \"Chromium\";v=\"126\", \"Google Chrome\";v=\"126\"")
	req.Header.Add("sec-ch-ua-mobile", "?0")
	req.Header.Add("sec-ch-ua-platform", "\"macOS\"")
	req.Header.Add("sec-fetch-dest", "empty")
	req.Header.Add("sec-fetch-mode", "cors")
	req.Header.Add("sec-fetch-site", "same-origin")
	req.Header.Add("traceparent", "00-37fdafbe8f6b8fa30ac39938ce2c49c1-35f996c633ca5b1f-01")
	req.Header.Add("tracestate", "1022681@nr=0-1-3396106-1385983001-35f996c633ca5b1f----1722278459725")
	req.Header.Add("user-agent", "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/126.0.0.0 Safari/537.36")
	req.Header.Add("x-newrelic-id", "VwUOV1dTDhABV1FRBgQOVVUF")
	req.Header.Add("x-requested-with", "XMLHttpRequest")

	resp, err := c.Do(req)
	if err != nil {
		return nil, err
	}
	defer func(Body io.ReadCloser) {
		_ = Body.Close()
	}(resp.Body)

	var ibd50Resp getIBD50Response
	if err =
json.NewDecoder(resp.Body).Decode(&ibd50Resp); err != nil {
		return nil, err
	}

	// If there are less than 10 stocks in the IBD50 list, it's likely that authentication failed.
	if len(ibd50Resp.D.ETablesDataList) < 10 {
		// Report cookie failure to DB
		if err = database.ReportCookieFailure(ctx, c.db, cookieId); err != nil {
			slog.Error("Failed to report cookie failure", "error", err)
		}
		return nil, errors.New("failed to get IBD50 list")
	}

	return ibd50Resp.ToStockList(), nil
}

// Stock is one entry of the IBD50 list: its rank, ticker symbol, company
// name, and (optionally) the IBD quote-page URL.
type Stock struct {
	Rank   int64
	Symbol string
	Name   string

	QuoteURL *url.URL
}

// getIBD50Response mirrors the JSON envelope returned by the GetIBD50
// SiteAjaxService endpoint. Most fields are unused; only Rank, Symbol,
// CompanyName and QuoteUrl are consumed by ToStockList.
type getIBD50Response struct {
	D struct {
		Type            *string `json:"__type"`
		ETablesDataList []struct {
			Rank               string  `json:"Rank"`
			Symbol             string  `json:"Symbol"`
			CompanyName        string  `json:"CompanyName"`
			CompRating         *string `json:"CompRating"`
			EPSRank            *string `json:"EPSRank"`
			RelSt              *string `json:"RelSt"`
			GrpStr             *string `json:"GrpStr"`
			Smr                *string `json:"Smr"`
			AccDis             *string `json:"AccDis"`
			SponRating         *string `json:"SponRating"`
			Price              *string `json:"Price"`
			PriceClose         *string `json:"PriceClose"`
			PriceChange        *string `json:"PriceChange"`
			PricePerChange     *string `json:"PricePerChange"`
			VolPerChange       *string `json:"VolPerChange"`
			DailyVol           *string `json:"DailyVol"`
			WeekHigh52         *string `json:"WeekHigh52"`
			PerOffHigh         *string `json:"PerOffHigh"`
			PERatio            *string `json:"PERatio"`
			DivYield           *string `json:"DivYield"`
			LastQtrSalesPerChg *string `json:"LastQtrSalesPerChg"`
			LastQtrEpsPerChg   *string `json:"LastQtrEpsPerChg"`
			ConsecQtrEpsGrt15  *string `json:"ConsecQtrEpsGrt15"`
			CurQtrEpsEstPerChg *string `json:"CurQtrEpsEstPerChg"`
			CurYrEpsEstPerChg  *string `json:"CurYrEpsEstPerChg"`
			PretaxMargin       *string `json:"PretaxMargin"`
			ROE                *string `json:"ROE"`
			MgmtOwnsPer        *string `json:"MgmtOwnsPer"`
			QuoteUrl           *string `json:"QuoteUrl"`
			StockCheckupUrl    *string `json:"StockCheckupUrl"`
			MarketsmithUrl     *string `json:"MarketsmithUrl"`
			LeaderboardUrl     *string
`json:"LeaderboardUrl"` + ChartAnalysisUrl *string `json:"ChartAnalysisUrl"` + Ibd100NewEntryFlag *string `json:"Ibd100NewEntryFlag"` + Ibd100UpInRankFlag *string `json:"Ibd100UpInRankFlag"` + IbdBigCap20NewEntryFlag *string `json:"IbdBigCap20NewEntryFlag"` + CompDesc *string `json:"CompDesc"` + NumberFunds *string `json:"NumberFunds"` + GlobalRank *string `json:"GlobalRank"` + EPSPriorQtr *string `json:"EPSPriorQtr"` + QtrsFundIncrease *string `json:"QtrsFundIncrease"` + } `json:"ETablesDataList"` + IBD50PdfUrl *string `json:"IBD50PdfUrl"` + CAP20PdfUrl *string `json:"CAP20PdfUrl"` + IBD50Date *string `json:"IBD50Date"` + CAP20Date *string `json:"CAP20Date"` + UpdatedDate *string `json:"UpdatedDate"` + GetAllFlags *string `json:"getAllFlags"` + Flag *int `json:"flag"` + Message *string `json:"Message"` + PaywallDesktopMarkup *string `json:"PaywallDesktopMarkup"` + PaywallMobileMarkup *string `json:"PaywallMobileMarkup"` + } `json:"d"` +} + +func (r getIBD50Response) ToStockList() (ibd []*Stock) { + ibd = make([]*Stock, 0, len(r.D.ETablesDataList)) + for _, data := range r.D.ETablesDataList { + rank, err := strconv.ParseInt(data.Rank, 10, 64) + if err != nil { + slog.Error( + "Failed to parse Rank", + "error", err, + "rank", data.Rank, + "symbol", data.Symbol, + "name", data.CompanyName, + ) + continue + } + + var quoteUrl *url.URL + if data.QuoteUrl != nil { + quoteUrl, err = url.Parse(*data.QuoteUrl) + if err != nil { + slog.Error( + "Failed to parse QuoteUrl", + "error", err, + "quoteUrl", *data.QuoteUrl, + "rank", data.Rank, + "symbol", data.Symbol, + "name", data.CompanyName, + ) + } + } + + ibd = append(ibd, &Stock{ + Rank: rank, + Symbol: data.Symbol, + Name: data.CompanyName, + QuoteURL: quoteUrl, + }) + } + return +} diff --git a/backend/internal/ibd/options.go b/backend/internal/ibd/options.go new file mode 100644 index 0000000..5c378d5 --- /dev/null +++ b/backend/internal/ibd/options.go @@ -0,0 +1,26 @@ +package ibd + +import 
"github.com/ansg191/ibd-trader-backend/internal/ibd/transport" + +type optionFunc func(*options) + +var defaultOptions = options{ + expectedStatuses: []int{200}, +} + +type options struct { + expectedStatuses []int + requiredProps transport.Properties +} + +func withExpectedStatuses(statuses ...int) optionFunc { + return func(o *options) { + o.expectedStatuses = append(o.expectedStatuses, statuses...) + } +} + +func withRequiredProps(props transport.Properties) optionFunc { + return func(o *options) { + o.requiredProps = props + } +} diff --git a/backend/internal/ibd/search.go b/backend/internal/ibd/search.go new file mode 100644 index 0000000..341b14b --- /dev/null +++ b/backend/internal/ibd/search.go @@ -0,0 +1,111 @@ +package ibd + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "time" + + "github.com/ansg191/ibd-trader-backend/internal/database" +) + +const ( + searchUrl = "https://ibdservices.investors.com/im/api/search" +) + +var ErrSymbolNotFound = fmt.Errorf("symbol not found") + +func (c *Client) Search(ctx context.Context, symbol string) (database.Stock, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, searchUrl, nil) + if err != nil { + return database.Stock{}, err + } + + _, cookie, err := c.getCookie(ctx, nil) + if err != nil { + return database.Stock{}, err + } + req.AddCookie(cookie) + + params := url.Values{} + params.Set("key", symbol) + req.URL.RawQuery = params.Encode() + + resp, err := c.Do(req) + if err != nil { + return database.Stock{}, err + } + defer func(Body io.ReadCloser) { + _ = Body.Close() + }(resp.Body) + + if resp.StatusCode != http.StatusOK { + content, err := io.ReadAll(resp.Body) + if err != nil { + return database.Stock{}, fmt.Errorf("failed to read response body: %w", err) + } + return database.Stock{}, fmt.Errorf( + "unexpected status code %d: %s", + resp.StatusCode, + string(content), + ) + } + + var sr searchResponse + if err = json.NewDecoder(resp.Body).Decode(&sr); 
err != nil {
		return database.Stock{}, err
	}

	// The API returns fuzzy matches too; accept only an exact symbol match.
	for _, stock := range sr.StockData {
		if stock.Symbol == symbol {
			return database.Stock{
				Symbol: stock.Symbol,
				Name:   stock.Company,
				IBDUrl: stock.QuoteUrl,
			}, nil
		}
	}

	return database.Stock{}, ErrSymbolNotFound
}

// searchResponse mirrors the JSON payload of the IBD search API. Only
// StockData.Symbol/Company/QuoteUrl are consumed by Search.
type searchResponse struct {
	Status    int    `json:"_status"`
	Timestamp string `json:"_timestamp"`
	StockData []struct {
		Id              int     `json:"id"`
		Symbol          string  `json:"symbol"`
		Company         string  `json:"company"`
		PriceDate       string  `json:"priceDate"`
		Price           float64 `json:"price"`
		PreviousPrice   float64 `json:"previousPrice"`
		PriceChange     float64 `json:"priceChange"`
		PricePctChange  float64 `json:"pricePctChange"`
		Volume          int     `json:"volume"`
		VolumeChange    int     `json:"volumeChange"`
		VolumePctChange int     `json:"volumePctChange"`
		QuoteUrl        string  `json:"quoteUrl"`
	} `json:"stockData"`
	News []struct {
		Title             string    `json:"title"`
		Category          string    `json:"category"`
		Body              string    `json:"body"`
		ImageAlt          string    `json:"imageAlt"`
		ImageUrl          string    `json:"imageUrl"`
		NewsUrl           string    `json:"newsUrl"`
		CategoryUrl       string    `json:"categoryUrl"`
		PublishDate       time.Time `json:"publishDate"`
		PublishDateUnixts int       `json:"publishDateUnixts"`
		Stocks            []struct {
			Id             int    `json:"id"`
			Index          int    `json:"index"`
			Symbol         string `json:"symbol"`
			PricePctChange string `json:"pricePctChange"`
		} `json:"stocks"`
		VideoFormat bool `json:"videoFormat"`
	} `json:"news"`
	FullUrl string `json:"fullUrl"`
}
diff --git a/backend/internal/ibd/search_test.go b/backend/internal/ibd/search_test.go
new file mode 100644
index 0000000..05e93dc
--- /dev/null
+++ b/backend/internal/ibd/search_test.go
@@ -0,0 +1,205 @@
package ibd

import (
	"context"
	"net/http"
	"testing"

	"github.com/ansg191/ibd-trader-backend/internal/ibd/transport"
	"github.com/jarcoal/httpmock"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// searchResponseJSON is a captured search-API response for "AAPL" used as a
// mock fixture.
const searchResponseJSON = `
{
  "_status": 200,
+ "_timestamp": "1722879439.724106", + "stockData": [ + { + "id": 13717, + "symbol": "AAPL", + "company": "Apple", + "priceDate": "2024-08-05T09:18:00", + "price": 212.33, + "previousPrice": 219.86, + "priceChange": -7.53, + "pricePctChange": -3.42, + "volume": 643433, + "volumeChange": -2138, + "volumePctChange": 124, + "quoteUrl": "https://research.investors.com/stock-quotes/nasdaq-apple-aapl.htm" + }, + { + "id": 79964, + "symbol": "AAPU", + "company": "Direxion AAPL Bull 2X", + "priceDate": "2024-08-05T09:18:00", + "price": 32.48, + "previousPrice": 34.9, + "priceChange": -2.42, + "pricePctChange": -6.92, + "volume": 15265, + "volumeChange": -35, + "volumePctChange": 212, + "quoteUrl": "https://research.investors.com/stock-quotes/nasdaq-direxion-aapl-bull-2x-aapu.htm" + }, + { + "id": 80423, + "symbol": "APLY", + "company": "YieldMax AAPL Option Incm", + "priceDate": "2024-08-05T09:11:00", + "price": 17.52, + "previousPrice": 18.15, + "priceChange": -0.63, + "pricePctChange": -3.47, + "volume": 617, + "volumeChange": -2, + "volumePctChange": 97, + "quoteUrl": "https://research.investors.com/stock-quotes/nyse-yieldmax-aapl-option-incm-aply.htm" + }, + { + "id": 79962, + "symbol": "AAPD", + "company": "Direxion Dly AAPL Br 1X", + "priceDate": "2024-08-05T09:18:00", + "price": 18.11, + "previousPrice": 17.53, + "priceChange": 0.58, + "pricePctChange": 3.31, + "volume": 14572, + "volumeChange": -7, + "volumePctChange": 885, + "quoteUrl": "https://research.investors.com/stock-quotes/nasdaq-direxion-dly-aapl-br-1x-aapd.htm" + }, + { + "id": 79968, + "symbol": "AAPB", + "company": "GraniteSh 2x Lg AAPL", + "priceDate": "2024-08-05T09:16:00", + "price": 25.22, + "previousPrice": 27.25, + "priceChange": -2.03, + "pricePctChange": -7.45, + "volume": 2505, + "volumeChange": -7, + "volumePctChange": 151, + "quoteUrl": "https://research.investors.com/stock-quotes/nasdaq-granitesh-2x-lg-aapl-aapb.htm" + } + ], + "news": [ + { + "title": "Warren Buffett Dumped Berkshire 
Hathaway's Favorite Stocks — Right Before They Plunged", + "category": "News", + "body": "Berkshire Hathaway earnings rose solidly in Q2. Warren Buffett sold nearly half his Apple stock stake. Berkshire stock fell...", + "imageAlt": "", + "imageUrl": "https://www.investors.com/wp-content/uploads/2024/06/Stock-WarrenBuffettwave-01-shutt-640x360.jpg", + "newsUrl": "https://investors.com/news/berkshire-hathaway-earnings-q2-2024-warren-buffett-apple/", + "categoryUrl": "https://investors.com/category/news/", + "publishDate": "2024-08-05T15:51:57+00:00", + "publishDateUnixts": 1722858717, + "stocks": [ + { + "id": 13717, + "index": 0, + "symbol": "AAPL", + "pricePctChange": "-3.42" + } + ], + "videoFormat": false + }, + { + "title": "Nvidia Plunges On Report Of AI Chip Flaw; Is It A Buy Now?", + "category": "Research", + "body": "Nvidia will roll out its Blackwell chip at least three months later than planned.", + "imageAlt": "", + "imageUrl": "https://www.investors.com/wp-content/uploads/2024/01/Stock-Nvidia-studio-01-company-640x360.jpg", + "newsUrl": "https://investors.com/research/nvda-stock-is-nvidia-a-buy-2/", + "categoryUrl": "https://investors.com/category/research/", + "publishDate": "2024-08-05T14:59:22+00:00", + "publishDateUnixts": 1722855562, + "stocks": [ + { + "id": 38607, + "index": 0, + "symbol": "NVDA", + "pricePctChange": "-5.18" + } + ], + "videoFormat": false + }, + { + "title": "Magnificent Seven Stocks Roiled: Nvidia Plunges On AI Chip Delay; Apple, Tesla Dive", + "category": "Research", + "body": "Nvidia stock dived Monday, while Apple and Tesla also fell sharply.", + "imageAlt": "", + "imageUrl": "https://www.investors.com/wp-content/uploads/2022/08/Stock-Nvidia-RTXa5500-comp-640x360.jpg", + "newsUrl": "https://investors.com/research/magnificent-seven-stocks-to-buy-and-and-watch/", + "categoryUrl": "https://investors.com/category/research/", + "publishDate": "2024-08-05T14:51:42+00:00", + "publishDateUnixts": 1722855102, + "stocks": [ + { + 
"id": 13717, + "index": 0, + "symbol": "AAPL", + "pricePctChange": "-3.42" + } + ], + "videoFormat": false + } + ], + "fullUrl": "https://www.investors.com/search-results/?query=AAPL" +}` + +const emptySearchResponseJSON = ` +{ + "_status": 200, + "_timestamp": "1722879662.804395", + "stockData": [], + "news": [], + "fullUrl": "https://www.investors.com/search-results/?query=abcdefg" +}` + +func TestClient_Search(t *testing.T) { + tests := []struct { + name string + response string + f func(t *testing.T, client *Client) + }{ + { + name: "found", + response: searchResponseJSON, + f: func(t *testing.T, client *Client) { + u, err := client.Search(context.Background(), "AAPL") + require.NoError(t, err) + assert.Equal(t, "AAPL", u.Symbol) + assert.Equal(t, "Apple", u.Name) + assert.Equal(t, "https://research.investors.com/stock-quotes/nasdaq-apple-aapl.htm", u.IBDUrl) + }, + }, + { + name: "not found", + response: emptySearchResponseJSON, + f: func(t *testing.T, client *Client) { + _, err := client.Search(context.Background(), "abcdefg") + assert.Error(t, err) + assert.ErrorIs(t, err, ErrSymbolNotFound) + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tp := httpmock.NewMockTransport() + tp.RegisterResponder("GET", searchUrl, httpmock.NewStringResponder(200, tt.response)) + + client := NewClient( + db, + new(kmsStub), + transport.NewStandardTransport(&http.Client{Transport: tp}), + ) + + tt.f(t, client) + }) + } +} diff --git a/backend/internal/ibd/stockinfo.go b/backend/internal/ibd/stockinfo.go new file mode 100644 index 0000000..1e3b96f --- /dev/null +++ b/backend/internal/ibd/stockinfo.go @@ -0,0 +1,233 @@ +package ibd + +import ( + "context" + "fmt" + "io" + "net/http" + "net/url" + "strconv" + "strings" + + "github.com/ansg191/ibd-trader-backend/internal/database" + "github.com/ansg191/ibd-trader-backend/internal/utils" + + "github.com/Rhymond/go-money" + "golang.org/x/net/html" +) + +func (c *Client) StockInfo(ctx context.Context, 
uri string) (*database.StockInfo, error) {
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, uri, nil)
	if err != nil {
		return nil, err
	}

	// Authenticate with any available IBD cookie.
	_, cookie, err := c.getCookie(ctx, nil)
	if err != nil {
		return nil, err
	}
	req.AddCookie(cookie)

	// Set required query parameters
	params := url.Values{}
	params.Set("list", "ibd50")
	params.Set("type", "weekly")
	req.URL.RawQuery = params.Encode()

	resp, err := c.Do(req)
	if err != nil {
		return nil, err
	}
	defer func(Body io.ReadCloser) {
		_ = Body.Close()
	}(resp.Body)

	if resp.StatusCode != http.StatusOK {
		// Include the response body in the error to aid debugging.
		content, err := io.ReadAll(resp.Body)
		if err != nil {
			return nil, fmt.Errorf("failed to read response body: %w", err)
		}
		return nil, fmt.Errorf(
			"unexpected status code %d: %s",
			resp.StatusCode,
			string(content),
		)
	}

	node, err := html.Parse(resp.Body)
	if err != nil {
		return nil, err
	}

	// Scrape each section of the quote page independently.
	name, symbol, err := extractNameAndSymbol(node)
	if err != nil {
		return nil, fmt.Errorf("failed to extract name and symbol: %w", err)
	}
	chartAnalysis, err := extractChartAnalysis(node)
	if err != nil {
		return nil, fmt.Errorf("failed to extract chart analysis: %w", err)
	}
	ratings, err := extractRatings(node)
	if err != nil {
		return nil, fmt.Errorf("failed to extract ratings: %w", err)
	}
	price, err := extractPrice(node)
	if err != nil {
		return nil, fmt.Errorf("failed to extract price: %w", err)
	}

	return &database.StockInfo{
		Symbol:        symbol,
		Name:          name,
		ChartAnalysis: chartAnalysis,
		Ratings:       ratings,
		Price:         price,
	}, nil
}

// extractNameAndSymbol pulls the company name (span#quote-symbol, with any
// trailing "(SYMB)" pair stripped) and the ticker symbol (span#qteSymb) out
// of the parsed quote page.
func extractNameAndSymbol(node *html.Node) (name string, symbol string, err error) {
	// Find span with ID "quote-symbol"
	quoteSymbolNode := findId(node, "quote-symbol")
	if quoteSymbolNode == nil {
		return "", "", fmt.Errorf("could not find `quote-symbol` span")
	}

	// Get the text of the quote-symbol span
	name = strings.TrimSpace(extractText(quoteSymbolNode))

	// Find span with ID
"qteSymb" + qteSymbNode := findId(node, "qteSymb") + if qteSymbNode == nil { + return "", "", fmt.Errorf("could not find `qteSymb` span") + } + + // Get the text of the qteSymb span + symbol = strings.TrimSpace(extractText(qteSymbNode)) + + // Get index of last closing parenthesis + lastParenIndex := strings.LastIndex(name, ")") + if lastParenIndex == -1 { + return + } + + // Find the last opening parenthesis before the closing parenthesis + lastOpenParenIndex := strings.LastIndex(name[:lastParenIndex], "(") + if lastOpenParenIndex == -1 { + return + } + + // Remove the parenthesis pair + name = strings.TrimSpace(name[:lastOpenParenIndex] + name[lastParenIndex+1:]) + return +} + +func extractPrice(node *html.Node) (*money.Money, error) { + // Find the div with the ID "lstPrice" + lstPriceNode := findId(node, "lstPrice") + if lstPriceNode == nil { + return nil, fmt.Errorf("could not find `lstPrice` div") + } + + // Get the text of the lstPrice div + priceStr := strings.TrimSpace(extractText(lstPriceNode)) + + // Parse the price + price, err := utils.ParseMoney(priceStr) + if err != nil { + return nil, fmt.Errorf("failed to parse price: %w", err) + } + + return price, nil +} + +func extractRatings(node *html.Node) (ratings database.Ratings, err error) { + // Find the div with class "smartContent" + smartSelectNode := findClass(node, "smartContent") + if smartSelectNode == nil { + return ratings, fmt.Errorf("could not find `smartContent` div") + } + + // Iterate over children, looking for "smartRating" divs + for c := smartSelectNode.FirstChild; c != nil; c = c.NextSibling { + if !isClass(c, "smartRating") { + continue + } + + err = processSmartRating(c, &ratings) + if err != nil { + return + } + } + return +} + +// processSmartRating extracts the rating from a "smartRating" div and updates the ratings struct. +// +// The node should look like this: +// +// <ul class="smartRating"> +// <li><a><span>Composite Rating</span></a></li> +// <li>94</li> +// ... 
+// </ul> +func processSmartRating(node *html.Node, ratings *database.Ratings) error { + // Check that the node is a ul + if node.Type != html.ElementNode || node.Data != "ul" { + return fmt.Errorf("expected ul node, got %s", node.Data) + } + + // Get all `li` children + children := findChildren(node, func(node *html.Node) bool { + return node.Type == html.ElementNode && node.Data == "li" + }) + + // Extract the rating name + ratingName := strings.TrimSpace(extractText(children[0])) + + // Extract the rating value + ratingValueStr := strings.TrimSpace(extractText(children[1])) + + switch ratingName { + case "Composite Rating": + ratingValue, err := strconv.ParseUint(ratingValueStr, 10, 8) + if err != nil { + return fmt.Errorf("failed to parse Composite Rating: %w", err) + } + ratings.Composite = uint8(ratingValue) + case "EPS Rating": + ratingValue, err := strconv.ParseUint(ratingValueStr, 10, 8) + if err != nil { + return fmt.Errorf("failed to parse EPS Rating: %w", err) + } + ratings.EPS = uint8(ratingValue) + case "RS Rating": + ratingValue, err := strconv.ParseUint(ratingValueStr, 10, 8) + if err != nil { + return fmt.Errorf("failed to parse RS Rating: %w", err) + } + ratings.RelStr = uint8(ratingValue) + case "Group RS Rating": + ratings.GroupRelStr = database.LetterRatingFromString(ratingValueStr) + case "SMR Rating": + ratings.SMR = database.LetterRatingFromString(ratingValueStr) + case "Acc/Dis Rating": + ratings.AccDis = database.LetterRatingFromString(ratingValueStr) + default: + return fmt.Errorf("unknown rating name: %s", ratingName) + } + + return nil +} + +func extractChartAnalysis(node *html.Node) (string, error) { + // Find the div with class "chartAnalysis" + chartAnalysisNode := findClass(node, "chartAnalysis") + if chartAnalysisNode == nil { + return "", fmt.Errorf("could not find `chartAnalysis` div") + } + + // Get the text of the chart analysis div + chartAnalysis := strings.TrimSpace(extractText(chartAnalysisNode)) + + return chartAnalysis, 
nil +} diff --git a/backend/internal/ibd/transport/scrapfly/options.go b/backend/internal/ibd/transport/scrapfly/options.go new file mode 100644 index 0000000..f16a4b0 --- /dev/null +++ b/backend/internal/ibd/transport/scrapfly/options.go @@ -0,0 +1,84 @@ +package scrapfly + +const BaseURL = "https://api.scrapfly.io/scrape" + +var defaultScrapeOptions = ScrapeOptions{ + baseURL: BaseURL, + country: nil, + asp: true, + proxyPool: ProxyPoolDatacenter, + renderJS: false, + cache: false, +} + +type ScrapeOption func(*ScrapeOptions) + +type ScrapeOptions struct { + baseURL string + country *string + asp bool + proxyPool ProxyPool + renderJS bool + cache bool + debug bool +} + +type ProxyPool uint8 + +const ( + ProxyPoolDatacenter ProxyPool = iota + ProxyPoolResidential +) + +func (p ProxyPool) String() string { + switch p { + case ProxyPoolDatacenter: + return "public_datacenter_pool" + case ProxyPoolResidential: + return "public_residential_pool" + default: + panic("invalid proxy pool") + } +} + +func WithCountry(country string) ScrapeOption { + return func(o *ScrapeOptions) { + o.country = &country + } +} + +func WithASP(asp bool) ScrapeOption { + return func(o *ScrapeOptions) { + o.asp = asp + } +} + +func WithProxyPool(proxyPool ProxyPool) ScrapeOption { + return func(o *ScrapeOptions) { + o.proxyPool = proxyPool + } +} + +func WithRenderJS(jsRender bool) ScrapeOption { + return func(o *ScrapeOptions) { + o.renderJS = jsRender + } +} + +func WithCache(cache bool) ScrapeOption { + return func(o *ScrapeOptions) { + o.cache = cache + } +} + +func WithDebug(debug bool) ScrapeOption { + return func(o *ScrapeOptions) { + o.debug = debug + } +} + +func WithBaseURL(baseURL string) ScrapeOption { + return func(o *ScrapeOptions) { + o.baseURL = baseURL + } +} diff --git a/backend/internal/ibd/transport/scrapfly/scraper_types.go b/backend/internal/ibd/transport/scrapfly/scraper_types.go new file mode 100644 index 0000000..f3cf651 --- /dev/null +++ 
b/backend/internal/ibd/transport/scrapfly/scraper_types.go
@@ -0,0 +1,253 @@
package scrapfly

import (
	"fmt"
	"io"
	"net/http"
	"strings"
	"time"
)

// ScraperResponse is the top-level JSON envelope returned by the Scrapfly
// scrape API: the echoed request config, the execution context, and the
// actual scrape result. Only Result is consumed by ToHTTPResponse; the other
// fields are kept for completeness/debugging.
type ScraperResponse struct {
	// Config echoes back the scrape parameters Scrapfly applied.
	Config struct {
		Asp bool `json:"asp"`
		AutoScroll bool `json:"auto_scroll"`
		Body interface{} `json:"body"`
		Cache bool `json:"cache"`
		CacheClear bool `json:"cache_clear"`
		CacheTtl int `json:"cache_ttl"`
		CorrelationId interface{} `json:"correlation_id"`
		CostBudget interface{} `json:"cost_budget"`
		Country interface{} `json:"country"`
		Debug bool `json:"debug"`
		Dns bool `json:"dns"`
		Env string `json:"env"`
		Extract interface{} `json:"extract"`
		ExtractionModel interface{} `json:"extraction_model"`
		ExtractionModelCustomSchema interface{} `json:"extraction_model_custom_schema"`
		ExtractionPrompt interface{} `json:"extraction_prompt"`
		ExtractionTemplate interface{} `json:"extraction_template"`
		Format string `json:"format"`
		Geolocation interface{} `json:"geolocation"`
		Headers struct {
			Cookie []string `json:"Cookie"`
		} `json:"headers"`
		JobUuid interface{} `json:"job_uuid"`
		Js interface{} `json:"js"`
		JsScenario interface{} `json:"js_scenario"`
		Lang interface{} `json:"lang"`
		LogEvictionDate string `json:"log_eviction_date"`
		Method string `json:"method"`
		Origin string `json:"origin"`
		Os interface{} `json:"os"`
		Project string `json:"project"`
		ProxyPool string `json:"proxy_pool"`
		RenderJs bool `json:"render_js"`
		RenderingStage string `json:"rendering_stage"`
		RenderingWait int `json:"rendering_wait"`
		Retry bool `json:"retry"`
		ScheduleName interface{} `json:"schedule_name"`
		ScreenshotFlags interface{} `json:"screenshot_flags"`
		ScreenshotResolution interface{} `json:"screenshot_resolution"`
		Screenshots interface{} `json:"screenshots"`
		Session interface{} `json:"session"`
		SessionStickyProxy bool `json:"session_sticky_proxy"`
		Ssl bool `json:"ssl"`
		Tags interface{} `json:"tags"`
		Timeout int `json:"timeout"`
		Url string `json:"url"`
		UserUuid string `json:"user_uuid"`
		Uuid string `json:"uuid"`
		WaitForSelector interface{} `json:"wait_for_selector"`
		WebhookName interface{} `json:"webhook_name"`
	} `json:"config"`
	// Context carries scrape-execution metadata (proxy, cookies, cost
	// accounting, caching, redirects, ...).
	Context struct {
		Asp interface{} `json:"asp"`
		BandwidthConsumed int `json:"bandwidth_consumed"`
		BandwidthImagesConsumed int `json:"bandwidth_images_consumed"`
		Cache struct {
			Entry interface{} `json:"entry"`
			State string `json:"state"`
		} `json:"cache"`
		Cookies []struct {
			Comment interface{} `json:"comment"`
			Domain string `json:"domain"`
			Expires *string `json:"expires"`
			HttpOnly bool `json:"http_only"`
			MaxAge interface{} `json:"max_age"`
			Name string `json:"name"`
			Path string `json:"path"`
			Secure bool `json:"secure"`
			Size int `json:"size"`
			Value string `json:"value"`
			Version interface{} `json:"version"`
		} `json:"cookies"`
		Cost struct {
			Details []struct {
				Amount int `json:"amount"`
				Code string `json:"code"`
				Description string `json:"description"`
			} `json:"details"`
			Total int `json:"total"`
		} `json:"cost"`
		CreatedAt string `json:"created_at"`
		Debug interface{} `json:"debug"`
		Env string `json:"env"`
		Fingerprint string `json:"fingerprint"`
		Headers struct {
			Cookie string `json:"Cookie"`
		} `json:"headers"`
		IsXmlHttpRequest bool `json:"is_xml_http_request"`
		Job interface{} `json:"job"`
		Lang []string `json:"lang"`
		Os struct {
			Distribution string `json:"distribution"`
			Name string `json:"name"`
			Type string `json:"type"`
			Version string `json:"version"`
		} `json:"os"`
		Project string `json:"project"`
		Proxy struct {
			Country string `json:"country"`
			Identity string `json:"identity"`
			Network string `json:"network"`
			Pool string `json:"pool"`
		} `json:"proxy"`
		Redirects []interface{} `json:"redirects"`
		Retry int `json:"retry"`
		Schedule interface{} `json:"schedule"`
		Session interface{} `json:"session"`
		Spider interface{} `json:"spider"`
		Throttler interface{} `json:"throttler"`
		Uri struct {
			BaseUrl string `json:"base_url"`
			Fragment interface{} `json:"fragment"`
			Host string `json:"host"`
			Params interface{} `json:"params"`
			Port int `json:"port"`
			Query string `json:"query"`
			RootDomain string `json:"root_domain"`
			Scheme string `json:"scheme"`
		} `json:"uri"`
		Url string `json:"url"`
		Webhook interface{} `json:"webhook"`
	} `json:"context"`
	Insights interface{} `json:"insights"`
	Result ScraperResult `json:"result"`
	Uuid string `json:"uuid"`
}

// ScraperResult is the payload of a scrape: the fetched content plus the
// upstream response metadata (status, headers, cookies).
type ScraperResult struct {
	BrowserData struct {
		JavascriptEvaluationResult interface{} `json:"javascript_evaluation_result"`
		JsScenario []interface{} `json:"js_scenario"`
		LocalStorageData struct {
		} `json:"local_storage_data"`
		SessionStorageData struct {
		} `json:"session_storage_data"`
		Websockets []interface{} `json:"websockets"`
		XhrCall interface{} `json:"xhr_call"`
	} `json:"browser_data"`
	Content string `json:"content"`
	ContentEncoding string `json:"content_encoding"`
	ContentFormat string `json:"content_format"`
	ContentType string `json:"content_type"`
	Cookies []ScraperCookie `json:"cookies"`
	Data interface{} `json:"data"`
	Dns interface{} `json:"dns"`
	Duration float64 `json:"duration"`
	Error interface{} `json:"error"`
	ExtractedData interface{} `json:"extracted_data"`
	Format string `json:"format"`
	Iframes []interface{} `json:"iframes"`
	LogUrl string `json:"log_url"`
	Reason string `json:"reason"`
	RequestHeaders map[string]string `json:"request_headers"`
	ResponseHeaders map[string]string `json:"response_headers"`
	Screenshots struct {
	} `json:"screenshots"`
	Size int `json:"size"`
	Ssl interface{} `json:"ssl"`
	Status string `json:"status"`
	StatusCode int `json:"status_code"`
	Success bool `json:"success"`
	Url string `json:"url"`
}

// ScraperCookie is Scrapfly's JSON representation of an HTTP cookie.
type ScraperCookie struct {
	Name string `json:"name"`
	Value string `json:"value"`
	// Expires uses the layout "2006-01-02 15:04:05" (see ToHTTPCookie);
	// empty means no expiry is set.
	Expires string `json:"expires"`
	Path string `json:"path"`
	Comment
string `json:"comment"` + Domain string `json:"domain"` + MaxAge int `json:"max_age"` + Secure bool `json:"secure"` + HttpOnly bool `json:"http_only"` + Version string `json:"version"` + Size int `json:"size"` +} + +func (c *ScraperCookie) ToHTTPCookie() (*http.Cookie, error) { + var expires time.Time + if c.Expires != "" { + var err error + expires, err = time.Parse("2006-01-02 15:04:05", c.Expires) + if err != nil { + return nil, fmt.Errorf("failed to parse cookie expiration: %w", err) + } + } + return &http.Cookie{ + Name: c.Name, + Value: c.Value, + Path: c.Path, + Domain: c.Domain, + Expires: expires, + Secure: c.Secure, + HttpOnly: c.HttpOnly, + }, nil +} + +func (c *ScraperCookie) FromHTTPCookie(cookie *http.Cookie) { + var expires string + if !cookie.Expires.IsZero() { + expires = cookie.Expires.Format("2006-01-02 15:04:05") + } + *c = ScraperCookie{ + Comment: "", + Domain: cookie.Domain, + Expires: expires, + HttpOnly: cookie.HttpOnly, + MaxAge: cookie.MaxAge, + Name: cookie.Name, + Path: cookie.Path, + Secure: cookie.Secure, + Size: len(cookie.Value), + Value: cookie.Value, + Version: "", + } +} + +func (r *ScraperResponse) ToHTTPResponse() (*http.Response, error) { + resp := &http.Response{ + StatusCode: r.Result.StatusCode, + Header: make(http.Header), + Body: io.NopCloser(strings.NewReader(r.Result.Content)), + ContentLength: int64(len(r.Result.Content)), + Close: true, + } + + for k, v := range r.Result.ResponseHeaders { + resp.Header.Set(k, v) + } + + for _, c := range r.Result.Cookies { + cookie, err := c.ToHTTPCookie() + if err != nil { + return nil, err + } + resp.Header.Add("Set-Cookie", cookie.String()) + } + + return resp, nil +} diff --git a/backend/internal/ibd/transport/scrapfly/scrapfly.go b/backend/internal/ibd/transport/scrapfly/scrapfly.go new file mode 100644 index 0000000..3b414de --- /dev/null +++ b/backend/internal/ibd/transport/scrapfly/scrapfly.go @@ -0,0 +1,103 @@ +package scrapfly + +import ( + "encoding/json" + "fmt" + "io" + 
"net/http" + "net/url" + "strconv" + + "github.com/ansg191/ibd-trader-backend/internal/ibd/transport" +) + +type ScrapflyTransport struct { + client *http.Client + apiKey string + options ScrapeOptions +} + +func New(client *http.Client, apiKey string, opts ...ScrapeOption) *ScrapflyTransport { + options := defaultScrapeOptions + for _, opt := range opts { + opt(&options) + } + + return &ScrapflyTransport{ + client: client, + apiKey: apiKey, + options: options, + } +} + +func (s *ScrapflyTransport) String() string { + return "scrapfly" +} + +func (s *ScrapflyTransport) Do(req *http.Request) (*http.Response, error) { + // Construct scrape request URL + scrapeUrl, err := url.Parse(s.options.baseURL) + if err != nil { + panic(err) + } + scrapeUrl.RawQuery = s.constructRawQuery(req.URL, req.Header) + + // We can't handle `Content-Type` header on GET requests + // Wierd quirk of the Scrapfly API + if req.Method == http.MethodGet && req.Header.Get("Content-Type") != "" { + return nil, transport.ErrUnsupportedRequest + } + + // Construct scrape request + scrapeReq, err := http.NewRequestWithContext(req.Context(), req.Method, scrapeUrl.String(), req.Body) + if err != nil { + return nil, err + } + + // Send scrape request + resp, err := s.client.Do(scrapeReq) + if err != nil { + return nil, err + } + defer func(Body io.ReadCloser) { + _ = Body.Close() + }(resp.Body) + + // Parse scrape response + scraperResponse := new(ScraperResponse) + err = json.NewDecoder(resp.Body).Decode(scraperResponse) + if err != nil { + return nil, err + } + + // Convert scraper response to http.Response + return scraperResponse.ToHTTPResponse() +} + +func (s *ScrapflyTransport) Properties() transport.Properties { + return transport.PropertiesReliable +} + +func (s *ScrapflyTransport) constructRawQuery(u *url.URL, headers http.Header) string { + params := url.Values{} + params.Set("key", s.apiKey) + params.Set("url", u.String()) + if s.options.country != nil { + params.Set("country", 
*s.options.country) + } + params.Set("asp", strconv.FormatBool(s.options.asp)) + params.Set("proxy_pool", s.options.proxyPool.String()) + params.Set("render_js", strconv.FormatBool(s.options.renderJS)) + params.Set("cache", strconv.FormatBool(s.options.cache)) + + for k, v := range headers { + for i, vv := range v { + params.Add( + fmt.Sprintf("headers[%s][%d]", k, i), + vv, + ) + } + } + + return params.Encode() +} diff --git a/backend/internal/ibd/transport/standard.go b/backend/internal/ibd/transport/standard.go new file mode 100644 index 0000000..9fa9ff9 --- /dev/null +++ b/backend/internal/ibd/transport/standard.go @@ -0,0 +1,41 @@ +package transport + +import ( + "net/http" + + "github.com/EDDYCJY/fake-useragent" +) + +type StandardTransport http.Client + +func NewStandardTransport(client *http.Client) *StandardTransport { + return (*StandardTransport)(client) +} + +func (t *StandardTransport) Do(req *http.Request) (*http.Response, error) { + addFakeHeaders(req) + return (*http.Client)(t).Do(req) +} + +func (t *StandardTransport) String() string { + return "standard" +} + +func (t *StandardTransport) Properties() Properties { + return PropertiesFree +} + +func addFakeHeaders(req *http.Request) { + req.Header.Set("User-Agent", browser.Linux()) + req.Header.Set("Sec-CH-UA", `"Not)A;Brand";v="99", "Google Chrome";v="127", "Chromium";v="127"`) + req.Header.Set("Sec-CH-UA-Mobile", "?0") + req.Header.Set("Sec-CH-UA-Platform", "Linux") + req.Header.Set("Upgrade-Insecure-Requests", "1") + req.Header.Set("Priority", "u=0, i") + req.Header.Set("Sec-Fetch-Site", "none") + req.Header.Set("Sec-Fetch-Mode", "navigate") + req.Header.Set("Sec-Fetch-Dest", "document") + req.Header.Set("Sec-Fetch-User", "?1") + req.Header.Set("Accept-Language", "en-US,en;q=0.9") + req.Header.Set("Accept", "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.7") +} diff --git 
a/backend/internal/ibd/transport/transport.go b/backend/internal/ibd/transport/transport.go new file mode 100644 index 0000000..95e9ef3 --- /dev/null +++ b/backend/internal/ibd/transport/transport.go @@ -0,0 +1,66 @@ +package transport + +import ( + "cmp" + "errors" + "fmt" + "net/http" + "slices" +) + +var ErrUnsupportedRequest = errors.New("unsupported request") + +type Properties uint8 + +const ( + // PropertiesFree indicates that the transport is free. + // This means that requests made with this transport don't cost any money. + PropertiesFree Properties = 1 << iota + // PropertiesReliable indicates that the transport is reliable. + // This means that requests made with this transport are guaranteed to be + // successful if the server is reachable. + PropertiesReliable +) + +func (p Properties) IsReliable() bool { + return p&PropertiesReliable != 0 +} + +func (p Properties) IsFree() bool { + return p&PropertiesFree != 0 +} + +type Transport interface { + fmt.Stringer + + Do(req *http.Request) (*http.Response, error) + Properties() Properties +} + +// SortTransports sorts the transports by their properties. +// +// The transports are sorted in the following order: +// 1. Free transports +// 2. 
Reliable transports +func SortTransports(transports []Transport) { + priorities := map[Properties]int{ + PropertiesFree | PropertiesReliable: 0, + PropertiesFree: 1, + PropertiesReliable: 2, + } + slices.SortStableFunc(transports, func(a, b Transport) int { + iPriority := priorities[a.Properties()] + jPriority := priorities[b.Properties()] + return cmp.Compare(iPriority, jPriority) + }) +} + +func FilterTransports(transport []Transport, props Properties) []Transport { + var filtered []Transport + for _, tp := range transport { + if tp.Properties()&props == props { + filtered = append(filtered, tp) + } + } + return filtered +} diff --git a/backend/internal/ibd/userinfo.go b/backend/internal/ibd/userinfo.go new file mode 100644 index 0000000..ed61497 --- /dev/null +++ b/backend/internal/ibd/userinfo.go @@ -0,0 +1,156 @@ +package ibd + +import ( + "context" + "encoding/json" + "fmt" + "io" + "log/slog" + "net/http" +) + +const ( + userInfoUrl = "https://myibd.investors.com/services/userprofile.aspx?format=json" +) + +func (c *Client) UserInfo(ctx context.Context, cookie *http.Cookie) (*UserProfile, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, userInfoUrl, nil) + if err != nil { + return nil, err + } + + req.AddCookie(cookie) + + resp, err := c.Do(req) + if err != nil { + return nil, err + } + defer func(Body io.ReadCloser) { + _ = Body.Close() + }(resp.Body) + + content, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf( + "unexpected status code %d: %s", + resp.StatusCode, + string(content), + ) + } + + up := new(UserProfile) + if err = up.UnmarshalJSON(content); err != nil { + return nil, err + } + + return up, nil +} + +type UserStatus string + +const ( + UserStatusUnknown UserStatus = "" + UserStatusVisitor UserStatus = "Visitor" + UserStatusSubscriber UserStatus = "Subscriber" +) + +type UserProfile struct 
{ + DisplayName string + Email string + FirstName string + LastName string + Status UserStatus +} + +func (u *UserProfile) UnmarshalJSON(bytes []byte) error { + var resp userProfileResponse + if err := json.Unmarshal(bytes, &resp); err != nil { + return err + } + + u.DisplayName = resp.UserProfile.UserDisplayName + u.Email = resp.UserProfile.UserEmailAddress + u.FirstName = resp.UserProfile.UserFirstName + u.LastName = resp.UserProfile.UserLastName + + switch resp.UserProfile.UserTrialStatus { + case "Visitor": + u.Status = UserStatusVisitor + case "Subscriber": + u.Status = UserStatusSubscriber + default: + slog.Warn("Unknown user status", "status", resp.UserProfile.UserTrialStatus) + u.Status = UserStatusUnknown + } + + return nil +} + +type userProfileResponse struct { + UserProfile userProfile `json:"userProfile"` +} + +type userProfile struct { + UserSubType string `json:"userSubType"` + UserId string `json:"userId"` + UserDisplayName string `json:"userDisplayName"` + Countrycode string `json:"countrycode"` + IsEUCountry string `json:"isEUCountry"` + Log string `json:"log"` + AgeGroup string `json:"ageGroup"` + Gender string `json:"gender"` + InvestingExperience string `json:"investingExperience"` + NumberOfTrades string `json:"numberOfTrades"` + Occupation string `json:"occupation"` + TypeOfInvestments string `json:"typeOfInvestments"` + UserEmailAddress string `json:"userEmailAddress"` + UserEmailAddressSHA1 string `json:"userEmailAddressSHA1"` + UserEmailAddressSHA256 string `json:"userEmailAddressSHA256"` + UserEmailAddressMD5 string `json:"userEmailAddressMD5"` + UserFirstName string `json:"userFirstName"` + UserLastName string `json:"userLastName"` + UserZip string `json:"userZip"` + UserTrialStatus string `json:"userTrialStatus"` + UserProductsOnTrial string `json:"userProductsOnTrial"` + UserProductsOwned string `json:"userProductsOwned"` + UserAdTrade string `json:"userAdTrade"` + UserAdTime string `json:"userAdTime"` + UserAdHold string 
`json:"userAdHold"` + UserAdJob string `json:"userAdJob"` + UserAdAge string `json:"userAdAge"` + UserAdOutSell string `json:"userAdOutSell"` + UserVisitCount string `json:"userVisitCount"` + RoleLeaderboard bool `json:"role_leaderboard"` + RoleOws bool `json:"role_ows"` + RoleIbdlive bool `json:"role_ibdlive"` + RoleFounderclub bool `json:"role_founderclub"` + RoleEibd bool `json:"role_eibd"` + RoleIcom bool `json:"role_icom"` + RoleEtables bool `json:"role_etables"` + RoleTru10 bool `json:"role_tru10"` + RoleMarketsurge bool `json:"role_marketsurge"` + RoleSwingtrader bool `json:"role_swingtrader"` + RoleAdfree bool `json:"role_adfree"` + RoleMarketdiem bool `json:"role_marketdiem"` + RoleWsjPlus bool `json:"role_wsj_plus"` + RoleWsj bool `json:"role_wsj"` + RoleBarrons bool `json:"role_barrons"` + RoleMarketwatch bool `json:"role_marketwatch"` + UserAdRoles string `json:"userAdRoles"` + TrialDailyPrintNeg bool `json:"trial_daily_print_neg"` + TrialDailyPrintNon bool `json:"trial_daily_print_non"` + TrialWeeklyPrintNeg bool `json:"trial_weekly_print_neg"` + TrialWeeklyPrintNon bool `json:"trial_weekly_print_non"` + TrialDailyComboNeg bool `json:"trial_daily_combo_neg"` + TrialDailyComboNon bool `json:"trial_daily_combo_non"` + TrialWeeklyComboNeg bool `json:"trial_weekly_combo_neg"` + TrialWeeklyComboNon bool `json:"trial_weekly_combo_non"` + TrialEibdNeg bool `json:"trial_eibd_neg"` + TrialEibdNon bool `json:"trial_eibd_non"` + UserVideoPreference string `json:"userVideoPreference"` + UserProfessionalStatus bool `json:"userProfessionalStatus"` +} diff --git a/backend/internal/keys/gcp.go b/backend/internal/keys/gcp.go new file mode 100644 index 0000000..9d10fc5 --- /dev/null +++ b/backend/internal/keys/gcp.go @@ -0,0 +1,131 @@ +package keys + +import ( + "context" + "crypto/rand" + "crypto/rsa" + "crypto/sha256" + "crypto/x509" + "encoding/pem" + "errors" + "fmt" + "hash/crc32" + "sync" + + kms "cloud.google.com/go/kms/apiv1" + 
"cloud.google.com/go/kms/apiv1/kmspb" + "google.golang.org/protobuf/types/known/wrapperspb" +) + +type GoogleKMS struct { + client *kms.KeyManagementClient + + mx sync.RWMutex + keyCache map[string]*rsa.PublicKey +} + +func NewGoogleKMS(ctx context.Context) (*GoogleKMS, error) { + client, err := kms.NewKeyManagementClient(ctx) + if err != nil { + return nil, err + } + + return &GoogleKMS{ + client: client, + keyCache: make(map[string]*rsa.PublicKey), + }, nil +} + +func (g *GoogleKMS) checkCache(keyName string) *rsa.PublicKey { + g.mx.RLock() + defer g.mx.RUnlock() + + return g.keyCache[keyName] +} + +func (g *GoogleKMS) getPublicKey(ctx context.Context, keyName string) (*rsa.PublicKey, error) { + if key := g.checkCache(keyName); key != nil { + return key, nil + } + + response, err := g.client.GetPublicKey(ctx, &kmspb.GetPublicKeyRequest{Name: keyName}) + if err != nil { + return nil, err + } + + block, _ := pem.Decode([]byte(response.Pem)) + if block == nil || block.Type != "PUBLIC KEY" { + return nil, errors.New("failed to decode PEM public key") + } + publicKey, err := x509.ParsePKIXPublicKey(block.Bytes) + if err != nil { + return nil, fmt.Errorf("failed to parse public key: %w", err) + } + rsaKey, ok := publicKey.(*rsa.PublicKey) + if !ok { + return nil, errors.New("public key is not an RSA key") + } + + g.mx.Lock() + defer g.mx.Unlock() + g.keyCache[keyName] = rsaKey + + return rsaKey, nil +} + +func (g *GoogleKMS) Encrypt(ctx context.Context, keyName string, plaintext []byte) ([]byte, error) { + publicKey, err := g.getPublicKey(ctx, keyName) + if err != nil { + return nil, err + } + + cipherText, err := rsa.EncryptOAEP(sha256.New(), rand.Reader, publicKey, plaintext, nil) + if err != nil { + return nil, fmt.Errorf("failed to encrypt plaintext: %w", err) + } + + return cipherText, nil +} + +func (g *GoogleKMS) Decrypt(ctx context.Context, keyName string, ciphertext []byte) ([]byte, error) { + req := &kmspb.AsymmetricDecryptRequest{ + Name: keyName, + 
Ciphertext: ciphertext, + CiphertextCrc32C: wrapperspb.Int64(int64(calcCRC32(ciphertext))), + } + + result, err := g.client.AsymmetricDecrypt(ctx, req) + if err != nil { + return nil, fmt.Errorf("failed to decrypt ciphertext: %w", err) + } + + if !result.VerifiedCiphertextCrc32C { + return nil, errors.New("AsymmetricDecrypt: request corrupted in-transit") + } + if int64(calcCRC32(result.Plaintext)) != result.PlaintextCrc32C.Value { + return nil, fmt.Errorf("AsymmetricDecrypt: response corrupted in-transit") + } + + return result.Plaintext, nil +} + +func (g *GoogleKMS) Close() error { + return g.client.Close() +} + +func calcCRC32(data []byte) uint32 { + t := crc32.MakeTable(crc32.Castagnoli) + return crc32.Checksum(data, t) +} + +type GCPKeyName struct { + Project string + Location string + KeyRing string + CryptoKey string + CryptoKeyVersion string +} + +func (k GCPKeyName) String() string { + return fmt.Sprintf("projects/%s/locations/%s/keyRings/%s/cryptoKeys/%s/cryptoKeyVersions/%s", k.Project, k.Location, k.KeyRing, k.CryptoKey, k.CryptoKeyVersion) +} diff --git a/backend/internal/keys/keys.go b/backend/internal/keys/keys.go new file mode 100644 index 0000000..ac73173 --- /dev/null +++ b/backend/internal/keys/keys.go @@ -0,0 +1,150 @@ +package keys + +import ( + "context" + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "fmt" + "io" +) + +var CSRNG = rand.Reader + +//go:generate mockgen -destination mock_keys_test.go -package keys_test -typed . KeyManagementService +type KeyManagementService interface { + io.Closer + + // Encrypt encrypts the given plaintext using the key with the given key name. + Encrypt(ctx context.Context, keyName string, plaintext []byte) ([]byte, error) + + // Decrypt decrypts the given ciphertext using the key with the given key name. + Decrypt(ctx context.Context, keyName string, ciphertext []byte) ([]byte, error) +} + +// Encrypt encrypts the given plaintext using a hybrid encryption scheme. 
+// +// It first generates a random AES 256-bit key and encrypts the plaintext with it. +// Then, it encrypts the AES key using the KMS. +// +// It returns the ciphertext, the encrypted AES key, and any errors that occurred. +func Encrypt( + ctx context.Context, + kms KeyManagementService, + keyName string, + plaintext []byte, +) (ciphertext []byte, encryptedKey []byte, err error) { + // Generate a random AES key + aesKey := make([]byte, 32) + if _, err = io.ReadFull(CSRNG, aesKey); err != nil { + return nil, nil, fmt.Errorf("unable to generate AES key: %w", err) + } + + // Encrypt the plaintext using the AES key + ciphertext, err = encrypt(aesKey, plaintext) + if err != nil { + return nil, nil, fmt.Errorf("unable to encrypt plaintext: %w", err) + } + + // Encrypt the AES key using the KMS + encryptedKey, err = kms.Encrypt(ctx, keyName, aesKey) + if err != nil { + return nil, nil, fmt.Errorf("unable to encrypt AES key: %w", err) + } + + return ciphertext, encryptedKey, nil +} + +// EncryptWithKey encrypts the given plaintext using a hybrid encryption scheme. +// +// This works similarly to Encrypt, but instead of generating a new AES key, it uses a given already encrypted AES key. 
+func EncryptWithKey( + ctx context.Context, + kms KeyManagementService, + keyName string, + encryptedKey []byte, + plaintext []byte, +) ([]byte, error) { + // Decrypt the AES key + aesKey, err := kms.Decrypt(ctx, keyName, encryptedKey) + if err != nil { + return nil, fmt.Errorf("unable to decrypt AES key: %w", err) + } + + // Encrypt the plaintext using the AES key + ciphertext, err := encrypt(aesKey, plaintext) + if err != nil { + return nil, fmt.Errorf("unable to encrypt plaintext: %w", err) + } + + return ciphertext, nil +} + +func encrypt(aesKey []byte, plaintext []byte) ([]byte, error) { + // Create an AES cipher + blockCipher, err := aes.NewCipher(aesKey) + if err != nil { + return nil, fmt.Errorf("unable to create AES cipher: %w", err) + } + + gcm, err := cipher.NewGCM(blockCipher) + if err != nil { + return nil, fmt.Errorf("unable to create GCM: %w", err) + } + + // Generate a random nonce + nonce := make([]byte, gcm.NonceSize()) + if _, err = io.ReadFull(CSRNG, nonce); err != nil { + return nil, fmt.Errorf("unable to generate nonce: %w", err) + } + + // Encrypt the plaintext + ciphertext := gcm.Seal(nonce, nonce, plaintext, nil) + return ciphertext, nil +} + +// Decrypt decrypts the given ciphertext using a hybrid encryption scheme. +// +// It first decrypts the AES key using the KMS. +// Then, it decrypts the ciphertext using the decrypted AES key. +// +// It returns the plaintext and any errors that occurred. 
+func Decrypt( + ctx context.Context, + kms KeyManagementService, + keyName string, + ciphertext []byte, + encryptedKey []byte, +) ([]byte, error) { + // Decrypt the AES key + aesKey, err := kms.Decrypt(ctx, keyName, encryptedKey) + if err != nil { + return nil, fmt.Errorf("unable to decrypt AES key: %w", err) + } + + // Create an AES cipher + blockCipher, err := aes.NewCipher(aesKey) + if err != nil { + return nil, fmt.Errorf("unable to create AES cipher: %w", err) + } + + gcm, err := cipher.NewGCM(blockCipher) + if err != nil { + return nil, fmt.Errorf("unable to create GCM: %w", err) + } + + // Extract the nonce from the ciphertext + nonceSize := gcm.NonceSize() + if len(ciphertext) < nonceSize { + return nil, fmt.Errorf("ciphertext is too short") + } + nonce, ciphertext := ciphertext[:nonceSize], ciphertext[nonceSize:] + + // Decrypt the ciphertext + plaintext, err := gcm.Open(nil, nonce, ciphertext, nil) + if err != nil { + return nil, fmt.Errorf("unable to decrypt ciphertext: %w", err) + } + + return plaintext, nil +} diff --git a/backend/internal/keys/keys_test.go b/backend/internal/keys/keys_test.go new file mode 100644 index 0000000..34aa493 --- /dev/null +++ b/backend/internal/keys/keys_test.go @@ -0,0 +1,64 @@ +package keys_test + +import ( + "bytes" + "context" + "encoding/hex" + "testing" + + "github.com/ansg191/ibd-trader-backend/internal/keys" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" +) + +func TestEncrypt(t *testing.T) { + ctrl := gomock.NewController(t) + + // Replace RNG with a deterministic RNG + aesKey := []byte("0123456789abcdef0123456789abcdef") + nonce := []byte("0123456789ab") + keys.CSRNG = bytes.NewReader(append(aesKey, nonce...)) + + // Create a mock KMS + kms := NewMockKeyManagementService(ctrl) + keyName := "keyName" + + ctx := context.Background() + plaintext := []byte("plaintext") + + kms.EXPECT(). + Encrypt(ctx, keyName, aesKey). 
+ Return([]byte("encryptedKey"), nil) + + ciphertext, encryptedKey, err := keys.Encrypt(ctx, kms, keyName, plaintext) + require.NoError(t, err) + + encrypted, err := hex.DecodeString("e9c586532dbefd63812293e1c4baf71edb7042a294c49c2020") + require.NoError(t, err) + assert.Equal(t, append(nonce, encrypted...), ciphertext) + assert.Equal(t, []byte("encryptedKey"), encryptedKey) +} + +func TestDecrypt(t *testing.T) { + ctrl := gomock.NewController(t) + + kms := NewMockKeyManagementService(ctrl) + keyName := "keyName" + + ctx := context.Background() + encryptedKey := []byte("encryptedKey") + ciphertext, err := hex.DecodeString("e9c586532dbefd63812293e1c4baf71edb7042a294c49c2020") + require.NoError(t, err) + ciphertext = append([]byte("0123456789ab"), ciphertext...) + + aesKey := []byte("0123456789abcdef0123456789abcdef") + kms.EXPECT(). + Decrypt(ctx, keyName, encryptedKey). + Return(aesKey, nil) + + plaintext, err := keys.Decrypt(ctx, kms, keyName, ciphertext, encryptedKey) + require.NoError(t, err) + assert.Equal(t, []byte("plaintext"), plaintext) +} diff --git a/backend/internal/keys/mock_keys_test.go b/backend/internal/keys/mock_keys_test.go new file mode 100644 index 0000000..19316e0 --- /dev/null +++ b/backend/internal/keys/mock_keys_test.go @@ -0,0 +1,156 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/ansg191/ibd-trader-backend/internal/keys (interfaces: KeyManagementService) +// +// Generated by this command: +// +// mockgen -destination mock_keys_test.go -package keys_test -typed . KeyManagementService +// + +// Package keys_test is a generated GoMock package. +package keys_test + +import ( + context "context" + reflect "reflect" + + gomock "go.uber.org/mock/gomock" +) + +// MockKeyManagementService is a mock of KeyManagementService interface. 
+type MockKeyManagementService struct { + ctrl *gomock.Controller + recorder *MockKeyManagementServiceMockRecorder +} + +// MockKeyManagementServiceMockRecorder is the mock recorder for MockKeyManagementService. +type MockKeyManagementServiceMockRecorder struct { + mock *MockKeyManagementService +} + +// NewMockKeyManagementService creates a new mock instance. +func NewMockKeyManagementService(ctrl *gomock.Controller) *MockKeyManagementService { + mock := &MockKeyManagementService{ctrl: ctrl} + mock.recorder = &MockKeyManagementServiceMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockKeyManagementService) EXPECT() *MockKeyManagementServiceMockRecorder { + return m.recorder +} + +// Close mocks base method. +func (m *MockKeyManagementService) Close() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Close") + ret0, _ := ret[0].(error) + return ret0 +} + +// Close indicates an expected call of Close. +func (mr *MockKeyManagementServiceMockRecorder) Close() *MockKeyManagementServiceCloseCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockKeyManagementService)(nil).Close)) + return &MockKeyManagementServiceCloseCall{Call: call} +} + +// MockKeyManagementServiceCloseCall wrap *gomock.Call +type MockKeyManagementServiceCloseCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockKeyManagementServiceCloseCall) Return(arg0 error) *MockKeyManagementServiceCloseCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockKeyManagementServiceCloseCall) Do(f func() error) *MockKeyManagementServiceCloseCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockKeyManagementServiceCloseCall) DoAndReturn(f func() error) *MockKeyManagementServiceCloseCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// 
Decrypt mocks base method. +func (m *MockKeyManagementService) Decrypt(arg0 context.Context, arg1 string, arg2 []byte) ([]byte, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Decrypt", arg0, arg1, arg2) + ret0, _ := ret[0].([]byte) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Decrypt indicates an expected call of Decrypt. +func (mr *MockKeyManagementServiceMockRecorder) Decrypt(arg0, arg1, arg2 any) *MockKeyManagementServiceDecryptCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Decrypt", reflect.TypeOf((*MockKeyManagementService)(nil).Decrypt), arg0, arg1, arg2) + return &MockKeyManagementServiceDecryptCall{Call: call} +} + +// MockKeyManagementServiceDecryptCall wrap *gomock.Call +type MockKeyManagementServiceDecryptCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockKeyManagementServiceDecryptCall) Return(arg0 []byte, arg1 error) *MockKeyManagementServiceDecryptCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockKeyManagementServiceDecryptCall) Do(f func(context.Context, string, []byte) ([]byte, error)) *MockKeyManagementServiceDecryptCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockKeyManagementServiceDecryptCall) DoAndReturn(f func(context.Context, string, []byte) ([]byte, error)) *MockKeyManagementServiceDecryptCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// Encrypt mocks base method. +func (m *MockKeyManagementService) Encrypt(arg0 context.Context, arg1 string, arg2 []byte) ([]byte, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Encrypt", arg0, arg1, arg2) + ret0, _ := ret[0].([]byte) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Encrypt indicates an expected call of Encrypt. 
+func (mr *MockKeyManagementServiceMockRecorder) Encrypt(arg0, arg1, arg2 any) *MockKeyManagementServiceEncryptCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Encrypt", reflect.TypeOf((*MockKeyManagementService)(nil).Encrypt), arg0, arg1, arg2) + return &MockKeyManagementServiceEncryptCall{Call: call} +} + +// MockKeyManagementServiceEncryptCall wrap *gomock.Call +type MockKeyManagementServiceEncryptCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockKeyManagementServiceEncryptCall) Return(arg0 []byte, arg1 error) *MockKeyManagementServiceEncryptCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockKeyManagementServiceEncryptCall) Do(f func(context.Context, string, []byte) ([]byte, error)) *MockKeyManagementServiceEncryptCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockKeyManagementServiceEncryptCall) DoAndReturn(f func(context.Context, string, []byte) ([]byte, error)) *MockKeyManagementServiceEncryptCall { + c.Call = c.Call.DoAndReturn(f) + return c +} diff --git a/backend/internal/leader/election/election.go b/backend/internal/leader/election/election.go new file mode 100644 index 0000000..6f83298 --- /dev/null +++ b/backend/internal/leader/election/election.go @@ -0,0 +1,128 @@ +package election + +import ( + "context" + "errors" + "log/slog" + "time" + + "github.com/bsm/redislock" +) + +var defaultLeaderElectionOptions = leaderElectionOptions{ + lockKey: "ibd-leader-election", + lockTTL: 10 * time.Second, +} + +func RunOrDie( + ctx context.Context, + client redislock.RedisClient, + onLeader func(context.Context), + opts ...LeaderElectionOption, +) { + o := defaultLeaderElectionOptions + for _, opt := range opts { + opt(&o) + } + + locker := redislock.New(client) + + // Election loop + for { + lock, err := locker.Obtain(ctx, o.lockKey, o.lockTTL, nil) + if errors.Is(err, 
redislock.ErrNotObtained) { + // Another instance is the leader + } else if err != nil { + slog.ErrorContext(ctx, "failed to obtain lock", "error", err) + } else { + // We are the leader + slog.DebugContext(ctx, "elected leader") + runLeader(ctx, lock, onLeader, o) + } + + // Sleep for a bit before trying again + timer := time.NewTimer(o.lockTTL / 5) + select { + case <-ctx.Done(): + if !timer.Stop() { + <-timer.C + } + return + case <-timer.C: + } + } +} + +func runLeader( + ctx context.Context, + lock *redislock.Lock, + onLeader func(context.Context), + o leaderElectionOptions, +) { + // A context that is canceled when the leader loses the lock + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + // Release the lock when done + defer func() { + // Create new context without cancel if the original context is already canceled + relCtx := ctx + if ctx.Err() != nil { + relCtx = context.WithoutCancel(ctx) + } + + // Add a timeout to the release context + relCtx, cancel := context.WithTimeout(relCtx, o.lockTTL) + defer cancel() + + if err := lock.Release(relCtx); err != nil { + slog.Error("failed to release lock", "error", err) + } + }() + + // Run the leader code + go func(ctx context.Context) { + onLeader(ctx) + + // If the leader code returns, cancel the context to release the lock + cancel() + }(ctx) + + // Refresh the lock periodically + ticker := time.NewTicker(o.lockTTL / 10) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + err := lock.Refresh(ctx, o.lockTTL, nil) + if errors.Is(err, redislock.ErrNotObtained) || errors.Is(err, redislock.ErrLockNotHeld) { + slog.ErrorContext(ctx, "leadership lost", "error", err) + return + } else if err != nil { + slog.ErrorContext(ctx, "failed to refresh lock", "error", err) + } + case <-ctx.Done(): + return + } + } +} + +type leaderElectionOptions struct { + lockKey string + lockTTL time.Duration +} + +type LeaderElectionOption func(*leaderElectionOptions) + +func WithLockKey(key string) 
LeaderElectionOption { + return func(o *leaderElectionOptions) { + o.lockKey = key + } +} + +func WithLockTTL(ttl time.Duration) LeaderElectionOption { + return func(o *leaderElectionOptions) { + o.lockTTL = ttl + } +} diff --git a/backend/internal/leader/manager/ibd/auth/auth.go b/backend/internal/leader/manager/ibd/auth/auth.go new file mode 100644 index 0000000..9b5502d --- /dev/null +++ b/backend/internal/leader/manager/ibd/auth/auth.go @@ -0,0 +1,111 @@ +package auth + +import ( + "context" + "log/slog" + "time" + + "github.com/ansg191/ibd-trader-backend/internal/database" + "github.com/ansg191/ibd-trader-backend/internal/redis/taskqueue" + + "github.com/redis/go-redis/v9" + "github.com/robfig/cron/v3" +) + +const ( + Queue = "auth-queue" + QueueEncoding = taskqueue.EncodingJSON +) + +// Manager is responsible for sending authentication tasks to the workers. +type Manager struct { + queue taskqueue.TaskQueue[TaskInfo] + db database.Executor + schedule cron.Schedule +} + +func New( + ctx context.Context, + db database.Executor, + rClient *redis.Client, + schedule cron.Schedule, +) (*Manager, error) { + queue, err := taskqueue.New( + ctx, + rClient, + Queue, + "auth-manager", + taskqueue.WithEncoding[TaskInfo](QueueEncoding), + ) + if err != nil { + return nil, err + } + + return &Manager{ + queue: queue, + db: db, + schedule: schedule, + }, nil +} + +func (m *Manager) Run(ctx context.Context) { + for { + now := time.Now() + // Find the next time + nextTime := m.schedule.Next(now) + if nextTime.IsZero() { + // Sleep until the next day + time.Sleep(time.Until(now.AddDate(0, 0, 1))) + continue + } + + timer := time.NewTimer(nextTime.Sub(now)) + slog.DebugContext(ctx, "waiting for next Auth scrape", "next_exec", nextTime) + + select { + case <-timer.C: + nextExec := m.schedule.Next(nextTime) + m.scrapeCookies(ctx, nextExec) + case <-ctx.Done(): + if !timer.Stop() { + <-timer.C + } + return + } + } +} + +// scrapeCookies scrapes the cookies for every user from the 
IBD website. +// +// This iterates through all users with IBD credentials and checks whether their cookies are still valid. +// If the cookies are invalid or missing, it re-authenticates the user and updates the cookies in the database. +func (m *Manager) scrapeCookies(ctx context.Context, deadline time.Time) { + ctx, cancel := context.WithDeadline(ctx, deadline) + defer cancel() + + // Get all users with IBD credentials + users, err := database.ListUsers(ctx, m.db, true) + if err != nil { + slog.ErrorContext(ctx, "failed to get users", "error", err) + return + } + + // Create a new task for each user + for _, user := range users { + task := TaskInfo{ + UserSubject: user.Subject, + } + + // Enqueue the task + _, err := m.queue.Enqueue(ctx, task) + if err != nil { + slog.ErrorContext(ctx, "failed to enqueue task", "error", err) + } + } + + slog.InfoContext(ctx, "enqueued tasks for all users") +} + +type TaskInfo struct { + UserSubject string `json:"user_subject"` +} diff --git a/backend/internal/leader/manager/ibd/ibd.go b/backend/internal/leader/manager/ibd/ibd.go new file mode 100644 index 0000000..e2d4fc0 --- /dev/null +++ b/backend/internal/leader/manager/ibd/ibd.go @@ -0,0 +1,8 @@ +package ibd + +type Schedules struct { + // Auth schedule + Auth string + // IBD50 schedule + IBD50 string +} diff --git a/backend/internal/leader/manager/ibd/scrape/scrape.go b/backend/internal/leader/manager/ibd/scrape/scrape.go new file mode 100644 index 0000000..870ce5e --- /dev/null +++ b/backend/internal/leader/manager/ibd/scrape/scrape.go @@ -0,0 +1,140 @@ +package scrape + +import ( + "context" + "errors" + "log/slog" + "time" + + "github.com/ansg191/ibd-trader-backend/internal/database" + "github.com/ansg191/ibd-trader-backend/internal/ibd" + "github.com/ansg191/ibd-trader-backend/internal/redis/taskqueue" + + "github.com/redis/go-redis/v9" + "github.com/robfig/cron/v3" +) + +const ( + Queue = "scrape-queue" + QueueEncoding = taskqueue.EncodingJSON + Channel = 
"scrape-channel" +) + +// Manager is responsible for sending scraping tasks to the workers. +type Manager struct { + client *ibd.Client + db database.Executor + queue taskqueue.TaskQueue[TaskInfo] + schedule cron.Schedule + pubsub *redis.PubSub +} + +func New( + ctx context.Context, + client *ibd.Client, + db database.Executor, + redis *redis.Client, + schedule cron.Schedule, +) (*Manager, error) { + queue, err := taskqueue.New( + ctx, + redis, + Queue, + "ibd-manager", + taskqueue.WithEncoding[TaskInfo](QueueEncoding), + ) + if err != nil { + return nil, err + } + + return &Manager{ + client: client, + db: db, + queue: queue, + schedule: schedule, + pubsub: redis.Subscribe(ctx, Channel), + }, nil +} + +func (m *Manager) Close() error { + return m.pubsub.Close() +} + +func (m *Manager) Run(ctx context.Context) { + ch := m.pubsub.Channel() + for { + now := time.Now() + // Find the next time + nextTime := m.schedule.Next(now) + if nextTime.IsZero() { + // Sleep until the next day + time.Sleep(time.Until(now.AddDate(0, 0, 1))) + continue + } + + timer := time.NewTimer(nextTime.Sub(now)) + slog.DebugContext(ctx, "waiting for next IBD50 scrape", "next_exec", nextTime) + + select { + case <-timer.C: + nextExec := m.schedule.Next(nextTime) + m.scrapeIBD50(ctx, nextExec) + case <-ch: + nextExec := m.schedule.Next(time.Now()) + m.scrapeIBD50(ctx, nextExec) + case <-ctx.Done(): + if !timer.Stop() { + <-timer.C + } + return + } + } +} + +func (m *Manager) scrapeIBD50(ctx context.Context, deadline time.Time) { + ctx, cancel := context.WithDeadline(ctx, deadline) + defer cancel() + + stocks, err := m.client.GetIBD50(ctx) + if err != nil { + if errors.Is(err, ibd.ErrNoAvailableCookies) { + slog.WarnContext(ctx, "no available cookies", "error", err) + return + } + slog.ErrorContext(ctx, "failed to get IBD50", "error", err) + return + } + + for _, stock := range stocks { + // Add stock to DB + err = database.AddStock(ctx, m.db, database.Stock{ + Symbol: stock.Symbol, + Name: 
stock.Name, + IBDUrl: stock.QuoteURL.String(), + }) + if err != nil { + slog.ErrorContext(ctx, "failed to add stock", "error", err) + continue + } + + // Add ranking to Db + err = database.AddRanking(ctx, m.db, stock.Symbol, int(stock.Rank), 0) + if err != nil { + slog.ErrorContext(ctx, "failed to add ranking", "error", err) + continue + } + + // Add scrape task to queue + task := TaskInfo{Symbol: stock.Symbol} + taskID, err := m.queue.Enqueue(ctx, task) + if err != nil { + slog.ErrorContext(ctx, "failed to enqueue task", "error", err) + } + + slog.DebugContext(ctx, "enqueued scrape task", "task_id", taskID, "symbol", stock.Symbol) + } +} + +type TaskInfo struct { + Symbol string `json:"symbol"` +} diff --git a/backend/internal/leader/manager/manager.go b/backend/internal/leader/manager/manager.go new file mode 100644 index 0000000..61e27e0 --- /dev/null +++ b/backend/internal/leader/manager/manager.go @@ -0,0 +1,90 @@ +package manager + +import ( + "context" + "fmt" + "log/slog" + "sync" + + "github.com/ansg191/ibd-trader-backend/internal/config" + "github.com/ansg191/ibd-trader-backend/internal/database" + "github.com/ansg191/ibd-trader-backend/internal/ibd" + "github.com/ansg191/ibd-trader-backend/internal/leader/manager/ibd/auth" + ibd2 "github.com/ansg191/ibd-trader-backend/internal/leader/manager/ibd/scrape" + + "github.com/redis/go-redis/v9" + "github.com/robfig/cron/v3" +) + +type Manager struct { + db database.Database + Monitor *WorkerMonitor + Scraper *ibd2.Manager + Auth *auth.Manager +} + +func New( + ctx context.Context, + cfg *config.Config, + client *redis.Client, + db database.Database, + ibd *ibd.Client, +) (*Manager, error) { + scraperSchedule, err := cron.ParseStandard(cfg.IBD.Schedules.IBD50) + if err != nil { + return nil, fmt.Errorf("unable to parse IBD50 schedule: %w", err) + } + scraper, err := ibd2.New(ctx, ibd, db, client, scraperSchedule) + if err != nil { + return nil, err + } + + authSchedule, err := 
cron.ParseStandard(cfg.IBD.Schedules.Auth) + if err != nil { + return nil, fmt.Errorf("unable to parse Auth schedule: %w", err) + } + authManager, err := auth.New(ctx, db, client, authSchedule) + if err != nil { + return nil, err + } + + return &Manager{ + db: db, + Monitor: NewWorkerMonitor(client), + Scraper: scraper, + Auth: authManager, + }, nil +} + +func (m *Manager) Run(ctx context.Context) error { + if err := m.db.Migrate(ctx); err != nil { + slog.ErrorContext(ctx, "Unable to migrate database", "error", err) + return err + } + + var wg sync.WaitGroup + wg.Add(4) + + go func() { + defer wg.Done() + m.db.Maintenance(ctx) + }() + + go func() { + defer wg.Done() + m.Monitor.Start(ctx) + }() + + go func() { + defer wg.Done() + m.Scraper.Run(ctx) + }() + + go func() { + defer wg.Done() + m.Auth.Run(ctx) + }() + + wg.Wait() + return ctx.Err() +} diff --git a/backend/internal/leader/manager/monitor.go b/backend/internal/leader/manager/monitor.go new file mode 100644 index 0000000..3b2e3ec --- /dev/null +++ b/backend/internal/leader/manager/monitor.go @@ -0,0 +1,164 @@ +package manager + +import ( + "context" + "fmt" + "log/slog" + "time" + + "github.com/buraksezer/consistent" + "github.com/cespare/xxhash/v2" + "github.com/redis/go-redis/v9" +) + +const ( + MonitorInterval = 5 * time.Second + ActiveWorkersSet = "active-workers" +) + +// WorkerMonitor is a struct that monitors workers and their heartbeats over redis. +type WorkerMonitor struct { + client *redis.Client + + // ring is a consistent hash ring that distributes partitions over detected workers. + ring *consistent.Consistent + // layoutChangeCh is a channel that others can listen to for layout changes to the ring. + layoutChangeCh chan struct{} +} + +// NewWorkerMonitor creates a new WorkerMonitor. 
+func NewWorkerMonitor(client *redis.Client) *WorkerMonitor { + var members []consistent.Member + return &WorkerMonitor{ + client: client, + ring: consistent.New(members, consistent.Config{ + Hasher: new(hasher), + PartitionCount: consistent.DefaultPartitionCount, + ReplicationFactor: consistent.DefaultReplicationFactor, + Load: consistent.DefaultLoad, + }), + layoutChangeCh: make(chan struct{}), + } +} + +func (m *WorkerMonitor) Close() error { + close(m.layoutChangeCh) + return nil +} + +func (m *WorkerMonitor) Changes() <-chan struct{} { + return m.layoutChangeCh +} + +func (m *WorkerMonitor) Start(ctx context.Context) { + m.monitorWorkers(ctx) + ticker := time.NewTicker(MonitorInterval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + m.monitorWorkers(ctx) + } + } +} + +func (m *WorkerMonitor) monitorWorkers(ctx context.Context) { + ctx, cancel := context.WithTimeout(ctx, MonitorInterval) + defer cancel() + + // Get all active workers. + workers, err := m.client.SMembers(ctx, ActiveWorkersSet).Result() + if err != nil { + slog.ErrorContext(ctx, "Unable to get active workers", "error", err) + return + } + + // Get existing workers in the ring. + existingWorkers := m.ring.GetMembers() + ewMap := make(map[string]bool) + for _, worker := range existingWorkers { + ewMap[worker.String()] = false + } + + // Check workers' heartbeats. + for _, worker := range workers { + exists, err := m.client.Exists(ctx, WorkerHeartbeatKey(worker)).Result() + if err != nil { + slog.ErrorContext(ctx, "Unable to check worker heartbeat", "worker", worker, "error", err) + continue + } + + if exists == 0 { + slog.WarnContext(ctx, "Worker heartbeat not found", "worker", worker) + + // Remove worker from active workers set. + if err = m.client.SRem(ctx, ActiveWorkersSet, worker).Err(); err != nil { + slog.ErrorContext(ctx, "Unable to remove worker from active workers set", "worker", worker, "error", err) + } + + // Remove worker from the ring. 
+ m.removeWorker(worker) + } else { + // Add worker to the ring if it doesn't exist. + if _, ok := ewMap[worker]; !ok { + slog.InfoContext(ctx, "New worker detected", "worker", worker) + m.addWorker(worker) + } else { + ewMap[worker] = true + } + } + } + + // Check for workers that are not active anymore. + for worker, exists := range ewMap { + if !exists { + slog.WarnContext(ctx, "Worker is not active anymore", "worker", worker) + m.removeWorker(worker) + } + } +} + +func (m *WorkerMonitor) addWorker(worker string) { + m.ring.Add(member{hostname: worker}) + + // Notify others about the layout change. + select { + case m.layoutChangeCh <- struct{}{}: + // Notify others. + default: + // No one is listening. + } +} + +func (m *WorkerMonitor) removeWorker(worker string) { + m.ring.Remove(worker) + + // Notify others about the layout change. + select { + case m.layoutChangeCh <- struct{}{}: + // Notify others. + default: + // No one is listening. + } +} + +func WorkerHeartbeatKey(hostname string) string { + return fmt.Sprintf("worker:%s:heartbeat", hostname) +} + +type hasher struct{} + +func (h *hasher) Sum64(data []byte) uint64 { + return xxhash.Sum64(data) +} + +type member struct { + hostname string +} + +func (m member) String() string { + return m.hostname +} diff --git a/backend/internal/redis/taskqueue/options.go b/backend/internal/redis/taskqueue/options.go new file mode 100644 index 0000000..2d5a23f --- /dev/null +++ b/backend/internal/redis/taskqueue/options.go @@ -0,0 +1,9 @@ +package taskqueue + +type Option[T any] func(*taskQueue[T]) + +func WithEncoding[T any](encoding Encoding) Option[T] { + return func(o *taskQueue[T]) { + o.encoding = encoding + } +} diff --git a/backend/internal/redis/taskqueue/queue.go b/backend/internal/redis/taskqueue/queue.go new file mode 100644 index 0000000..a4b799e --- /dev/null +++ b/backend/internal/redis/taskqueue/queue.go @@ -0,0 +1,545 @@ +package taskqueue + +import ( + "bytes" + "context" + "encoding/base64" + 
"encoding/gob" + "encoding/json" + "errors" + "fmt" + "log/slog" + "reflect" + "strconv" + "strings" + "time" + + "github.com/redis/go-redis/v9" +) + +type Encoding uint8 + +const ( + EncodingJSON Encoding = iota + EncodingGob + + ResultKey = "result" + ErrorKey = "error" + NextAttemptKey = "next_attempt" +) + +var MaxAttempts = 3 +var ErrTaskNotFound = errors.New("task not found") + +type TaskQueue[T any] interface { + // Enqueue adds a task to the queue. + // Returns the generated task ID. + Enqueue(ctx context.Context, data T) (TaskInfo[T], error) + + // Dequeue removes a task from the queue and returns it. + // The task data is placed into dataOut. + // + // Dequeue blocks until a task is available, timeout, or the context is canceled. + // The returned task is placed in a pending state for lockTimeout duration. + // The task must be completed with Complete or extended with Extend before the lock expires. + // If the lock expires, the task is returned to the queue, where it may be picked up by another worker. + Dequeue( + ctx context.Context, + lockTimeout, + timeout time.Duration, + ) (*TaskInfo[T], error) + + // Extend extends the lock on a task. + Extend(ctx context.Context, taskID TaskID) error + + // Complete marks a task as complete. Optionally, an error can be provided to store additional information. + Complete(ctx context.Context, taskID TaskID, result string) error + + // Data returns the info of a task. + Data(ctx context.Context, taskID TaskID) (TaskInfo[T], error) + + // Return returns a task to the queue and returns the new task ID. + // Increments the attempt counter. + // Tasks with too many attempts (MaxAttempts) are considered failed and aren't returned to the queue. + Return(ctx context.Context, taskID TaskID, err error) (TaskID, error) + + // List returns a list of task IDs in the queue. + // The list is ordered by the time the task was added to the queue. The most recent task is first. 
+ // The count parameter limits the number of tasks returned. + // The start and end parameters limit the range of tasks returned. + // End is exclusive. + // Start must be before end. + List(ctx context.Context, start, end TaskID, count int64) ([]TaskInfo[T], error) +} + +type TaskInfo[T any] struct { + // ID is the unique identifier of the task. Generated by redis. + ID TaskID + // Data is the task data. Stored in stream. + Data T + // Attempts is the number of times the task has been attempted. Stored in stream. + Attempts uint8 + // Result is the result of the task. Stored in a hash. + Result isTaskResult +} + +type isTaskResult interface { + isTaskResult() +} + +type TaskResultSuccess struct { + Result string +} + +type TaskResultError struct { + Error string + NextAttempt TaskID +} + +func (*TaskResultSuccess) isTaskResult() {} +func (*TaskResultError) isTaskResult() {} + +type TaskID struct { + timestamp time.Time + sequence uint64 +} + +func NewTaskID(timestamp time.Time, sequence uint64) TaskID { + return TaskID{timestamp, sequence} +} + +func ParseTaskID(s string) (TaskID, error) { + tPart, sPart, ok := strings.Cut(s, "-") + if !ok { + return TaskID{}, errors.New("invalid task ID") + } + + timestamp, err := strconv.ParseInt(tPart, 10, 64) + if err != nil { + return TaskID{}, err + } + + sequence, err := strconv.ParseUint(sPart, 10, 64) + if err != nil { + return TaskID{}, err + } + + return NewTaskID(time.UnixMilli(timestamp), sequence), nil +} + +func (t TaskID) Timestamp() time.Time { + return t.timestamp +} + +func (t TaskID) String() string { + tPart := strconv.FormatInt(t.timestamp.UnixMilli(), 10) + sPart := strconv.FormatUint(t.sequence, 10) + return tPart + "-" + sPart +} + +type taskQueue[T any] struct { + rdb *redis.Client + encoding Encoding + + streamKey string + groupName string + + workerName string +} + +func New[T any](ctx context.Context, rdb *redis.Client, name string, workerName string, opts ...Option[T]) (TaskQueue[T], error) { + tq := 
&taskQueue[T]{ + rdb: rdb, + encoding: EncodingJSON, + streamKey: "taskqueue:" + name, + groupName: "default", + workerName: workerName, + } + + for _, opt := range opts { + opt(tq) + } + + // Create the stream if it doesn't exist + err := rdb.XGroupCreateMkStream(ctx, tq.streamKey, tq.groupName, "0").Err() + if err != nil && err.Error() != "BUSYGROUP Consumer Group name already exists" { + return nil, err + } + + return tq, nil +} + +func (q *taskQueue[T]) Enqueue(ctx context.Context, data T) (TaskInfo[T], error) { + task := TaskInfo[T]{ + Data: data, + Attempts: 0, + } + + values, err := encode[T](task, q.encoding) + if err != nil { + return TaskInfo[T]{}, err + } + + taskID, err := q.rdb.XAdd(ctx, &redis.XAddArgs{ + Stream: q.streamKey, + Values: values, + }).Result() + if err != nil { + return TaskInfo[T]{}, err + } + + id, err := ParseTaskID(taskID) + if err != nil { + return TaskInfo[T]{}, err + } + task.ID = id + return task, nil +} + +func (q *taskQueue[T]) Dequeue(ctx context.Context, lockTimeout, timeout time.Duration) (*TaskInfo[T], error) { + // Try to recover a task + task, err := q.recover(ctx, lockTimeout) + if err != nil { + return nil, err + } + if task != nil { + return task, nil + } + + // Check for new tasks + ids, err := q.rdb.XReadGroup(ctx, &redis.XReadGroupArgs{ + Group: q.groupName, + Consumer: q.workerName, + Streams: []string{q.streamKey, ">"}, + Count: 1, + Block: timeout, + }).Result() + if err != nil && !errors.Is(err, redis.Nil) { + return nil, err + } + + if len(ids) == 0 || len(ids[0].Messages) == 0 || errors.Is(err, redis.Nil) { + return nil, nil + } + + msg := ids[0].Messages[0] + task = new(TaskInfo[T]) + *task, err = decode[T](&msg, q.encoding) + if err != nil { + return nil, err + } + return task, nil +} + +func (q *taskQueue[T]) Extend(ctx context.Context, taskID TaskID) error { + _, err := q.rdb.XClaim(ctx, &redis.XClaimArgs{ + Stream: q.streamKey, + Group: q.groupName, + Consumer: q.workerName, + MinIdle: 0, + Messages: 
[]string{taskID.String()}, + }).Result() + if err != nil && !errors.Is(err, redis.Nil) { + return err + } + return nil +} + +func (q *taskQueue[T]) Data(ctx context.Context, taskID TaskID) (TaskInfo[T], error) { + msg, err := q.rdb.XRange(ctx, q.streamKey, taskID.String(), taskID.String()).Result() + if err != nil { + return TaskInfo[T]{}, err + } + + if len(msg) == 0 { + return TaskInfo[T]{}, ErrTaskNotFound + } + + t, err := decode[T](&msg[0], q.encoding) + if err != nil { + return TaskInfo[T]{}, err + } + + t.Result, err = q.getResult(ctx, taskID) + if err != nil { + return TaskInfo[T]{}, nil + } + return t, nil +} + +func (q *taskQueue[T]) Complete(ctx context.Context, taskID TaskID, result string) error { + return q.ack(ctx, taskID, false, result) +} + +var retScript = redis.NewScript(` +local stream_key = KEYS[1] +local hash_key = KEYS[2] + +-- Re-add the task to the stream +local task_id = redis.call('XADD', stream_key, '*', unpack(ARGV)) + +-- Update the hash key to point to the new task +redis.call('HSET', hash_key, 'next_attempt', task_id) + +return task_id +`) + +func (q *taskQueue[T]) Return(ctx context.Context, taskID TaskID, err1 error) (TaskID, error) { + msgs, err := q.rdb.XRange(ctx, q.streamKey, taskID.String(), taskID.String()).Result() + if err != nil { + return TaskID{}, err + } + if len(msgs) == 0 { + return TaskID{}, ErrTaskNotFound + } + + var ackMsg string + if err1 != nil { + ackMsg = err1.Error() + } + + // Ack the task + err = q.ack(ctx, taskID, true, ackMsg) + if err != nil { + return TaskID{}, err + } + + msg := msgs[0] + task, err := decode[T](&msg, q.encoding) + if err != nil { + return TaskID{}, err + } + + task.Attempts++ + if int(task.Attempts) >= MaxAttempts { + // Task has failed + slog.ErrorContext(ctx, "task failed completely", + "taskID", taskID, + "data", task.Data, + "attempts", task.Attempts, + "maxAttempts", MaxAttempts, + ) + return TaskID{}, nil + } + + valuesMap, err := encode[T](task, q.encoding) + if err != nil { + 
return TaskID{}, err + } + + values := make([]string, 0, len(valuesMap)*2) + for k, v := range valuesMap { + values = append(values, k, v) + } + + keys := []string{ + q.streamKey, + fmt.Sprintf("%s:%s", q.streamKey, taskID.String()), + } + newTaskId, err := retScript.Run(ctx, q.rdb, keys, values).Result() + if err != nil { + return TaskID{}, err + } + + return ParseTaskID(newTaskId.(string)) +} + +func (q *taskQueue[T]) List(ctx context.Context, start, end TaskID, count int64) ([]TaskInfo[T], error) { + if !start.timestamp.IsZero() && !end.timestamp.IsZero() && start.timestamp.After(end.timestamp) { + return nil, errors.New("start must be before end") + } + + var startStr, endStr string + if !start.timestamp.IsZero() { + startStr = start.String() + } else { + startStr = "-" + } + if !end.timestamp.IsZero() { + endStr = "(" + end.String() + } else { + endStr = "+" + } + + msgs, err := q.rdb.XRevRangeN(ctx, q.streamKey, endStr, startStr, count).Result() + if err != nil { + return nil, err + } + if len(msgs) == 0 { + return []TaskInfo[T]{}, nil + } + + tasks := make([]TaskInfo[T], len(msgs)) + for i := range msgs { + tasks[i], err = decode[T](&msgs[i], q.encoding) + if err != nil { + return nil, err + } + + tasks[i].Result, err = q.getResult(ctx, tasks[i].ID) + if err != nil { + return nil, err + } + } + + return tasks, nil +} + +func (q *taskQueue[T]) getResult(ctx context.Context, taskID TaskID) (isTaskResult, error) { + key := fmt.Sprintf("%s:%s", q.streamKey, taskID.String()) + results, err := q.rdb.HMGet(ctx, key, ResultKey, ErrorKey, NextAttemptKey).Result() + if err != nil { + return nil, err + } + + var ret isTaskResult + if results[0] != nil { + ret = &TaskResultSuccess{Result: results[0].(string)} + } else if results[1] != nil { + ret = &TaskResultError{Error: results[1].(string)} + if results[2] != nil { + nextAttempt, err := ParseTaskID(results[2].(string)) + if err != nil { + return nil, err + } + ret.(*TaskResultError).NextAttempt = nextAttempt + } + } + 
return ret, nil +} + +func (q *taskQueue[T]) recover(ctx context.Context, idleTimeout time.Duration) (*TaskInfo[T], error) { + msgs, _, err := q.rdb.XAutoClaim(ctx, &redis.XAutoClaimArgs{ + Stream: q.streamKey, + Group: q.groupName, + MinIdle: idleTimeout, + Start: "0", + Count: 1, + Consumer: q.workerName, + }).Result() + if err != nil { + return nil, err + } + + if len(msgs) == 0 { + return nil, nil + } + + msg := msgs[0] + task, err := decode[T](&msg, q.encoding) + if err != nil { + return nil, err + } + return &task, nil +} + +func (q *taskQueue[T]) ack(ctx context.Context, taskID TaskID, errored bool, msg string) error { + key := fmt.Sprintf("%s:%s", q.streamKey, taskID.String()) + _, err := q.rdb.TxPipelined(ctx, func(pipe redis.Pipeliner) error { + pipe.XAck(ctx, q.streamKey, q.groupName, taskID.String()) + if errored { + pipe.HSet(ctx, key, ErrorKey, msg) + } else { + pipe.HSet(ctx, key, ResultKey, msg) + } + return nil + }) + return err +} + +func decode[T any](msg *redis.XMessage, encoding Encoding) (task TaskInfo[T], err error) { + task.ID, err = ParseTaskID(msg.ID) + if err != nil { + return + } + + err = getField(msg, "attempts", &task.Attempts) + if err != nil { + return + } + + var data string + err = getField(msg, "data", &data) + if err != nil { + return + } + + switch encoding { + case EncodingJSON: + err = json.Unmarshal([]byte(data), &task.Data) + case EncodingGob: + var decoded []byte + decoded, err = base64.StdEncoding.DecodeString(data) + if err != nil { + return + } + err = gob.NewDecoder(bytes.NewReader(decoded)).Decode(&task.Data) + default: + err = errors.New("unsupported encoding") + } + return +} + +func getField(msg *redis.XMessage, field string, v any) error { + vVal, ok := msg.Values[field] + if !ok { + return errors.New("missing field") + } + + vStr, ok := vVal.(string) + if !ok { + return errors.New("invalid field type") + } + + value := reflect.ValueOf(v).Elem() + switch value.Kind() { + case reflect.String: + 
value.SetString(vStr) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + i, err := strconv.ParseInt(vStr, 10, 64) + if err != nil { + return err + } + value.SetInt(i) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + i, err := strconv.ParseUint(vStr, 10, 64) + if err != nil { + return err + } + value.SetUint(i) + case reflect.Bool: + b, err := strconv.ParseBool(vStr) + if err != nil { + return err + } + value.SetBool(b) + default: + return errors.New("unsupported field type") + } + return nil +} + +func encode[T any](task TaskInfo[T], encoding Encoding) (ret map[string]string, err error) { + ret = make(map[string]string) + ret["attempts"] = strconv.FormatUint(uint64(task.Attempts), 10) + + switch encoding { + case EncodingJSON: + var data []byte + data, err = json.Marshal(task.Data) + if err != nil { + return + } + ret["data"] = string(data) + case EncodingGob: + var data bytes.Buffer + err = gob.NewEncoder(&data).Encode(task.Data) + if err != nil { + return + } + ret["data"] = base64.StdEncoding.EncodeToString(data.Bytes()) + default: + err = errors.New("unsupported encoding") + } + return +} diff --git a/backend/internal/redis/taskqueue/queue_test.go b/backend/internal/redis/taskqueue/queue_test.go new file mode 100644 index 0000000..ee95d39 --- /dev/null +++ b/backend/internal/redis/taskqueue/queue_test.go @@ -0,0 +1,467 @@ +package taskqueue + +import ( + "context" + "errors" + "fmt" + "log" + "testing" + "time" + + "github.com/ory/dockertest/v3" + "github.com/ory/dockertest/v3/docker" + "github.com/redis/go-redis/v9" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +var client *redis.Client + +func TestMain(m *testing.M) { + pool, err := dockertest.NewPool("") + if err != nil { + log.Fatalf("Could not create pool: %s", err) + } + + err = pool.Client.Ping() + if err != nil { + log.Fatalf("Could not connect to Docker: %s", err) + } + + //resource, err := 
pool.Run("redis", "7", nil) + resource, err := pool.RunWithOptions(&dockertest.RunOptions{ + Repository: "redis", + Tag: "7", + }, func(config *docker.HostConfig) { + config.AutoRemove = true + config.RestartPolicy = docker.RestartPolicy{Name: "no"} + }) + if err != nil { + log.Fatalf("Could not start resource: %s", err) + } + + //_ = resource.Expire(60) + + if err = pool.Retry(func() error { + client = redis.NewClient(&redis.Options{ + Addr: fmt.Sprintf("localhost:%s", resource.GetPort("6379/tcp")), + }) + return client.Ping(context.Background()).Err() + }); err != nil { + log.Fatalf("Could not connect to redis: %s", err) + } + + defer func() { + if err = client.Close(); err != nil { + log.Printf("Could not close client: %s", err) + } + if err = pool.Purge(resource); err != nil { + log.Fatalf("Could not purge resource: %s", err) + } + }() + + m.Run() +} + +func TestTaskQueue(t *testing.T) { + if testing.Short() { + t.Skip() + } + + lockTimeout := 100 * time.Millisecond + + tests := []struct { + name string + f func(t *testing.T) + }{ + { + name: "Create queue", + f: func(t *testing.T) { + q, err := New[string](context.Background(), client, "test", "worker") + require.NoError(t, err) + require.NotNil(t, q) + }, + }, + { + name: "enqueue & dequeue", + f: func(t *testing.T) { + q, err := New[string](context.Background(), client, "test", "worker") + require.NoError(t, err) + require.NotNil(t, q) + + taskId, err := q.Enqueue(context.Background(), "hello") + require.NoError(t, err) + require.NotEmpty(t, taskId) + + task, err := q.Dequeue(context.Background(), lockTimeout, 10*time.Millisecond) + require.NoError(t, err) + require.NotNil(t, task) + require.Equal(t, "hello", task.Data) + }, + }, + { + name: "complex data", + f: func(t *testing.T) { + type foo struct { + A int + B string + } + + q, err := New[foo](context.Background(), client, "test", "worker") + require.NoError(t, err) + require.NotNil(t, q) + + taskId, err := q.Enqueue(context.Background(), foo{A: 42, B: 
"hello"}) + require.NoError(t, err) + require.NotEmpty(t, taskId) + + task, err := q.Dequeue(context.Background(), lockTimeout, 10*time.Millisecond) + require.NoError(t, err) + require.NotNil(t, task) + require.Equal(t, foo{A: 42, B: "hello"}, task.Data) + }, + }, + { + name: "different workers", + f: func(t *testing.T) { + q1, err := New[string](context.Background(), client, "test", "worker1") + require.NoError(t, err) + require.NotNil(t, q1) + + q2, err := New[string](context.Background(), client, "test", "worker2") + require.NoError(t, err) + require.NotNil(t, q2) + + taskId, err := q1.Enqueue(context.Background(), "hello") + require.NoError(t, err) + require.NotEmpty(t, taskId) + + task, err := q2.Dequeue(context.Background(), lockTimeout, 10*time.Millisecond) + require.NoError(t, err) + require.NotNil(t, task) + assert.Equal(t, "hello", task.Data) + }, + }, + { + name: "complete", + f: func(t *testing.T) { + q, err := New[string](context.Background(), client, "test", "worker") + require.NoError(t, err) + require.NotNil(t, q) + + // Enqueue a task + taskId, err := q.Enqueue(context.Background(), "hello") + require.NoError(t, err) + require.NotEmpty(t, taskId) + + // Dequeue the task + task, err := q.Dequeue(context.Background(), lockTimeout, 10*time.Millisecond) + require.NoError(t, err) + require.NotNil(t, task) + assert.Equal(t, "hello", task.Data) + + // Complete the task + err = q.Complete(context.Background(), task.ID, "done") + require.NoError(t, err) + + // Try to dequeue the task again + task, err = q.Dequeue(context.Background(), lockTimeout, 10*time.Millisecond) + require.NoError(t, err) + require.Nil(t, task) + }, + }, + { + name: "timeout", + f: func(t *testing.T) { + q, err := New[string](context.Background(), client, "test", "worker") + require.NoError(t, err) + require.NotNil(t, q) + + // Enqueue a task + taskId, err := q.Enqueue(context.Background(), "hello") + require.NoError(t, err) + require.NotEmpty(t, taskId) + + // Dequeue the task + task, 
// TestTaskQueue_List is an integration test for TaskQueue.List. It
// requires the package-level Redis `client`, so it is skipped under
// `go test -short`.
func TestTaskQueue_List(t *testing.T) {
	if testing.Short() {
		t.Skip()
	}

	tests := []struct {
		name string
		f    func(t *testing.T)
	}{
		{
			// An empty queue lists no tasks and returns no error.
			name: "empty",
			f: func(t *testing.T) {
				q, err := New[any](context.Background(), client, "test", "worker")
				require.NoError(t, err)
				require.NotNil(t, q)

				tasks, err := q.List(context.Background(), TaskID{}, TaskID{}, 1)
				require.NoError(t, err)
				assert.Empty(t, tasks)
			},
		},
		{
			// A single enqueued task is returned as-is.
			name: "single",
			f: func(t *testing.T) {
				q, err := New[string](context.Background(), client, "test", "worker")
				require.NoError(t, err)
				require.NotNil(t, q)

				taskID, err := q.Enqueue(context.Background(), "hello")
				require.NoError(t, err)

				tasks, err := q.List(context.Background(), TaskID{}, TaskID{}, 1)
				require.NoError(t, err)
				assert.Len(t, tasks, 1)
				assert.Equal(t, taskID, tasks[0])
			},
		},
		{
			// List returns tasks newest-first: the most recently enqueued
			// task appears at index 0.
			name: "multiple",
			f: func(t *testing.T) {
				q, err := New[string](context.Background(), client, "test", "worker")
				require.NoError(t, err)
				require.NotNil(t, q)

				taskID, err := q.Enqueue(context.Background(), "hello")
				require.NoError(t, err)
				taskID2, err := q.Enqueue(context.Background(), "world")
				require.NoError(t, err)

				tasks, err := q.List(context.Background(), TaskID{}, TaskID{}, 2)
				require.NoError(t, err)
				assert.Len(t, tasks, 2)
				assert.Equal(t, taskID, tasks[1])
				assert.Equal(t, taskID2, tasks[0])
			},
		},
		{
			// The limit argument caps the result to the newest entries.
			name: "multiple limited",
			f: func(t *testing.T) {
				q, err := New[string](context.Background(), client, "test", "worker")
				require.NoError(t, err)
				require.NotNil(t, q)

				_, err = q.Enqueue(context.Background(), "hello")
				require.NoError(t, err)
				taskID2, err := q.Enqueue(context.Background(), "world")
				require.NoError(t, err)

				tasks, err := q.List(context.Background(), TaskID{}, TaskID{}, 1)
				require.NoError(t, err)
				assert.Len(t, tasks, 1)
				assert.Equal(t, taskID2, tasks[0])
			},
		},
		{
			// The start/end TaskID arguments bound the listed range; the
			// zero TaskID{} appears to mean "unbounded" on that side.
			// The sleep guarantees the two tasks get distinct timestamps.
			name: "multiple time range",
			f: func(t *testing.T) {
				q, err := New[string](context.Background(), client, "test", "worker")
				require.NoError(t, err)
				require.NotNil(t, q)

				taskID, err := q.Enqueue(context.Background(), "hello")
				require.NoError(t, err)
				time.Sleep(10 * time.Millisecond)
				taskID2, err := q.Enqueue(context.Background(), "world")
				require.NoError(t, err)

				tasks, err := q.List(context.Background(), TaskID{}, taskID2.ID, 100)
				require.NoError(t, err)
				assert.Len(t, tasks, 1)
				assert.Equal(t, taskID, tasks[0])

				tasks, err = q.List(context.Background(), taskID2.ID, TaskID{}, 100)
				require.NoError(t, err)
				assert.Len(t, tasks, 1)
				assert.Equal(t, taskID2, tasks[0])
			},
		},
		{
			// Completed tasks remain listable and carry their result.
			name: "completed tasks",
			f: func(t *testing.T) {
				q, err := New[string](context.Background(), client, "test", "worker")
				require.NoError(t, err)
				require.NotNil(t, q)

				task1, err := q.Enqueue(context.Background(), "hello")
				require.NoError(t, err)
				task2, err := q.Enqueue(context.Background(), "world")
				require.NoError(t, err)

				err = q.Complete(context.Background(), task1.ID, "done")
				require.NoError(t, err)

				tasks, err := q.List(context.Background(), TaskID{}, TaskID{}, 100)
				require.NoError(t, err)
				assert.Len(t, tasks, 2)
				assert.Equal(t, task2, tasks[0])

				assert.Equal(t, "hello", tasks[1].Data)
				require.IsType(t, &TaskResultSuccess{}, tasks[1].Result)
				assert.Equal(t, "done", tasks[1].Result.(*TaskResultSuccess).Result)
			},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// Each case starts from a clean Redis database.
			if err := client.FlushDB(context.Background()).Err(); err != nil {
				t.Fatal(err)
			}

			tt.f(t)
		})
	}

	// Best-effort cleanup after the suite; error intentionally ignored.
	_ = client.FlushDB(context.Background())
}
// claimAndFail dequeues the next task from q and immediately returns it
// as failed, asserting that a task was available. It returns the ID of
// the retry task created by Return, which must differ from the claimed
// task's ID.
func claimAndFail[T any](t *testing.T, q TaskQueue[T], lockTimeout time.Duration) TaskID {
	task, err := q.Dequeue(context.Background(), lockTimeout, 10*time.Millisecond)
	require.NoError(t, err)
	require.NotNil(t, task)

	id, err := q.Return(context.Background(), task.ID, errors.New("failed"))
	require.NoError(t, err)
	// Return creates a new attempt with a fresh ID.
	assert.NotEqual(t, task.ID, id)
	return id
}
"github.com/ansg191/ibd-trader-backend/internal/redis/taskqueue" + + "cloud.google.com/go/longrunning/autogen/longrunningpb" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + "google.golang.org/protobuf/types/known/anypb" + "google.golang.org/protobuf/types/known/timestamppb" +) + +const ScrapeOperationPrefix = "scrape" + +type Server struct { + pb.UnimplementedStockServiceServer + + db database.Executor + queue taskqueue.TaskQueue[scrape.TaskInfo] +} + +func New(db database.Executor, queue taskqueue.TaskQueue[scrape.TaskInfo]) *Server { + return &Server{db: db, queue: queue} +} + +func (s *Server) CreateStock(ctx context.Context, request *pb.CreateStockRequest) (*pb.CreateStockResponse, error) { + task, err := s.queue.Enqueue(ctx, scrape.TaskInfo{Symbol: request.Symbol}) + if err != nil { + slog.ErrorContext(ctx, "failed to enqueue task", "err", err) + return nil, status.New(codes.Internal, "failed to enqueue task").Err() + } + op := &longrunningpb.Operation{ + Name: fmt.Sprintf("%s/%s", ScrapeOperationPrefix, task.ID.String()), + Metadata: new(anypb.Any), + Done: false, + Result: nil, + } + err = op.Metadata.MarshalFrom(&pb.StockScrapeOperationMetadata{ + Symbol: request.Symbol, + StartTime: timestamppb.New(task.ID.Timestamp()), + }) + if err != nil { + slog.ErrorContext(ctx, "failed to marshal metadata", "err", err) + return nil, status.New(codes.Internal, "failed to marshal metadata").Err() + } + return &pb.CreateStockResponse{Operation: op}, nil +} + +func (s *Server) GetStock(ctx context.Context, request *pb.GetStockRequest) (*pb.GetStockResponse, error) { + //TODO implement me + panic("implement me") +} + +func (s *Server) ListStocks(ctx context.Context, request *pb.ListStocksRequest) (*pb.ListStocksResponse, error) { + //TODO implement me + panic("implement me") +} diff --git a/backend/internal/server/idb/user/v1/user.go b/backend/internal/server/idb/user/v1/user.go new file mode 100644 index 0000000..2f32e03 --- /dev/null +++ 
b/backend/internal/server/idb/user/v1/user.go @@ -0,0 +1,159 @@ +package user + +import ( + "context" + "errors" + + pb "github.com/ansg191/ibd-trader-backend/api/gen/idb/user/v1" + "github.com/ansg191/ibd-trader-backend/internal/database" + "github.com/ansg191/ibd-trader-backend/internal/ibd" + "github.com/ansg191/ibd-trader-backend/internal/keys" + + "github.com/mennanov/fmutils" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + "google.golang.org/protobuf/proto" +) + +type Server struct { + pb.UnimplementedUserServiceServer + + db database.TransactionExecutor + kms keys.KeyManagementService + keyName string + client *ibd.Client +} + +func New(db database.TransactionExecutor, kms keys.KeyManagementService, keyName string, client *ibd.Client) *Server { + return &Server{ + db: db, + kms: kms, + keyName: keyName, + client: client, + } +} + +func (u *Server) CreateUser(ctx context.Context, request *pb.CreateUserRequest) (*pb.CreateUserResponse, error) { + err := database.AddUser(ctx, u.db, request.Subject) + if err != nil { + return nil, status.Errorf(codes.Internal, "unable to create user: %v", err) + } + + user, err := database.GetUser(ctx, u.db, request.Subject) + if err != nil { + return nil, status.Errorf(codes.Internal, "unable to get user: %v", err) + } + + return &pb.CreateUserResponse{ + User: &pb.User{ + Subject: user.Subject, + IbdUsername: user.IBDUsername, + IbdPassword: nil, + }, + }, nil +} + +func (u *Server) GetUser(ctx context.Context, request *pb.GetUserRequest) (*pb.GetUserResponse, error) { + user, err := database.GetUser(ctx, u.db, request.Subject) + if errors.Is(err, database.ErrUserNotFound) { + return nil, status.New(codes.NotFound, "user not found").Err() + } + if err != nil { + return nil, status.Errorf(codes.Internal, "unable to get user: %v", err) + } + + return &pb.GetUserResponse{ + User: &pb.User{ + Subject: user.Subject, + IbdUsername: user.IBDUsername, + IbdPassword: nil, + }, + }, nil +} + +func (u *Server) 
UpdateUser(ctx context.Context, request *pb.UpdateUserRequest) (*pb.UpdateUserResponse, error) { + request.UpdateMask.Normalize() + if !request.UpdateMask.IsValid(request.User) { + return nil, status.Errorf(codes.InvalidArgument, "invalid update mask") + } + + existingUserRes, err := u.GetUser(ctx, &pb.GetUserRequest{Subject: request.User.Subject}) + if err != nil { + return nil, err + } + existingUser := existingUserRes.User + + newUser := proto.Clone(existingUser).(*pb.User) + fmutils.Overwrite(request.User, newUser, request.UpdateMask.Paths) + + // if IDB creds are both set and are different, update them + if (newUser.IbdPassword != nil && newUser.IbdUsername != nil) && + (newUser.IbdPassword != existingUser.IbdPassword || + newUser.IbdUsername != existingUser.IbdUsername) { + // Update IBD creds + err = database.AddIBDCreds(ctx, u.db, u.kms, u.keyName, newUser.Subject, *newUser.IbdUsername, *newUser.IbdPassword) + if err != nil { + return nil, status.Errorf(codes.Internal, "unable to update user: %v", err) + } + } + + newUser.IbdPassword = nil + return &pb.UpdateUserResponse{ + User: newUser, + }, nil +} + +func (u *Server) CheckIBDUsername(ctx context.Context, req *pb.CheckIBDUsernameRequest) (*pb.CheckIBDUsernameResponse, error) { + username := req.IbdUsername + if username == "" { + return nil, status.Errorf(codes.InvalidArgument, "username cannot be empty") + } + + // Check if the username exists + exists, err := u.client.CheckIBDUsername(ctx, username) + if err != nil { + return nil, status.Errorf(codes.Internal, "unable to check username: %v", err) + } + + return &pb.CheckIBDUsernameResponse{ + Exists: exists, + }, nil +} + +func (u *Server) AuthenticateUser(ctx context.Context, req *pb.AuthenticateUserRequest) (*pb.AuthenticateUserResponse, error) { + // Check if user has cookies + cookies, err := database.GetCookies(ctx, u.db, u.kms, req.Subject, false) + if err != nil { + return nil, status.Errorf(codes.Internal, "unable to get cookies: %v", err) + } + 
if len(cookies) > 0 { + return &pb.AuthenticateUserResponse{ + Authenticated: true, + }, nil + } + + // Authenticate user + // Get IBD creds + username, password, err := database.GetIBDCreds(ctx, u.db, u.kms, req.Subject) + if errors.Is(err, database.ErrIBDCredsNotFound) { + return nil, status.New(codes.NotFound, "User has no IDB creds").Err() + } + if err != nil { + return nil, status.Errorf(codes.Internal, "unable to get IBD creds: %v", err) + } + + // Authenticate user + cookie, err := u.client.Authenticate(ctx, username, password) + if errors.Is(err, ibd.ErrBadCredentials) { + return &pb.AuthenticateUserResponse{ + Authenticated: false, + }, nil + } + if err != nil { + return nil, status.Errorf(codes.Internal, "unable to authenticate user: %v", err) + } + + return &pb.AuthenticateUserResponse{ + Authenticated: cookie != nil, + }, nil +} diff --git a/backend/internal/server/operations.go b/backend/internal/server/operations.go new file mode 100644 index 0000000..2487427 --- /dev/null +++ b/backend/internal/server/operations.go @@ -0,0 +1,142 @@ +package server + +import ( + "context" + "errors" + "fmt" + "log/slog" + "strings" + + "cloud.google.com/go/longrunning/autogen/longrunningpb" + spb "github.com/ansg191/ibd-trader-backend/api/gen/idb/stock/v1" + "github.com/ansg191/ibd-trader-backend/internal/leader/manager/ibd/scrape" + "github.com/ansg191/ibd-trader-backend/internal/redis/taskqueue" + "github.com/ansg191/ibd-trader-backend/internal/server/idb/stock/v1" + epb "google.golang.org/genproto/googleapis/rpc/errdetails" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + "google.golang.org/protobuf/types/known/anypb" + "google.golang.org/protobuf/types/known/timestamppb" +) + +type operationServer struct { + longrunningpb.UnimplementedOperationsServer + + scrape taskqueue.TaskQueue[scrape.TaskInfo] +} + +func newOperationServer(scrapeQueue taskqueue.TaskQueue[scrape.TaskInfo]) *operationServer { + return &operationServer{scrape: scrapeQueue} 
// ListOperations lists long-running operations under the queue named by
// req.Name. Only the scrape queue (stock.ScrapeOperationPrefix) is
// currently supported. Pagination: the page token is the task ID of
// the last entry of the previous page.
func (o *operationServer) ListOperations(
	ctx context.Context,
	req *longrunningpb.ListOperationsRequest,
) (*longrunningpb.ListOperationsResponse, error) {
	// Decode the page token into an upper-bound task ID, if present.
	var end taskqueue.TaskID
	if req.PageToken != "" {
		var err error
		end, err = taskqueue.ParseTaskID(req.PageToken)
		if err != nil {
			return nil, status.New(codes.InvalidArgument, err.Error()).Err()
		}
	} else {
		end = taskqueue.TaskID{}
	}

	switch req.Name {
	case stock.ScrapeOperationPrefix:
		tasks, err := o.scrape.List(ctx, taskqueue.TaskID{}, end, int64(req.PageSize))
		if err != nil {
			return nil, status.New(codes.Internal, "unable to list IDs").Err()
		}

		ops := make([]*longrunningpb.Operation, len(tasks))
		for i, task := range tasks {
			// A task with any Result (success or error) is done.
			ops[i] = &longrunningpb.Operation{
				Name:     fmt.Sprintf("%s/%s", stock.ScrapeOperationPrefix, task.ID.String()),
				Metadata: new(anypb.Any),
				Done:     task.Result != nil,
				Result:   nil,
			}
			err = ops[i].Metadata.MarshalFrom(&spb.StockScrapeOperationMetadata{
				Symbol:    task.Data.Symbol,
				StartTime: timestamppb.New(task.ID.Timestamp()),
			})
			if err != nil {
				return nil, status.New(codes.Internal, "unable to marshal metadata").Err()
			}

			switch res := task.Result.(type) {
			case *taskqueue.TaskResultSuccess:
				// NOTE(review): a single successfully-completed task makes
				// the whole ListOperations call fail with Unimplemented —
				// presumably a placeholder until success results are
				// mapped to an Operation_Response; confirm intent.
				return nil, status.New(codes.Unimplemented, "not implemented").Err()
			case *taskqueue.TaskResultError:
				s := status.New(codes.Unknown, res.Error)
				// TODO(review): ErrorInfo is attached with every field
				// empty; either populate Reason/Domain or drop the detail.
				s, err = s.WithDetails(
					&epb.ErrorInfo{
						Reason:   "",
						Domain:   "",
						Metadata: nil,
					})
				if err != nil {
					return nil, status.New(codes.Internal, "unable to marshal error details").Err()
				}
				ops[i].Result = &longrunningpb.Operation_Error{Error: s.Proto()}
			}
		}

		// A full page implies more results may follow; the last task ID
		// becomes the next page token.
		var nextPageToken string
		if len(tasks) == int(req.PageSize) {
			nextPageToken = tasks[len(tasks)-1].ID.String()
		} else {
			nextPageToken = ""
		}

		return &longrunningpb.ListOperationsResponse{
			Operations:    ops,
			NextPageToken: nextPageToken,
		}, nil
	default:
		return nil, status.New(codes.NotFound, "unknown operation type").Err()
	}
}

// GetOperation resolves a single operation by name, where the name has
// the form "<prefix>/<taskID>".
func (o *operationServer) GetOperation(ctx context.Context, req *longrunningpb.GetOperationRequest) (*longrunningpb.Operation, error) {
	prefix, id, ok := strings.Cut(req.Name, "/")
	if !ok || prefix == "" || id == "" {
		return nil, status.New(codes.InvalidArgument, "invalid operation name").Err()
	}

	taskID, err := taskqueue.ParseTaskID(id)
	if err != nil {
		return nil, status.New(codes.InvalidArgument, err.Error()).Err()
	}

	switch prefix {
	case stock.ScrapeOperationPrefix:
		task, err := o.scrape.Data(ctx, taskID)
		if errors.Is(err, taskqueue.ErrTaskNotFound) {
			return nil, status.New(codes.NotFound, "operation not found").Err()
		}
		if err != nil {
			slog.ErrorContext(ctx, "unable to get operation", "error", err)
			return nil, status.New(codes.Internal, "unable to get operation").Err()
		}
		// NOTE(review): unlike ListOperations, a finished task's
		// result/error is not mapped onto the returned Operation here;
		// Done is set but Result stays nil — confirm this is intended.
		op := &longrunningpb.Operation{
			Name:     req.Name,
			Metadata: new(anypb.Any),
			Done:     task.Result != nil,
			Result:   nil,
		}
		err = op.Metadata.MarshalFrom(&spb.StockScrapeOperationMetadata{
			Symbol:    task.Data.Symbol,
			StartTime: timestamppb.New(task.ID.Timestamp()),
		})
		if err != nil {
			return nil, status.New(codes.Internal, "unable to marshal metadata").Err()
		}
		return op, nil
	default:
		return nil, status.New(codes.NotFound, "unknown operation type").Err()
	}
}
//go:generate make -C ../../api/ generate

// Server wraps the gRPC server and the port it listens on.
type Server struct {
	s    *grpc.Server
	port uint16
}

// New builds the gRPC server, wiring up the user, stock, and
// long-running-operations services. The scrape task queue is shared
// between the stock service (which enqueues scrape tasks) and the
// operations service (which inspects their state).
func New(
	ctx context.Context,
	port uint16,
	db database.TransactionExecutor,
	rClient *redis.Client,
	client *ibd.Client,
	kms keys.KeyManagementService,
	keyName string,
) (*Server, error) {
	scrapeQueue, err := taskqueue.New(
		ctx,
		rClient,
		scrape.Queue,
		"grpc-server",
		taskqueue.WithEncoding[scrape.TaskInfo](scrape.QueueEncoding))
	if err != nil {
		return nil, err
	}

	s := grpc.NewServer()
	upb.RegisterUserServiceServer(s, user.New(db, kms, keyName, client))
	spb.RegisterStockServiceServer(s, stock.New(db, scrapeQueue))
	longrunningpb.RegisterOperationsServer(s, newOperationServer(scrapeQueue))
	// Reflection lets tools like grpcurl discover the services.
	reflection.Register(s)
	return &Server{s, port}, nil
}

// Serve listens on the configured port and blocks until the server
// stops. Cancelling ctx triggers a graceful stop.
func (s *Server) Serve(ctx context.Context) error {
	lis, err := net.Listen("tcp", fmt.Sprintf(":%d", s.port))
	if err != nil {
		return err
	}

	// Graceful shutdown
	go func() {
		<-ctx.Done()
		slog.ErrorContext(ctx,
			"Shutting down server",
			"err", ctx.Err(),
			"cause", context.Cause(ctx),
		)
		s.s.GracefulStop()
	}()

	slog.InfoContext(ctx, "Starting gRPC server", "port", s.port)
	return s.s.Serve(lis)
}
money.GetCurrency(money.USD), + "EUR": money.GetCurrency(money.EUR), + "GBP": money.GetCurrency(money.GBP), + "JPY": money.GetCurrency(money.JPY), + "CNY": money.GetCurrency(money.CNY), +} + +func ParseMoney(s string) (*money.Money, error) { + for _, c := range currencies { + numPart, ok := isCurrency(s, c) + if !ok { + continue + } + + // Parse the number part + num, err := strconv.ParseUint(numPart, 10, 64) + if err != nil { + return nil, fmt.Errorf("failed to parse number: %w", err) + } + + return money.New(int64(num), c.Code), nil + } + return nil, fmt.Errorf("matching currency not found") +} + +func isCurrency(s string, c *money.Currency) (string, bool) { + var numPart string + for _, tp := range c.Template { + switch tp { + case '$': + // There should be a matching grapheme in the s at this position + remaining, ok := strings.CutPrefix(s, c.Grapheme) + if !ok { + return "", false + } + s = remaining + case '1': + // There should be a number, thousands, or decimal separator in the s at this position + // Number of expected decimal places + decimalFound := -1 + // Read from string until a non-number, non-thousands, non-decimal, or EOF is found + for len(s) > 0 && (string(s[0]) == c.Thousand || string(s[0]) == c.Decimal || '0' <= s[0] && s[0] <= '9') { + // If the character is a number + if '0' <= s[0] && s[0] <= '9' { + // If we've hit decimal limit, break + if decimalFound == 0 { + break + } + // add the number to the numPart + numPart += string(s[0]) + // Decrement the decimal count + // If the decimal has been found, `decimalFound` is positive + // If the decimal hasn't been found, `decimalFound` is negative, and decrementing it does nothing + decimalFound-- + } + // If decimal has been found (>= 0) and the character is a thousand separator or decimal separator, + // then the number is invalid + if decimalFound >= 0 && (string(s[0]) == c.Thousand || string(s[0]) == c.Decimal) { + return "", false + } + // If the character is a decimal separator, set 
`decimalFound` to the number of + // expected decimal places for the currency + if string(s[0]) == c.Decimal { + decimalFound = c.Fraction + } + // Move to the next character + s = s[1:] + } + if decimalFound > 0 { + // If there should be more decimal places, add them + numPart += strings.Repeat("0", decimalFound) + } else if decimalFound < 0 { + // If no decimal was found, add the expected number of decimal places + numPart += strings.Repeat("0", c.Fraction) + } + case ' ': + // There should be a space in the s at this position + if len(s) == 0 || s[0] != ' ' { + return "", false + } + s = s[1:] + default: + panic(fmt.Sprintf("unsupported template character: %c", tp)) + } + } + return numPart, true +} diff --git a/backend/internal/utils/money_test.go b/backend/internal/utils/money_test.go new file mode 100644 index 0000000..27ace06 --- /dev/null +++ b/backend/internal/utils/money_test.go @@ -0,0 +1,106 @@ +package utils + +import ( + "testing" + + "github.com/Rhymond/go-money" + "github.com/stretchr/testify/assert" +) + +func TestParseMoney(t *testing.T) { + tests := []struct { + name string + input string + want *money.Money + }{ + { + name: "en-US int no comma", + input: "$123", + want: money.New(12300, money.USD), + }, + { + name: "en-US int comma", + input: "$1,123", + want: money.New(112300, money.USD), + }, + { + name: "en-US decimal comma", + input: "$1,123.45", + want: money.New(112345, money.USD), + }, + { + name: "zh-CN decimal comma", + input: "1,234.56 \u5143", + want: money.New(123456, money.CNY), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + m, err := ParseMoney(tt.input) + assert.NoError(t, err) + assert.Equal(t, tt.want, m) + }) + } +} + +func Test_isCurrency(t *testing.T) { + tests := []struct { + name string + input string + currency string + numPart string + ok bool + }{ + { + name: "en-US int no comma", + input: "$123", + currency: money.USD, + numPart: "12300", + ok: true, + }, + { + name: "en-US int comma", + 
input: "$1,123", + currency: money.USD, + numPart: "112300", + ok: true, + }, + { + name: "en-US decimal comma", + input: "$1,123.45", + currency: money.USD, + numPart: "112345", + ok: true, + }, + { + name: "en-US 1 decimal comma", + input: "$1,123.5", + currency: money.USD, + numPart: "112350", + ok: true, + }, + { + name: "en-US no grapheme", + input: "1,234.56", + currency: money.USD, + numPart: "", + ok: false, + }, + { + name: "zh-CN decimal comma", + input: "1,234.56 \u5143", + currency: money.CNY, + numPart: "123456", + ok: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := money.GetCurrency(tt.currency) + numPart, ok := isCurrency(tt.input, c) + assert.Equal(t, tt.ok, ok) + assert.Equal(t, tt.numPart, numPart) + }) + } +} diff --git a/backend/internal/worker/analyzer/analyzer.go b/backend/internal/worker/analyzer/analyzer.go new file mode 100644 index 0000000..20621dd --- /dev/null +++ b/backend/internal/worker/analyzer/analyzer.go @@ -0,0 +1,142 @@ +package analyzer + +import ( + "context" + "log/slog" + "time" + + "github.com/ansg191/ibd-trader-backend/internal/analyzer" + "github.com/ansg191/ibd-trader-backend/internal/database" + "github.com/ansg191/ibd-trader-backend/internal/redis/taskqueue" + + "github.com/redis/go-redis/v9" +) + +const ( + Queue = "analyzer" + QueueEncoding = taskqueue.EncodingJSON + + lockTimeout = 1 * time.Minute + dequeueTimeout = 5 * time.Second +) + +func RunAnalyzer( + ctx context.Context, + redis *redis.Client, + analyzer analyzer.Analyzer, + db database.Executor, + name string, +) error { + queue, err := taskqueue.New( + ctx, + redis, + Queue, + name, + taskqueue.WithEncoding[TaskInfo](QueueEncoding), + ) + if err != nil { + return err + } + + for { + select { + case <-ctx.Done(): + return ctx.Err() + default: + waitForTask(ctx, queue, analyzer, db) + } + } +} + +func waitForTask( + ctx context.Context, + queue taskqueue.TaskQueue[TaskInfo], + analyzer analyzer.Analyzer, + db 
database.Executor, +) { + task, err := queue.Dequeue(ctx, lockTimeout, dequeueTimeout) + if err != nil { + slog.ErrorContext(ctx, "Failed to dequeue task", "error", err) + return + } + if task == nil { + // No task available. + return + } + + errCh := make(chan error) + resCh := make(chan string) + defer close(errCh) + defer close(resCh) + go func() { + res, err := analyzeStock(ctx, analyzer, db, task.Data.ID) + if err != nil { + errCh <- err + return + } + resCh <- res + }() + + ticker := time.NewTicker(lockTimeout / 5) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + // Context was canceled. Return early. + return + case <-ticker.C: + // Extend the lock periodically. + func() { + ctx, cancel := context.WithTimeout(ctx, lockTimeout/5) + defer cancel() + + err := queue.Extend(ctx, task.ID) + if err != nil { + slog.ErrorContext(ctx, "Failed to extend lock", "error", err) + } + }() + case err = <-errCh: + // analyzeStock has errored. + slog.ErrorContext(ctx, "Failed to analyze", "error", err) + _, err = queue.Return(ctx, task.ID, err) + if err != nil { + slog.ErrorContext(ctx, "Failed to return task", "error", err) + return + } + return + case res := <-resCh: + // analyzeStock has completed successfully. 
+ slog.DebugContext(ctx, "Analyzed ID", "id", task.Data.ID, "result", res) + err = queue.Complete(ctx, task.ID, res) + if err != nil { + slog.ErrorContext(ctx, "Failed to complete task", "error", err) + return + } + return + } + } +} + +func analyzeStock(ctx context.Context, a analyzer.Analyzer, db database.Executor, id string) (string, error) { + info, err := database.GetStockInfo(ctx, db, id) + if err != nil { + return "", err + } + + analysis, err := a.Analyze( + ctx, + info.Symbol, + info.Price, + info.ChartAnalysis, + ) + if err != nil { + return "", err + } + + return database.AddAnalysis(ctx, db, id, analysis) +} + +type TaskInfo struct { + ID string `json:"id"` +} diff --git a/backend/internal/worker/auth/auth.go b/backend/internal/worker/auth/auth.go new file mode 100644 index 0000000..0daa112 --- /dev/null +++ b/backend/internal/worker/auth/auth.go @@ -0,0 +1,239 @@ +package auth + +import ( + "context" + "database/sql" + "errors" + "fmt" + "log/slog" + "time" + + "github.com/ansg191/ibd-trader-backend/internal/database" + "github.com/ansg191/ibd-trader-backend/internal/ibd" + "github.com/ansg191/ibd-trader-backend/internal/keys" + "github.com/ansg191/ibd-trader-backend/internal/leader/manager/ibd/auth" + "github.com/ansg191/ibd-trader-backend/internal/redis/taskqueue" + + "github.com/redis/go-redis/v9" +) + +const ( + lockTimeout = 1 * time.Minute + dequeueTimeout = 5 * time.Second +) + +func RunAuthScraper( + ctx context.Context, + client *ibd.Client, + redis *redis.Client, + db database.TransactionExecutor, + kms keys.KeyManagementService, + name string, +) error { + queue, err := taskqueue.New( + ctx, + redis, + auth.Queue, + name, + taskqueue.WithEncoding[auth.TaskInfo](auth.QueueEncoding), + ) + if err != nil { + return err + } + + for { + select { + case <-ctx.Done(): + return ctx.Err() + default: + waitForTask(ctx, queue, client, db, kms) + } + } +} + +func waitForTask( + ctx context.Context, + queue taskqueue.TaskQueue[auth.TaskInfo], + client 
*ibd.Client, + db database.TransactionExecutor, + kms keys.KeyManagementService, +) { + task, err := queue.Dequeue(ctx, lockTimeout, dequeueTimeout) + if err != nil { + slog.ErrorContext(ctx, "Failed to dequeue task", "error", err) + return + } + if task == nil { + // No task available. + return + } + slog.DebugContext(ctx, "Picked up auth task", "task", task.ID, "user", task.Data.UserSubject) + + ch := make(chan error) + defer close(ch) + go func() { + ch <- scrapeCookies(ctx, client, db, kms, task.Data.UserSubject) + }() + + ticker := time.NewTicker(lockTimeout / 5) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + // The context was canceled. Return early. + return + case <-ticker.C: + // Extend the lock periodically. + func() { + ctx, cancel := context.WithTimeout(ctx, lockTimeout/5) + defer cancel() + + err := queue.Extend(ctx, task.ID) + if err != nil { + slog.ErrorContext(ctx, "Failed to extend lock", "error", err) + } + }() + case err = <-ch: + // scrapeCookies has completed. 
+ if err != nil { + slog.ErrorContext(ctx, "Failed to scrape cookies", "error", err) + _, err = queue.Return(ctx, task.ID, err) + if err != nil { + slog.ErrorContext(ctx, "Failed to return task", "error", err) + return + } + } else { + err = queue.Complete(ctx, task.ID, "") + if err != nil { + slog.ErrorContext(ctx, "Failed to complete task", "error", err) + return + } + slog.DebugContext(ctx, "Authenticated user", "user", task.Data.UserSubject) + } + return + } + } +} + +func scrapeCookies( + ctx context.Context, + client *ibd.Client, + db database.TransactionExecutor, + kms keys.KeyManagementService, + user string, +) error { + ctx, cancel := context.WithTimeout(ctx, lockTimeout) + defer cancel() + + // Check if the user has valid cookies + done, err := hasValidCookies(ctx, db, user) + if err != nil { + return fmt.Errorf("failed to check cookies: %w", err) + } + if done { + return nil + } + + // Health check degraded cookies + done, err = healthCheckDegradedCookies(ctx, client, db, kms, user) + if err != nil { + return fmt.Errorf("failed to health check cookies: %w", err) + } + if done { + return nil + } + + // No cookies are valid, so scrape new cookies + return scrapeNewCookies(ctx, client, db, kms, user) +} + +func hasValidCookies(ctx context.Context, db database.Executor, user string) (bool, error) { + // Check if the user has non-degraded cookies + row := db.QueryRowContext(ctx, ` +SELECT 1 +FROM ibd_tokens +WHERE user_subject = $1 + AND expires_at > NOW() + AND degraded = FALSE;`, user) + + var exists bool + err := row.Scan(&exists) + if errors.Is(err, sql.ErrNoRows) { + return false, nil + } + if err != nil { + return false, fmt.Errorf("failed to get non-degraded cookies: %w", err) + } + + return true, nil +} + +func healthCheckDegradedCookies( + ctx context.Context, + client *ibd.Client, + db database.Executor, + kms keys.KeyManagementService, + user string, +) (bool, error) { + // Check if the user has degraded cookies + cookies, err := 
database.GetCookies(ctx, db, kms, user, true) + if err != nil { + return false, fmt.Errorf("failed to get degraded cookies: %w", err) + } + + valid := false + for _, cookie := range cookies { + slog.DebugContext(ctx, "Health checking cookie", "cookie", cookie.ID) + + // Health check the cookie + up, err := client.UserInfo(ctx, cookie.ToHTTPCookie()) + if err != nil { + slog.ErrorContext(ctx, "Failed to health check cookie", "error", err) + continue + } + + if up.Status != ibd.UserStatusSubscriber { + continue + } + + // Cookie is valid + valid = true + + // Update the cookie + err = database.RepairCookie(ctx, db, cookie.ID) + if err != nil { + slog.ErrorContext(ctx, "Failed to repair cookie", "error", err) + } + } + + return valid, nil +} + +func scrapeNewCookies( + ctx context.Context, + client *ibd.Client, + db database.TransactionExecutor, + kms keys.KeyManagementService, + user string, +) error { + // Get the user's credentials + username, password, err := database.GetIBDCreds(ctx, db, kms, user) + if err != nil { + return fmt.Errorf("failed to get IBD credentials: %w", err) + } + + // Scrape the user's cookies + cookie, err := client.Authenticate(ctx, username, password) + if err != nil { + return fmt.Errorf("failed to authenticate user: %w", err) + } + + // Store the cookie + err = database.AddCookie(ctx, db, kms, user, cookie) + if err != nil { + return fmt.Errorf("failed to store cookie: %w", err) + } + + return nil +} diff --git a/backend/internal/worker/scraper/scraper.go b/backend/internal/worker/scraper/scraper.go new file mode 100644 index 0000000..c5c1b6c --- /dev/null +++ b/backend/internal/worker/scraper/scraper.go @@ -0,0 +1,198 @@ +package scraper + +import ( + "context" + "errors" + "fmt" + "log/slog" + "time" + + "github.com/ansg191/ibd-trader-backend/internal/database" + "github.com/ansg191/ibd-trader-backend/internal/ibd" + "github.com/ansg191/ibd-trader-backend/internal/leader/manager/ibd/scrape" + 
"github.com/ansg191/ibd-trader-backend/internal/redis/taskqueue" + "github.com/ansg191/ibd-trader-backend/internal/worker/analyzer" + + "github.com/redis/go-redis/v9" +) + +const ( + lockTimeout = 1 * time.Minute + dequeueTimeout = 5 * time.Second +) + +func RunScraper( + ctx context.Context, + redis *redis.Client, + client *ibd.Client, + db database.TransactionExecutor, + name string, +) error { + queue, err := taskqueue.New( + ctx, + redis, + scrape.Queue, + name, + taskqueue.WithEncoding[scrape.TaskInfo](scrape.QueueEncoding), + ) + if err != nil { + return err + } + + aQueue, err := taskqueue.New( + ctx, + redis, + analyzer.Queue, + name, + taskqueue.WithEncoding[analyzer.TaskInfo](analyzer.QueueEncoding), + ) + if err != nil { + return err + } + + for { + select { + case <-ctx.Done(): + return ctx.Err() + default: + waitForTask(ctx, queue, aQueue, client, db) + } + } +} + +func waitForTask( + ctx context.Context, + queue taskqueue.TaskQueue[scrape.TaskInfo], + aQueue taskqueue.TaskQueue[analyzer.TaskInfo], + client *ibd.Client, + db database.TransactionExecutor, +) { + task, err := queue.Dequeue(ctx, lockTimeout, dequeueTimeout) + if err != nil { + slog.ErrorContext(ctx, "Failed to dequeue task", "error", err) + return + } + if task == nil { + // No task available. + return + } + + errCh := make(chan error) + resCh := make(chan string) + defer close(errCh) + defer close(resCh) + go func() { + res, err := scrapeUrl(ctx, client, db, aQueue, task.Data.Symbol) + if err != nil { + errCh <- err + return + } + resCh <- res + }() + + ticker := time.NewTicker(lockTimeout / 5) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + // Context was canceled. Return early. + return + case <-ticker.C: + // Extend the lock periodically. 
+ func() { + ctx, cancel := context.WithTimeout(ctx, lockTimeout/5) + defer cancel() + + err := queue.Extend(ctx, task.ID) + if err != nil { + slog.ErrorContext(ctx, "Failed to extend lock", "error", err) + } + }() + case err = <-errCh: + // scrapeUrl has errored. + slog.ErrorContext(ctx, "Failed to scrape URL", "error", err) + _, err = queue.Return(ctx, task.ID, err) + if err != nil { + slog.ErrorContext(ctx, "Failed to return task", "error", err) + return + } + return + case res := <-resCh: + // scrapeUrl has completed successfully. + slog.DebugContext(ctx, "Scraped URL", "symbol", task.Data.Symbol) + err = queue.Complete(ctx, task.ID, res) + if err != nil { + slog.ErrorContext(ctx, "Failed to complete task", "error", err) + return + } + return + } + } +} + +func scrapeUrl( + ctx context.Context, + client *ibd.Client, + db database.TransactionExecutor, + aQueue taskqueue.TaskQueue[analyzer.TaskInfo], + symbol string, +) (string, error) { + ctx, cancel := context.WithTimeout(ctx, lockTimeout) + defer cancel() + + stockUrl, err := getStockUrl(ctx, db, client, symbol) + if err != nil { + return "", fmt.Errorf("failed to get stock url: %w", err) + } + + // Scrape the stock info. + info, err := client.StockInfo(ctx, stockUrl) + if err != nil { + return "", fmt.Errorf("failed to get stock info: %w", err) + } + + // Add stock info to the database. + id, err := database.AddStockInfo(ctx, db, info) + if err != nil { + return "", fmt.Errorf("failed to add stock info: %w", err) + } + + // Add the stock to the analyzer queue. + _, err = aQueue.Enqueue(ctx, analyzer.TaskInfo{ID: id}) + if err != nil { + return "", fmt.Errorf("failed to enqueue analysis task: %w", err) + } + + slog.DebugContext(ctx, "Added stock info", "id", id) + + return id, nil +} + +func getStockUrl(ctx context.Context, db database.TransactionExecutor, client *ibd.Client, symbol string) (string, error) { + // Get the stock from the database. 
+ stock, err := database.GetStock(ctx, db, symbol) + if err == nil { + return stock.IBDUrl, nil + } + if !errors.Is(err, database.ErrStockNotFound) { + return "", fmt.Errorf("failed to get stock: %w", err) + } + + // If stock isn't found in the database, get the stock from IBD. + stock, err = client.Search(ctx, symbol) + if errors.Is(err, ibd.ErrSymbolNotFound) { + return "", fmt.Errorf("symbol not found: %w", err) + } + if err != nil { + return "", fmt.Errorf("failed to search for symbol: %w", err) + } + + // Add the stock to the database. + err = database.AddStock(ctx, db, stock) + if err != nil { + return "", fmt.Errorf("failed to add stock: %w", err) + } + + return stock.IBDUrl, nil +} diff --git a/backend/internal/worker/worker.go b/backend/internal/worker/worker.go new file mode 100644 index 0000000..6017fb7 --- /dev/null +++ b/backend/internal/worker/worker.go @@ -0,0 +1,151 @@ +package worker + +import ( + "context" + "crypto/rand" + "encoding/base64" + "io" + "log/slog" + "os" + "time" + + "github.com/ansg191/ibd-trader-backend/internal/analyzer" + "github.com/ansg191/ibd-trader-backend/internal/database" + "github.com/ansg191/ibd-trader-backend/internal/ibd" + "github.com/ansg191/ibd-trader-backend/internal/keys" + "github.com/ansg191/ibd-trader-backend/internal/leader/manager" + analyzer2 "github.com/ansg191/ibd-trader-backend/internal/worker/analyzer" + "github.com/ansg191/ibd-trader-backend/internal/worker/auth" + "github.com/ansg191/ibd-trader-backend/internal/worker/scraper" + + "github.com/redis/go-redis/v9" + "golang.org/x/sync/errgroup" +) + +const ( + HeartbeatInterval = 5 * time.Second + HeartbeatTTL = 30 * time.Second +) + +func StartWorker( + ctx context.Context, + ibdClient *ibd.Client, + client *redis.Client, + db database.TransactionExecutor, + kms keys.KeyManagementService, + a analyzer.Analyzer, +) error { + // Get the worker name. 
+ name, err := workerName() + if err != nil { + return err + } + slog.InfoContext(ctx, "Starting worker", "worker", name) + + g, ctx := errgroup.WithContext(ctx) + + g.Go(func() error { + return workerRegistrationLoop(ctx, client, name) + }) + g.Go(func() error { + return scraper.RunScraper(ctx, client, ibdClient, db, name) + }) + g.Go(func() error { + return auth.RunAuthScraper(ctx, ibdClient, client, db, kms, name) + }) + g.Go(func() error { + return analyzer2.RunAnalyzer(ctx, client, a, db, name) + }) + + return g.Wait() +} + +func workerRegistrationLoop(ctx context.Context, client *redis.Client, name string) error { + sendHeartbeat(ctx, client, name) + + ticker := time.NewTicker(HeartbeatInterval) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + sendHeartbeat(ctx, client, name) + case <-ctx.Done(): + removeWorker(ctx, client, name) + return ctx.Err() + } + } +} + +// sendHeartbeat sends a heartbeat for the worker. +// It ensures that the worker is in the active workers set and its heartbeat exists. +func sendHeartbeat(ctx context.Context, client *redis.Client, name string) { + ctx, cancel := context.WithTimeout(ctx, HeartbeatInterval) + defer cancel() + + // Add the worker to the active workers set. + if err := client.SAdd(ctx, manager.ActiveWorkersSet, name).Err(); err != nil { + slog.ErrorContext(ctx, + "Unable to add worker to active workers set", + "worker", name, + "error", err, + ) + return + } + + // Set the worker's heartbeat. + heartbeatKey := manager.WorkerHeartbeatKey(name) + if err := client.Set(ctx, heartbeatKey, time.Now().Unix(), HeartbeatTTL).Err(); err != nil { + slog.ErrorContext(ctx, + "Unable to set worker heartbeat", + "worker", name, + "error", err, + ) + return + } +} + +// removeWorker removes the worker from the active workers set. +func removeWorker(ctx context.Context, client *redis.Client, name string) { + if ctx.Err() != nil { + // If the context is canceled, create a new uncanceled context. 
+ ctx = context.WithoutCancel(ctx) + } + ctx, cancel := context.WithTimeout(ctx, 10*time.Second) + defer cancel() + + // Remove the worker from the active workers set. + if err := client.SRem(ctx, manager.ActiveWorkersSet, name).Err(); err != nil { + slog.ErrorContext(ctx, + "Unable to remove worker from active workers set", + "worker", name, + "error", err, + ) + return + } + + // Remove the worker's heartbeat. + heartbeatKey := manager.WorkerHeartbeatKey(name) + if err := client.Del(ctx, heartbeatKey).Err(); err != nil { + slog.ErrorContext(ctx, + "Unable to remove worker heartbeat", + "worker", name, + "error", err, + ) + return + } +} + +func workerName() (string, error) { + hostname, err := os.Hostname() + if err != nil { + return "", err + } + + bytes := make([]byte, 12) + if _, err = io.ReadFull(rand.Reader, bytes); err != nil { + return "", err + } + + return hostname + "-" + base64.URLEncoding.EncodeToString(bytes), nil +} |