aboutsummaryrefslogtreecommitdiff
path: root/backend/internal
diff options
context:
space:
mode:
Diffstat (limited to 'backend/internal')
-rw-r--r--backend/internal/analyzer/analyzer.go32
-rw-r--r--backend/internal/analyzer/openai/openai.go126
-rw-r--r--backend/internal/analyzer/openai/openai_test.go1
-rw-r--r--backend/internal/analyzer/openai/options.go45
-rw-r--r--backend/internal/analyzer/openai/system.txt34
-rw-r--r--backend/internal/auth/auth.go55
-rw-r--r--backend/internal/config/config.go114
-rw-r--r--backend/internal/database/cookies.go189
-rw-r--r--backend/internal/database/database.go166
-rw-r--r--backend/internal/database/database_test.go79
-rw-r--r--backend/internal/database/stocks.go293
-rw-r--r--backend/internal/database/users.go151
-rw-r--r--backend/internal/ibd/auth.go333
-rw-r--r--backend/internal/ibd/auth_test.go215
-rw-r--r--backend/internal/ibd/check_ibd_username.go68
-rw-r--r--backend/internal/ibd/client.go97
-rw-r--r--backend/internal/ibd/client_test.go201
-rw-r--r--backend/internal/ibd/html_helpers.go99
-rw-r--r--backend/internal/ibd/html_helpers_test.go79
-rw-r--r--backend/internal/ibd/ibd50.go182
-rw-r--r--backend/internal/ibd/options.go26
-rw-r--r--backend/internal/ibd/search.go111
-rw-r--r--backend/internal/ibd/search_test.go205
-rw-r--r--backend/internal/ibd/stockinfo.go233
-rw-r--r--backend/internal/ibd/transport/scrapfly/options.go84
-rw-r--r--backend/internal/ibd/transport/scrapfly/scraper_types.go253
-rw-r--r--backend/internal/ibd/transport/scrapfly/scrapfly.go103
-rw-r--r--backend/internal/ibd/transport/standard.go41
-rw-r--r--backend/internal/ibd/transport/transport.go66
-rw-r--r--backend/internal/ibd/userinfo.go156
-rw-r--r--backend/internal/keys/gcp.go131
-rw-r--r--backend/internal/keys/keys.go150
-rw-r--r--backend/internal/keys/keys_test.go64
-rw-r--r--backend/internal/keys/mock_keys_test.go156
-rw-r--r--backend/internal/leader/election/election.go128
-rw-r--r--backend/internal/leader/manager/ibd/auth/auth.go111
-rw-r--r--backend/internal/leader/manager/ibd/ibd.go8
-rw-r--r--backend/internal/leader/manager/ibd/scrape/scrape.go140
-rw-r--r--backend/internal/leader/manager/manager.go90
-rw-r--r--backend/internal/leader/manager/monitor.go164
-rw-r--r--backend/internal/redis/taskqueue/options.go9
-rw-r--r--backend/internal/redis/taskqueue/queue.go545
-rw-r--r--backend/internal/redis/taskqueue/queue_test.go467
-rw-r--r--backend/internal/server/idb/stock/v1/stock.go64
-rw-r--r--backend/internal/server/idb/user/v1/user.go159
-rw-r--r--backend/internal/server/operations.go142
-rw-r--r--backend/internal/server/server.go77
-rw-r--r--backend/internal/utils/money.go99
-rw-r--r--backend/internal/utils/money_test.go106
-rw-r--r--backend/internal/worker/analyzer/analyzer.go142
-rw-r--r--backend/internal/worker/auth/auth.go239
-rw-r--r--backend/internal/worker/scraper/scraper.go198
-rw-r--r--backend/internal/worker/worker.go151
53 files changed, 7377 insertions, 0 deletions
diff --git a/backend/internal/analyzer/analyzer.go b/backend/internal/analyzer/analyzer.go
new file mode 100644
index 0000000..c055647
--- /dev/null
+++ b/backend/internal/analyzer/analyzer.go
@@ -0,0 +1,32 @@
+package analyzer
+
+import (
+ "context"
+
+ "github.com/Rhymond/go-money"
+)
+
// Analyzer produces a trade recommendation for a stock from its raw
// chart-analysis text.
type Analyzer interface {
	// Analyze evaluates rawAnalysis for symbol at the given current price
	// and returns a structured recommendation.
	Analyze(
		ctx context.Context,
		symbol string,
		price *money.Money,
		rawAnalysis string,
	) (*Analysis, error)
}

// ChartAction is the recommended course of action for a stock.
type ChartAction string

// The set of actions an Analyzer may recommend.
const (
	Buy ChartAction = "buy"
	Sell ChartAction = "sell"
	Hold ChartAction = "hold"
	Unknown ChartAction = "unknown"
)

// Analysis is the structured result of analyzing a stock.
type Analysis struct {
	// Action is the recommended course of action.
	Action ChartAction
	// Price is the price at which the action should be taken.
	Price *money.Money
	// Reason is a prose explanation for the recommendation.
	Reason string
	// Confidence is the 0-100 confidence level in the action.
	Confidence uint8
}
diff --git a/backend/internal/analyzer/openai/openai.go b/backend/internal/analyzer/openai/openai.go
new file mode 100644
index 0000000..0419c57
--- /dev/null
+++ b/backend/internal/analyzer/openai/openai.go
@@ -0,0 +1,126 @@
+package openai
+
import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"math"
	"strconv"
	"strings"
	"time"

	"github.com/ansg191/ibd-trader-backend/internal/analyzer"
	"github.com/ansg191/ibd-trader-backend/internal/utils"

	"github.com/Rhymond/go-money"
	"github.com/sashabaranov/go-openai"
)
+
// Client is the subset of the go-openai client used by Analyzer,
// abstracted behind an interface so tests can substitute a fake.
type Client interface {
	CreateChatCompletion(
		ctx context.Context,
		request openai.ChatCompletionRequest,
	) (response openai.ChatCompletionResponse, err error)
}

// Analyzer implements analyzer.Analyzer using an OpenAI chat model.
// Construct with NewAnalyzer; a Client is required.
type Analyzer struct {
	client Client // chat-completion backend (required)
	model string // chat model name
	systemMsg string // system prompt sent with every request
	temperature float32 // sampling temperature
}
+
+func NewAnalyzer(opts ...Option) *Analyzer {
+ a := &Analyzer{
+ client: nil,
+ model: defaultModel,
+ systemMsg: defaultSystemMsg,
+ temperature: defaultTemperature,
+ }
+ for _, option := range opts {
+ option(a)
+ }
+ if a.client == nil {
+ panic("client is required")
+ }
+
+ return a
+}
+
+func (a *Analyzer) Analyze(
+ ctx context.Context,
+ symbol string,
+ price *money.Money,
+ rawAnalysis string,
+) (*analyzer.Analysis, error) {
+ usrMsg := fmt.Sprintf(
+ "%s\n%s\n%s\n%s\n",
+ time.Now().Format(time.RFC3339),
+ symbol,
+ price.Display(),
+ rawAnalysis,
+ )
+ res, err := a.client.CreateChatCompletion(ctx, openai.ChatCompletionRequest{
+ Model: a.model,
+ Messages: []openai.ChatCompletionMessage{
+ {
+ Role: openai.ChatMessageRoleSystem,
+ Content: a.systemMsg,
+ },
+ {
+ Role: openai.ChatMessageRoleUser,
+ Content: usrMsg,
+ },
+ },
+ MaxTokens: 0,
+ Temperature: a.temperature,
+ ResponseFormat: &openai.ChatCompletionResponseFormat{Type: openai.ChatCompletionResponseFormatTypeJSONObject},
+ })
+ if err != nil {
+ return nil, err
+ }
+
+ var resp response
+ if err = json.Unmarshal([]byte(res.Choices[0].Message.Content), &resp); err != nil {
+ return nil, fmt.Errorf("failed to unmarshal gpt response: %w", err)
+ }
+
+ var action analyzer.ChartAction
+ switch strings.ToLower(resp.Action) {
+ case "buy":
+ action = analyzer.Buy
+ case "sell":
+ action = analyzer.Sell
+ case "hold":
+ action = analyzer.Hold
+ default:
+ action = analyzer.Unknown
+ }
+
+ m, err := utils.ParseMoney(resp.Price)
+ if err != nil {
+ return nil, fmt.Errorf("failed to parse price: %w", err)
+ }
+
+ confidence, err := strconv.ParseFloat(resp.Confidence, 64)
+ if err != nil {
+ return nil, fmt.Errorf("failed to parse confidence: %w", err)
+ }
+ if confidence < 0 || confidence > 100 {
+ return nil, fmt.Errorf("confidence must be between 0 and 100, got %f", confidence)
+ }
+
+ return &analyzer.Analysis{
+ Action: action,
+ Price: m,
+ Reason: resp.Reason,
+ Confidence: uint8(math.Floor(confidence)),
+ }, nil
+}
+
// response mirrors the JSON object the model is instructed (via the system
// prompt) to return. All fields arrive as strings and are parsed and
// validated by Analyze.
type response struct {
	Action string `json:"action"`
	Price string `json:"price"`
	Reason string `json:"reason"`
	Confidence string `json:"confidence"`
}
diff --git a/backend/internal/analyzer/openai/openai_test.go b/backend/internal/analyzer/openai/openai_test.go
new file mode 100644
index 0000000..0aac709
--- /dev/null
+++ b/backend/internal/analyzer/openai/openai_test.go
@@ -0,0 +1 @@
+package openai
diff --git a/backend/internal/analyzer/openai/options.go b/backend/internal/analyzer/openai/options.go
new file mode 100644
index 0000000..11d691f
--- /dev/null
+++ b/backend/internal/analyzer/openai/options.go
@@ -0,0 +1,45 @@
+package openai
+
+import (
+ _ "embed"
+
+ "github.com/sashabaranov/go-openai"
+)
+
// defaultSystemMsg is the analyzer's default system prompt, embedded from
// system.txt.
//
//go:embed system.txt
var defaultSystemMsg string

// Defaults used by NewAnalyzer when no overriding option is supplied.
const defaultModel = openai.GPT4o
const defaultTemperature = 0.25

// Option customizes an Analyzer during construction.
type Option func(*Analyzer)

// WithClientConfig sets the OpenAI client built from a full client config.
func WithClientConfig(cfg openai.ClientConfig) Option {
	return func(a *Analyzer) {
		a.client = openai.NewClientWithConfig(cfg)
	}
}

// WithDefaultConfig sets an OpenAI client built from just an API key.
func WithDefaultConfig(apiKey string) Option {
	return func(a *Analyzer) {
		a.client = openai.NewClient(apiKey)
	}
}

// WithModel overrides the chat model (default: defaultModel).
func WithModel(model string) Option {
	return func(a *Analyzer) {
		a.model = model
	}
}

// WithSystemMsg overrides the embedded system prompt.
func WithSystemMsg(msg string) Option {
	return func(a *Analyzer) {
		a.systemMsg = msg
	}
}

// WithTemperature overrides the sampling temperature (default: defaultTemperature).
func WithTemperature(temp float32) Option {
	return func(a *Analyzer) {
		a.temperature = temp
	}
}
diff --git a/backend/internal/analyzer/openai/system.txt b/backend/internal/analyzer/openai/system.txt
new file mode 100644
index 0000000..82e3b2a
--- /dev/null
+++ b/backend/internal/analyzer/openai/system.txt
@@ -0,0 +1,34 @@
+You're a stock analyzer.
+You will be given a stock symbol, its current price, and its chart analysis.
+Your job is to determine the best course of action to do with the stock (buy, sell, hold, or unknown),
+the price at which the action should be taken, the reason for the action, and the confidence (0-100)
+level you have in the action.
+
+The reason should be a paragraph explaining why you chose the action.
+
+The date the chart analysis was done may be mentioned in the chart analysis.
+Make sure to take that into account when making your decision.
+If the chart analysis is older than 1 week, lower your confidence accordingly and mention that in the reason.
+If the chart analysis is too outdated, set the action to "unknown".
+
+The information will be given in the following format:
+```
+<current datetime>
+<stock symbol>
+<current price>
+<chart analysis>
+```
+
+Your response should be in the following JSON format:
+```
+{
+ "action": "<action>",
+ "price": "<price>",
+ "reason": "<reason>",
+ "confidence": "<confidence>"
+}
+```
+All fields are required and must be strings.
+`action` must be one of the following: "buy", "sell", "hold", or "unknown".
+`price` must contain the symbol for the currency and the price (e.g. "$100").
+The system WILL validate your response. \ No newline at end of file
diff --git a/backend/internal/auth/auth.go b/backend/internal/auth/auth.go
new file mode 100644
index 0000000..edad914
--- /dev/null
+++ b/backend/internal/auth/auth.go
@@ -0,0 +1,55 @@
+package auth
+
+import (
+ "context"
+ "errors"
+
+ "github.com/ansg191/ibd-trader-backend/internal/config"
+
+ "github.com/coreos/go-oidc/v3/oidc"
+ "golang.org/x/oauth2"
+)
+
// Authenticator is used to authenticate our users.
// It embeds the OIDC provider (for ID-token verification) and the OAuth2
// config (for the authorization-code flow).
type Authenticator struct {
	*oidc.Provider
	oauth2.Config
}
+
+// New instantiates the *Authenticator.
+func New(cfg *config.Config) (*Authenticator, error) {
+ provider, err := oidc.NewProvider(
+ context.Background(),
+ "https://"+cfg.Auth.Domain+"/",
+ )
+ if err != nil {
+ return nil, err
+ }
+
+ conf := oauth2.Config{
+ ClientID: cfg.Auth.ClientID,
+ ClientSecret: cfg.Auth.ClientSecret,
+ RedirectURL: cfg.Auth.CallbackURL,
+ Endpoint: provider.Endpoint(),
+ Scopes: []string{oidc.ScopeOpenID, "profile", "email"},
+ }
+
+ return &Authenticator{
+ Provider: provider,
+ Config: conf,
+ }, nil
+}
+
+// VerifyIDToken verifies that an *oauth2.Token is a valid *oidc.IDToken.
+func (a *Authenticator) VerifyIDToken(ctx context.Context, token *oauth2.Token) (*oidc.IDToken, error) {
+ rawIDToken, ok := token.Extra("id_token").(string)
+ if !ok {
+ return nil, errors.New("no id_token field in oauth2 token")
+ }
+
+ oidcConfig := &oidc.Config{
+ ClientID: a.ClientID,
+ }
+
+ return a.Verifier(oidcConfig).Verify(ctx, rawIDToken)
+}
diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go
new file mode 100644
index 0000000..c37588b
--- /dev/null
+++ b/backend/internal/config/config.go
@@ -0,0 +1,114 @@
+package config
+
+import (
+ "github.com/ansg191/ibd-trader-backend/internal/keys"
+ "github.com/ansg191/ibd-trader-backend/internal/leader/manager/ibd"
+
+ "github.com/spf13/viper"
+)
+
// Config is the application's full configuration, populated by New from a
// config file (config.{yaml,json,...}) and/or environment variables.
type Config struct {
	// Logging configuration
	Log struct {
		// Log level
		Level string
		// Add source info to log messages
		AddSource bool
		// Enable colorized output
		Color bool
	}
	// Database configuration
	DB struct {
		// Database URL
		URL string
	}
	// Redis configuration
	Redis struct {
		// Redis address
		Addr string
		// Redis password
		Password string
	}
	// KMS configuration
	KMS struct {
		// GCP KMS configuration (nil when GCP KMS is not used)
		GCP *keys.GCPKeyName
	}
	// Server configuration
	Server struct {
		// Server port
		Port uint16
	}
	// OAuth 2.0 configuration
	Auth struct {
		// OAuth 2.0 domain
		Domain string
		// OAuth 2.0 client ID
		ClientID string
		// OAuth 2.0 client secret
		ClientSecret string
		// OAuth 2.0 callback URL
		CallbackURL string
	}
	// IBD configuration
	IBD struct {
		// Scraper API Key
		APIKey string
		// Proxy URL
		ProxyURL string
		// Scrape schedules. In cron format.
		Schedules ibd.Schedules
	}
	// Analyzer configuration
	Analyzer struct {
		// Use OpenAI for analysis (nil disables OpenAI)
		OpenAI *struct {
			// OpenAI API Key
			APIKey string
		}
	}
}
+
+func New() (*Config, error) {
+ v := viper.New()
+
+ v.SetDefault("log.level", "INFO")
+ v.SetDefault("log.addSource", false)
+ v.SetDefault("log.color", false)
+ v.SetDefault("server.port", 8000)
+
+ v.SetConfigName("config")
+ v.AddConfigPath("/etc/ibd-trader/")
+ v.AddConfigPath("$HOME/.ibd-trader")
+ v.AddConfigPath(".")
+ err := v.ReadInConfig()
+ if err != nil {
+ if _, ok := err.(viper.ConfigFileNotFoundError); ok {
+ // Config file not found; ignore error
+ } else {
+ return nil, err
+ }
+ }
+
+ v.MustBindEnv("db.url", "DATABASE_URL")
+ v.MustBindEnv("redis.addr", "REDIS_ADDR")
+ v.MustBindEnv("redis.password", "REDIS_PASSWORD")
+ v.MustBindEnv("log.level", "LOG_LEVEL")
+ v.MustBindEnv("server.port", "SERVER_PORT")
+ v.MustBindEnv("auth.domain", "AUTH_DOMAIN")
+ v.MustBindEnv("auth.clientID", "AUTH_CLIENT_ID")
+ v.MustBindEnv("auth.clientSecret", "AUTH_CLIENT_SECRET")
+ v.MustBindEnv("auth.callbackURL", "AUTH_CALLBACK_URL")
+
+ config := new(Config)
+ err = v.Unmarshal(config)
+ if err != nil {
+ return nil, err
+ }
+
+ return config, config.assert()
+}
+
// assert validates the loaded configuration.
// TODO: currently a no-op; add required-field checks here.
func (c *Config) assert() error {
	return nil
}
diff --git a/backend/internal/database/cookies.go b/backend/internal/database/cookies.go
new file mode 100644
index 0000000..3ea21d0
--- /dev/null
+++ b/backend/internal/database/cookies.go
@@ -0,0 +1,189 @@
+package database
+
+import (
+ "context"
+ "database/sql"
+ "errors"
+ "fmt"
+ "net/http"
+ "time"
+
+ "github.com/ansg191/ibd-trader-backend/internal/keys"
+)
+
+func GetAnyCookie(ctx context.Context, exec Executor, kms keys.KeyManagementService) (*IBDCookie, error) {
+ row := exec.QueryRowContext(ctx, `
+SELECT ibd_tokens.id, token, encrypted_key, kms_key_name, expires_at
+FROM ibd_tokens
+ INNER JOIN keys ON encryption_key = keys.id
+WHERE expires_at > NOW()
+ AND degraded = FALSE
+ORDER BY random()
+LIMIT 1;`)
+
+ var id uint
+ var encryptedToken, encryptedKey []byte
+ var keyName string
+ var expiry time.Time
+ err := row.Scan(&id, &encryptedToken, &encryptedKey, &keyName, &expiry)
+ if err != nil {
+ if errors.Is(err, sql.ErrNoRows) {
+ return nil, nil
+ }
+ return nil, fmt.Errorf("unable to scan sql row into ibd cookie: %w", err)
+ }
+
+ // Set the expiry to UTC explicitly.
+ // For some reason, the expiry time is set to location="".
+ expiry = expiry.UTC()
+
+ token, err := keys.Decrypt(ctx, kms, keyName, encryptedToken, encryptedKey)
+ if err != nil {
+ return nil, fmt.Errorf("unable to decrypt token: %w", err)
+ }
+ return &IBDCookie{
+ Token: string(token),
+ Expiry: expiry,
+ }, nil
+}
+
+func GetCookies(
+ ctx context.Context,
+ exec Executor,
+ kms keys.KeyManagementService,
+ subject string,
+ degraded bool,
+) ([]IBDCookie, error) {
+ rows, err := exec.QueryContext(ctx, `
+SELECT ibd_tokens.id, token, encrypted_key, kms_key_name, expires_at
+FROM ibd_tokens
+ INNER JOIN keys ON encryption_key = keys.id
+WHERE user_subject = $1
+ AND expires_at > NOW()
+ AND degraded = $2
+ORDER BY expires_at DESC;`, subject, degraded)
+ if err != nil {
+ return nil, fmt.Errorf("unable to get ibd cookies: %w", err)
+ }
+
+ cookies := make([]IBDCookie, 0)
+ for rows.Next() {
+ var id uint
+ var encryptedToken, encryptedKey []byte
+ var keyName string
+ var expiry time.Time
+ err = rows.Scan(&id, &encryptedToken, &encryptedKey, &keyName, &expiry)
+ if err != nil {
+ return nil, fmt.Errorf("unable to scan sql row into ibd cookie: %w", err)
+ }
+
+ // Set the expiry to UTC explicitly.
+ // For some reason, the expiry time is set to location="".
+ expiry = expiry.UTC()
+
+ token, err := keys.Decrypt(ctx, kms, keyName, encryptedToken, encryptedKey)
+ if err != nil {
+ return nil, fmt.Errorf("unable to decrypt token: %w", err)
+ }
+ cookie := IBDCookie{
+ ID: id,
+ Token: string(token),
+ Expiry: expiry,
+ }
+ cookies = append(cookies, cookie)
+ }
+
+ return cookies, nil
+}
+
+func AddCookie(
+ ctx context.Context,
+ exec TransactionExecutor,
+ kms keys.KeyManagementService,
+ subject string,
+ cookie *http.Cookie,
+) error {
+ tx, err := exec.BeginTx(ctx, nil)
+ if err != nil {
+ return err
+ }
+
+ // Get the key ID for the user
+ user, err := GetUser(ctx, tx, subject)
+ if err != nil {
+ return fmt.Errorf("unable to get user: %w", err)
+ }
+ if user.EncryptionKeyID == nil {
+ return errors.New("user does not have an encryption key")
+ }
+
+ // Get the key
+ var keyName string
+ var key []byte
+ err = tx.QueryRowContext(ctx, `
+SELECT kms_key_name, encrypted_key
+FROM keys
+WHERE id = $1;`,
+ *user.EncryptionKeyID,
+ ).Scan(&keyName, &key)
+ if err != nil {
+ return fmt.Errorf("unable to get key: %w", err)
+ }
+
+ // Encrypt the token
+ encryptedToken, err := keys.EncryptWithKey(ctx, kms, keyName, key, []byte(cookie.Value))
+ if err != nil {
+ return fmt.Errorf("unable to encrypt token: %w", err)
+ }
+
+ // Add the cookie to the database
+ _, err = exec.ExecContext(ctx, `
+INSERT INTO ibd_tokens (token, expires_at, user_subject, encryption_key)
+VALUES ($1, $2, $3, $4)`, encryptedToken, cookie.Expires, subject, *user.EncryptionKeyID)
+ if err != nil {
+ return fmt.Errorf("unable to add cookie: %w", err)
+ }
+
+ return nil
+}
+
+func ReportCookieFailure(ctx context.Context, exec Executor, id uint) error {
+ _, err := exec.ExecContext(ctx, `
+UPDATE ibd_tokens
+SET degraded = TRUE
+WHERE id = $1;`, id)
+ if err != nil {
+ return fmt.Errorf("unable to report cookie failure: %w", err)
+ }
+ return nil
+}
+
+func RepairCookie(ctx context.Context, exec Executor, id uint) error {
+ _, err := exec.ExecContext(ctx, `
+UPDATE ibd_tokens
+SET degraded = FALSE
+WHERE id = $1;`, id)
+ if err != nil {
+ return fmt.Errorf("unable to report cookie failure: %w", err)
+ }
+ return nil
+}
+
+type IBDCookie struct {
+ ID uint
+ Token string
+ Expiry time.Time
+}
+
+func (c *IBDCookie) ToHTTPCookie() *http.Cookie {
+ return &http.Cookie{
+ Name: ".ASPXAUTH",
+ Value: c.Token,
+ Path: "/",
+ Domain: "investors.com",
+ Expires: c.Expiry,
+ Secure: true,
+ HttpOnly: false,
+ SameSite: http.SameSiteLaxMode,
+ }
+}
diff --git a/backend/internal/database/database.go b/backend/internal/database/database.go
new file mode 100644
index 0000000..409dd3c
--- /dev/null
+++ b/backend/internal/database/database.go
@@ -0,0 +1,166 @@
+package database
+
+import (
+ "context"
+ "database/sql"
+ "database/sql/driver"
+ "errors"
+ "io"
+ "log/slog"
+ "sync"
+ "time"
+
+ "github.com/ansg191/ibd-trader-backend/db"
+ "github.com/ansg191/ibd-trader-backend/internal/keys"
+
+ "github.com/golang-migrate/migrate/v4"
+ _ "github.com/golang-migrate/migrate/v4/database/postgres"
+ "github.com/golang-migrate/migrate/v4/source/iofs"
+ _ "github.com/lib/pq"
+)
+
// Database is the application's handle to Postgres: an executor and
// transaction starter, plus lifecycle operations (close, ping, migrations,
// background maintenance).
type Database interface {
	io.Closer
	TransactionExecutor
	driver.Pinger

	// Migrate applies pending schema migrations.
	Migrate(ctx context.Context) error
	// Maintenance blocks, running periodic upkeep until ctx is cancelled.
	Maintenance(ctx context.Context)
}

// database is the *sql.DB-backed implementation of Database.
type database struct {
	logger *slog.Logger

	db *sql.DB
	url string // connection URL, retained for running migrations

	kms keys.KeyManagementService
	keyName string // KMS key name used for encryption work
}
+
+func New(ctx context.Context, logger *slog.Logger, url string, kms keys.KeyManagementService, keyName string) (Database, error) {
+ sqlDB, err := sql.Open("postgres", url)
+ if err != nil {
+ return nil, err
+ }
+
+ err = sqlDB.PingContext(ctx)
+ if err != nil {
+ // Ping failed. Don't error, but give a warning.
+ logger.WarnContext(ctx, "Unable to ping database", "error", err)
+ }
+
+ return &database{
+ logger: logger,
+ db: sqlDB,
+ url: url,
+ kms: kms,
+ keyName: keyName,
+ }, nil
+}
+
// Close closes the underlying connection pool.
func (d *database) Close() error {
	return d.db.Close()
}

// Migrate applies pending migrations using the stored connection URL.
func (d *database) Migrate(ctx context.Context) error {
	return Migrate(ctx, d.url)
}
+
+func (d *database) Maintenance(ctx context.Context) {
+ ticker := time.NewTicker(15 * time.Minute)
+ defer ticker.Stop()
+
+ for {
+ select {
+ case <-ticker.C:
+ func() {
+ var wg sync.WaitGroup
+ wg.Add(1)
+
+ _, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
+ defer cancel()
+
+ wg.Wait()
+ }()
+ case <-ctx.Done():
+ return
+ }
+ }
+}
+
+func Migrate(ctx context.Context, url string) error {
+ fs, err := iofs.New(db.Migrations, "migrations")
+ if err != nil {
+ return err
+ }
+
+ m, err := migrate.NewWithSourceInstance("iofs", fs, url)
+ if err != nil {
+ return err
+ }
+
+ slog.InfoContext(ctx, "Running DB migration")
+ err = m.Up()
+ if err != nil && !errors.Is(err, migrate.ErrNoChange) {
+ slog.ErrorContext(ctx, "DB migration failed", "error", err)
+ return err
+ }
+
+ return nil
+}
+
// Ping verifies the database connection is alive.
func (d *database) Ping(ctx context.Context) error {
	return d.db.PingContext(ctx)
}

// Executor is the common query surface shared by the pooled database and
// *sql.Tx, letting data-access helpers run inside or outside a transaction.
type Executor interface {
	ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
	QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
	QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row
}

// TransactionExecutor is an Executor that can also begin transactions.
type TransactionExecutor interface {
	Executor
	BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error)
}
+
+func (d *database) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) {
+ d.logger.DebugContext(ctx, "Executing query", "query", query)
+
+ now := time.Now()
+ ret, err := d.db.ExecContext(ctx, query, args...)
+ if err != nil {
+ return nil, err
+ }
+
+ d.logger.DebugContext(ctx, "Query executed successfully", "duration", time.Since(now))
+ return ret, nil
+}
+
+func (d *database) QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) {
+ d.logger.DebugContext(ctx, "Executing query", "query", query)
+
+ now := time.Now()
+ ret, err := d.db.QueryContext(ctx, query, args...)
+ if err != nil {
+ return nil, err
+ }
+
+ d.logger.DebugContext(ctx, "Query executed successfully", "duration", time.Since(now))
+ return ret, nil
+}
+
+func (d *database) QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row {
+ d.logger.DebugContext(ctx, "Executing query", "query", query)
+
+ now := time.Now()
+ ret := d.db.QueryRowContext(ctx, query, args...)
+
+ d.logger.DebugContext(ctx, "Query executed successfully", "duration", time.Since(now))
+ return ret
+}
+
+func (d *database) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) {
+ return d.db.BeginTx(ctx, opts)
+}
diff --git a/backend/internal/database/database_test.go b/backend/internal/database/database_test.go
new file mode 100644
index 0000000..407a09a
--- /dev/null
+++ b/backend/internal/database/database_test.go
@@ -0,0 +1,79 @@
+package database
+
+import (
+ "context"
+ "database/sql"
+ "fmt"
+ "log"
+ "testing"
+ "time"
+
+ "github.com/ory/dockertest/v3"
+ "github.com/ory/dockertest/v3/docker"
+)
+
// exec is the shared database handle for all tests in this package,
// pointing at the dockerized Postgres started in TestMain.
var exec *sql.DB

// TestMain spins up a disposable Postgres 16 container via dockertest, waits
// until it accepts connections, runs the schema migrations, and then runs the
// package's tests against it. The container is purged when the tests finish
// and self-expires after 120 seconds as a safety net.
func TestMain(m *testing.M) {
	pool, err := dockertest.NewPool("")
	if err != nil {
		log.Fatalf("Could not create pool: %s", err)
	}

	err = pool.Client.Ping()
	if err != nil {
		log.Fatalf("Could not connect to Docker: %s", err)
	}

	resource, err := pool.RunWithOptions(&dockertest.RunOptions{
		Repository: "postgres",
		Tag:        "16",
		Env: []string{
			"POSTGRES_PASSWORD=secret",
			"POSTGRES_USER=ibd-client-test",
			"POSTGRES_DB=ibd-client-test",
			"listen_addresses='*'",
		},
		Cmd: []string{
			"postgres",
			"-c",
			"log_statement=all",
		},
	}, func(config *docker.HostConfig) {
		// Auto-remove so no stopped containers accumulate between runs.
		config.AutoRemove = true
		config.RestartPolicy = docker.RestartPolicy{Name: "no"}
	})
	if err != nil {
		log.Fatalf("Could not start resource: %s", err)
	}

	hostAndPort := resource.GetHostPort("5432/tcp")
	databaseUrl := fmt.Sprintf("postgres://ibd-client-test:secret@%s/ibd-client-test?sslmode=disable", hostAndPort)

	// Kill container after 120 seconds
	_ = resource.Expire(120)

	// Postgres may take a moment to accept connections; retry until ready.
	pool.MaxWait = 120 * time.Second
	if err = pool.Retry(func() error {
		exec, err = sql.Open("postgres", databaseUrl)
		if err != nil {
			return err
		}
		return exec.Ping()
	}); err != nil {
		log.Fatalf("Could not connect to database: %s", err)
	}

	err = Migrate(context.Background(), databaseUrl)
	if err != nil {
		log.Fatalf("Could not migrate database: %s", err)
	}

	defer func() {
		if err := pool.Purge(resource); err != nil {
			log.Fatalf("Could not purge resource: %s", err)
		}
	}()

	// Since Go 1.15 the exit code of m.Run is applied automatically when
	// TestMain returns, so the deferred Purge still runs.
	m.Run()
}
diff --git a/backend/internal/database/stocks.go b/backend/internal/database/stocks.go
new file mode 100644
index 0000000..24f5fe7
--- /dev/null
+++ b/backend/internal/database/stocks.go
@@ -0,0 +1,293 @@
+package database
+
+import (
+ "context"
+ "database/sql"
+ "database/sql/driver"
+ "errors"
+
+ pb "github.com/ansg191/ibd-trader-backend/api/gen/idb/stock/v1"
+ "github.com/ansg191/ibd-trader-backend/internal/analyzer"
+ "github.com/ansg191/ibd-trader-backend/internal/utils"
+
+ "github.com/Rhymond/go-money"
+)
+
+var ErrStockNotFound = errors.New("stock not found")
+
+func GetStock(ctx context.Context, exec Executor, symbol string) (Stock, error) {
+ row := exec.QueryRowContext(ctx, `
+SELECT symbol, name, ibd_url
+FROM stocks
+WHERE symbol = $1;
+`, symbol)
+
+ var stock Stock
+ if err := row.Scan(&stock.Symbol, &stock.Name, &stock.IBDUrl); err != nil {
+ if errors.Is(err, sql.ErrNoRows) {
+ return Stock{}, ErrStockNotFound
+ }
+ return Stock{}, err
+ }
+
+ return stock, nil
+}
+
+func AddStock(ctx context.Context, exec Executor, stock Stock) error {
+ _, err := exec.ExecContext(ctx, `
+INSERT INTO stocks (symbol, name, ibd_url)
+VALUES ($1, $2, $3)
+ON CONFLICT (symbol)
+ DO UPDATE SET name = $2,
+ ibd_url = $3;`, stock.Symbol, stock.Name, stock.IBDUrl)
+ return err
+}
+
+func AddRanking(ctx context.Context, exec Executor, symbol string, ibd50, cap20 int) error {
+ if ibd50 > 0 {
+ _, err := exec.ExecContext(ctx, `
+INSERT INTO stock_rank (symbol, rank_type, rank)
+VALUES ($1, $2, $3)`, symbol, "ibd50", ibd50)
+ if err != nil {
+ return err
+ }
+ }
+ if cap20 > 0 {
+ _, err := exec.ExecContext(ctx, `
+INSERT INTO stock_rank (symbol, rank_type, rank)
+VALUES ($1, $2, $3)`, symbol, "cap20", cap20)
+ if err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
// AddStockInfo stores a scraped stock snapshot in one transaction: first the
// raw chart-analysis text, then the ratings row referencing it.
// Returns the new ratings row ID.
func AddStockInfo(ctx context.Context, exec TransactionExecutor, info *StockInfo) (string, error) {
	tx, err := exec.BeginTx(ctx, nil)
	if err != nil {
		return "", err
	}
	// Rollback on any early return; it is a no-op after a successful Commit.
	defer func(tx *sql.Tx) {
		_ = tx.Rollback()
	}(tx)

	// Add raw chart analysis
	row := tx.QueryRowContext(ctx, `
INSERT INTO chart_analysis (raw_analysis)
VALUES ($1)
RETURNING id;`, info.ChartAnalysis)

	var chartAnalysisID string
	if err = row.Scan(&chartAnalysisID); err != nil {
		return "", err
	}

	// Add stock info, linking it to the chart-analysis row created above.
	// Price is stored in its display form (e.g. "$100.00").
	row = tx.QueryRowContext(ctx,
		`
INSERT INTO ratings (symbol, composite, eps, rel_str, group_rel_str, smr, acc_dis, chart_analysis, price)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
RETURNING id;`,
		info.Symbol,
		info.Ratings.Composite,
		info.Ratings.EPS,
		info.Ratings.RelStr,
		info.Ratings.GroupRelStr,
		info.Ratings.SMR,
		info.Ratings.AccDis,
		chartAnalysisID,
		info.Price.Display(),
	)

	var ratingsID string
	if err = row.Scan(&ratingsID); err != nil {
		return "", err
	}

	return ratingsID, tx.Commit()
}
+
// GetStockInfo loads the ratings row with the given ID, joined with the
// stock's name and its raw chart-analysis text. The stored display-form
// price string is parsed back into a *money.Money.
func GetStockInfo(ctx context.Context, exec Executor, id string) (*StockInfo, error) {
	row := exec.QueryRowContext(ctx, `
SELECT r.symbol,
       s.name,
       ca.raw_analysis,
       r.composite,
       r.eps,
       r.rel_str,
       r.group_rel_str,
       r.smr,
       r.acc_dis,
       r.price
FROM ratings r
         INNER JOIN stocks s on r.symbol = s.symbol
         INNER JOIN chart_analysis ca on r.chart_analysis = ca.id
WHERE r.id = $1;`, id)

	var info StockInfo
	var priceStr string
	err := row.Scan(
		&info.Symbol,
		&info.Name,
		&info.ChartAnalysis,
		&info.Ratings.Composite,
		&info.Ratings.EPS,
		&info.Ratings.RelStr,
		&info.Ratings.GroupRelStr,
		&info.Ratings.SMR,
		&info.Ratings.AccDis,
		&priceStr,
	)
	if err != nil {
		return nil, err
	}

	// Price is stored as a display string (see AddStockInfo); parse it back.
	info.Price, err = utils.ParseMoney(priceStr)
	if err != nil {
		return nil, err
	}

	return &info, nil
}
+
// AddAnalysis marks the chart-analysis row belonging to the given ratings row
// as processed and stores the analyzer's structured result on it.
// Returns the chart_analysis row ID that was updated; sql.ErrNoRows if the
// ratings row does not exist.
func AddAnalysis(
	ctx context.Context,
	exec Executor,
	ratingId string,
	analysis *analyzer.Analysis,
) (id string, err error) {
	// UPDATE ... FROM joins ratings to locate the linked chart_analysis row.
	err = exec.QueryRowContext(ctx, `
UPDATE chart_analysis ca
SET processed  = true,
    action     = $2,
    price      = $3,
    reason     = $4,
    confidence = $5
FROM ratings r
WHERE r.id = $1
  AND r.chart_analysis = ca.id
RETURNING ca.id;`,
		ratingId,
		analysis.Action,
		analysis.Price.Display(),
		analysis.Reason,
		analysis.Confidence,
	).Scan(&id)
	return id, err
}
+
// Stock is a row of the stocks table.
type Stock struct {
	Symbol string
	Name string
	// IBDUrl is the stock's page on investors.com.
	IBDUrl string
}

// StockInfo is a scraped snapshot of a stock: its IBD ratings, raw
// chart-analysis text, and price at scrape time.
type StockInfo struct {
	Symbol string
	Name string
	ChartAnalysis string
	Ratings Ratings
	Price *money.Money
}

// Ratings holds the IBD rating values for a stock. Composite, EPS, and
// RelStr are numeric ratings; the rest are letter grades.
type Ratings struct {
	Composite uint8
	EPS uint8
	RelStr uint8
	GroupRelStr LetterRating
	SMR LetterRating
	AccDis LetterRating
}
+
+type LetterRating pb.LetterGrade
+
+func (r LetterRating) String() string {
+ switch pb.LetterGrade(r) {
+ case pb.LetterGrade_LETTER_GRADE_E:
+ return "E"
+ case pb.LetterGrade_LETTER_GRADE_E_PLUS:
+ return "E+"
+ case pb.LetterGrade_LETTER_GRADE_D_MINUS:
+ return "D-"
+ case pb.LetterGrade_LETTER_GRADE_D:
+ return "D"
+ case pb.LetterGrade_LETTER_GRADE_D_PLUS:
+ return "D+"
+ case pb.LetterGrade_LETTER_GRADE_C_MINUS:
+ return "C-"
+ case pb.LetterGrade_LETTER_GRADE_C:
+ return "C"
+ case pb.LetterGrade_LETTER_GRADE_C_PLUS:
+ return "C+"
+ case pb.LetterGrade_LETTER_GRADE_B_MINUS:
+ return "B-"
+ case pb.LetterGrade_LETTER_GRADE_B:
+ return "B"
+ case pb.LetterGrade_LETTER_GRADE_B_PLUS:
+ return "B+"
+ case pb.LetterGrade_LETTER_GRADE_A_MINUS:
+ return "A-"
+ case pb.LetterGrade_LETTER_GRADE_A:
+ return "A"
+ case pb.LetterGrade_LETTER_GRADE_A_PLUS:
+ return "A+"
+ default:
+ return "NA"
+ }
+}
+
+func LetterRatingFromString(str string) LetterRating {
+ switch str {
+ case "E":
+ return LetterRating(pb.LetterGrade_LETTER_GRADE_E)
+ case "E+":
+ return LetterRating(pb.LetterGrade_LETTER_GRADE_E_PLUS)
+ case "D-":
+ return LetterRating(pb.LetterGrade_LETTER_GRADE_D_MINUS)
+ case "D":
+ return LetterRating(pb.LetterGrade_LETTER_GRADE_D)
+ case "D+":
+ return LetterRating(pb.LetterGrade_LETTER_GRADE_D_PLUS)
+ case "C-":
+ return LetterRating(pb.LetterGrade_LETTER_GRADE_C_MINUS)
+ case "C":
+ return LetterRating(pb.LetterGrade_LETTER_GRADE_C)
+ case "C+":
+ return LetterRating(pb.LetterGrade_LETTER_GRADE_C_PLUS)
+ case "B-":
+ return LetterRating(pb.LetterGrade_LETTER_GRADE_B_MINUS)
+ case "B":
+ return LetterRating(pb.LetterGrade_LETTER_GRADE_B)
+ case "B+":
+ return LetterRating(pb.LetterGrade_LETTER_GRADE_B_PLUS)
+ case "A-":
+ return LetterRating(pb.LetterGrade_LETTER_GRADE_A_MINUS)
+ case "A":
+ return LetterRating(pb.LetterGrade_LETTER_GRADE_A)
+ case "A+":
+ return LetterRating(pb.LetterGrade_LETTER_GRADE_A_PLUS)
+ case "NA":
+ fallthrough
+ default:
+ return LetterRating(pb.LetterGrade_LETTER_GRADE_UNSPECIFIED)
+ }
+}
+
+func (r LetterRating) Value() (driver.Value, error) {
+ return r.String(), nil
+}
+
+func (r *LetterRating) Scan(src any) error {
+ var source string
+ switch v := src.(type) {
+ case string:
+ source = v
+ case []byte:
+ source = string(v)
+ default:
+ return errors.New("incompatible type for LetterRating")
+ }
+ *r = LetterRatingFromString(source)
+ return nil
+}
diff --git a/backend/internal/database/users.go b/backend/internal/database/users.go
new file mode 100644
index 0000000..f7998fb
--- /dev/null
+++ b/backend/internal/database/users.go
@@ -0,0 +1,151 @@
+package database
+
+import (
+ "context"
+ "database/sql"
+ "errors"
+ "fmt"
+
+ "github.com/ansg191/ibd-trader-backend/internal/keys"
+)
+
// ErrUserNotFound is returned by lookups that reference an unknown subject.
var ErrUserNotFound = fmt.Errorf("user not found")

// ErrIBDCredsNotFound is returned when a user has no stored IBD credentials.
var ErrIBDCredsNotFound = fmt.Errorf("ibd creds not found")
+
+func AddUser(ctx context.Context, exec Executor, subject string) (err error) {
+ _, err = exec.ExecContext(ctx, `
+INSERT INTO users (subject)
+VALUES ($1)
+ON CONFLICT DO NOTHING;`, subject)
+ return
+}
+
// GetUser fetches the user row for the given OIDC subject.
// Returns ErrUserNotFound when no such row exists. The nullable IBD
// credential columns scan into pointer fields, which remain nil for users
// that have not stored credentials.
func GetUser(ctx context.Context, exec Executor, subject string) (*User, error) {
	row := exec.QueryRowContext(ctx, `
SELECT subject, ibd_username, ibd_password, encryption_key
FROM users
WHERE subject = $1;`, subject)

	user := &User{}
	err := row.Scan(&user.Subject, &user.IBDUsername, &user.EncryptedIBDPassword, &user.EncryptionKeyID)
	if err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			return nil, ErrUserNotFound
		}
		return nil, fmt.Errorf("unable to scan sql row into user: %w", err)
	}

	return user, nil
}
+
+func ListUsers(ctx context.Context, exec Executor, hasIBDCreds bool) ([]User, error) {
+ rows, err := exec.QueryContext(ctx, `
+SELECT subject, ibd_username, ibd_password, encryption_key
+FROM users;
+`)
+ if err != nil {
+ return nil, fmt.Errorf("unable to list users: %w", err)
+ }
+
+ users := make([]User, 0)
+ for rows.Next() {
+ user := User{}
+ err = rows.Scan(&user.Subject, &user.IBDUsername, &user.EncryptedIBDPassword, &user.EncryptionKeyID)
+ if err != nil {
+ return nil, fmt.Errorf("unable to scan sql row into user: %w", err)
+ }
+
+ if hasIBDCreds && user.IBDUsername == nil {
+ continue
+ }
+ users = append(users, user)
+ }
+
+ return users, nil
+}
+
+func AddIBDCreds(
+ ctx context.Context,
+ exec TransactionExecutor,
+ kms keys.KeyManagementService,
+ keyName, subject, username, password string,
+) error {
+ encryptedPass, encryptedKey, err := keys.Encrypt(ctx, kms, keyName, []byte(password))
+ if err != nil {
+ return fmt.Errorf("unable to encrypt password: %w", err)
+ }
+
+ tx, err := exec.BeginTx(ctx, nil)
+ if err != nil {
+ return err
+ }
+ defer func(tx *sql.Tx) {
+ _ = tx.Rollback()
+ }(tx)
+
+ var keyId int
+ err = tx.QueryRowContext(ctx, `
+INSERT INTO keys (kms_key_name, encrypted_key)
+VALUES ($1, $2)
+RETURNING id;`, keyName, encryptedKey).Scan(&keyId)
+ if err != nil {
+ return fmt.Errorf("unable to add ibd creds key: %w", err)
+ }
+
+ _, err = exec.ExecContext(ctx, `
+UPDATE users
+SET ibd_username = $2,
+ ibd_password = $3,
+ encryption_key = $4
+WHERE subject = $1;`, subject, username, encryptedPass, keyId)
+ if err != nil {
+ return fmt.Errorf("unable to add ibd creds to user: %w", err)
+ }
+
+ if err = tx.Commit(); err != nil {
+ return fmt.Errorf("unable to commit transaction: %w", err)
+ }
+
+ return nil
+}
+
// GetIBDCreds returns the decrypted IBD username and password stored for the
// given OIDC subject.
//
// Returns ErrIBDCredsNotFound when no credentials are stored: the INNER JOIN
// on keys drops users whose encryption_key is NULL, so credential-less users
// surface as sql.ErrNoRows rather than as a NULL-scan failure.
func GetIBDCreds(
	ctx context.Context,
	exec Executor,
	kms keys.KeyManagementService,
	subject string,
) (
	username string,
	password string,
	err error,
) {
	row := exec.QueryRowContext(ctx, `
SELECT ibd_username, ibd_password, encrypted_key, kms_key_name
FROM users
INNER JOIN public.keys k on k.id = users.encryption_key
WHERE subject = $1;`, subject)

	var encryptedPass, encryptedKey []byte
	var keyName string
	err = row.Scan(&username, &encryptedPass, &encryptedKey, &keyName)
	if err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			return "", "", ErrIBDCredsNotFound
		}
		return "", "", fmt.Errorf("unable to scan sql row into ibd creds: %w", err)
	}

	// Envelope decryption: the KMS key (keyName) unwraps encryptedKey, which
	// in turn decrypts the stored password bytes.
	passwordBytes, err := keys.Decrypt(ctx, kms, keyName, encryptedPass, encryptedKey)
	if err != nil {
		return "", "", fmt.Errorf("unable to decrypt password: %w", err)
	}

	return username, string(passwordBytes), nil
}
+
// User is a row of the users table. The pointer fields map to nullable
// columns and stay nil until AddIBDCreds stores credentials for the user.
type User struct {
	Subject              string  // OIDC subject identifying the user
	IBDUsername          *string // IBD login username, nil if not set
	EncryptedIBDPassword *string // envelope-encrypted IBD password
	EncryptionKeyID      *int    // references keys.id of the wrapping key
}
diff --git a/backend/internal/ibd/auth.go b/backend/internal/ibd/auth.go
new file mode 100644
index 0000000..7b82057
--- /dev/null
+++ b/backend/internal/ibd/auth.go
@@ -0,0 +1,333 @@
+package ibd
+
+import (
+ "bytes"
+ "context"
+ "encoding/base64"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "io"
+ "net/http"
+ "strings"
+
+ "github.com/ansg191/ibd-trader-backend/internal/ibd/transport"
+ "golang.org/x/net/html"
+)
+
const (
	// signInUrl is the myIBD sign-in page; the eurl query parameter is the
	// post-login redirect target.
	signInUrl = "https://myibd.investors.com/secure/signin.aspx?eurl=https%3A%2F%2Fwww.investors.com"
	// authenticateUrl is the Dow Jones SSO credential-submission endpoint.
	authenticateUrl = "https://sso.accounts.dowjones.com/authenticate"
	// postAuthUrl exchanges the SSO token/params for the IBD session cookie.
	postAuthUrl = "https://sso.accounts.dowjones.com/postauth/handler"
	// cookieName is the ASP.NET forms-auth cookie that carries the session.
	cookieName = ".ASPXAUTH"
)

// ErrAuthCookieNotFound is returned when the post-auth response does not set
// the expected session cookie.
var ErrAuthCookieNotFound = errors.New("cookie not found")

// ErrBadCredentials is returned when the SSO rejects the username/password.
var ErrBadCredentials = errors.New("bad credentials")
+
// Authenticate logs in to IBD via the Dow Jones SSO flow and returns the
// resulting ".ASPXAUTH" session cookie.
//
// The flow has three steps:
//  1. fetch the sign-in page and extract its embedded AUTH_CONFIG,
//  2. submit the credentials to the SSO /authenticate endpoint,
//  3. post the returned token/params to the post-auth handler, which sets
//     the session cookie.
//
// Returns ErrBadCredentials if the SSO rejects the credentials.
func (c *Client) Authenticate(
	ctx context.Context,
	username,
	password string,
) (*http.Cookie, error) {
	cfg, err := c.getLoginPage(ctx)
	if err != nil {
		return nil, err
	}

	token, params, err := c.sendAuthRequest(ctx, cfg, username, password)
	if err != nil {
		return nil, err
	}

	return c.sendPostAuth(ctx, token, params)
}
+
+func (c *Client) getLoginPage(ctx context.Context) (*authConfig, error) {
+ req, err := http.NewRequestWithContext(ctx, http.MethodGet, signInUrl, nil)
+ if err != nil {
+ return nil, err
+ }
+
+ resp, err := c.Do(req, withRequiredProps(transport.PropertiesReliable))
+ if err != nil {
+ return nil, err
+ }
+ defer func(Body io.ReadCloser) {
+ _ = Body.Close()
+ }(resp.Body)
+
+ if resp.StatusCode != http.StatusOK {
+ content, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return nil, fmt.Errorf("failed to read response body: %w", err)
+ }
+ return nil, fmt.Errorf(
+ "unexpected status code %d: %s",
+ resp.StatusCode,
+ string(content),
+ )
+ }
+
+ node, err := html.Parse(resp.Body)
+ if err != nil {
+ return nil, err
+ }
+
+ cfg, err := extractAuthConfig(node)
+ if err != nil {
+ return nil, fmt.Errorf("failed to extract auth config: %w", err)
+ }
+
+ return cfg, nil
+}
+
+func (c *Client) sendAuthRequest(ctx context.Context, cfg *authConfig, username, password string) (string, string, error) {
+ body := authRequestBody{
+ ClientId: cfg.ClientID,
+ RedirectUri: cfg.CallbackURL,
+ Tenant: "sso",
+ ResponseType: cfg.ExtraParams.ResponseType,
+ Username: username,
+ Password: password,
+ Scope: cfg.ExtraParams.Scope,
+ State: cfg.ExtraParams.State,
+ Headers: struct {
+ XRemoteUser string `json:"x-_remote-_user"`
+ }(struct{ XRemoteUser string }{
+ XRemoteUser: username,
+ }),
+ XOidcProvider: "localop",
+ Protocol: cfg.ExtraParams.Protocol,
+ Nonce: cfg.ExtraParams.Nonce,
+ UiLocales: cfg.ExtraParams.UiLocales,
+ Csrf: cfg.ExtraParams.Csrf,
+ Intstate: cfg.ExtraParams.Intstate,
+ Connection: "DJldap",
+ }
+ bodyJson, err := json.Marshal(body)
+ if err != nil {
+ return "", "", fmt.Errorf("failed to marshal auth request body: %w", err)
+ }
+
+ req, err := http.NewRequestWithContext(ctx, http.MethodPost, authenticateUrl, bytes.NewReader(bodyJson))
+ if err != nil {
+ return "", "", err
+ }
+
+ req.Header.Set("Content-Type", "application/json")
+ req.Header.Set("Auth0-Client", "eyJuYW1lIjoiYXV0aDAuanMtdWxwIiwidmVyc2lvbiI6IjkuMjQuMSJ9")
+
+ resp, err := c.Do(req,
+ withRequiredProps(transport.PropertiesReliable),
+ withExpectedStatuses(http.StatusOK, http.StatusUnauthorized))
+ if err != nil {
+ return "", "", err
+ }
+ defer func(Body io.ReadCloser) {
+ _ = Body.Close()
+ }(resp.Body)
+
+ if resp.StatusCode == http.StatusUnauthorized {
+ return "", "", ErrBadCredentials
+ } else if resp.StatusCode != http.StatusOK {
+ content, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return "", "", fmt.Errorf("failed to read response body: %w", err)
+ }
+ return "", "", fmt.Errorf(
+ "unexpected status code %d: %s",
+ resp.StatusCode,
+ string(content),
+ )
+ }
+
+ node, err := html.Parse(resp.Body)
+ if err != nil {
+ return "", "", err
+ }
+
+ return extractTokenParams(node)
+}
+
// sendPostAuth posts the SSO-issued token and params to the post-auth
// handler and returns the ".ASPXAUTH" session cookie set on the response.
// Returns ErrAuthCookieNotFound if the response does not set that cookie.
func (c *Client) sendPostAuth(ctx context.Context, token, params string) (*http.Cookie, error) {
	// NOTE(review): token and params come verbatim from the hidden form
	// inputs and appear to be already URL-safe / percent-encoded (the test
	// asserts the server sees the decoded params), so they are deliberately
	// not re-encoded here — confirm before switching to url.Values.
	body := fmt.Sprintf("token=%s&params=%s", token, params)
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, postAuthUrl, strings.NewReader(body))
	if err != nil {
		return nil, err
	}

	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")

	resp, err := c.Do(req, withRequiredProps(transport.PropertiesReliable))
	if err != nil {
		return nil, err
	}
	defer func(Body io.ReadCloser) {
		_ = Body.Close()
	}(resp.Body)

	if resp.StatusCode != http.StatusOK {
		content, err := io.ReadAll(resp.Body)
		if err != nil {
			return nil, fmt.Errorf("failed to read response body: %w", err)
		}
		return nil, fmt.Errorf(
			"unexpected status code %d: %s",
			resp.StatusCode,
			string(content),
		)
	}

	// Extract the session cookie from the response's Set-Cookie headers.
	for _, cookie := range resp.Cookies() {
		if cookie.Name == cookieName {
			return cookie, nil
		}
	}

	return nil, ErrAuthCookieNotFound
}
+
+func extractAuthConfig(node *html.Node) (*authConfig, error) {
+ // Find `root` element
+ root := findId(node, "root")
+ if root == nil {
+ return nil, fmt.Errorf("root element not found")
+ }
+
+ // Get adjacent script element
+ var script *html.Node
+ for s := root.NextSibling; s != nil; s = s.NextSibling {
+ if s.Type == html.ElementNode && s.Data == "script" {
+ script = s
+ break
+ }
+ }
+
+ if script == nil {
+ return nil, fmt.Errorf("script element not found")
+ }
+
+ // Get script content
+ content := extractText(script)
+
+ // Find `AUTH_CONFIG` variable
+ const authConfigVar = "const AUTH_CONFIG = '"
+ i := strings.Index(content, authConfigVar)
+ if i == -1 {
+ return nil, fmt.Errorf("AUTH_CONFIG not found")
+ }
+
+ // Find end of `AUTH_CONFIG` variable
+ j := strings.Index(content[i+len(authConfigVar):], "'")
+
+ // Extract `AUTH_CONFIG` value
+ authConfigJSONB64 := content[i+len(authConfigVar) : i+len(authConfigVar)+j]
+
+ // Decode `AUTH_CONFIG` value
+ authConfigJSON, err := base64.StdEncoding.DecodeString(authConfigJSONB64)
+ if err != nil {
+ return nil, fmt.Errorf("failed to decode AUTH_CONFIG: %w", err)
+ }
+
+ // Unmarshal `AUTH_CONFIG` value
+ var cfg authConfig
+ if err = json.Unmarshal(authConfigJSON, &cfg); err != nil {
+ return nil, fmt.Errorf("failed to unmarshal AUTH_CONFIG: %w", err)
+ }
+
+ return &cfg, nil
+}
+
// authConfig mirrors the JSON AUTH_CONFIG blob embedded (base64-encoded) in
// the Dow Jones SSO sign-in page. Only ClientID, CallbackURL, and the
// ExtraParams fields are consumed by sendAuthRequest; the rest are retained
// so the whole blob round-trips.
type authConfig struct {
	Auth0Domain string `json:"auth0Domain"`
	CallbackURL string `json:"callbackURL"`
	ClientID    string `json:"clientID"`
	ExtraParams struct {
		Protocol     string `json:"protocol"`
		Scope        string `json:"scope"`
		ResponseType string `json:"response_type"`
		Nonce        string `json:"nonce"`
		UiLocales    string `json:"ui_locales"`
		Csrf         string `json:"_csrf"`
		Intstate     string `json:"_intstate"`
		State        string `json:"state"`
	} `json:"extraParams"`
	InternalOptions struct {
		ResponseType string `json:"response_type"`
		ClientId     string `json:"client_id"`
		Scope        string `json:"scope"`
		RedirectUri  string `json:"redirect_uri"`
		UiLocales    string `json:"ui_locales"`
		Eurl         string `json:"eurl"`
		Nonce        string `json:"nonce"`
		State        string `json:"state"`
		Resource     string `json:"resource"`
		Protocol     string `json:"protocol"`
		Client       string `json:"client"`
	} `json:"internalOptions"`
	IsThirdPartyClient  bool `json:"isThirdPartyClient"`
	AuthorizationServer struct {
		Url    string `json:"url"`
		Issuer string `json:"issuer"`
	} `json:"authorizationServer"`
}
+
+func extractTokenParams(node *html.Node) (token string, params string, err error) {
+ inputs := findChildrenRecursive(node, func(node *html.Node) bool {
+ return node.Type == html.ElementNode && node.Data == "input"
+ })
+
+ var tokenNode, paramsNode *html.Node
+ for _, input := range inputs {
+ for _, attr := range input.Attr {
+ if attr.Key == "name" && attr.Val == "token" {
+ tokenNode = input
+ } else if attr.Key == "name" && attr.Val == "params" {
+ paramsNode = input
+ }
+ }
+ }
+
+ if tokenNode == nil {
+ return "", "", fmt.Errorf("token input not found")
+ }
+ if paramsNode == nil {
+ return "", "", fmt.Errorf("params input not found")
+ }
+
+ for _, attr := range tokenNode.Attr {
+ if attr.Key == "value" {
+ token = attr.Val
+ }
+ }
+ for _, attr := range paramsNode.Attr {
+ if attr.Key == "value" {
+ params = attr.Val
+ }
+ }
+
+ return
+}
+
// authRequestBody is the JSON payload POSTed to the SSO /authenticate
// endpoint. The unusual "x-_remote-_user" and "x-_oidc-_provider" JSON keys
// replicate exactly what the endpoint expects; most other fields echo values
// taken from the sign-in page's authConfig.
type authRequestBody struct {
	ClientId     string `json:"client_id"`
	RedirectUri  string `json:"redirect_uri"`
	Tenant       string `json:"tenant"`
	ResponseType string `json:"response_type"`
	Username     string `json:"username"`
	Password     string `json:"password"`
	Scope        string `json:"scope"`
	State        string `json:"state"`
	Headers      struct {
		XRemoteUser string `json:"x-_remote-_user"`
	} `json:"headers"`
	XOidcProvider string `json:"x-_oidc-_provider"`
	Protocol      string `json:"protocol"`
	Nonce         string `json:"nonce"`
	UiLocales     string `json:"ui_locales"`
	Csrf          string `json:"_csrf"`
	Intstate      string `json:"_intstate"`
	Connection    string `json:"connection"`
}
diff --git a/backend/internal/ibd/auth_test.go b/backend/internal/ibd/auth_test.go
new file mode 100644
index 0000000..157b507
--- /dev/null
+++ b/backend/internal/ibd/auth_test.go
@@ -0,0 +1,215 @@
+package ibd
+
+import (
+ "context"
+ "encoding/json"
+ "net/http"
+ "net/url"
+ "strings"
+ "testing"
+ "time"
+
+ "github.com/ansg191/ibd-trader-backend/internal/ibd/transport"
+ "github.com/jarcoal/httpmock"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ "golang.org/x/net/html"
+)
+
+const extractAuthHtml = `
+<!doctype html>
+<html lang="en">
+ <head>
+ <title>Log in · Dow Jones</title>
+ <meta charset="UTF-8"/>
+ <meta name="theme-color" content="white"/>
+ <meta name="viewport" content="width=device-width,initial-scale=1"/>
+ <meta name="description" content="Dow Jones One Identity Login page"/>
+ <link rel="apple-touch-icon" sizes="180x180" href="/one_identity_login_pages/login/0ce1520be322adcd762319573804f56d/images/apple-touch-icon.png"/>
+ <link rel="icon" type="image/png" sizes="32x32" href="/one_identity_login_pages/login/0ce1520be322adcd762319573804f56d/images/favicon-32x32.png"/>
+ <link rel="icon" type="image/png" sizes="16x16" href="/one_identity_login_pages/login/0ce1520be322adcd762319573804f56d/images/favicon-16x16.png"/>
+ <link rel="icon" type="image/png" sizes="192x192" href="/one_identity_login_pages/login/0ce1520be322adcd762319573804f56d/images/android-chrome-192x192.png"/>
+ <link rel="icon" type="image/png" sizes="512x512" href="/one_identity_login_pages/login/0ce1520be322adcd762319573804f56d/images/android-chrome-512x512.png"/>
+ <link rel="prefetch" href="https://cdn.optimizely.com/js/14856860742.js"/>
+ <link rel="preconnect" href="//cdn.optimizely.com"/>
+ <link rel="preconnect" href="//logx.optimizely.com"/>
+ <script type="module" crossorigin src="/one_identity_login_pages/login/0ce1520be322adcd762319573804f56d/js/index.js"></script>
+ <link rel="modulepreload" crossorigin href="/one_identity_login_pages/login/0ce1520be322adcd762319573804f56d/js/vendor.js">
+ <link rel="modulepreload" crossorigin href="/one_identity_login_pages/login/0ce1520be322adcd762319573804f56d/js/auth.js">
+ <link rel="modulepreload" crossorigin href="/one_identity_login_pages/login/0ce1520be322adcd762319573804f56d/js/router.js">
+ <link rel="stylesheet" crossorigin href="/one_identity_login_pages/login/0ce1520be322adcd762319573804f56d/assets/styles.css">
+ </head>
+ <body>
+ <div id="root" vaul-drawer-wrapper="" class="root-container"></div>
+ <script>
+ const AUTH_CONFIG = 'eyJhdXRoMERvbWFpbiI6InNzby5hY2NvdW50cy5kb3dqb25lcy5jb20iLCJjYWxsYmFja1VSTCI6Imh0dHBzOi8vbXlpYmQuaW52ZXN0b3JzLmNvbS9vaWRjL2NhbGxiYWNrIiwiY2xpZW50SUQiOiJHU1UxcEcyQnJnZDNQdjJLQm5BWjI0enZ5NXVXU0NRbiIsImV4dHJhUGFyYW1zIjp7InByb3RvY29sIjoib2F1dGgyIiwic2NvcGUiOiJvcGVuaWQgaWRwX2lkIHJvbGVzIGVtYWlsIGdpdmVuX25hbWUgZmFtaWx5X25hbWUgdXVpZCBkalVzZXJuYW1lIGRqU3RhdHVzIHRyYWNraWQgdGFncyBwcnRzIHVwZGF0ZWRfYXQgY3JlYXRlZF9hdCBvZmZsaW5lX2FjY2VzcyBkamlkIiwicmVzcG9uc2VfdHlwZSI6ImNvZGUiLCJub25jZSI6IjY0MDJmYWJiLTFiNzUtNGEyYy1hODRmLTExYWQ2MWFhZGI2YiIsInVpX2xvY2FsZXMiOiJlbi11cy14LWliZC0yMy03IiwiX2NzcmYiOiJOZFVSZ3dPQ3VYRU5URXFDcDhNV25tcGtxd3lva2JjU2E2VV9fLTVib3lWc1NzQVNWTkhLU0EiLCJfaW50c3RhdGUiOiJkZXByZWNhdGVkIiwic3RhdGUiOiJlYXJjN3E2UnE2a3lHS3h5LlltbGlxOU4xRXZvU1V0ejhDVjhuMFZBYzZWc1V4RElSTTRTcmxtSWJXMmsifSwiaW50ZXJuYWxPcHRpb25zIjp7InJlc3BvbnNlX3R5cGUiOiJjb2RlIiwiY2xpZW50X2lkIjoiR1NVMXBHMkJyZ2QzUHYyS0JuQVoyNHp2eTV1V1NDUW4iLCJzY29wZSI6Im9wZW5pZCBpZHBfaWQgcm9sZXMgZW1haWwgZ2l2ZW5fbmFtZSBmYW1pbHlfbmFtZSB1dWlkIGRqVXNlcm5hbWUgZGpTdGF0dXMgdHJhY2tpZCB0YWdzIHBydHMgdXBkYXRlZF9hdCBjcmVhdGVkX2F0IG9mZmxpbmVfYWNjZXNzIGRqaWQiLCJyZWRpcmVjdF91cmkiOiJodHRwczovL215aWJkLmludmVzdG9ycy5jb20vb2lkYy9jYWxsYmFjayIsInVpX2xvY2FsZXMiOiJlbi11cy14LWliZC0yMy03IiwiZXVybCI6Imh0dHBzOi8vd3d3LmludmVzdG9ycy5jb20iLCJub25jZSI6IjY0MDJmYWJiLTFiNzUtNGEyYy1hODRmLTExYWQ2MWFhZGI2YiIsInN0YXRlIjoiZWFyYzdxNlJxNmt5R0t4eS5ZbWxpcTlOMUV2b1NVdHo4Q1Y4bjBWQWM2VnNVeERJUk00U3JsbUliVzJrIiwicmVzb3VyY2UiOiJodHRwcyUzQSUyRiUyRnd3dy5pbnZlc3RvcnMuY29tIiwicHJvdG9jb2wiOiJvYXV0aDIiLCJjbGllbnQiOiJHU1UxcEcyQnJnZDNQdjJLQm5BWjI0enZ5NXVXU0NRbiJ9LCJpc1RoaXJkUGFydHlDbGllbnQiOmZhbHNlLCJhdXRob3JpemF0aW9uU2VydmVyIjp7InVybCI6Imh0dHBzOi8vc3NvLmFjY291bnRzLmRvd2pvbmVzLmNvbSIsImlzc3VlciI6Imh0dHBzOi8vc3NvLmFjY291bnRzLmRvd2pvbmVzLmNvbS8ifX0='
+ const ENV_CONFIG = 'production'
+
+ window.sessionStorage.setItem('auth-config', AUTH_CONFIG)
+ window.sessionStorage.setItem('env-config', ENV_CONFIG)
+ </script>
+ <script src="https://cdn.optimizely.com/js/14856860742.js" crossorigin="anonymous"></script>
+ <script type="text/javascript" src="https://dcdd29eaa743c493e732-7dc0216bc6cc2f4ed239035dfc17235b.ssl.cf3.rackcdn.com/tags/wsj/hokbottom.js"></script>
+ <script type="text/javascript" src="/R8As7u5b/init.js"></script>
+ </body>
+</html>
+`
+
// Test_extractAuthConfig parses the sign-in page fixture and checks that the
// base64 AUTH_CONFIG blob decodes to the expected configuration. The JSON
// below is the decoded form of the fixture's AUTH_CONFIG constant.
func Test_extractAuthConfig(t *testing.T) {
	t.Parallel()
	expectedJSON := `
{
  "auth0Domain": "sso.accounts.dowjones.com",
  "callbackURL": "https://myibd.investors.com/oidc/callback",
  "clientID": "GSU1pG2Brgd3Pv2KBnAZ24zvy5uWSCQn",
  "extraParams": {
    "protocol": "oauth2",
    "scope": "openid idp_id roles email given_name family_name uuid djUsername djStatus trackid tags prts updated_at created_at offline_access djid",
    "response_type": "code",
    "nonce": "6402fabb-1b75-4a2c-a84f-11ad61aadb6b",
    "ui_locales": "en-us-x-ibd-23-7",
    "_csrf": "NdURgwOCuXENTEqCp8MWnmpkqwyokbcSa6U__-5boyVsSsASVNHKSA",
    "_intstate": "deprecated",
    "state": "earc7q6Rq6kyGKxy.Ymliq9N1EvoSUtz8CV8n0VAc6VsUxDIRM4SrlmIbW2k"
  },
  "internalOptions": {
    "response_type": "code",
    "client_id": "GSU1pG2Brgd3Pv2KBnAZ24zvy5uWSCQn",
    "scope": "openid idp_id roles email given_name family_name uuid djUsername djStatus trackid tags prts updated_at created_at offline_access djid",
    "redirect_uri": "https://myibd.investors.com/oidc/callback",
    "ui_locales": "en-us-x-ibd-23-7",
    "eurl": "https://www.investors.com",
    "nonce": "6402fabb-1b75-4a2c-a84f-11ad61aadb6b",
    "state": "earc7q6Rq6kyGKxy.Ymliq9N1EvoSUtz8CV8n0VAc6VsUxDIRM4SrlmIbW2k",
    "resource": "https%3A%2F%2Fwww.investors.com",
    "protocol": "oauth2",
    "client": "GSU1pG2Brgd3Pv2KBnAZ24zvy5uWSCQn"
  },
  "isThirdPartyClient": false,
  "authorizationServer": {
    "url": "https://sso.accounts.dowjones.com",
    "issuer": "https://sso.accounts.dowjones.com/"
  }
}`
	var expectedCfg authConfig
	err := json.Unmarshal([]byte(expectedJSON), &expectedCfg)
	require.NoError(t, err)

	// Parse the fixture HTML and run the extractor under test.
	node, err := html.Parse(strings.NewReader(extractAuthHtml))
	require.NoError(t, err)

	cfg, err := extractAuthConfig(node)
	require.NoError(t, err)
	require.NotNil(t, cfg)

	assert.Equal(t, expectedCfg, *cfg)
}
+
+const extractTokenParamsHtml = `
+<form method="post" name="hiddenform" action="https://sso.accounts.dowjones.com/postauth/handler">
+ <input type="hidden" name="token" value="eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJkalVzZXJuYW1lIjoiYW5zZzE5MUB5YWhvby5jb20iLCJpZCI6IjAxZWFmNTE5LTA0OWItNGIyOS04ZjZhLWQyNjIyZjNiMWJjNiIsImdpdmVuX25hbWUiOiJBbnNodWwiLCJmYW1pbHlfbmFtZSI6Ikd1cHRhIiwibmFtZSI6IkFuc2h1bCBHdXB0YSIsImVtYWlsIjoiYW5zZzE5MUB5YWhvby5jb20iLCJlbWFpbF92ZXJpZmllZCI6dHJ1ZSwiYWNjb3VudF9pZCI6Ijk5NzI5Mzc0NDIxMiIsImRqaWQiOiIwMWVhZjUxOS0wNDliLTRiMjktOGY2YS1kMjYyMmYzYjFiYzYiLCJ0cmFja2lkIjoiMWM0NGQyMTRmM2VlYTZiMzcyNDYxNDc3NDc0NDMyODJmMTRmY2ZjYmI4NmE4NmVjYTI0MDc2ZDVlMzU4ZmUzZCIsInVwZGF0ZWRfYXQiOjE3MTI3OTQxNTYsImNyZWF0ZWRfYXQiOjE3MTI3OTQxNTYsInVhdCI6MTcyMjU1MjMzOSwicm9sZXMiOlsiQkFSUk9OUy1DSEFOR0VQQVNTV09SRCIsIkZSRUVSRUctQkFTRSIsIkZSRUVSRUctSU5ESVZJRFVBTCIsIldTSi1DSEFOR0VQQVNTV09SRCIsIldTSi1BUkNISVZFIiwiV1NKLVNFTEZTRVJWIiwiSUJELUlORElWSURVQUwiLCJJQkQtSUNBIiwiSUJELUFFSSJdLCJkalN0YXR1cyI6WyJJQkRfVVNFUlMiXSwicHJ0cyI6IjIwMjQwNDEwMTcwOTE2LTA0MDAiLCJjcmVhdGVUaW1lc3RhbXAiOiIyMDI0MDQxMTAwMDkxNloiLCJzdXVpZCI6Ik1ERmxZV1kxTVRrdE1EUTVZaTAwWWpJNUxUaG1ObUV0WkRJMk1qSm1NMkl4WW1NMi50S09fM014VkVReks3dE5qTkdxUXNZMlBNbXp5cUxGRkxySnBrZGhrcDZrIiwic3ViIjoiMDFlYWY1MTktMDQ5Yi00YjI5LThmNmEtZDI2MjJmM2IxYmM2IiwiYXVkIjoiR1NVMXBHMkJyZ2QzUHYyS0JuQVoyNHp2eTV1V1NDUW4iLCJpc3MiOiJodHRwczovL3Nzby5hY2NvdW50cy5kb3dqb25lcy5jb20vIiwiaWF0IjoxNzIyNTUyMzM5MTI0LCJleHAiOjE3MjI1NTI3NzExMjR9.HVn33IFttQrG1JKEV2oElIy3mm8TJ-3GpV_jqZE81_cY22z4IMWPz7zUGz0WgOoUuQGyrYXiaNrfxD6GaoimRL6wxrH0Fy5iYC3dOEdlGfldswfgEOwSiZkBJRc2wWTVQLm93EeJ5ZZyKIXGY_ZkwcYfhrwaTAz8McBBnRmZkm0eiNJQ5YK-QZL-yFa3DxMdPPW91jLA2rjOIVnJ-I_0nMwaJ4ZwXHG2Sw4aAXxtbFqIqarKwIdOUSpRFOCSYpeWcxmbliurKlP1djrKrYgYSZxsKOHZhnbikZDtoDCAlPRlfbKOO4u36KXooDYGJ6p__s2kGCLOLLkP_QLHMNU8Jg">
+ <input type="hidden" name="params" value="%7B%22response_type%22%3A%22code%22%2C%22client_id%22%3A%22GSU1pG2Brgd3Pv2KBnAZ24zvy5uWSCQn%22%2C%22redirect_uri%22%3A%22https%3A%2F%2Fmyibd.investors.com%2Foidc%2Fcallback%22%2C%22state%22%3A%22J-ihUYZIYzey682D.aOLszineC9qjPkM6Y6wWgFC61ABYBiuK9u48AHTFS5I%22%2C%22scope%22%3A%22openid%20idp_id%20roles%20email%20given_name%20family_name%20uuid%20djUsername%20djStatus%20trackid%20tags%20prts%20updated_at%20created_at%20offline_access%20djid%22%2C%22nonce%22%3A%22457bb517-f490-43b6-a55f-d93f90d698ad%22%7D">
+ <noscript>
+ <p>Script is disabled. Click Submit to continue.</p>
+ <input type="submit" value="Submit">
+ </noscript>
+</form>
+`
+const extractTokenExpectedToken = "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJkalVzZXJuYW1lIjoiYW5zZzE5MUB5YWhvby5jb20iLCJpZCI6IjAxZWFmNTE5LTA0OWItNGIyOS04ZjZhLWQyNjIyZjNiMWJjNiIsImdpdmVuX25hbWUiOiJBbnNodWwiLCJmYW1pbHlfbmFtZSI6Ikd1cHRhIiwibmFtZSI6IkFuc2h1bCBHdXB0YSIsImVtYWlsIjoiYW5zZzE5MUB5YWhvby5jb20iLCJlbWFpbF92ZXJpZmllZCI6dHJ1ZSwiYWNjb3VudF9pZCI6Ijk5NzI5Mzc0NDIxMiIsImRqaWQiOiIwMWVhZjUxOS0wNDliLTRiMjktOGY2YS1kMjYyMmYzYjFiYzYiLCJ0cmFja2lkIjoiMWM0NGQyMTRmM2VlYTZiMzcyNDYxNDc3NDc0NDMyODJmMTRmY2ZjYmI4NmE4NmVjYTI0MDc2ZDVlMzU4ZmUzZCIsInVwZGF0ZWRfYXQiOjE3MTI3OTQxNTYsImNyZWF0ZWRfYXQiOjE3MTI3OTQxNTYsInVhdCI6MTcyMjU1MjMzOSwicm9sZXMiOlsiQkFSUk9OUy1DSEFOR0VQQVNTV09SRCIsIkZSRUVSRUctQkFTRSIsIkZSRUVSRUctSU5ESVZJRFVBTCIsIldTSi1DSEFOR0VQQVNTV09SRCIsIldTSi1BUkNISVZFIiwiV1NKLVNFTEZTRVJWIiwiSUJELUlORElWSURVQUwiLCJJQkQtSUNBIiwiSUJELUFFSSJdLCJkalN0YXR1cyI6WyJJQkRfVVNFUlMiXSwicHJ0cyI6IjIwMjQwNDEwMTcwOTE2LTA0MDAiLCJjcmVhdGVUaW1lc3RhbXAiOiIyMDI0MDQxMTAwMDkxNloiLCJzdXVpZCI6Ik1ERmxZV1kxTVRrdE1EUTVZaTAwWWpJNUxUaG1ObUV0WkRJMk1qSm1NMkl4WW1NMi50S09fM014VkVReks3dE5qTkdxUXNZMlBNbXp5cUxGRkxySnBrZGhrcDZrIiwic3ViIjoiMDFlYWY1MTktMDQ5Yi00YjI5LThmNmEtZDI2MjJmM2IxYmM2IiwiYXVkIjoiR1NVMXBHMkJyZ2QzUHYyS0JuQVoyNHp2eTV1V1NDUW4iLCJpc3MiOiJodHRwczovL3Nzby5hY2NvdW50cy5kb3dqb25lcy5jb20vIiwiaWF0IjoxNzIyNTUyMzM5MTI0LCJleHAiOjE3MjI1NTI3NzExMjR9.HVn33IFttQrG1JKEV2oElIy3mm8TJ-3GpV_jqZE81_cY22z4IMWPz7zUGz0WgOoUuQGyrYXiaNrfxD6GaoimRL6wxrH0Fy5iYC3dOEdlGfldswfgEOwSiZkBJRc2wWTVQLm93EeJ5ZZyKIXGY_ZkwcYfhrwaTAz8McBBnRmZkm0eiNJQ5YK-QZL-yFa3DxMdPPW91jLA2rjOIVnJ-I_0nMwaJ4ZwXHG2Sw4aAXxtbFqIqarKwIdOUSpRFOCSYpeWcxmbliurKlP1djrKrYgYSZxsKOHZhnbikZDtoDCAlPRlfbKOO4u36KXooDYGJ6p__s2kGCLOLLkP_QLHMNU8Jg"
+const extractTokenExpectedParams = "%7B%22response_type%22%3A%22code%22%2C%22client_id%22%3A%22GSU1pG2Brgd3Pv2KBnAZ24zvy5uWSCQn%22%2C%22redirect_uri%22%3A%22https%3A%2F%2Fmyibd.investors.com%2Foidc%2Fcallback%22%2C%22state%22%3A%22J-ihUYZIYzey682D.aOLszineC9qjPkM6Y6wWgFC61ABYBiuK9u48AHTFS5I%22%2C%22scope%22%3A%22openid%20idp_id%20roles%20email%20given_name%20family_name%20uuid%20djUsername%20djStatus%20trackid%20tags%20prts%20updated_at%20created_at%20offline_access%20djid%22%2C%22nonce%22%3A%22457bb517-f490-43b6-a55f-d93f90d698ad%22%7D"
+
// Test_extractTokenParams verifies that the hidden token and params input
// values are extracted verbatim from the auto-submit form fixture.
func Test_extractTokenParams(t *testing.T) {
	t.Parallel()

	node, err := html.Parse(strings.NewReader(extractTokenParamsHtml))
	require.NoError(t, err)

	token, params, err := extractTokenParams(node)
	require.NoError(t, err)
	assert.Equal(t, extractTokenExpectedToken, token)
	assert.Equal(t, extractTokenExpectedParams, params)
}
+
// TestClient_Authenticate drives the full three-step login flow against
// httpmock responders and verifies the session cookie comes back intact.
func TestClient_Authenticate(t *testing.T) {
	t.Parallel()

	expectedVal := "test-cookie"
	// Round to seconds and use UTC because Set-Cookie serialization drops
	// sub-second precision and normalizes to GMT.
	expectedExp := time.Now().Add(time.Hour).Round(time.Second).In(time.UTC)

	tp := httpmock.NewMockTransport()
	// Step 1: sign-in page carrying the embedded AUTH_CONFIG fixture.
	tp.RegisterResponder("GET", signInUrl,
		httpmock.NewStringResponder(http.StatusOK, extractAuthHtml))
	// Step 2: credential POST; assert the submitted username/password.
	tp.RegisterResponder("POST", authenticateUrl,
		func(request *http.Request) (*http.Response, error) {
			var body authRequestBody
			require.NoError(t, json.NewDecoder(request.Body).Decode(&body))
			assert.Equal(t, "abc", body.Username)
			assert.Equal(t, "xyz", body.Password)

			return httpmock.NewStringResponse(http.StatusOK, extractTokenParamsHtml), nil
		})
	// Step 3: post-auth handler; assert the forwarded token/params and set
	// the session cookie on the response.
	tp.RegisterResponder("POST", postAuthUrl,
		func(request *http.Request) (*http.Response, error) {
			require.NoError(t, request.ParseForm())
			assert.Equal(t, extractTokenExpectedToken, request.Form.Get("token"))

			// ParseForm percent-decodes values, so compare against the
			// unescaped params string.
			params, err := url.QueryUnescape(extractTokenExpectedParams)
			require.NoError(t, err)
			assert.Equal(t, params, request.Form.Get("params"))

			resp := httpmock.NewStringResponse(http.StatusOK, "OK")
			cookie := &http.Cookie{Name: cookieName, Value: expectedVal, Expires: expectedExp}
			resp.Header.Set("Set-Cookie", cookie.String())
			return resp, nil
		})

	client := NewClient(nil, nil, newTransport(tp))

	cookie, err := client.Authenticate(context.Background(), "abc", "xyz")
	require.NoError(t, err)
	require.NotNil(t, cookie)

	assert.Equal(t, expectedVal, cookie.Value)
	assert.Equal(t, expectedExp, cookie.Expires)
}
+
// TestClient_Authenticate_401 verifies that a 401 from the SSO authenticate
// endpoint surfaces as ErrBadCredentials with no cookie.
func TestClient_Authenticate_401(t *testing.T) {
	t.Parallel()

	tp := httpmock.NewMockTransport()
	tp.RegisterResponder("GET", signInUrl,
		httpmock.NewStringResponder(http.StatusOK, extractAuthHtml))
	tp.RegisterResponder("POST", authenticateUrl,
		func(request *http.Request) (*http.Response, error) {
			var body authRequestBody
			require.NoError(t, json.NewDecoder(request.Body).Decode(&body))
			assert.Equal(t, "abc", body.Username)
			assert.Equal(t, "xyz", body.Password)

			// Mimic the SSO's real bad-credentials error payload.
			return httpmock.NewStringResponse(http.StatusUnauthorized, `{"name":"ValidationError","code":"ERR016","message":"Wrong username or password","description":"Wrong username or password"}`), nil
		})

	client := NewClient(nil, nil, newTransport(tp))

	cookie, err := client.Authenticate(context.Background(), "abc", "xyz")
	assert.Nil(t, cookie)
	assert.ErrorIs(t, err, ErrBadCredentials)
}
+
// testReliableTransport adapts an http.Client (backed by httpmock) into a
// transport.Transport that reports itself as free and reliable, so tests can
// exercise Client.Do without real network access.
type testReliableTransport http.Client

// newTransport wraps an httpmock transport in a testReliableTransport.
func newTransport(tp *httpmock.MockTransport) *testReliableTransport {
	return (*testReliableTransport)(&http.Client{Transport: tp})
}

// String identifies the transport in log output.
func (t *testReliableTransport) String() string {
	return "testReliableTransport"
}

// Do delegates to the underlying http.Client.
func (t *testReliableTransport) Do(req *http.Request) (*http.Response, error) {
	return (*http.Client)(t).Do(req)
}

// Properties marks this transport as both free and reliable.
func (t *testReliableTransport) Properties() transport.Properties {
	return transport.PropertiesFree | transport.PropertiesReliable
}
diff --git a/backend/internal/ibd/check_ibd_username.go b/backend/internal/ibd/check_ibd_username.go
new file mode 100644
index 0000000..b026151
--- /dev/null
+++ b/backend/internal/ibd/check_ibd_username.go
@@ -0,0 +1,68 @@
+package ibd
+
+import (
+ "bytes"
+ "context"
+ "encoding/json"
+ "fmt"
+ "io"
+ "net/http"
+)
+
const (
	// checkUsernameUrl is the Dow Jones SSO endpoint that reports whether a
	// username corresponds to an existing account.
	checkUsernameUrl = "https://sso.accounts.dowjones.com/getuser"
)
+
// CheckIBDUsername reports whether username corresponds to an existing IBD
// account. It first loads the sign-in page to obtain a fresh CSRF token from
// its AUTH_CONFIG, then queries the SSO getuser endpoint.
func (c *Client) CheckIBDUsername(ctx context.Context, username string) (bool, error) {
	cfg, err := c.getLoginPage(ctx)
	if err != nil {
		return false, err
	}

	return c.checkIBDUsername(ctx, cfg, username)
}
+
+func (c *Client) checkIBDUsername(ctx context.Context, cfg *authConfig, username string) (bool, error) {
+ body := map[string]string{
+ "username": username,
+ "csrf": cfg.ExtraParams.Csrf,
+ }
+ bodyJson, err := json.Marshal(body)
+ if err != nil {
+ return false, err
+ }
+
+ req, err := http.NewRequestWithContext(ctx, http.MethodPost, checkUsernameUrl, bytes.NewReader(bodyJson))
+ if err != nil {
+ return false, err
+ }
+
+ req.Header.Set("Content-Type", "application/json")
+ req.Header.Set("X-REMOTE-USER", username)
+ req.Header.Set("X-REQUEST-EDITIONID", "IBD-EN_US")
+ req.Header.Set("X-REQUEST-SCHEME", "https")
+
+ resp, err := c.Do(req, withExpectedStatuses(http.StatusOK, http.StatusUnauthorized))
+ if err != nil {
+ return false, err
+ }
+ defer func(Body io.ReadCloser) {
+ _ = Body.Close()
+ }(resp.Body)
+
+ if resp.StatusCode == http.StatusUnauthorized {
+ return false, nil
+ } else if resp.StatusCode != http.StatusOK {
+ contentBytes, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return false, fmt.Errorf("failed to read response body: %w", err)
+ }
+ content := string(contentBytes)
+ return false, fmt.Errorf(
+ "unexpected status code %d: %s",
+ resp.StatusCode,
+ content,
+ )
+ }
+ return true, nil
+}
diff --git a/backend/internal/ibd/client.go b/backend/internal/ibd/client.go
new file mode 100644
index 0000000..c8575e3
--- /dev/null
+++ b/backend/internal/ibd/client.go
@@ -0,0 +1,97 @@
+package ibd
+
import (
	"context"
	"errors"
	"io"
	"log/slog"
	"net/http"
	"slices"

	"github.com/ansg191/ibd-trader-backend/internal/database"
	"github.com/ansg191/ibd-trader-backend/internal/ibd/transport"
	"github.com/ansg191/ibd-trader-backend/internal/keys"
)
+
// ErrNoAvailableCookies is returned when no stored session cookie matches
// the request's subject requirement.
var ErrNoAvailableCookies = errors.New("no available cookies")

// ErrNoAvailableTransports is returned when every candidate transport failed
// or was unable to handle the request.
var ErrNoAvailableTransports = errors.New("no available transports")

// Client is an IBD scraping client that fans requests out over a prioritized
// set of transports and sources session cookies from the database.
type Client struct {
	transports []transport.Transport     // candidate transports, filtered/sorted per request
	db         database.Executor         // backing store for users and session cookies
	kms        keys.KeyManagementService // decrypts stored cookies/credentials
}
+
// NewClient builds an IBD client that tries the given transports in
// (filtered, sorted) order for each request, using db and kms to load and
// decrypt stored session cookies.
func NewClient(
	db database.Executor,
	kms keys.KeyManagementService,
	transports ...transport.Transport,
) *Client {
	return &Client{transports, db, kms}
}
+
+func (c *Client) getCookie(ctx context.Context, subject *string) (uint, *http.Cookie, error) {
+ if subject == nil {
+ // No subject requirement, get any cookie
+ cookie, err := database.GetAnyCookie(ctx, c.db, c.kms)
+ if err != nil {
+ return 0, nil, err
+ }
+ if cookie == nil {
+ return 0, nil, ErrNoAvailableCookies
+ }
+
+ return cookie.ID, cookie.ToHTTPCookie(), nil
+ }
+
+ // Get cookie by subject
+ cookies, err := database.GetCookies(ctx, c.db, c.kms, *subject, false)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ if len(cookies) == 0 {
+ return 0, nil, ErrNoAvailableCookies
+ }
+
+ cookie := cookies[0]
+
+ return cookie.ID, cookie.ToHTTPCookie(), nil
+}
+
+func (c *Client) Do(req *http.Request, opts ...optionFunc) (*http.Response, error) {
+ o := defaultOptions
+ for _, opt := range opts {
+ opt(&o)
+ }
+
+ // Sort and filter transports by properties
+ transports := transport.FilterTransports(c.transports, o.requiredProps)
+ transport.SortTransports(transports)
+
+ for _, tp := range transports {
+ resp, err := tp.Do(req)
+ if errors.Is(err, transport.ErrUnsupportedRequest) {
+ // Skip unsupported transport
+ continue
+ }
+ if err != nil {
+ slog.ErrorContext(req.Context(), "transport error",
+ "transport", tp.String(),
+ "error", err,
+ )
+ continue
+ }
+ if slices.Contains(o.expectedStatuses, resp.StatusCode) {
+ return resp, nil
+ } else {
+ slog.ErrorContext(req.Context(), "unexpected status code",
+ "transport", tp.String(),
+ "expected", o.expectedStatuses,
+ "actual", resp.StatusCode,
+ )
+ continue
+ }
+ }
+
+ return nil, ErrNoAvailableTransports
+}
diff --git a/backend/internal/ibd/client_test.go b/backend/internal/ibd/client_test.go
new file mode 100644
index 0000000..2368a31
--- /dev/null
+++ b/backend/internal/ibd/client_test.go
@@ -0,0 +1,201 @@
+package ibd
+
+import (
+ "context"
+ "database/sql"
+ "fmt"
+ "log"
+ "math/rand/v2"
+ "testing"
+ "time"
+
+ "github.com/ansg191/ibd-trader-backend/internal/database"
+ "github.com/ansg191/ibd-trader-backend/internal/keys"
+ _ "github.com/lib/pq"
+ "github.com/ory/dockertest/v3"
+ "github.com/ory/dockertest/v3/docker"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
var (
	// db is the shared connection to the dockerized Postgres instance
	// created in TestMain.
	db *sql.DB
	// maxTime is a far-future expiry used for test cookies so they never
	// expire during a run.
	maxTime = time.Date(2100, 1, 1, 0, 0, 0, 0, time.UTC)
	// letterRunes is the alphabet randStringRunes draws from.
	letterRunes = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ")
)
+
// TestMain spins up a disposable Postgres 16 container via dockertest,
// waits for it to accept connections, runs the schema migrations, points
// the package-level db at it, and then executes the test suite.
func TestMain(m *testing.M) {
	pool, err := dockertest.NewPool("")
	if err != nil {
		log.Fatalf("Could not create pool: %s", err)
	}

	err = pool.Client.Ping()
	if err != nil {
		log.Fatalf("Could not connect to Docker: %s", err)
	}

	// Start a throwaway Postgres 16 container with statement logging on
	// to ease debugging of failing tests.
	resource, err := pool.RunWithOptions(&dockertest.RunOptions{
		Repository: "postgres",
		Tag:        "16",
		Env: []string{
			"POSTGRES_PASSWORD=secret",
			"POSTGRES_USER=ibd-client-test",
			"POSTGRES_DB=ibd-client-test",
			"listen_addresses='*'",
		},
		Cmd: []string{
			"postgres",
			"-c",
			"log_statement=all",
		},
	}, func(config *docker.HostConfig) {
		// Remove the container automatically when it stops.
		config.AutoRemove = true
		config.RestartPolicy = docker.RestartPolicy{Name: "no"}
	})
	if err != nil {
		log.Fatalf("Could not start resource: %s", err)
	}

	hostAndPort := resource.GetHostPort("5432/tcp")
	databaseUrl := fmt.Sprintf("postgres://ibd-client-test:secret@%s/ibd-client-test?sslmode=disable", hostAndPort)

	// Kill container after 120 seconds
	_ = resource.Expire(120)

	// Postgres takes a moment to come up; retry opening + pinging until
	// it accepts connections (or MaxWait elapses).
	pool.MaxWait = 120 * time.Second
	if err = pool.Retry(func() error {
		db, err = sql.Open("postgres", databaseUrl)
		if err != nil {
			return err
		}
		return db.Ping()
	}); err != nil {
		log.Fatalf("Could not connect to database: %s", err)
	}

	// Apply the application schema before any test runs.
	err = database.Migrate(context.Background(), databaseUrl)
	if err != nil {
		log.Fatalf("Could not migrate database: %s", err)
	}

	defer func() {
		if err := pool.Purge(resource); err != nil {
			log.Fatalf("Could not purge resource: %s", err)
		}
	}()

	// NOTE(review): m.Run's result is not passed to os.Exit; this relies
	// on Go 1.15+ semantics where the testing package uses the result of
	// m.Run when TestMain returns normally — confirm the module's Go
	// version. (Calling os.Exit here would also skip the deferred Purge.)
	m.Run()
}
+
+func randStringRunes(n int) string {
+ b := make([]rune, n)
+ for i := range b {
+ b[i] = letterRunes[rand.IntN(len(letterRunes))]
+ }
+ return string(b)
+}
+
// addCookie inserts a freshly generated user, wrapped encryption key, and
// IBD session token into the database, all in one transaction, and
// returns the generated subject and plaintext token.
//
// The token is envelope-encrypted via keys.Encrypt with kmsStub (an
// identity transform), and never expires within the run
// (expires_at = maxTime).
func addCookie(t *testing.T) (user, token string) {
	t.Helper()

	// Randomly generate a user and token
	user = randStringRunes(8)
	token = randStringRunes(16)

	// Envelope-encrypt the token; kmsStub keeps the wrapped key readable.
	ciphertext, key, err := keys.Encrypt(context.Background(), new(kmsStub), "", []byte(token))
	require.NoError(t, err)

	tx, err := db.Begin()
	require.NoError(t, err)

	// Store the wrapped data-encryption key and capture its row ID.
	var keyID uint
	err = tx.QueryRow(`
INSERT INTO keys (kms_key_name, encrypted_key)
    VALUES ('', $1)
    RETURNING id;
`, key).Scan(&keyID)
	require.NoError(t, err)

	// Create the user that owns the token.
	_, err = tx.Exec(`
INSERT
INTO users (subject, encryption_key)
VALUES ($1, $2);
`, user, keyID)
	require.NoError(t, err)

	// Finally store the encrypted IBD token itself.
	_, err = tx.Exec(`
INSERT
INTO ibd_tokens (user_subject, token, encryption_key, expires_at)
VALUES ($1, $2, $3, $4);`,
		user,
		ciphertext,
		keyID,
		maxTime,
	)
	require.NoError(t, err)

	err = tx.Commit()
	require.NoError(t, err)

	return user, token
}
+
// TestClient_getCookie exercises Client.getCookie against the dockerized
// database from TestMain: both empty-database error paths and both happy
// paths (with and without a subject filter).
//
// NOTE(review): the two "no cookies" subtests rely on running before any
// addCookie call has populated the database — keep the subtest order.
func TestClient_getCookie(t *testing.T) {
	t.Run("no cookies", func(t *testing.T) {
		client := NewClient(db, new(kmsStub))

		// Empty database: subjectless lookup must fail.
		_, _, err := client.getCookie(context.Background(), nil)
		assert.ErrorIs(t, err, ErrNoAvailableCookies)
	})

	t.Run("no cookies by subject", func(t *testing.T) {
		client := NewClient(db, new(kmsStub))

		// Unknown subject: lookup must also fail.
		subject := "test"
		_, _, err := client.getCookie(context.Background(), &subject)
		assert.ErrorIs(t, err, ErrNoAvailableCookies)
	})

	t.Run("get any cookie", func(t *testing.T) {
		// Seed one cookie; the subjectless lookup should return it.
		_, token := addCookie(t)

		client := NewClient(db, new(kmsStub))

		_, cookie, err := client.getCookie(context.Background(), nil)
		require.NoError(t, err)
		assert.Equal(t, cookieName, cookie.Name)
		assert.Equal(t, token, cookie.Value)
		assert.Equal(t, "/", cookie.Path)
		assert.Equal(t, maxTime, cookie.Expires)
		assert.Equal(t, "investors.com", cookie.Domain)
	})

	t.Run("get cookie by subject", func(t *testing.T) {
		// Seed a cookie and look it up by its owner's subject.
		subject, token := addCookie(t)

		client := NewClient(db, new(kmsStub))

		_, cookie, err := client.getCookie(context.Background(), &subject)
		require.NoError(t, err)
		assert.Equal(t, cookieName, cookie.Name)
		assert.Equal(t, token, cookie.Value)
		assert.Equal(t, "/", cookie.Path)
		assert.Equal(t, maxTime, cookie.Expires)
		assert.Equal(t, "investors.com", cookie.Domain)
	})
}
+
// kmsStub is a no-op key-management service for tests: Encrypt and
// Decrypt return their input unchanged, so "ciphertexts" written to the
// database are the plaintext bytes.
type kmsStub struct{}

// Close releases nothing; the stub holds no resources.
func (k *kmsStub) Close() error {
	return nil
}

// Encrypt returns the plaintext unchanged (identity "encryption").
func (k *kmsStub) Encrypt(_ context.Context, _ string, plaintext []byte) ([]byte, error) {
	return plaintext, nil
}

// Decrypt returns the ciphertext unchanged (identity "decryption").
func (k *kmsStub) Decrypt(_ context.Context, _ string, ciphertext []byte) ([]byte, error) {
	return ciphertext, nil
}
diff --git a/backend/internal/ibd/html_helpers.go b/backend/internal/ibd/html_helpers.go
new file mode 100644
index 0000000..0176bc5
--- /dev/null
+++ b/backend/internal/ibd/html_helpers.go
@@ -0,0 +1,99 @@
+package ibd
+
+import (
+ "strings"
+
+ "golang.org/x/net/html"
+)
+
+func findChildren(node *html.Node, f func(node *html.Node) bool) (found []*html.Node) {
+ for c := node.FirstChild; c != nil; c = c.NextSibling {
+ if f(c) {
+ found = append(found, c)
+ }
+ }
+ return
+}
+
+func findChildrenRecursive(node *html.Node, f func(node *html.Node) bool) (found []*html.Node) {
+ if f(node) {
+ found = append(found, node)
+ }
+
+ for c := node.FirstChild; c != nil; c = c.NextSibling {
+ found = append(found, findChildrenRecursive(c, f)...)
+ }
+
+ return
+}
+
+func findClass(node *html.Node, className string) (found *html.Node) {
+ if isClass(node, className) {
+ return node
+ }
+
+ for c := node.FirstChild; c != nil; c = c.NextSibling {
+ if found = findClass(c, className); found != nil {
+ return
+ }
+ }
+
+ return
+}
+
+func isClass(node *html.Node, className string) bool {
+ if node.Type == html.ElementNode {
+ for _, attr := range node.Attr {
+ if attr.Key != "class" {
+ continue
+ }
+ classes := strings.Fields(attr.Val)
+ for _, class := range classes {
+ if class == className {
+ return true
+ }
+ }
+ }
+ }
+ return false
+}
+
+func extractText(node *html.Node) string {
+ var result strings.Builder
+ extractTextInner(node, &result)
+ return result.String()
+}
+
+func extractTextInner(node *html.Node, result *strings.Builder) {
+ if node.Type == html.TextNode {
+ result.WriteString(node.Data)
+ }
+ for c := node.FirstChild; c != nil; c = c.NextSibling {
+ extractTextInner(c, result)
+ }
+}
+
+func findId(node *html.Node, id string) (found *html.Node) {
+ if isId(node, id) {
+ return node
+ }
+
+ for c := node.FirstChild; c != nil; c = c.NextSibling {
+ if found = findId(c, id); found != nil {
+ return
+ }
+ }
+
+ return
+}
+
+func isId(node *html.Node, id string) bool {
+ if node.Type == html.ElementNode {
+ for _, attr := range node.Attr {
+ if attr.Key == "id" && attr.Val == id {
+ return true
+ }
+ }
+ }
+ return false
+}
diff --git a/backend/internal/ibd/html_helpers_test.go b/backend/internal/ibd/html_helpers_test.go
new file mode 100644
index 0000000..d251c39
--- /dev/null
+++ b/backend/internal/ibd/html_helpers_test.go
@@ -0,0 +1,79 @@
+package ibd
+
+import (
+ "strings"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ "golang.org/x/net/html"
+)
+
+func Test_findClass(t *testing.T) {
+ t.Parallel()
+ tests := []struct {
+ name string
+ html string
+ className string
+ found bool
+ expData string
+ }{
+ {
+ name: "class exists",
+ html: `<div class="foo"></div>`,
+ className: "foo",
+ found: true,
+ expData: "div",
+ },
+ {
+ name: "class exists nested",
+ html: `<div class="foo"><a class="abc"></a></div>`,
+ className: "abc",
+ found: true,
+ expData: "a",
+ },
+ {
+ name: "class exists multiple",
+ html: `<div class="foo"><a class="foo"></a></div>`,
+ className: "foo",
+ found: true,
+ expData: "div",
+ },
+ {
+ name: "class missing",
+ html: `<div class="abc"><a class="xyz"></a></div>`,
+ className: "foo",
+ found: false,
+ expData: "",
+ },
+ {
+ name: "class missing",
+ html: `<div id="foo"><a abc="xyz"></a></div>`,
+ className: "foo",
+ found: false,
+ expData: "",
+ },
+ {
+ name: "class exists multiple save div",
+ html: `<div class="foo bar"></div>`,
+ className: "bar",
+ found: true,
+ expData: "div",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ node, err := html.Parse(strings.NewReader(tt.html))
+ require.NoError(t, err)
+
+ got := findClass(node, tt.className)
+ if !tt.found {
+ require.Nil(t, got)
+ return
+ }
+ require.NotNil(t, got)
+ assert.Equal(t, tt.expData, got.Data)
+ })
+ }
+}
diff --git a/backend/internal/ibd/ibd50.go b/backend/internal/ibd/ibd50.go
new file mode 100644
index 0000000..52e28aa
--- /dev/null
+++ b/backend/internal/ibd/ibd50.go
@@ -0,0 +1,182 @@
+package ibd
+
+import (
+ "context"
+ "encoding/json"
+ "errors"
+ "io"
+ "log/slog"
+ "net/http"
+ "net/url"
+ "strconv"
+
+ "github.com/ansg191/ibd-trader-backend/internal/database"
+)
+
+const ibd50Url = "https://research.investors.com/Services/SiteAjaxService.asmx/GetIBD50?sortcolumn1=%22ibd100rank%22&sortOrder1=%22asc%22&sortcolumn2=%22%22&sortOrder2=%22ASC%22"
+
+// GetIBD50 returns the IBD50 list.
+func (c *Client) GetIBD50(ctx context.Context) ([]*Stock, error) {
+ // We cannot use the scraper here because scrapfly does not support
+ // Content-Type in GET requests.
+ req, err := http.NewRequestWithContext(ctx, http.MethodGet, ibd50Url, nil)
+ if err != nil {
+ return nil, err
+ }
+
+ cookieId, cookie, err := c.getCookie(ctx, nil)
+ if err != nil {
+ return nil, err
+ }
+ req.AddCookie(cookie)
+
+ req.Header.Add("content-type", "application/json; charset=utf-8")
+ // Add browser-emulating headers
+ req.Header.Add("accept", "*/*")
+ req.Header.Add("accept-language", "en-US,en;q=0.9")
+ req.Header.Add("newrelic", "eyJ2IjpbMCwxXSwiZCI6eyJ0eSI6IkJyb3dzZXIiLCJhYyI6IjMzOTYxMDYiLCJhcCI6IjEzODU5ODMwMDEiLCJpZCI6IjM1Zjk5NmM2MzNjYTViMWYiLCJ0ciI6IjM3ZmRhZmJlOGY2YjhmYTMwYWMzOTkzOGNlMmM0OWMxIiwidGkiOjE3MjIyNzg0NTk3MjUsInRrIjoiMTAyMjY4MSJ9fQ==")
+ req.Header.Add("priority", "u=1, i")
+ req.Header.Add("referer", "https://research.investors.com/stock-lists/ibd-50/")
+ req.Header.Add("sec-ch-ua", "\"Not/A)Brand\";v=\"8\", \"Chromium\";v=\"126\", \"Google Chrome\";v=\"126\"")
+ req.Header.Add("sec-ch-ua-mobile", "?0")
+ req.Header.Add("sec-ch-ua-platform", "\"macOS\"")
+ req.Header.Add("sec-fetch-dest", "empty")
+ req.Header.Add("sec-fetch-mode", "cors")
+ req.Header.Add("sec-fetch-site", "same-origin")
+ req.Header.Add("traceparent", "00-37fdafbe8f6b8fa30ac39938ce2c49c1-35f996c633ca5b1f-01")
+ req.Header.Add("tracestate", "1022681@nr=0-1-3396106-1385983001-35f996c633ca5b1f----1722278459725")
+ req.Header.Add("user-agent", "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/126.0.0.0 Safari/537.36")
+ req.Header.Add("x-newrelic-id", "VwUOV1dTDhABV1FRBgQOVVUF")
+ req.Header.Add("x-requested-with", "XMLHttpRequest")
+
+ resp, err := c.Do(req)
+ if err != nil {
+ return nil, err
+ }
+ defer func(Body io.ReadCloser) {
+ _ = Body.Close()
+ }(resp.Body)
+
+ var ibd50Resp getIBD50Response
+ if err = json.NewDecoder(resp.Body).Decode(&ibd50Resp); err != nil {
+ return nil, err
+ }
+
+ // If there are less than 10 stocks in the IBD50 list, it's likely that authentication failed.
+ if len(ibd50Resp.D.ETablesDataList) < 10 {
+ // Report cookie failure to DB
+ if err = database.ReportCookieFailure(ctx, c.db, cookieId); err != nil {
+ slog.Error("Failed to report cookie failure", "error", err)
+ }
+ return nil, errors.New("failed to get IBD50 list")
+ }
+
+ return ibd50Resp.ToStockList(), nil
+}
+
// Stock is a single entry of the IBD50 list.
type Stock struct {
	Rank   int64  // position in the IBD50 ranking, parsed from the API's string field
	Symbol string // ticker symbol
	Name   string // company name

	QuoteURL *url.URL // IBD quote page; nil when the API value was absent or unparsable
}
+
// getIBD50Response mirrors the JSON envelope returned by the GetIBD50
// AJAX endpoint. Only Rank, Symbol, CompanyName, and QuoteUrl are
// consumed by ToStockList; the remaining fields are kept as pointers so
// absent values decode to nil.
type getIBD50Response struct {
	D struct {
		Type *string `json:"__type"`
		// ETablesDataList holds one row per ranked stock.
		ETablesDataList []struct {
			Rank                    string  `json:"Rank"`
			Symbol                  string  `json:"Symbol"`
			CompanyName             string  `json:"CompanyName"`
			CompRating              *string `json:"CompRating"`
			EPSRank                 *string `json:"EPSRank"`
			RelSt                   *string `json:"RelSt"`
			GrpStr                  *string `json:"GrpStr"`
			Smr                     *string `json:"Smr"`
			AccDis                  *string `json:"AccDis"`
			SponRating              *string `json:"SponRating"`
			Price                   *string `json:"Price"`
			PriceClose              *string `json:"PriceClose"`
			PriceChange             *string `json:"PriceChange"`
			PricePerChange          *string `json:"PricePerChange"`
			VolPerChange            *string `json:"VolPerChange"`
			DailyVol                *string `json:"DailyVol"`
			WeekHigh52              *string `json:"WeekHigh52"`
			PerOffHigh              *string `json:"PerOffHigh"`
			PERatio                 *string `json:"PERatio"`
			DivYield                *string `json:"DivYield"`
			LastQtrSalesPerChg      *string `json:"LastQtrSalesPerChg"`
			LastQtrEpsPerChg        *string `json:"LastQtrEpsPerChg"`
			ConsecQtrEpsGrt15       *string `json:"ConsecQtrEpsGrt15"`
			CurQtrEpsEstPerChg      *string `json:"CurQtrEpsEstPerChg"`
			CurYrEpsEstPerChg       *string `json:"CurYrEpsEstPerChg"`
			PretaxMargin            *string `json:"PretaxMargin"`
			ROE                     *string `json:"ROE"`
			MgmtOwnsPer             *string `json:"MgmtOwnsPer"`
			QuoteUrl                *string `json:"QuoteUrl"`
			StockCheckupUrl         *string `json:"StockCheckupUrl"`
			MarketsmithUrl          *string `json:"MarketsmithUrl"`
			LeaderboardUrl          *string `json:"LeaderboardUrl"`
			ChartAnalysisUrl        *string `json:"ChartAnalysisUrl"`
			Ibd100NewEntryFlag      *string `json:"Ibd100NewEntryFlag"`
			Ibd100UpInRankFlag      *string `json:"Ibd100UpInRankFlag"`
			IbdBigCap20NewEntryFlag *string `json:"IbdBigCap20NewEntryFlag"`
			CompDesc                *string `json:"CompDesc"`
			NumberFunds             *string `json:"NumberFunds"`
			GlobalRank              *string `json:"GlobalRank"`
			EPSPriorQtr             *string `json:"EPSPriorQtr"`
			QtrsFundIncrease        *string `json:"QtrsFundIncrease"`
		} `json:"ETablesDataList"`
		IBD50PdfUrl          *string `json:"IBD50PdfUrl"`
		CAP20PdfUrl          *string `json:"CAP20PdfUrl"`
		IBD50Date            *string `json:"IBD50Date"`
		CAP20Date            *string `json:"CAP20Date"`
		UpdatedDate          *string `json:"UpdatedDate"`
		GetAllFlags          *string `json:"getAllFlags"`
		Flag                 *int    `json:"flag"`
		Message              *string `json:"Message"`
		PaywallDesktopMarkup *string `json:"PaywallDesktopMarkup"`
		PaywallMobileMarkup  *string `json:"PaywallMobileMarkup"`
	} `json:"d"`
}
+
+func (r getIBD50Response) ToStockList() (ibd []*Stock) {
+ ibd = make([]*Stock, 0, len(r.D.ETablesDataList))
+ for _, data := range r.D.ETablesDataList {
+ rank, err := strconv.ParseInt(data.Rank, 10, 64)
+ if err != nil {
+ slog.Error(
+ "Failed to parse Rank",
+ "error", err,
+ "rank", data.Rank,
+ "symbol", data.Symbol,
+ "name", data.CompanyName,
+ )
+ continue
+ }
+
+ var quoteUrl *url.URL
+ if data.QuoteUrl != nil {
+ quoteUrl, err = url.Parse(*data.QuoteUrl)
+ if err != nil {
+ slog.Error(
+ "Failed to parse QuoteUrl",
+ "error", err,
+ "quoteUrl", *data.QuoteUrl,
+ "rank", data.Rank,
+ "symbol", data.Symbol,
+ "name", data.CompanyName,
+ )
+ }
+ }
+
+ ibd = append(ibd, &Stock{
+ Rank: rank,
+ Symbol: data.Symbol,
+ Name: data.CompanyName,
+ QuoteURL: quoteUrl,
+ })
+ }
+ return
+}
diff --git a/backend/internal/ibd/options.go b/backend/internal/ibd/options.go
new file mode 100644
index 0000000..5c378d5
--- /dev/null
+++ b/backend/internal/ibd/options.go
@@ -0,0 +1,26 @@
+package ibd
+
+import "github.com/ansg191/ibd-trader-backend/internal/ibd/transport"
+
// optionFunc mutates the per-request options of Client.Do.
type optionFunc func(*options)

// defaultOptions is the starting point for every Client.Do call: only
// HTTP 200 counts as success and no transport properties are required.
var defaultOptions = options{
	expectedStatuses: []int{200},
}

// options holds the per-request settings consumed by Client.Do.
type options struct {
	expectedStatuses []int                // status codes treated as success
	requiredProps    transport.Properties // properties a transport must provide
}
+
// withExpectedStatuses adds statuses to the set of status codes that
// Client.Do treats as success.
//
// NOTE(review): this appends to — rather than replaces — the default set,
// so HTTP 200 always remains accepted; confirm that is intended.
func withExpectedStatuses(statuses ...int) optionFunc {
	return func(o *options) {
		o.expectedStatuses = append(o.expectedStatuses, statuses...)
	}
}
+
// withRequiredProps restricts Client.Do to transports that provide the
// given properties (see transport.FilterTransports).
func withRequiredProps(props transport.Properties) optionFunc {
	return func(o *options) {
		o.requiredProps = props
	}
}
diff --git a/backend/internal/ibd/search.go b/backend/internal/ibd/search.go
new file mode 100644
index 0000000..341b14b
--- /dev/null
+++ b/backend/internal/ibd/search.go
@@ -0,0 +1,111 @@
+package ibd
+
import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net/http"
	"net/url"
	"time"

	"github.com/ansg191/ibd-trader-backend/internal/database"
)
+
const (
	// searchUrl is the IBD instant-search API endpoint; the query goes in
	// the "key" parameter.
	searchUrl = "https://ibdservices.investors.com/im/api/search"
)
+
+var ErrSymbolNotFound = fmt.Errorf("symbol not found")
+
// Search queries the IBD search API for symbol and returns the matching
// stock as a database.Stock.
//
// Only an exact symbol match within the returned stockData is accepted;
// otherwise ErrSymbolNotFound is returned.
func (c *Client) Search(ctx context.Context, symbol string) (database.Stock, error) {
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, searchUrl, nil)
	if err != nil {
		return database.Stock{}, err
	}

	// Authenticate with any available IBD session cookie.
	_, cookie, err := c.getCookie(ctx, nil)
	if err != nil {
		return database.Stock{}, err
	}
	req.AddCookie(cookie)

	// The search term goes in the "key" query parameter.
	params := url.Values{}
	params.Set("key", symbol)
	req.URL.RawQuery = params.Encode()

	resp, err := c.Do(req)
	if err != nil {
		return database.Stock{}, err
	}
	defer func(Body io.ReadCloser) {
		_ = Body.Close()
	}(resp.Body)

	// Defensive: c.Do already filters on expected statuses (200 by
	// default), so this branch is normally unreachable.
	if resp.StatusCode != http.StatusOK {
		content, err := io.ReadAll(resp.Body)
		if err != nil {
			return database.Stock{}, fmt.Errorf("failed to read response body: %w", err)
		}
		return database.Stock{}, fmt.Errorf(
			"unexpected status code %d: %s",
			resp.StatusCode,
			string(content),
		)
	}

	var sr searchResponse
	if err = json.NewDecoder(resp.Body).Decode(&sr); err != nil {
		return database.Stock{}, err
	}

	// Accept only an exact symbol match; the API also returns related
	// instruments (ETFs etc.).
	for _, stock := range sr.StockData {
		if stock.Symbol == symbol {
			return database.Stock{
				Symbol: stock.Symbol,
				Name:   stock.Company,
				IBDUrl: stock.QuoteUrl,
			}, nil
		}
	}

	return database.Stock{}, ErrSymbolNotFound
}
+
// searchResponse mirrors the JSON returned by the IBD search API. Only
// StockData's Symbol, Company, and QuoteUrl fields are consumed by
// Search; the news section is decoded but unused.
type searchResponse struct {
	Status    int    `json:"_status"`
	Timestamp string `json:"_timestamp"`
	StockData []struct {
		Id              int     `json:"id"`
		Symbol          string  `json:"symbol"`
		Company         string  `json:"company"`
		PriceDate       string  `json:"priceDate"`
		Price           float64 `json:"price"`
		PreviousPrice   float64 `json:"previousPrice"`
		PriceChange     float64 `json:"priceChange"`
		PricePctChange  float64 `json:"pricePctChange"`
		Volume          int     `json:"volume"`
		VolumeChange    int     `json:"volumeChange"`
		VolumePctChange int     `json:"volumePctChange"`
		QuoteUrl        string  `json:"quoteUrl"`
	} `json:"stockData"`
	News []struct {
		Title             string    `json:"title"`
		Category          string    `json:"category"`
		Body              string    `json:"body"`
		ImageAlt          string    `json:"imageAlt"`
		ImageUrl          string    `json:"imageUrl"`
		NewsUrl           string    `json:"newsUrl"`
		CategoryUrl       string    `json:"categoryUrl"`
		PublishDate       time.Time `json:"publishDate"`
		PublishDateUnixts int       `json:"publishDateUnixts"`
		Stocks            []struct {
			Id             int    `json:"id"`
			Index          int    `json:"index"`
			Symbol         string `json:"symbol"`
			PricePctChange string `json:"pricePctChange"`
		} `json:"stocks"`
		VideoFormat bool `json:"videoFormat"`
	} `json:"news"`
	FullUrl string `json:"fullUrl"`
}
diff --git a/backend/internal/ibd/search_test.go b/backend/internal/ibd/search_test.go
new file mode 100644
index 0000000..05e93dc
--- /dev/null
+++ b/backend/internal/ibd/search_test.go
@@ -0,0 +1,205 @@
+package ibd
+
+import (
+ "context"
+ "net/http"
+ "testing"
+
+ "github.com/ansg191/ibd-trader-backend/internal/ibd/transport"
+ "github.com/jarcoal/httpmock"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+const searchResponseJSON = `
+{
+ "_status": 200,
+ "_timestamp": "1722879439.724106",
+ "stockData": [
+ {
+ "id": 13717,
+ "symbol": "AAPL",
+ "company": "Apple",
+ "priceDate": "2024-08-05T09:18:00",
+ "price": 212.33,
+ "previousPrice": 219.86,
+ "priceChange": -7.53,
+ "pricePctChange": -3.42,
+ "volume": 643433,
+ "volumeChange": -2138,
+ "volumePctChange": 124,
+ "quoteUrl": "https://research.investors.com/stock-quotes/nasdaq-apple-aapl.htm"
+ },
+ {
+ "id": 79964,
+ "symbol": "AAPU",
+ "company": "Direxion AAPL Bull 2X",
+ "priceDate": "2024-08-05T09:18:00",
+ "price": 32.48,
+ "previousPrice": 34.9,
+ "priceChange": -2.42,
+ "pricePctChange": -6.92,
+ "volume": 15265,
+ "volumeChange": -35,
+ "volumePctChange": 212,
+ "quoteUrl": "https://research.investors.com/stock-quotes/nasdaq-direxion-aapl-bull-2x-aapu.htm"
+ },
+ {
+ "id": 80423,
+ "symbol": "APLY",
+ "company": "YieldMax AAPL Option Incm",
+ "priceDate": "2024-08-05T09:11:00",
+ "price": 17.52,
+ "previousPrice": 18.15,
+ "priceChange": -0.63,
+ "pricePctChange": -3.47,
+ "volume": 617,
+ "volumeChange": -2,
+ "volumePctChange": 97,
+ "quoteUrl": "https://research.investors.com/stock-quotes/nyse-yieldmax-aapl-option-incm-aply.htm"
+ },
+ {
+ "id": 79962,
+ "symbol": "AAPD",
+ "company": "Direxion Dly AAPL Br 1X",
+ "priceDate": "2024-08-05T09:18:00",
+ "price": 18.11,
+ "previousPrice": 17.53,
+ "priceChange": 0.58,
+ "pricePctChange": 3.31,
+ "volume": 14572,
+ "volumeChange": -7,
+ "volumePctChange": 885,
+ "quoteUrl": "https://research.investors.com/stock-quotes/nasdaq-direxion-dly-aapl-br-1x-aapd.htm"
+ },
+ {
+ "id": 79968,
+ "symbol": "AAPB",
+ "company": "GraniteSh 2x Lg AAPL",
+ "priceDate": "2024-08-05T09:16:00",
+ "price": 25.22,
+ "previousPrice": 27.25,
+ "priceChange": -2.03,
+ "pricePctChange": -7.45,
+ "volume": 2505,
+ "volumeChange": -7,
+ "volumePctChange": 151,
+ "quoteUrl": "https://research.investors.com/stock-quotes/nasdaq-granitesh-2x-lg-aapl-aapb.htm"
+ }
+ ],
+ "news": [
+ {
+ "title": "Warren Buffett Dumped Berkshire Hathaway's Favorite Stocks — Right Before They Plunged",
+ "category": "News",
+ "body": "Berkshire Hathaway earnings rose solidly in Q2. Warren Buffett sold nearly half his Apple stock stake. Berkshire stock fell...",
+ "imageAlt": "",
+ "imageUrl": "https://www.investors.com/wp-content/uploads/2024/06/Stock-WarrenBuffettwave-01-shutt-640x360.jpg",
+ "newsUrl": "https://investors.com/news/berkshire-hathaway-earnings-q2-2024-warren-buffett-apple/",
+ "categoryUrl": "https://investors.com/category/news/",
+ "publishDate": "2024-08-05T15:51:57+00:00",
+ "publishDateUnixts": 1722858717,
+ "stocks": [
+ {
+ "id": 13717,
+ "index": 0,
+ "symbol": "AAPL",
+ "pricePctChange": "-3.42"
+ }
+ ],
+ "videoFormat": false
+ },
+ {
+ "title": "Nvidia Plunges On Report Of AI Chip Flaw; Is It A Buy Now?",
+ "category": "Research",
+ "body": "Nvidia will roll out its Blackwell chip at least three months later than planned.",
+ "imageAlt": "",
+ "imageUrl": "https://www.investors.com/wp-content/uploads/2024/01/Stock-Nvidia-studio-01-company-640x360.jpg",
+ "newsUrl": "https://investors.com/research/nvda-stock-is-nvidia-a-buy-2/",
+ "categoryUrl": "https://investors.com/category/research/",
+ "publishDate": "2024-08-05T14:59:22+00:00",
+ "publishDateUnixts": 1722855562,
+ "stocks": [
+ {
+ "id": 38607,
+ "index": 0,
+ "symbol": "NVDA",
+ "pricePctChange": "-5.18"
+ }
+ ],
+ "videoFormat": false
+ },
+ {
+ "title": "Magnificent Seven Stocks Roiled: Nvidia Plunges On AI Chip Delay; Apple, Tesla Dive",
+ "category": "Research",
+ "body": "Nvidia stock dived Monday, while Apple and Tesla also fell sharply.",
+ "imageAlt": "",
+ "imageUrl": "https://www.investors.com/wp-content/uploads/2022/08/Stock-Nvidia-RTXa5500-comp-640x360.jpg",
+ "newsUrl": "https://investors.com/research/magnificent-seven-stocks-to-buy-and-and-watch/",
+ "categoryUrl": "https://investors.com/category/research/",
+ "publishDate": "2024-08-05T14:51:42+00:00",
+ "publishDateUnixts": 1722855102,
+ "stocks": [
+ {
+ "id": 13717,
+ "index": 0,
+ "symbol": "AAPL",
+ "pricePctChange": "-3.42"
+ }
+ ],
+ "videoFormat": false
+ }
+ ],
+ "fullUrl": "https://www.investors.com/search-results/?query=AAPL"
+}`
+
+const emptySearchResponseJSON = `
+{
+ "_status": 200,
+ "_timestamp": "1722879662.804395",
+ "stockData": [],
+ "news": [],
+ "fullUrl": "https://www.investors.com/search-results/?query=abcdefg"
+}`
+
// TestClient_Search runs Client.Search against a mocked HTTP transport,
// covering both an exact-symbol hit and the ErrSymbolNotFound path.
// It still needs the dockerized db from TestMain for cookie lookup, so a
// cookie must already exist by the time this test runs (addCookie is
// called by earlier tests in this file).
func TestClient_Search(t *testing.T) {
	tests := []struct {
		name     string
		response string // canned JSON body served for the search endpoint
		f        func(t *testing.T, client *Client)
	}{
		{
			name:     "found",
			response: searchResponseJSON,
			f: func(t *testing.T, client *Client) {
				u, err := client.Search(context.Background(), "AAPL")
				require.NoError(t, err)
				assert.Equal(t, "AAPL", u.Symbol)
				assert.Equal(t, "Apple", u.Name)
				assert.Equal(t, "https://research.investors.com/stock-quotes/nasdaq-apple-aapl.htm", u.IBDUrl)
			},
		},
		{
			name:     "not found",
			response: emptySearchResponseJSON,
			f: func(t *testing.T, client *Client) {
				_, err := client.Search(context.Background(), "abcdefg")
				assert.Error(t, err)
				assert.ErrorIs(t, err, ErrSymbolNotFound)
			},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// Serve the canned response for the search endpoint.
			tp := httpmock.NewMockTransport()
			tp.RegisterResponder("GET", searchUrl, httpmock.NewStringResponder(200, tt.response))

			client := NewClient(
				db,
				new(kmsStub),
				transport.NewStandardTransport(&http.Client{Transport: tp}),
			)

			tt.f(t, client)
		})
	}
}
diff --git a/backend/internal/ibd/stockinfo.go b/backend/internal/ibd/stockinfo.go
new file mode 100644
index 0000000..1e3b96f
--- /dev/null
+++ b/backend/internal/ibd/stockinfo.go
@@ -0,0 +1,233 @@
+package ibd
+
+import (
+ "context"
+ "fmt"
+ "io"
+ "net/http"
+ "net/url"
+ "strconv"
+ "strings"
+
+ "github.com/ansg191/ibd-trader-backend/internal/database"
+ "github.com/ansg191/ibd-trader-backend/internal/utils"
+
+ "github.com/Rhymond/go-money"
+ "golang.org/x/net/html"
+)
+
// StockInfo fetches the IBD quote page at uri and scrapes it into a
// database.StockInfo: company name, symbol, chart-analysis text, the
// rating set, and the last price.
func (c *Client) StockInfo(ctx context.Context, uri string) (*database.StockInfo, error) {
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, uri, nil)
	if err != nil {
		return nil, err
	}

	// Authenticate with any available IBD session cookie.
	_, cookie, err := c.getCookie(ctx, nil)
	if err != nil {
		return nil, err
	}
	req.AddCookie(cookie)

	// Set required query parameters
	params := url.Values{}
	params.Set("list", "ibd50")
	params.Set("type", "weekly")
	req.URL.RawQuery = params.Encode()

	resp, err := c.Do(req)
	if err != nil {
		return nil, err
	}
	defer func(Body io.ReadCloser) {
		_ = Body.Close()
	}(resp.Body)

	// Defensive: c.Do already filters on expected statuses (200 by
	// default), so this branch is normally unreachable.
	if resp.StatusCode != http.StatusOK {
		content, err := io.ReadAll(resp.Body)
		if err != nil {
			return nil, fmt.Errorf("failed to read response body: %w", err)
		}
		return nil, fmt.Errorf(
			"unexpected status code %d: %s",
			resp.StatusCode,
			string(content),
		)
	}

	node, err := html.Parse(resp.Body)
	if err != nil {
		return nil, err
	}

	// Scrape each section of the page independently; any missing section
	// aborts the whole extraction.
	name, symbol, err := extractNameAndSymbol(node)
	if err != nil {
		return nil, fmt.Errorf("failed to extract name and symbol: %w", err)
	}
	chartAnalysis, err := extractChartAnalysis(node)
	if err != nil {
		return nil, fmt.Errorf("failed to extract chart analysis: %w", err)
	}
	ratings, err := extractRatings(node)
	if err != nil {
		return nil, fmt.Errorf("failed to extract ratings: %w", err)
	}
	price, err := extractPrice(node)
	if err != nil {
		return nil, fmt.Errorf("failed to extract price: %w", err)
	}

	return &database.StockInfo{
		Symbol:        symbol,
		Name:          name,
		ChartAnalysis: chartAnalysis,
		Ratings:       ratings,
		Price:         price,
	}, nil
}
+
// extractNameAndSymbol pulls the company name and ticker symbol out of a
// parsed quote page.
//
// The name is the text of the element with ID "quote-symbol" — with a
// trailing "(...)" parenthetical (presumably the symbol) stripped if
// present — and the symbol is the text of the element with ID "qteSymb".
// NOTE(review): the "quote-symbol" element yielding the company NAME is
// surprising given its ID; confirmed only by this code's usage.
func extractNameAndSymbol(node *html.Node) (name string, symbol string, err error) {
	// Find span with ID "quote-symbol"
	quoteSymbolNode := findId(node, "quote-symbol")
	if quoteSymbolNode == nil {
		return "", "", fmt.Errorf("could not find `quote-symbol` span")
	}

	// Get the text of the quote-symbol span
	name = strings.TrimSpace(extractText(quoteSymbolNode))

	// Find span with ID "qteSymb"
	qteSymbNode := findId(node, "qteSymb")
	if qteSymbNode == nil {
		return "", "", fmt.Errorf("could not find `qteSymb` span")
	}

	// Get the text of the qteSymb span
	symbol = strings.TrimSpace(extractText(qteSymbNode))

	// Get index of last closing parenthesis
	lastParenIndex := strings.LastIndex(name, ")")
	if lastParenIndex == -1 {
		return
	}

	// Find the last opening parenthesis before the closing parenthesis
	lastOpenParenIndex := strings.LastIndex(name[:lastParenIndex], "(")
	if lastOpenParenIndex == -1 {
		return
	}

	// Remove the parenthesis pair
	name = strings.TrimSpace(name[:lastOpenParenIndex] + name[lastParenIndex+1:])
	return
}
+
+func extractPrice(node *html.Node) (*money.Money, error) {
+ // Find the div with the ID "lstPrice"
+ lstPriceNode := findId(node, "lstPrice")
+ if lstPriceNode == nil {
+ return nil, fmt.Errorf("could not find `lstPrice` div")
+ }
+
+ // Get the text of the lstPrice div
+ priceStr := strings.TrimSpace(extractText(lstPriceNode))
+
+ // Parse the price
+ price, err := utils.ParseMoney(priceStr)
+ if err != nil {
+ return nil, fmt.Errorf("failed to parse price: %w", err)
+ }
+
+ return price, nil
+}
+
+func extractRatings(node *html.Node) (ratings database.Ratings, err error) {
+ // Find the div with class "smartContent"
+ smartSelectNode := findClass(node, "smartContent")
+ if smartSelectNode == nil {
+ return ratings, fmt.Errorf("could not find `smartContent` div")
+ }
+
+ // Iterate over children, looking for "smartRating" divs
+ for c := smartSelectNode.FirstChild; c != nil; c = c.NextSibling {
+ if !isClass(c, "smartRating") {
+ continue
+ }
+
+ err = processSmartRating(c, &ratings)
+ if err != nil {
+ return
+ }
+ }
+ return
+}
+
+// processSmartRating extracts the rating from a "smartRating" div and updates the ratings struct.
+//
+// The node should look like this:
+//
+// <ul class="smartRating">
+// <li><a><span>Composite Rating</span></a></li>
+// <li>94</li>
+// ...
+// </ul>
+func processSmartRating(node *html.Node, ratings *database.Ratings) error {
+ // Check that the node is a ul
+ if node.Type != html.ElementNode || node.Data != "ul" {
+ return fmt.Errorf("expected ul node, got %s", node.Data)
+ }
+
+ // Get all `li` children
+ children := findChildren(node, func(node *html.Node) bool {
+ return node.Type == html.ElementNode && node.Data == "li"
+ })
+
+ // Extract the rating name
+ ratingName := strings.TrimSpace(extractText(children[0]))
+
+ // Extract the rating value
+ ratingValueStr := strings.TrimSpace(extractText(children[1]))
+
+ switch ratingName {
+ case "Composite Rating":
+ ratingValue, err := strconv.ParseUint(ratingValueStr, 10, 8)
+ if err != nil {
+ return fmt.Errorf("failed to parse Composite Rating: %w", err)
+ }
+ ratings.Composite = uint8(ratingValue)
+ case "EPS Rating":
+ ratingValue, err := strconv.ParseUint(ratingValueStr, 10, 8)
+ if err != nil {
+ return fmt.Errorf("failed to parse EPS Rating: %w", err)
+ }
+ ratings.EPS = uint8(ratingValue)
+ case "RS Rating":
+ ratingValue, err := strconv.ParseUint(ratingValueStr, 10, 8)
+ if err != nil {
+ return fmt.Errorf("failed to parse RS Rating: %w", err)
+ }
+ ratings.RelStr = uint8(ratingValue)
+ case "Group RS Rating":
+ ratings.GroupRelStr = database.LetterRatingFromString(ratingValueStr)
+ case "SMR Rating":
+ ratings.SMR = database.LetterRatingFromString(ratingValueStr)
+ case "Acc/Dis Rating":
+ ratings.AccDis = database.LetterRatingFromString(ratingValueStr)
+ default:
+ return fmt.Errorf("unknown rating name: %s", ratingName)
+ }
+
+ return nil
+}
+
+func extractChartAnalysis(node *html.Node) (string, error) {
+ // Find the div with class "chartAnalysis"
+ chartAnalysisNode := findClass(node, "chartAnalysis")
+ if chartAnalysisNode == nil {
+ return "", fmt.Errorf("could not find `chartAnalysis` div")
+ }
+
+ // Get the text of the chart analysis div
+ chartAnalysis := strings.TrimSpace(extractText(chartAnalysisNode))
+
+ return chartAnalysis, nil
+}
diff --git a/backend/internal/ibd/transport/scrapfly/options.go b/backend/internal/ibd/transport/scrapfly/options.go
new file mode 100644
index 0000000..f16a4b0
--- /dev/null
+++ b/backend/internal/ibd/transport/scrapfly/options.go
@@ -0,0 +1,84 @@
+package scrapfly
+
// BaseURL is the default endpoint of the Scrapfly scrape API.
const BaseURL = "https://api.scrapfly.io/scrape"

// ScrapeOption mutates a ScrapeOptions value; options are applied by New on
// top of defaultScrapeOptions.
type ScrapeOption func(*ScrapeOptions)

// ScrapeOptions holds the per-transport settings that are forwarded to the
// Scrapfly API as query parameters.
type ScrapeOptions struct {
	baseURL   string
	country   *string // nil means "let Scrapfly choose"
	asp       bool
	proxyPool ProxyPool
	renderJS  bool
	cache     bool
	debug     bool
}

// ProxyPool selects which Scrapfly proxy pool requests are routed through.
type ProxyPool uint8

const (
	ProxyPoolDatacenter ProxyPool = iota
	ProxyPoolResidential
)

// String returns the Scrapfly API name for the pool. It panics on a value
// outside the declared constants, which indicates a programming error.
func (p ProxyPool) String() string {
	switch p {
	case ProxyPoolDatacenter:
		return "public_datacenter_pool"
	case ProxyPoolResidential:
		return "public_residential_pool"
	}
	panic("invalid proxy pool")
}

// defaultScrapeOptions are the settings used when no ScrapeOption overrides
// them: anti-scraping protection on, datacenter proxies, no JS rendering,
// no caching, no debug. Unset fields rely on their zero values.
var defaultScrapeOptions = ScrapeOptions{
	baseURL:   BaseURL,
	asp:       true,
	proxyPool: ProxyPoolDatacenter,
}
+
+func WithCountry(country string) ScrapeOption {
+ return func(o *ScrapeOptions) {
+ o.country = &country
+ }
+}
+
+func WithASP(asp bool) ScrapeOption {
+ return func(o *ScrapeOptions) {
+ o.asp = asp
+ }
+}
+
+func WithProxyPool(proxyPool ProxyPool) ScrapeOption {
+ return func(o *ScrapeOptions) {
+ o.proxyPool = proxyPool
+ }
+}
+
+func WithRenderJS(jsRender bool) ScrapeOption {
+ return func(o *ScrapeOptions) {
+ o.renderJS = jsRender
+ }
+}
+
+func WithCache(cache bool) ScrapeOption {
+ return func(o *ScrapeOptions) {
+ o.cache = cache
+ }
+}
+
+func WithDebug(debug bool) ScrapeOption {
+ return func(o *ScrapeOptions) {
+ o.debug = debug
+ }
+}
+
+func WithBaseURL(baseURL string) ScrapeOption {
+ return func(o *ScrapeOptions) {
+ o.baseURL = baseURL
+ }
+}
diff --git a/backend/internal/ibd/transport/scrapfly/scraper_types.go b/backend/internal/ibd/transport/scrapfly/scraper_types.go
new file mode 100644
index 0000000..f3cf651
--- /dev/null
+++ b/backend/internal/ibd/transport/scrapfly/scraper_types.go
@@ -0,0 +1,253 @@
+package scrapfly
+
+import (
+ "fmt"
+ "io"
+ "net/http"
+ "strings"
+ "time"
+)
+
// ScraperResponse mirrors the JSON envelope returned by the Scrapfly scrape
// API. Config echoes the request configuration, Context describes how the
// scrape was executed (proxy, cost, cookies, ...), and Result carries the
// upstream response itself.
//
// NOTE(review): field types were generated from sample payloads; the
// interface{} fields are ones whose concrete type wasn't observed — confirm
// against the Scrapfly API reference before relying on them.
type ScraperResponse struct {
	// Config echoes back the scrape parameters the API acted on.
	Config struct {
		Asp bool `json:"asp"`
		AutoScroll bool `json:"auto_scroll"`
		Body interface{} `json:"body"`
		Cache bool `json:"cache"`
		CacheClear bool `json:"cache_clear"`
		CacheTtl int `json:"cache_ttl"`
		CorrelationId interface{} `json:"correlation_id"`
		CostBudget interface{} `json:"cost_budget"`
		Country interface{} `json:"country"`
		Debug bool `json:"debug"`
		Dns bool `json:"dns"`
		Env string `json:"env"`
		Extract interface{} `json:"extract"`
		ExtractionModel interface{} `json:"extraction_model"`
		ExtractionModelCustomSchema interface{} `json:"extraction_model_custom_schema"`
		ExtractionPrompt interface{} `json:"extraction_prompt"`
		ExtractionTemplate interface{} `json:"extraction_template"`
		Format string `json:"format"`
		Geolocation interface{} `json:"geolocation"`
		Headers struct {
			Cookie []string `json:"Cookie"`
		} `json:"headers"`
		JobUuid interface{} `json:"job_uuid"`
		Js interface{} `json:"js"`
		JsScenario interface{} `json:"js_scenario"`
		Lang interface{} `json:"lang"`
		LogEvictionDate string `json:"log_eviction_date"`
		Method string `json:"method"`
		Origin string `json:"origin"`
		Os interface{} `json:"os"`
		Project string `json:"project"`
		ProxyPool string `json:"proxy_pool"`
		RenderJs bool `json:"render_js"`
		RenderingStage string `json:"rendering_stage"`
		RenderingWait int `json:"rendering_wait"`
		Retry bool `json:"retry"`
		ScheduleName interface{} `json:"schedule_name"`
		ScreenshotFlags interface{} `json:"screenshot_flags"`
		ScreenshotResolution interface{} `json:"screenshot_resolution"`
		Screenshots interface{} `json:"screenshots"`
		Session interface{} `json:"session"`
		SessionStickyProxy bool `json:"session_sticky_proxy"`
		Ssl bool `json:"ssl"`
		Tags interface{} `json:"tags"`
		Timeout int `json:"timeout"`
		Url string `json:"url"`
		UserUuid string `json:"user_uuid"`
		Uuid string `json:"uuid"`
		WaitForSelector interface{} `json:"wait_for_selector"`
		WebhookName interface{} `json:"webhook_name"`
	} `json:"config"`
	// Context describes the execution of the scrape: bandwidth and cost
	// accounting, the proxy used, cookies observed, redirects, etc.
	Context struct {
		Asp interface{} `json:"asp"`
		BandwidthConsumed int `json:"bandwidth_consumed"`
		BandwidthImagesConsumed int `json:"bandwidth_images_consumed"`
		Cache struct {
			Entry interface{} `json:"entry"`
			State string `json:"state"`
		} `json:"cache"`
		Cookies []struct {
			Comment interface{} `json:"comment"`
			Domain string `json:"domain"`
			Expires *string `json:"expires"`
			HttpOnly bool `json:"http_only"`
			MaxAge interface{} `json:"max_age"`
			Name string `json:"name"`
			Path string `json:"path"`
			Secure bool `json:"secure"`
			Size int `json:"size"`
			Value string `json:"value"`
			Version interface{} `json:"version"`
		} `json:"cookies"`
		Cost struct {
			Details []struct {
				Amount int `json:"amount"`
				Code string `json:"code"`
				Description string `json:"description"`
			} `json:"details"`
			Total int `json:"total"`
		} `json:"cost"`
		CreatedAt string `json:"created_at"`
		Debug interface{} `json:"debug"`
		Env string `json:"env"`
		Fingerprint string `json:"fingerprint"`
		Headers struct {
			Cookie string `json:"Cookie"`
		} `json:"headers"`
		IsXmlHttpRequest bool `json:"is_xml_http_request"`
		Job interface{} `json:"job"`
		Lang []string `json:"lang"`
		Os struct {
			Distribution string `json:"distribution"`
			Name string `json:"name"`
			Type string `json:"type"`
			Version string `json:"version"`
		} `json:"os"`
		Project string `json:"project"`
		Proxy struct {
			Country string `json:"country"`
			Identity string `json:"identity"`
			Network string `json:"network"`
			Pool string `json:"pool"`
		} `json:"proxy"`
		Redirects []interface{} `json:"redirects"`
		Retry int `json:"retry"`
		Schedule interface{} `json:"schedule"`
		Session interface{} `json:"session"`
		Spider interface{} `json:"spider"`
		Throttler interface{} `json:"throttler"`
		Uri struct {
			BaseUrl string `json:"base_url"`
			Fragment interface{} `json:"fragment"`
			Host string `json:"host"`
			Params interface{} `json:"params"`
			Port int `json:"port"`
			Query string `json:"query"`
			RootDomain string `json:"root_domain"`
			Scheme string `json:"scheme"`
		} `json:"uri"`
		Url string `json:"url"`
		Webhook interface{} `json:"webhook"`
	} `json:"context"`
	Insights interface{} `json:"insights"`
	// Result is the upstream HTTP response captured by the scrape.
	Result ScraperResult `json:"result"`
	Uuid string `json:"uuid"`
}
+
// ScraperResult is the upstream HTTP response as captured by Scrapfly:
// status, headers, cookies, and the page content itself. It is converted
// back into an *http.Response by ScraperResponse.ToHTTPResponse.
type ScraperResult struct {
	BrowserData struct {
		JavascriptEvaluationResult interface{} `json:"javascript_evaluation_result"`
		JsScenario []interface{} `json:"js_scenario"`
		LocalStorageData struct {
		} `json:"local_storage_data"`
		SessionStorageData struct {
		} `json:"session_storage_data"`
		Websockets []interface{} `json:"websockets"`
		XhrCall interface{} `json:"xhr_call"`
	} `json:"browser_data"`
	// Content is the response body (encoding described by ContentEncoding).
	Content string `json:"content"`
	ContentEncoding string `json:"content_encoding"`
	ContentFormat string `json:"content_format"`
	ContentType string `json:"content_type"`
	Cookies []ScraperCookie `json:"cookies"`
	Data interface{} `json:"data"`
	Dns interface{} `json:"dns"`
	Duration float64 `json:"duration"`
	Error interface{} `json:"error"`
	ExtractedData interface{} `json:"extracted_data"`
	Format string `json:"format"`
	Iframes []interface{} `json:"iframes"`
	LogUrl string `json:"log_url"`
	Reason string `json:"reason"`
	RequestHeaders map[string]string `json:"request_headers"`
	ResponseHeaders map[string]string `json:"response_headers"`
	Screenshots struct {
	} `json:"screenshots"`
	Size int `json:"size"`
	Ssl interface{} `json:"ssl"`
	// Status is the upstream status text; StatusCode the numeric code.
	Status string `json:"status"`
	StatusCode int `json:"status_code"`
	Success bool `json:"success"`
	Url string `json:"url"`
}
+
// ScraperCookie is the JSON representation of a cookie in a Scrapfly scrape
// result.
type ScraperCookie struct {
	Name     string `json:"name"`
	Value    string `json:"value"`
	Expires  string `json:"expires"`
	Path     string `json:"path"`
	Comment  string `json:"comment"`
	Domain   string `json:"domain"`
	MaxAge   int    `json:"max_age"`
	Secure   bool   `json:"secure"`
	HttpOnly bool   `json:"http_only"`
	Version  string `json:"version"`
	Size     int    `json:"size"`
}

// cookieTimeLayout is the timestamp layout Scrapfly uses for cookie
// expiration times. It carries no zone, so values parse as UTC —
// TODO confirm upstream that the API reports UTC.
const cookieTimeLayout = "2006-01-02 15:04:05"

// ToHTTPCookie converts c into an *http.Cookie.
//
// An empty Expires string maps to the zero time.Time; a malformed value is
// reported as an error.
func (c *ScraperCookie) ToHTTPCookie() (*http.Cookie, error) {
	var expires time.Time
	if c.Expires != "" {
		var err error
		expires, err = time.Parse(cookieTimeLayout, c.Expires)
		if err != nil {
			return nil, fmt.Errorf("failed to parse cookie expiration: %w", err)
		}
	}
	return &http.Cookie{
		Name:    c.Name,
		Value:   c.Value,
		Path:    c.Path,
		Domain:  c.Domain,
		Expires: expires,
		// MaxAge was previously dropped, making the conversion asymmetric
		// with FromHTTPCookie (which copies it); carry it through so a
		// round trip is lossless.
		MaxAge:   c.MaxAge,
		Secure:   c.Secure,
		HttpOnly: c.HttpOnly,
	}, nil
}

// FromHTTPCookie populates c from an *http.Cookie, overwriting every field.
// A zero Expires time maps to an empty string; Comment and Version are not
// carried over.
func (c *ScraperCookie) FromHTTPCookie(cookie *http.Cookie) {
	var expires string
	if !cookie.Expires.IsZero() {
		expires = cookie.Expires.Format(cookieTimeLayout)
	}
	*c = ScraperCookie{
		Domain:   cookie.Domain,
		Expires:  expires,
		HttpOnly: cookie.HttpOnly,
		MaxAge:   cookie.MaxAge,
		Name:     cookie.Name,
		Path:     cookie.Path,
		Secure:   cookie.Secure,
		Size:     len(cookie.Value),
		Value:    cookie.Value,
	}
}
+
+func (r *ScraperResponse) ToHTTPResponse() (*http.Response, error) {
+ resp := &http.Response{
+ StatusCode: r.Result.StatusCode,
+ Header: make(http.Header),
+ Body: io.NopCloser(strings.NewReader(r.Result.Content)),
+ ContentLength: int64(len(r.Result.Content)),
+ Close: true,
+ }
+
+ for k, v := range r.Result.ResponseHeaders {
+ resp.Header.Set(k, v)
+ }
+
+ for _, c := range r.Result.Cookies {
+ cookie, err := c.ToHTTPCookie()
+ if err != nil {
+ return nil, err
+ }
+ resp.Header.Add("Set-Cookie", cookie.String())
+ }
+
+ return resp, nil
+}
diff --git a/backend/internal/ibd/transport/scrapfly/scrapfly.go b/backend/internal/ibd/transport/scrapfly/scrapfly.go
new file mode 100644
index 0000000..3b414de
--- /dev/null
+++ b/backend/internal/ibd/transport/scrapfly/scrapfly.go
@@ -0,0 +1,103 @@
+package scrapfly
+
+import (
+ "encoding/json"
+ "fmt"
+ "io"
+ "net/http"
+ "net/url"
+ "strconv"
+
+ "github.com/ansg191/ibd-trader-backend/internal/ibd/transport"
+)
+
+type ScrapflyTransport struct {
+ client *http.Client
+ apiKey string
+ options ScrapeOptions
+}
+
+func New(client *http.Client, apiKey string, opts ...ScrapeOption) *ScrapflyTransport {
+ options := defaultScrapeOptions
+ for _, opt := range opts {
+ opt(&options)
+ }
+
+ return &ScrapflyTransport{
+ client: client,
+ apiKey: apiKey,
+ options: options,
+ }
+}
+
+func (s *ScrapflyTransport) String() string {
+ return "scrapfly"
+}
+
+func (s *ScrapflyTransport) Do(req *http.Request) (*http.Response, error) {
+ // Construct scrape request URL
+ scrapeUrl, err := url.Parse(s.options.baseURL)
+ if err != nil {
+ panic(err)
+ }
+ scrapeUrl.RawQuery = s.constructRawQuery(req.URL, req.Header)
+
+ // We can't handle `Content-Type` header on GET requests
+ // Wierd quirk of the Scrapfly API
+ if req.Method == http.MethodGet && req.Header.Get("Content-Type") != "" {
+ return nil, transport.ErrUnsupportedRequest
+ }
+
+ // Construct scrape request
+ scrapeReq, err := http.NewRequestWithContext(req.Context(), req.Method, scrapeUrl.String(), req.Body)
+ if err != nil {
+ return nil, err
+ }
+
+ // Send scrape request
+ resp, err := s.client.Do(scrapeReq)
+ if err != nil {
+ return nil, err
+ }
+ defer func(Body io.ReadCloser) {
+ _ = Body.Close()
+ }(resp.Body)
+
+ // Parse scrape response
+ scraperResponse := new(ScraperResponse)
+ err = json.NewDecoder(resp.Body).Decode(scraperResponse)
+ if err != nil {
+ return nil, err
+ }
+
+ // Convert scraper response to http.Response
+ return scraperResponse.ToHTTPResponse()
+}
+
// Properties reports this transport as reliable (it costs API credits per
// request, so it is not free).
func (s *ScrapflyTransport) Properties() transport.Properties {
	return transport.PropertiesReliable
}
+
+func (s *ScrapflyTransport) constructRawQuery(u *url.URL, headers http.Header) string {
+ params := url.Values{}
+ params.Set("key", s.apiKey)
+ params.Set("url", u.String())
+ if s.options.country != nil {
+ params.Set("country", *s.options.country)
+ }
+ params.Set("asp", strconv.FormatBool(s.options.asp))
+ params.Set("proxy_pool", s.options.proxyPool.String())
+ params.Set("render_js", strconv.FormatBool(s.options.renderJS))
+ params.Set("cache", strconv.FormatBool(s.options.cache))
+
+ for k, v := range headers {
+ for i, vv := range v {
+ params.Add(
+ fmt.Sprintf("headers[%s][%d]", k, i),
+ vv,
+ )
+ }
+ }
+
+ return params.Encode()
+}
diff --git a/backend/internal/ibd/transport/standard.go b/backend/internal/ibd/transport/standard.go
new file mode 100644
index 0000000..9fa9ff9
--- /dev/null
+++ b/backend/internal/ibd/transport/standard.go
@@ -0,0 +1,41 @@
+package transport
+
+import (
+ "net/http"
+
+ "github.com/EDDYCJY/fake-useragent"
+)
+
+type StandardTransport http.Client
+
+func NewStandardTransport(client *http.Client) *StandardTransport {
+ return (*StandardTransport)(client)
+}
+
+func (t *StandardTransport) Do(req *http.Request) (*http.Response, error) {
+ addFakeHeaders(req)
+ return (*http.Client)(t).Do(req)
+}
+
+func (t *StandardTransport) String() string {
+ return "standard"
+}
+
+func (t *StandardTransport) Properties() Properties {
+ return PropertiesFree
+}
+
+func addFakeHeaders(req *http.Request) {
+ req.Header.Set("User-Agent", browser.Linux())
+ req.Header.Set("Sec-CH-UA", `"Not)A;Brand";v="99", "Google Chrome";v="127", "Chromium";v="127"`)
+ req.Header.Set("Sec-CH-UA-Mobile", "?0")
+ req.Header.Set("Sec-CH-UA-Platform", "Linux")
+ req.Header.Set("Upgrade-Insecure-Requests", "1")
+ req.Header.Set("Priority", "u=0, i")
+ req.Header.Set("Sec-Fetch-Site", "none")
+ req.Header.Set("Sec-Fetch-Mode", "navigate")
+ req.Header.Set("Sec-Fetch-Dest", "document")
+ req.Header.Set("Sec-Fetch-User", "?1")
+ req.Header.Set("Accept-Language", "en-US,en;q=0.9")
+ req.Header.Set("Accept", "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.7")
+}
diff --git a/backend/internal/ibd/transport/transport.go b/backend/internal/ibd/transport/transport.go
new file mode 100644
index 0000000..95e9ef3
--- /dev/null
+++ b/backend/internal/ibd/transport/transport.go
@@ -0,0 +1,66 @@
+package transport
+
+import (
+ "cmp"
+ "errors"
+ "fmt"
+ "net/http"
+ "slices"
+)
+
+var ErrUnsupportedRequest = errors.New("unsupported request")
+
// Properties is a bit set describing the characteristics of a Transport.
type Properties uint8

const (
	// PropertiesFree indicates that the transport is free.
	// This means that requests made with this transport don't cost any money.
	PropertiesFree Properties = 1 << iota
	// PropertiesReliable indicates that the transport is reliable.
	// This means that requests made with this transport are guaranteed to be
	// successful if the server is reachable.
	PropertiesReliable
)

// IsReliable reports whether the PropertiesReliable bit is set.
func (p Properties) IsReliable() bool {
	return p&PropertiesReliable != 0
}

// IsFree reports whether the PropertiesFree bit is set.
func (p Properties) IsFree() bool {
	return p&PropertiesFree != 0
}

// Transport executes HTTP requests and advertises its characteristics so
// callers can pick the cheapest suitable transport.
type Transport interface {
	fmt.Stringer

	Do(req *http.Request) (*http.Response, error)
	Properties() Properties
}

// sortPriority maps a property set to a sort rank; lower ranks sort first.
// Transports with no recognized properties rank last.
func sortPriority(p Properties) int {
	switch {
	case p.IsFree() && p.IsReliable():
		return 0
	case p.IsFree():
		return 1
	case p.IsReliable():
		return 2
	default:
		return 3
	}
}

// SortTransports stably sorts the transports by their properties.
//
// The transports are sorted in the following order:
//  1. Free and reliable transports
//  2. Free transports
//  3. Reliable transports
//  4. Everything else
//
// (Previously a map lookup gave unrecognized property sets the zero
// priority, incorrectly sorting them ahead of everything.)
func SortTransports(transports []Transport) {
	slices.SortStableFunc(transports, func(a, b Transport) int {
		return cmp.Compare(sortPriority(a.Properties()), sortPriority(b.Properties()))
	})
}

// FilterTransports returns the transports that have every property bit in
// props set.
func FilterTransports(transports []Transport, props Properties) []Transport {
	var filtered []Transport
	for _, tp := range transports {
		if tp.Properties()&props == props {
			filtered = append(filtered, tp)
		}
	}
	return filtered
}
diff --git a/backend/internal/ibd/userinfo.go b/backend/internal/ibd/userinfo.go
new file mode 100644
index 0000000..ed61497
--- /dev/null
+++ b/backend/internal/ibd/userinfo.go
@@ -0,0 +1,156 @@
+package ibd
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "io"
+ "log/slog"
+ "net/http"
+)
+
const (
	// userInfoUrl returns the signed-in IBD user's profile as JSON.
	userInfoUrl = "https://myibd.investors.com/services/userprofile.aspx?format=json"
)
+
+func (c *Client) UserInfo(ctx context.Context, cookie *http.Cookie) (*UserProfile, error) {
+ req, err := http.NewRequestWithContext(ctx, http.MethodGet, userInfoUrl, nil)
+ if err != nil {
+ return nil, err
+ }
+
+ req.AddCookie(cookie)
+
+ resp, err := c.Do(req)
+ if err != nil {
+ return nil, err
+ }
+ defer func(Body io.ReadCloser) {
+ _ = Body.Close()
+ }(resp.Body)
+
+ content, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return nil, fmt.Errorf("failed to read response body: %w", err)
+ }
+
+ if resp.StatusCode != http.StatusOK {
+ return nil, fmt.Errorf(
+ "unexpected status code %d: %s",
+ resp.StatusCode,
+ string(content),
+ )
+ }
+
+ up := new(UserProfile)
+ if err = up.UnmarshalJSON(content); err != nil {
+ return nil, err
+ }
+
+ return up, nil
+}
+
// UserStatus classifies an IBD account, derived from the service's
// userTrialStatus field.
type UserStatus string

const (
	// UserStatusUnknown is used when the service reports a status this
	// package does not recognize.
	UserStatusUnknown UserStatus = ""
	UserStatusVisitor UserStatus = "Visitor"
	UserStatusSubscriber UserStatus = "Subscriber"
)

// UserProfile is the subset of the IBD user profile that this application
// uses; see userProfile for the full wire format.
type UserProfile struct {
	DisplayName string
	Email string
	FirstName string
	LastName string
	Status UserStatus
}
+
+func (u *UserProfile) UnmarshalJSON(bytes []byte) error {
+ var resp userProfileResponse
+ if err := json.Unmarshal(bytes, &resp); err != nil {
+ return err
+ }
+
+ u.DisplayName = resp.UserProfile.UserDisplayName
+ u.Email = resp.UserProfile.UserEmailAddress
+ u.FirstName = resp.UserProfile.UserFirstName
+ u.LastName = resp.UserProfile.UserLastName
+
+ switch resp.UserProfile.UserTrialStatus {
+ case "Visitor":
+ u.Status = UserStatusVisitor
+ case "Subscriber":
+ u.Status = UserStatusSubscriber
+ default:
+ slog.Warn("Unknown user status", "status", resp.UserProfile.UserTrialStatus)
+ u.Status = UserStatusUnknown
+ }
+
+ return nil
+}
+
// userProfileResponse is the top-level envelope of the userprofile.aspx
// JSON payload.
type userProfileResponse struct {
	UserProfile userProfile `json:"userProfile"`
}
+
// userProfile is the full wire format of an IBD user profile as returned by
// userprofile.aspx. Only a handful of fields are surfaced through
// UserProfile; the rest are kept for completeness.
//
// NOTE(review): most values (including counters) arrive as strings; the
// Role*/Trial* booleans describe product entitlements — field meanings are
// inferred from their names and should be confirmed against real payloads.
type userProfile struct {
	UserSubType string `json:"userSubType"`
	UserId string `json:"userId"`
	UserDisplayName string `json:"userDisplayName"`
	Countrycode string `json:"countrycode"`
	IsEUCountry string `json:"isEUCountry"`
	Log string `json:"log"`
	AgeGroup string `json:"ageGroup"`
	Gender string `json:"gender"`
	InvestingExperience string `json:"investingExperience"`
	NumberOfTrades string `json:"numberOfTrades"`
	Occupation string `json:"occupation"`
	TypeOfInvestments string `json:"typeOfInvestments"`
	UserEmailAddress string `json:"userEmailAddress"`
	UserEmailAddressSHA1 string `json:"userEmailAddressSHA1"`
	UserEmailAddressSHA256 string `json:"userEmailAddressSHA256"`
	UserEmailAddressMD5 string `json:"userEmailAddressMD5"`
	UserFirstName string `json:"userFirstName"`
	UserLastName string `json:"userLastName"`
	UserZip string `json:"userZip"`
	UserTrialStatus string `json:"userTrialStatus"`
	UserProductsOnTrial string `json:"userProductsOnTrial"`
	UserProductsOwned string `json:"userProductsOwned"`
	UserAdTrade string `json:"userAdTrade"`
	UserAdTime string `json:"userAdTime"`
	UserAdHold string `json:"userAdHold"`
	UserAdJob string `json:"userAdJob"`
	UserAdAge string `json:"userAdAge"`
	UserAdOutSell string `json:"userAdOutSell"`
	UserVisitCount string `json:"userVisitCount"`
	RoleLeaderboard bool `json:"role_leaderboard"`
	RoleOws bool `json:"role_ows"`
	RoleIbdlive bool `json:"role_ibdlive"`
	RoleFounderclub bool `json:"role_founderclub"`
	RoleEibd bool `json:"role_eibd"`
	RoleIcom bool `json:"role_icom"`
	RoleEtables bool `json:"role_etables"`
	RoleTru10 bool `json:"role_tru10"`
	RoleMarketsurge bool `json:"role_marketsurge"`
	RoleSwingtrader bool `json:"role_swingtrader"`
	RoleAdfree bool `json:"role_adfree"`
	RoleMarketdiem bool `json:"role_marketdiem"`
	RoleWsjPlus bool `json:"role_wsj_plus"`
	RoleWsj bool `json:"role_wsj"`
	RoleBarrons bool `json:"role_barrons"`
	RoleMarketwatch bool `json:"role_marketwatch"`
	UserAdRoles string `json:"userAdRoles"`
	TrialDailyPrintNeg bool `json:"trial_daily_print_neg"`
	TrialDailyPrintNon bool `json:"trial_daily_print_non"`
	TrialWeeklyPrintNeg bool `json:"trial_weekly_print_neg"`
	TrialWeeklyPrintNon bool `json:"trial_weekly_print_non"`
	TrialDailyComboNeg bool `json:"trial_daily_combo_neg"`
	TrialDailyComboNon bool `json:"trial_daily_combo_non"`
	TrialWeeklyComboNeg bool `json:"trial_weekly_combo_neg"`
	TrialWeeklyComboNon bool `json:"trial_weekly_combo_non"`
	TrialEibdNeg bool `json:"trial_eibd_neg"`
	TrialEibdNon bool `json:"trial_eibd_non"`
	UserVideoPreference string `json:"userVideoPreference"`
	UserProfessionalStatus bool `json:"userProfessionalStatus"`
}
diff --git a/backend/internal/keys/gcp.go b/backend/internal/keys/gcp.go
new file mode 100644
index 0000000..9d10fc5
--- /dev/null
+++ b/backend/internal/keys/gcp.go
@@ -0,0 +1,131 @@
+package keys
+
+import (
+ "context"
+ "crypto/rand"
+ "crypto/rsa"
+ "crypto/sha256"
+ "crypto/x509"
+ "encoding/pem"
+ "errors"
+ "fmt"
+ "hash/crc32"
+ "sync"
+
+ kms "cloud.google.com/go/kms/apiv1"
+ "cloud.google.com/go/kms/apiv1/kmspb"
+ "google.golang.org/protobuf/types/known/wrapperspb"
+)
+
// GoogleKMS implements KeyManagementService on top of Google Cloud KMS
// asymmetric keys: encryption happens locally with the fetched RSA public
// key, decryption via the AsymmetricDecrypt RPC.
type GoogleKMS struct {
	client *kms.KeyManagementClient

	// mx guards keyCache, which memoizes public keys by fully-qualified
	// key version name to avoid repeated GetPublicKey calls.
	mx sync.RWMutex
	keyCache map[string]*rsa.PublicKey
}
+
+func NewGoogleKMS(ctx context.Context) (*GoogleKMS, error) {
+ client, err := kms.NewKeyManagementClient(ctx)
+ if err != nil {
+ return nil, err
+ }
+
+ return &GoogleKMS{
+ client: client,
+ keyCache: make(map[string]*rsa.PublicKey),
+ }, nil
+}
+
// checkCache returns the cached public key for keyName, or nil if it has
// not been fetched yet. Safe for concurrent use (read lock).
func (g *GoogleKMS) checkCache(keyName string) *rsa.PublicKey {
	g.mx.RLock()
	defer g.mx.RUnlock()

	return g.keyCache[keyName]
}
+
+func (g *GoogleKMS) getPublicKey(ctx context.Context, keyName string) (*rsa.PublicKey, error) {
+ if key := g.checkCache(keyName); key != nil {
+ return key, nil
+ }
+
+ response, err := g.client.GetPublicKey(ctx, &kmspb.GetPublicKeyRequest{Name: keyName})
+ if err != nil {
+ return nil, err
+ }
+
+ block, _ := pem.Decode([]byte(response.Pem))
+ if block == nil || block.Type != "PUBLIC KEY" {
+ return nil, errors.New("failed to decode PEM public key")
+ }
+ publicKey, err := x509.ParsePKIXPublicKey(block.Bytes)
+ if err != nil {
+ return nil, fmt.Errorf("failed to parse public key: %w", err)
+ }
+ rsaKey, ok := publicKey.(*rsa.PublicKey)
+ if !ok {
+ return nil, errors.New("public key is not an RSA key")
+ }
+
+ g.mx.Lock()
+ defer g.mx.Unlock()
+ g.keyCache[keyName] = rsaKey
+
+ return rsaKey, nil
+}
+
+func (g *GoogleKMS) Encrypt(ctx context.Context, keyName string, plaintext []byte) ([]byte, error) {
+ publicKey, err := g.getPublicKey(ctx, keyName)
+ if err != nil {
+ return nil, err
+ }
+
+ cipherText, err := rsa.EncryptOAEP(sha256.New(), rand.Reader, publicKey, plaintext, nil)
+ if err != nil {
+ return nil, fmt.Errorf("failed to encrypt plaintext: %w", err)
+ }
+
+ return cipherText, nil
+}
+
+func (g *GoogleKMS) Decrypt(ctx context.Context, keyName string, ciphertext []byte) ([]byte, error) {
+ req := &kmspb.AsymmetricDecryptRequest{
+ Name: keyName,
+ Ciphertext: ciphertext,
+ CiphertextCrc32C: wrapperspb.Int64(int64(calcCRC32(ciphertext))),
+ }
+
+ result, err := g.client.AsymmetricDecrypt(ctx, req)
+ if err != nil {
+ return nil, fmt.Errorf("failed to decrypt ciphertext: %w", err)
+ }
+
+ if !result.VerifiedCiphertextCrc32C {
+ return nil, errors.New("AsymmetricDecrypt: request corrupted in-transit")
+ }
+ if int64(calcCRC32(result.Plaintext)) != result.PlaintextCrc32C.Value {
+ return nil, fmt.Errorf("AsymmetricDecrypt: response corrupted in-transit")
+ }
+
+ return result.Plaintext, nil
+}
+
// Close releases the underlying Cloud KMS client connection.
func (g *GoogleKMS) Close() error {
	return g.client.Close()
}
+
// calcCRC32 computes the CRC32C (Castagnoli) checksum of data, as required
// by the Cloud KMS integrity-verification fields.
func calcCRC32(data []byte) uint32 {
	return crc32.Checksum(data, crc32.MakeTable(crc32.Castagnoli))
}
+
// GCPKeyName identifies a specific Cloud KMS crypto key version.
type GCPKeyName struct {
	Project          string
	Location         string
	KeyRing          string
	CryptoKey        string
	CryptoKeyVersion string
}

// String renders the fully-qualified KMS resource name expected by the
// Cloud KMS API.
func (k GCPKeyName) String() string {
	return "projects/" + k.Project +
		"/locations/" + k.Location +
		"/keyRings/" + k.KeyRing +
		"/cryptoKeys/" + k.CryptoKey +
		"/cryptoKeyVersions/" + k.CryptoKeyVersion
}
diff --git a/backend/internal/keys/keys.go b/backend/internal/keys/keys.go
new file mode 100644
index 0000000..ac73173
--- /dev/null
+++ b/backend/internal/keys/keys.go
@@ -0,0 +1,150 @@
+package keys
+
+import (
+ "context"
+ "crypto/aes"
+ "crypto/cipher"
+ "crypto/rand"
+ "fmt"
+ "io"
+)
+
+var CSRNG = rand.Reader
+
//go:generate mockgen -destination mock_keys_test.go -package keys_test -typed . KeyManagementService

// KeyManagementService abstracts a KMS capable of encrypting and decrypting
// small payloads (here: AES data keys) under a named key.
type KeyManagementService interface {
	io.Closer

	// Encrypt encrypts the given plaintext using the key with the given key name.
	Encrypt(ctx context.Context, keyName string, plaintext []byte) ([]byte, error)

	// Decrypt decrypts the given ciphertext using the key with the given key name.
	Decrypt(ctx context.Context, keyName string, ciphertext []byte) ([]byte, error)
}
+
+// Encrypt encrypts the given plaintext using a hybrid encryption scheme.
+//
+// It first generates a random AES 256-bit key and encrypts the plaintext with it.
+// Then, it encrypts the AES key using the KMS.
+//
+// It returns the ciphertext, the encrypted AES key, and any errors that occurred.
+func Encrypt(
+ ctx context.Context,
+ kms KeyManagementService,
+ keyName string,
+ plaintext []byte,
+) (ciphertext []byte, encryptedKey []byte, err error) {
+ // Generate a random AES key
+ aesKey := make([]byte, 32)
+ if _, err = io.ReadFull(CSRNG, aesKey); err != nil {
+ return nil, nil, fmt.Errorf("unable to generate AES key: %w", err)
+ }
+
+ // Encrypt the plaintext using the AES key
+ ciphertext, err = encrypt(aesKey, plaintext)
+ if err != nil {
+ return nil, nil, fmt.Errorf("unable to encrypt plaintext: %w", err)
+ }
+
+ // Encrypt the AES key using the KMS
+ encryptedKey, err = kms.Encrypt(ctx, keyName, aesKey)
+ if err != nil {
+ return nil, nil, fmt.Errorf("unable to encrypt AES key: %w", err)
+ }
+
+ return ciphertext, encryptedKey, nil
+}
+
+// EncryptWithKey encrypts the given plaintext using a hybrid encryption scheme.
+//
+// This works similarly to Encrypt, but instead of generating a new AES key, it uses a given already encrypted AES key.
+func EncryptWithKey(
+ ctx context.Context,
+ kms KeyManagementService,
+ keyName string,
+ encryptedKey []byte,
+ plaintext []byte,
+) ([]byte, error) {
+ // Decrypt the AES key
+ aesKey, err := kms.Decrypt(ctx, keyName, encryptedKey)
+ if err != nil {
+ return nil, fmt.Errorf("unable to decrypt AES key: %w", err)
+ }
+
+ // Encrypt the plaintext using the AES key
+ ciphertext, err := encrypt(aesKey, plaintext)
+ if err != nil {
+ return nil, fmt.Errorf("unable to encrypt plaintext: %w", err)
+ }
+
+ return ciphertext, nil
+}
+
+func encrypt(aesKey []byte, plaintext []byte) ([]byte, error) {
+ // Create an AES cipher
+ blockCipher, err := aes.NewCipher(aesKey)
+ if err != nil {
+ return nil, fmt.Errorf("unable to create AES cipher: %w", err)
+ }
+
+ gcm, err := cipher.NewGCM(blockCipher)
+ if err != nil {
+ return nil, fmt.Errorf("unable to create GCM: %w", err)
+ }
+
+ // Generate a random nonce
+ nonce := make([]byte, gcm.NonceSize())
+ if _, err = io.ReadFull(CSRNG, nonce); err != nil {
+ return nil, fmt.Errorf("unable to generate nonce: %w", err)
+ }
+
+ // Encrypt the plaintext
+ ciphertext := gcm.Seal(nonce, nonce, plaintext, nil)
+ return ciphertext, nil
+}
+
+// Decrypt decrypts the given ciphertext using a hybrid encryption scheme.
+//
+// It first decrypts the AES key using the KMS.
+// Then, it decrypts the ciphertext using the decrypted AES key.
+//
+// It returns the plaintext and any errors that occurred.
+func Decrypt(
+ ctx context.Context,
+ kms KeyManagementService,
+ keyName string,
+ ciphertext []byte,
+ encryptedKey []byte,
+) ([]byte, error) {
+ // Decrypt the AES key
+ aesKey, err := kms.Decrypt(ctx, keyName, encryptedKey)
+ if err != nil {
+ return nil, fmt.Errorf("unable to decrypt AES key: %w", err)
+ }
+
+ // Create an AES cipher
+ blockCipher, err := aes.NewCipher(aesKey)
+ if err != nil {
+ return nil, fmt.Errorf("unable to create AES cipher: %w", err)
+ }
+
+ gcm, err := cipher.NewGCM(blockCipher)
+ if err != nil {
+ return nil, fmt.Errorf("unable to create GCM: %w", err)
+ }
+
+ // Extract the nonce from the ciphertext
+ nonceSize := gcm.NonceSize()
+ if len(ciphertext) < nonceSize {
+ return nil, fmt.Errorf("ciphertext is too short")
+ }
+ nonce, ciphertext := ciphertext[:nonceSize], ciphertext[nonceSize:]
+
+ // Decrypt the ciphertext
+ plaintext, err := gcm.Open(nil, nonce, ciphertext, nil)
+ if err != nil {
+ return nil, fmt.Errorf("unable to decrypt ciphertext: %w", err)
+ }
+
+ return plaintext, nil
+}
diff --git a/backend/internal/keys/keys_test.go b/backend/internal/keys/keys_test.go
new file mode 100644
index 0000000..34aa493
--- /dev/null
+++ b/backend/internal/keys/keys_test.go
@@ -0,0 +1,64 @@
+package keys_test
+
+import (
+ "bytes"
+ "context"
+ "encoding/hex"
+ "testing"
+
+ "github.com/ansg191/ibd-trader-backend/internal/keys"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ "go.uber.org/mock/gomock"
+)
+
+func TestEncrypt(t *testing.T) {
+ ctrl := gomock.NewController(t)
+
+ // Replace RNG with a deterministic RNG
+ aesKey := []byte("0123456789abcdef0123456789abcdef")
+ nonce := []byte("0123456789ab")
+ keys.CSRNG = bytes.NewReader(append(aesKey, nonce...))
+
+ // Create a mock KMS
+ kms := NewMockKeyManagementService(ctrl)
+ keyName := "keyName"
+
+ ctx := context.Background()
+ plaintext := []byte("plaintext")
+
+ kms.EXPECT().
+ Encrypt(ctx, keyName, aesKey).
+ Return([]byte("encryptedKey"), nil)
+
+ ciphertext, encryptedKey, err := keys.Encrypt(ctx, kms, keyName, plaintext)
+ require.NoError(t, err)
+
+ encrypted, err := hex.DecodeString("e9c586532dbefd63812293e1c4baf71edb7042a294c49c2020")
+ require.NoError(t, err)
+ assert.Equal(t, append(nonce, encrypted...), ciphertext)
+ assert.Equal(t, []byte("encryptedKey"), encryptedKey)
+}
+
+func TestDecrypt(t *testing.T) {
+ ctrl := gomock.NewController(t)
+
+ kms := NewMockKeyManagementService(ctrl)
+ keyName := "keyName"
+
+ ctx := context.Background()
+ encryptedKey := []byte("encryptedKey")
+ ciphertext, err := hex.DecodeString("e9c586532dbefd63812293e1c4baf71edb7042a294c49c2020")
+ require.NoError(t, err)
+ ciphertext = append([]byte("0123456789ab"), ciphertext...)
+
+ aesKey := []byte("0123456789abcdef0123456789abcdef")
+ kms.EXPECT().
+ Decrypt(ctx, keyName, encryptedKey).
+ Return(aesKey, nil)
+
+ plaintext, err := keys.Decrypt(ctx, kms, keyName, ciphertext, encryptedKey)
+ require.NoError(t, err)
+ assert.Equal(t, []byte("plaintext"), plaintext)
+}
diff --git a/backend/internal/keys/mock_keys_test.go b/backend/internal/keys/mock_keys_test.go
new file mode 100644
index 0000000..19316e0
--- /dev/null
+++ b/backend/internal/keys/mock_keys_test.go
@@ -0,0 +1,156 @@
+// Code generated by MockGen. DO NOT EDIT.
+// Source: github.com/ansg191/ibd-trader-backend/internal/keys (interfaces: KeyManagementService)
+//
+// Generated by this command:
+//
+// mockgen -destination mock_keys_test.go -package keys_test -typed . KeyManagementService
+//
+
+// Package keys_test is a generated GoMock package.
+package keys_test
+
+import (
+ context "context"
+ reflect "reflect"
+
+ gomock "go.uber.org/mock/gomock"
+)
+
+// MockKeyManagementService is a mock of KeyManagementService interface.
+type MockKeyManagementService struct {
+ ctrl *gomock.Controller
+ recorder *MockKeyManagementServiceMockRecorder
+}
+
+// MockKeyManagementServiceMockRecorder is the mock recorder for MockKeyManagementService.
+type MockKeyManagementServiceMockRecorder struct {
+ mock *MockKeyManagementService
+}
+
+// NewMockKeyManagementService creates a new mock instance.
+func NewMockKeyManagementService(ctrl *gomock.Controller) *MockKeyManagementService {
+ mock := &MockKeyManagementService{ctrl: ctrl}
+ mock.recorder = &MockKeyManagementServiceMockRecorder{mock}
+ return mock
+}
+
+// EXPECT returns an object that allows the caller to indicate expected use.
+func (m *MockKeyManagementService) EXPECT() *MockKeyManagementServiceMockRecorder {
+ return m.recorder
+}
+
+// Close mocks base method.
+func (m *MockKeyManagementService) Close() error {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "Close")
+ ret0, _ := ret[0].(error)
+ return ret0
+}
+
+// Close indicates an expected call of Close.
+func (mr *MockKeyManagementServiceMockRecorder) Close() *MockKeyManagementServiceCloseCall {
+ mr.mock.ctrl.T.Helper()
+ call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockKeyManagementService)(nil).Close))
+ return &MockKeyManagementServiceCloseCall{Call: call}
+}
+
+// MockKeyManagementServiceCloseCall wrap *gomock.Call
+type MockKeyManagementServiceCloseCall struct {
+ *gomock.Call
+}
+
+// Return rewrite *gomock.Call.Return
+func (c *MockKeyManagementServiceCloseCall) Return(arg0 error) *MockKeyManagementServiceCloseCall {
+ c.Call = c.Call.Return(arg0)
+ return c
+}
+
+// Do rewrite *gomock.Call.Do
+func (c *MockKeyManagementServiceCloseCall) Do(f func() error) *MockKeyManagementServiceCloseCall {
+ c.Call = c.Call.Do(f)
+ return c
+}
+
+// DoAndReturn rewrite *gomock.Call.DoAndReturn
+func (c *MockKeyManagementServiceCloseCall) DoAndReturn(f func() error) *MockKeyManagementServiceCloseCall {
+ c.Call = c.Call.DoAndReturn(f)
+ return c
+}
+
+// Decrypt mocks base method.
+func (m *MockKeyManagementService) Decrypt(arg0 context.Context, arg1 string, arg2 []byte) ([]byte, error) {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "Decrypt", arg0, arg1, arg2)
+ ret0, _ := ret[0].([]byte)
+ ret1, _ := ret[1].(error)
+ return ret0, ret1
+}
+
+// Decrypt indicates an expected call of Decrypt.
+func (mr *MockKeyManagementServiceMockRecorder) Decrypt(arg0, arg1, arg2 any) *MockKeyManagementServiceDecryptCall {
+ mr.mock.ctrl.T.Helper()
+ call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Decrypt", reflect.TypeOf((*MockKeyManagementService)(nil).Decrypt), arg0, arg1, arg2)
+ return &MockKeyManagementServiceDecryptCall{Call: call}
+}
+
+// MockKeyManagementServiceDecryptCall wrap *gomock.Call
+type MockKeyManagementServiceDecryptCall struct {
+ *gomock.Call
+}
+
+// Return rewrite *gomock.Call.Return
+func (c *MockKeyManagementServiceDecryptCall) Return(arg0 []byte, arg1 error) *MockKeyManagementServiceDecryptCall {
+ c.Call = c.Call.Return(arg0, arg1)
+ return c
+}
+
+// Do rewrite *gomock.Call.Do
+func (c *MockKeyManagementServiceDecryptCall) Do(f func(context.Context, string, []byte) ([]byte, error)) *MockKeyManagementServiceDecryptCall {
+ c.Call = c.Call.Do(f)
+ return c
+}
+
+// DoAndReturn rewrite *gomock.Call.DoAndReturn
+func (c *MockKeyManagementServiceDecryptCall) DoAndReturn(f func(context.Context, string, []byte) ([]byte, error)) *MockKeyManagementServiceDecryptCall {
+ c.Call = c.Call.DoAndReturn(f)
+ return c
+}
+
+// Encrypt mocks base method.
+func (m *MockKeyManagementService) Encrypt(arg0 context.Context, arg1 string, arg2 []byte) ([]byte, error) {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "Encrypt", arg0, arg1, arg2)
+ ret0, _ := ret[0].([]byte)
+ ret1, _ := ret[1].(error)
+ return ret0, ret1
+}
+
+// Encrypt indicates an expected call of Encrypt.
+func (mr *MockKeyManagementServiceMockRecorder) Encrypt(arg0, arg1, arg2 any) *MockKeyManagementServiceEncryptCall {
+ mr.mock.ctrl.T.Helper()
+ call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Encrypt", reflect.TypeOf((*MockKeyManagementService)(nil).Encrypt), arg0, arg1, arg2)
+ return &MockKeyManagementServiceEncryptCall{Call: call}
+}
+
+// MockKeyManagementServiceEncryptCall wrap *gomock.Call
+type MockKeyManagementServiceEncryptCall struct {
+ *gomock.Call
+}
+
+// Return rewrite *gomock.Call.Return
+func (c *MockKeyManagementServiceEncryptCall) Return(arg0 []byte, arg1 error) *MockKeyManagementServiceEncryptCall {
+ c.Call = c.Call.Return(arg0, arg1)
+ return c
+}
+
+// Do rewrite *gomock.Call.Do
+func (c *MockKeyManagementServiceEncryptCall) Do(f func(context.Context, string, []byte) ([]byte, error)) *MockKeyManagementServiceEncryptCall {
+ c.Call = c.Call.Do(f)
+ return c
+}
+
+// DoAndReturn rewrite *gomock.Call.DoAndReturn
+func (c *MockKeyManagementServiceEncryptCall) DoAndReturn(f func(context.Context, string, []byte) ([]byte, error)) *MockKeyManagementServiceEncryptCall {
+ c.Call = c.Call.DoAndReturn(f)
+ return c
+}
diff --git a/backend/internal/leader/election/election.go b/backend/internal/leader/election/election.go
new file mode 100644
index 0000000..6f83298
--- /dev/null
+++ b/backend/internal/leader/election/election.go
@@ -0,0 +1,128 @@
+package election
+
+import (
+ "context"
+ "errors"
+ "log/slog"
+ "time"
+
+ "github.com/bsm/redislock"
+)
+
+// defaultLeaderElectionOptions are the settings RunOrDie uses unless
+// overridden with LeaderElectionOption arguments.
+var defaultLeaderElectionOptions = leaderElectionOptions{
+ lockKey: "ibd-leader-election",
+ lockTTL: 10 * time.Second,
+}
+
+// RunOrDie runs a Redis-lock-based leader-election loop until ctx is
+// canceled. Whichever process holds the lock runs onLeader; on loss of
+// leadership or failure to obtain the lock, it waits lockTTL/5 and tries
+// again. It never returns an error; failures are logged and retried.
+func RunOrDie(
+ ctx context.Context,
+ client redislock.RedisClient,
+ onLeader func(context.Context),
+ opts ...LeaderElectionOption,
+) {
+ o := defaultLeaderElectionOptions
+ for _, opt := range opts {
+ opt(&o)
+ }
+
+ locker := redislock.New(client)
+
+ // Election loop
+ for {
+ lock, err := locker.Obtain(ctx, o.lockKey, o.lockTTL, nil)
+ if errors.Is(err, redislock.ErrNotObtained) {
+ // Another instance is the leader
+ } else if err != nil {
+ slog.ErrorContext(ctx, "failed to obtain lock", "error", err)
+ } else {
+ // We are the leader
+ slog.DebugContext(ctx, "elected leader")
+ runLeader(ctx, lock, onLeader, o)
+ }
+
+ // Sleep for a bit before trying again
+ timer := time.NewTimer(o.lockTTL / 5)
+ select {
+ case <-ctx.Done():
+ // Stop the timer; drain its channel if it already fired.
+ if !timer.Stop() {
+ <-timer.C
+ }
+ return
+ case <-timer.C:
+ }
+ }
+}
+
+// runLeader runs onLeader while holding the election lock, refreshing the
+// lock every lockTTL/10. It returns when leadership is lost, ctx is
+// canceled, or onLeader returns. The lock is only released after the
+// onLeader goroutine has fully stopped, so another instance cannot become
+// leader while the old leader code is still executing.
+func runLeader(
+ ctx context.Context,
+ lock *redislock.Lock,
+ onLeader func(context.Context),
+ o leaderElectionOptions,
+) {
+ // A context that is canceled when the leader loses the lock
+ ctx, cancel := context.WithCancel(ctx)
+ defer cancel()
+
+ // Release the lock when done.
+ // Registered first so it runs LAST — i.e. after the wait-for-onLeader
+ // defer below has confirmed the leader goroutine stopped.
+ defer func() {
+ // Create new context without cancel if the original context is already canceled
+ relCtx := ctx
+ if ctx.Err() != nil {
+ relCtx = context.WithoutCancel(ctx)
+ }
+
+ // Add a timeout to the release context
+ relCtx, cancel := context.WithTimeout(relCtx, o.lockTTL)
+ defer cancel()
+
+ if err := lock.Release(relCtx); err != nil {
+ slog.Error("failed to release lock", "error", err)
+ }
+ }()
+
+ // Run the leader code
+ done := make(chan struct{})
+ go func(ctx context.Context) {
+ defer close(done)
+ onLeader(ctx)
+
+ // If the leader code returns, cancel the context to release the lock
+ cancel()
+ }(ctx)
+
+ // Wait for onLeader to stop before the lock is released (defers run
+ // LIFO, so this executes before the release defer above).
+ defer func() {
+ cancel()
+ <-done
+ }()
+
+ // Refresh the lock periodically so it doesn't expire mid-leadership
+ ticker := time.NewTicker(o.lockTTL / 10)
+ defer ticker.Stop()
+
+ for {
+ select {
+ case <-ticker.C:
+ err := lock.Refresh(ctx, o.lockTTL, nil)
+ if errors.Is(err, redislock.ErrNotObtained) || errors.Is(err, redislock.ErrLockNotHeld) {
+ slog.ErrorContext(ctx, "leadership lost", "error", err)
+ return
+ } else if err != nil {
+ // Transient refresh failure: keep trying until the TTL would lapse.
+ slog.ErrorContext(ctx, "failed to refresh lock", "error", err)
+ }
+ case <-ctx.Done():
+ return
+ }
+ }
+}
+
+// leaderElectionOptions holds the tunable parameters for RunOrDie.
+type leaderElectionOptions struct {
+ lockKey string
+ lockTTL time.Duration
+}
+
+// LeaderElectionOption mutates the election options.
+type LeaderElectionOption func(*leaderElectionOptions)
+
+// WithLockKey overrides the Redis key used for the election lock.
+func WithLockKey(key string) LeaderElectionOption {
+ return func(opts *leaderElectionOptions) {
+ opts.lockKey = key
+ }
+}
+
+// WithLockTTL overrides the TTL of the election lock.
+func WithLockTTL(ttl time.Duration) LeaderElectionOption {
+ return func(opts *leaderElectionOptions) {
+ opts.lockTTL = ttl
+ }
+}
diff --git a/backend/internal/leader/manager/ibd/auth/auth.go b/backend/internal/leader/manager/ibd/auth/auth.go
new file mode 100644
index 0000000..9b5502d
--- /dev/null
+++ b/backend/internal/leader/manager/ibd/auth/auth.go
@@ -0,0 +1,111 @@
+package auth
+
+import (
+ "context"
+ "log/slog"
+ "time"
+
+ "github.com/ansg191/ibd-trader-backend/internal/database"
+ "github.com/ansg191/ibd-trader-backend/internal/redis/taskqueue"
+
+ "github.com/redis/go-redis/v9"
+ "github.com/robfig/cron/v3"
+)
+
+const (
+ // Queue is the Redis task-queue name for authentication tasks.
+ Queue = "auth-queue"
+ // QueueEncoding is the payload encoding used on the auth queue.
+ QueueEncoding = taskqueue.EncodingJSON
+)
+
+// Manager is responsible for sending authentication tasks to the workers.
+type Manager struct {
+ queue taskqueue.TaskQueue[TaskInfo] // queue of per-user auth tasks
+ db database.Executor // source of users with IBD credentials
+ schedule cron.Schedule // when to enqueue a new round of tasks
+}
+
+// New builds an auth Manager that enqueues re-authentication tasks onto
+// the shared auth queue according to the given cron schedule.
+func New(
+ ctx context.Context,
+ db database.Executor,
+ rClient *redis.Client,
+ schedule cron.Schedule,
+) (*Manager, error) {
+ q, err := taskqueue.New(
+ ctx,
+ rClient,
+ Queue,
+ "auth-manager",
+ taskqueue.WithEncoding[TaskInfo](QueueEncoding),
+ )
+ if err != nil {
+ return nil, err
+ }
+
+ m := &Manager{
+ queue: q,
+ db: db,
+ schedule: schedule,
+ }
+ return m, nil
+}
+
+// Run blocks until ctx is canceled, triggering a cookie scrape at every
+// time produced by the configured cron schedule.
+func (m *Manager) Run(ctx context.Context) {
+ for {
+ now := time.Now()
+ // Find the next time
+ nextTime := m.schedule.Next(now)
+ if nextTime.IsZero() {
+ // No upcoming execution; re-check the schedule tomorrow.
+ // Use a timer (not time.Sleep) so ctx cancellation is honored.
+ timer := time.NewTimer(time.Until(now.AddDate(0, 0, 1)))
+ select {
+ case <-timer.C:
+ case <-ctx.Done():
+ timer.Stop()
+ return
+ }
+ continue
+ }
+
+ timer := time.NewTimer(nextTime.Sub(now))
+ slog.DebugContext(ctx, "waiting for next Auth scrape", "next_exec", nextTime)
+
+ select {
+ case <-timer.C:
+ // The scrape gets until the following scheduled run to finish.
+ nextExec := m.schedule.Next(nextTime)
+ m.scrapeCookies(ctx, nextExec)
+ case <-ctx.Done():
+ if !timer.Stop() {
+ <-timer.C
+ }
+ return
+ }
+ }
+}
+
+// scrapeCookies scrapes the cookies for every user from the IBD website.
+//
+// This iterates through all users with IBD credentials and enqueues a
+// re-authentication task for each. Individual enqueue failures are logged
+// and skipped so one bad user doesn't stop the rest. The whole pass is
+// bounded by the given deadline (the next scheduled run).
+func (m *Manager) scrapeCookies(ctx context.Context, deadline time.Time) {
+ ctx, cancel := context.WithDeadline(ctx, deadline)
+ defer cancel()
+
+ // Get all users with IBD credentials
+ users, err := database.ListUsers(ctx, m.db, true)
+ if err != nil {
+ slog.ErrorContext(ctx, "failed to get users", "error", err)
+ return
+ }
+
+ // Create a new task for each user
+ enqueued := 0
+ for _, user := range users {
+ task := TaskInfo{
+ UserSubject: user.Subject,
+ }
+
+ // Enqueue the task
+ _, err := m.queue.Enqueue(ctx, task)
+ if err != nil {
+ slog.ErrorContext(ctx, "failed to enqueue task", "error", err)
+ continue
+ }
+ enqueued++
+ }
+
+ // Report how many succeeded rather than claiming all were enqueued.
+ slog.InfoContext(ctx, "enqueued auth tasks", "enqueued", enqueued, "total", len(users))
+}
+
+// TaskInfo is the payload of an auth task: the subject identifier of the
+// user whose IBD cookies should be refreshed.
+type TaskInfo struct {
+ UserSubject string `json:"user_subject"`
+}
diff --git a/backend/internal/leader/manager/ibd/ibd.go b/backend/internal/leader/manager/ibd/ibd.go
new file mode 100644
index 0000000..e2d4fc0
--- /dev/null
+++ b/backend/internal/leader/manager/ibd/ibd.go
@@ -0,0 +1,8 @@
+package ibd
+
+type Schedules struct {
+ // Auth schedule
+ Auth string
+ // IBD50 schedule
+ IBD50 string
+}
diff --git a/backend/internal/leader/manager/ibd/scrape/scrape.go b/backend/internal/leader/manager/ibd/scrape/scrape.go
new file mode 100644
index 0000000..870ce5e
--- /dev/null
+++ b/backend/internal/leader/manager/ibd/scrape/scrape.go
@@ -0,0 +1,140 @@
+package scrape
+
+import (
+ "context"
+ "errors"
+ "log/slog"
+ "time"
+
+ "github.com/ansg191/ibd-trader-backend/internal/database"
+ "github.com/ansg191/ibd-trader-backend/internal/ibd"
+ "github.com/ansg191/ibd-trader-backend/internal/redis/taskqueue"
+
+ "github.com/redis/go-redis/v9"
+ "github.com/robfig/cron/v3"
+)
+
+const (
+ // Queue is the Redis task-queue name for per-symbol scrape tasks.
+ Queue = "scrape-queue"
+ // QueueEncoding is the payload encoding used on the scrape queue.
+ QueueEncoding = taskqueue.EncodingJSON
+ // Channel is the pub/sub channel that triggers an immediate scrape.
+ Channel = "scrape-channel"
+)
+
+// Manager is responsible for sending scraping tasks to the workers.
+type Manager struct {
+ client *ibd.Client // IBD client used to fetch the IBD50 list
+ db database.Executor // records stocks and rankings
+ queue taskqueue.TaskQueue[TaskInfo] // queue of per-symbol scrape tasks
+ schedule cron.Schedule // when to scrape the IBD50 list
+ pubsub *redis.PubSub // subscription for manual scrape triggers
+}
+
+// New creates a scrape Manager backed by the given IBD client, database,
+// and Redis connection. The schedule controls when the IBD50 list is
+// scraped; a message on Channel triggers an immediate scrape.
+//
+// Note: the Redis parameter is named rdb so it doesn't shadow the
+// imported redis package.
+func New(
+ ctx context.Context,
+ client *ibd.Client,
+ db database.Executor,
+ rdb *redis.Client,
+ schedule cron.Schedule,
+) (*Manager, error) {
+ queue, err := taskqueue.New(
+ ctx,
+ rdb,
+ Queue,
+ "ibd-manager",
+ taskqueue.WithEncoding[TaskInfo](QueueEncoding),
+ )
+ if err != nil {
+ return nil, err
+ }
+
+ return &Manager{
+ client: client,
+ db: db,
+ queue: queue,
+ schedule: schedule,
+ pubsub: rdb.Subscribe(ctx, Channel),
+ }, nil
+}
+
+// Close releases the pub/sub subscription used for manual scrape triggers.
+func (m *Manager) Close() error {
+ return m.pubsub.Close()
+}
+
+// Run blocks until ctx is canceled. It scrapes the IBD50 list at every
+// scheduled time, and additionally whenever a message arrives on the
+// scrape pub/sub channel.
+func (m *Manager) Run(ctx context.Context) {
+ ch := m.pubsub.Channel()
+ for {
+ now := time.Now()
+ // Find the next time
+ nextTime := m.schedule.Next(now)
+ if nextTime.IsZero() {
+ // No upcoming execution; re-check the schedule tomorrow.
+ // Use a timer (not time.Sleep) so ctx cancellation is honored.
+ timer := time.NewTimer(time.Until(now.AddDate(0, 0, 1)))
+ select {
+ case <-timer.C:
+ case <-ctx.Done():
+ timer.Stop()
+ return
+ }
+ continue
+ }
+
+ timer := time.NewTimer(nextTime.Sub(now))
+ slog.DebugContext(ctx, "waiting for next IBD50 scrape", "next_exec", nextTime)
+
+ select {
+ case <-timer.C:
+ nextExec := m.schedule.Next(nextTime)
+ m.scrapeIBD50(ctx, nextExec)
+ case <-ch:
+ // Manual trigger: stop the pending timer (draining it if it
+ // already fired) so it doesn't linger until it expires.
+ if !timer.Stop() {
+ <-timer.C
+ }
+ nextExec := m.schedule.Next(time.Now())
+ m.scrapeIBD50(ctx, nextExec)
+ case <-ctx.Done():
+ if !timer.Stop() {
+ <-timer.C
+ }
+ return
+ }
+ }
+}
+
+// scrapeIBD50 fetches the current IBD50 list, stores each stock and its
+// ranking, and enqueues a per-symbol scrape task for the workers. The
+// whole pass is bounded by the given deadline (the next scheduled run).
+func (m *Manager) scrapeIBD50(ctx context.Context, deadline time.Time) {
+ ctx, cancel := context.WithDeadline(ctx, deadline)
+ defer cancel()
+
+ stocks, err := m.client.GetIBD50(ctx)
+ if err != nil {
+ if errors.Is(err, ibd.ErrNoAvailableCookies) {
+ // Expected when no user session is available; just wait for the
+ // next run rather than treating it as an error.
+ slog.WarnContext(ctx, "no available cookies", "error", err)
+ return
+ }
+ slog.ErrorContext(ctx, "failed to get IBD50", "error", err)
+ return
+ }
+
+ for _, stock := range stocks {
+ // Add stock to DB
+ err = database.AddStock(ctx, m.db, database.Stock{
+ Symbol: stock.Symbol,
+ Name: stock.Name,
+ IBDUrl: stock.QuoteURL.String(),
+ })
+ if err != nil {
+ slog.ErrorContext(ctx, "failed to add stock", "error", err)
+ continue
+ }
+
+ // Add ranking to Db
+ err = database.AddRanking(ctx, m.db, stock.Symbol, int(stock.Rank), 0)
+ if err != nil {
+ slog.ErrorContext(ctx, "failed to add ranking", "error", err)
+ continue
+ }
+
+ // Add scrape task to queue
+ task := TaskInfo{Symbol: stock.Symbol}
+ taskID, err := m.queue.Enqueue(ctx, task)
+ if err != nil {
+ slog.ErrorContext(ctx, "failed to enqueue task", "error", err)
+ // Skip the success log: taskID is meaningless on error.
+ continue
+ }
+
+ slog.DebugContext(ctx, "enqueued scrape task", "task_id", taskID, "symbol", stock.Symbol)
+ }
+}
+
+// TaskInfo is the payload of a scrape task: the stock symbol to scrape.
+type TaskInfo struct {
+ Symbol string `json:"symbol"`
+}
diff --git a/backend/internal/leader/manager/manager.go b/backend/internal/leader/manager/manager.go
new file mode 100644
index 0000000..61e27e0
--- /dev/null
+++ b/backend/internal/leader/manager/manager.go
@@ -0,0 +1,90 @@
+package manager
+
+import (
+ "context"
+ "fmt"
+ "log/slog"
+ "sync"
+
+ "github.com/ansg191/ibd-trader-backend/internal/config"
+ "github.com/ansg191/ibd-trader-backend/internal/database"
+ "github.com/ansg191/ibd-trader-backend/internal/ibd"
+ "github.com/ansg191/ibd-trader-backend/internal/leader/manager/ibd/auth"
+ ibd2 "github.com/ansg191/ibd-trader-backend/internal/leader/manager/ibd/scrape"
+
+ "github.com/redis/go-redis/v9"
+ "github.com/robfig/cron/v3"
+)
+
+// Manager bundles the leader-only background components: database
+// maintenance, worker monitoring, IBD50 scraping, and IBD auth refresh.
+type Manager struct {
+ db database.Database
+ Monitor *WorkerMonitor
+ Scraper *ibd2.Manager
+ Auth *auth.Manager
+}
+
+// New wires up the leader components from configuration: it parses the
+// IBD50 and Auth cron schedules and constructs the scraper, auth manager,
+// and worker monitor. Returns an error if either schedule is invalid or a
+// component fails to initialize.
+func New(
+ ctx context.Context,
+ cfg *config.Config,
+ client *redis.Client,
+ db database.Database,
+ ibd *ibd.Client,
+) (*Manager, error) {
+ scraperSchedule, err := cron.ParseStandard(cfg.IBD.Schedules.IBD50)
+ if err != nil {
+ return nil, fmt.Errorf("unable to parse IBD50 schedule: %w", err)
+ }
+ scraper, err := ibd2.New(ctx, ibd, db, client, scraperSchedule)
+ if err != nil {
+ return nil, err
+ }
+
+ authSchedule, err := cron.ParseStandard(cfg.IBD.Schedules.Auth)
+ if err != nil {
+ return nil, fmt.Errorf("unable to parse Auth schedule: %w", err)
+ }
+ authManager, err := auth.New(ctx, db, client, authSchedule)
+ if err != nil {
+ return nil, err
+ }
+
+ return &Manager{
+ db: db,
+ Monitor: NewWorkerMonitor(client),
+ Scraper: scraper,
+ Auth: authManager,
+ }, nil
+}
+
+// Run migrates the database and then runs database maintenance, worker
+// monitoring, the IBD50 scraper, and the auth manager concurrently until
+// ctx is canceled. It returns the context's error on shutdown, or the
+// migration error if migration fails.
+func (m *Manager) Run(ctx context.Context) error {
+ if err := m.db.Migrate(ctx); err != nil {
+ slog.ErrorContext(ctx, "Unable to migrate database", "error", err)
+ return err
+ }
+
+ // Each long-running component gets its own goroutine; wait for all of
+ // them to stop before returning.
+ components := []func(context.Context){
+ m.db.Maintenance,
+ m.Monitor.Start,
+ m.Scraper.Run,
+ m.Auth.Run,
+ }
+
+ var wg sync.WaitGroup
+ for _, run := range components {
+ wg.Add(1)
+ go func(run func(context.Context)) {
+ defer wg.Done()
+ run(ctx)
+ }(run)
+ }
+
+ wg.Wait()
+ return ctx.Err()
+}
diff --git a/backend/internal/leader/manager/monitor.go b/backend/internal/leader/manager/monitor.go
new file mode 100644
index 0000000..3b2e3ec
--- /dev/null
+++ b/backend/internal/leader/manager/monitor.go
@@ -0,0 +1,164 @@
+package manager
+
+import (
+ "context"
+ "fmt"
+ "log/slog"
+ "time"
+
+ "github.com/buraksezer/consistent"
+ "github.com/cespare/xxhash/v2"
+ "github.com/redis/go-redis/v9"
+)
+
+const (
+ // MonitorInterval is how often worker heartbeats are polled.
+ MonitorInterval = 5 * time.Second
+ // ActiveWorkersSet is the Redis set holding active worker hostnames.
+ ActiveWorkersSet = "active-workers"
+)
+
+// WorkerMonitor is a struct that monitors workers and their heartbeats over redis.
+type WorkerMonitor struct {
+ client *redis.Client
+
+ // ring is a consistent hash ring that distributes partitions over detected workers.
+ ring *consistent.Consistent
+ // layoutChangeCh is a channel that others can listen to for layout changes to the ring.
+ layoutChangeCh chan struct{}
+}
+
+// NewWorkerMonitor creates a WorkerMonitor with an empty consistent-hash
+// ring using xxhash and the library's default ring parameters.
+func NewWorkerMonitor(client *redis.Client) *WorkerMonitor {
+ cfg := consistent.Config{
+ Hasher: new(hasher),
+ PartitionCount: consistent.DefaultPartitionCount,
+ ReplicationFactor: consistent.DefaultReplicationFactor,
+ Load: consistent.DefaultLoad,
+ }
+
+ var noMembers []consistent.Member
+ return &WorkerMonitor{
+ client: client,
+ ring: consistent.New(noMembers, cfg),
+ layoutChangeCh: make(chan struct{}),
+ }
+}
+
+// Close shuts down the monitor's layout-change channel.
+// NOTE(review): closing layoutChangeCh while monitorWorkers could still be
+// sending on it would panic — confirm Start has returned before Close.
+func (m *WorkerMonitor) Close() error {
+ close(m.layoutChangeCh)
+ return nil
+}
+
+// Changes returns a channel that receives a signal whenever the ring's
+// membership changes. Notifications are best-effort: they are dropped if
+// no receiver is ready at the time of the change.
+func (m *WorkerMonitor) Changes() <-chan struct{} {
+ return m.layoutChangeCh
+}
+
+// Start runs the monitor loop until ctx is canceled, polling worker
+// heartbeats once immediately and then every MonitorInterval.
+func (m *WorkerMonitor) Start(ctx context.Context) {
+ // Prime the ring right away instead of waiting a full interval.
+ m.monitorWorkers(ctx)
+
+ ticker := time.NewTicker(MonitorInterval)
+ defer ticker.Stop()
+
+ for {
+ select {
+ case <-ticker.C:
+ m.monitorWorkers(ctx)
+ case <-ctx.Done():
+ return
+ }
+ }
+}
+
+// monitorWorkers reconciles the hash ring with the active-workers set in
+// Redis: workers with a live heartbeat are added to the ring, workers with
+// a missing heartbeat are removed from both the set and the ring, and ring
+// members no longer present in the set are evicted. Bounded by a
+// MonitorInterval timeout.
+func (m *WorkerMonitor) monitorWorkers(ctx context.Context) {
+ ctx, cancel := context.WithTimeout(ctx, MonitorInterval)
+ defer cancel()
+
+ // Get all active workers.
+ workers, err := m.client.SMembers(ctx, ActiveWorkersSet).Result()
+ if err != nil {
+ slog.ErrorContext(ctx, "Unable to get active workers", "error", err)
+ return
+ }
+
+ // Get existing workers in the ring.
+ // The map value tracks whether the ring member was seen alive this pass.
+ existingWorkers := m.ring.GetMembers()
+ ewMap := make(map[string]bool)
+ for _, worker := range existingWorkers {
+ ewMap[worker.String()] = false
+ }
+
+ // Check workers' heartbeats.
+ for _, worker := range workers {
+ exists, err := m.client.Exists(ctx, WorkerHeartbeatKey(worker)).Result()
+ if err != nil {
+ slog.ErrorContext(ctx, "Unable to check worker heartbeat", "worker", worker, "error", err)
+ continue
+ }
+
+ if exists == 0 {
+ slog.WarnContext(ctx, "Worker heartbeat not found", "worker", worker)
+
+ // Remove worker from active workers set.
+ if err = m.client.SRem(ctx, ActiveWorkersSet, worker).Err(); err != nil {
+ slog.ErrorContext(ctx, "Unable to remove worker from active workers set", "worker", worker, "error", err)
+ }
+
+ // Remove worker from the ring.
+ m.removeWorker(worker)
+ } else {
+ // Add worker to the ring if it doesn't exist.
+ if _, ok := ewMap[worker]; !ok {
+ slog.InfoContext(ctx, "New worker detected", "worker", worker)
+ m.addWorker(worker)
+ } else {
+ ewMap[worker] = true
+ }
+ }
+ }
+
+ // Check for workers that are not active anymore.
+ // Any ring member never marked true above has left the active set.
+ for worker, exists := range ewMap {
+ if !exists {
+ slog.WarnContext(ctx, "Worker is not active anymore", "worker", worker)
+ m.removeWorker(worker)
+ }
+ }
+}
+
+// addWorker inserts a worker into the hash ring and signals listeners.
+// The notification is best-effort: it is dropped when no one is receiving.
+func (m *WorkerMonitor) addWorker(worker string) {
+ m.ring.Add(member{hostname: worker})
+
+ // Non-blocking layout-change notification.
+ select {
+ case m.layoutChangeCh <- struct{}{}:
+ default:
+ }
+}
+
+// removeWorker evicts a worker from the hash ring and signals listeners.
+// The notification is best-effort: it is dropped when no one is receiving.
+func (m *WorkerMonitor) removeWorker(worker string) {
+ m.ring.Remove(worker)
+
+ // Non-blocking layout-change notification.
+ select {
+ case m.layoutChangeCh <- struct{}{}:
+ default:
+ }
+}
+
+// WorkerHeartbeatKey returns the Redis key holding a worker's heartbeat,
+// of the form "worker:<hostname>:heartbeat".
+func WorkerHeartbeatKey(hostname string) string {
+ return "worker:" + hostname + ":heartbeat"
+}
+
+// hasher adapts xxhash to the consistent.Hasher interface.
+type hasher struct{}
+
+// Sum64 hashes data with xxhash for placement on the ring.
+func (h *hasher) Sum64(data []byte) uint64 {
+ return xxhash.Sum64(data)
+}
+
+// member wraps a worker hostname as a consistent.Member.
+type member struct {
+ hostname string
+}
+
+// String returns the hostname, which serves as the ring member identity.
+func (m member) String() string {
+ return m.hostname
+}
diff --git a/backend/internal/redis/taskqueue/options.go b/backend/internal/redis/taskqueue/options.go
new file mode 100644
index 0000000..2d5a23f
--- /dev/null
+++ b/backend/internal/redis/taskqueue/options.go
@@ -0,0 +1,9 @@
+package taskqueue
+
+// Option configures a taskQueue at construction time.
+type Option[T any] func(*taskQueue[T])
+
+// WithEncoding sets the payload encoding (JSON or gob) for task data.
+func WithEncoding[T any](encoding Encoding) Option[T] {
+ return func(o *taskQueue[T]) {
+ o.encoding = encoding
+ }
+}
diff --git a/backend/internal/redis/taskqueue/queue.go b/backend/internal/redis/taskqueue/queue.go
new file mode 100644
index 0000000..a4b799e
--- /dev/null
+++ b/backend/internal/redis/taskqueue/queue.go
@@ -0,0 +1,545 @@
+package taskqueue
+
+import (
+ "bytes"
+ "context"
+ "encoding/base64"
+ "encoding/gob"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "log/slog"
+ "reflect"
+ "strconv"
+ "strings"
+ "time"
+
+ "github.com/redis/go-redis/v9"
+)
+
+// Encoding selects how task payloads are serialized into the stream.
+type Encoding uint8
+
+const (
+ // EncodingJSON stores task data as JSON (the default).
+ EncodingJSON Encoding = iota
+ // EncodingGob stores task data as base64-encoded gob.
+ EncodingGob
+
+ // ResultKey is the hash field holding a successful task's result.
+ ResultKey = "result"
+ // ErrorKey is the hash field holding a failed task's error message.
+ ErrorKey = "error"
+ // NextAttemptKey is the hash field pointing at the retry's task ID.
+ NextAttemptKey = "next_attempt"
+)
+
+// MaxAttempts is the number of attempts before a task is dropped as failed.
+var MaxAttempts = 3
+var ErrTaskNotFound = errors.New("task not found")
+
+type TaskQueue[T any] interface {
+ // Enqueue adds a task to the queue.
+ // Returns the generated task ID.
+ Enqueue(ctx context.Context, data T) (TaskInfo[T], error)
+
+ // Dequeue removes a task from the queue and returns it.
+ // The task data is placed into dataOut.
+ //
+ // Dequeue blocks until a task is available, timeout, or the context is canceled.
+ // The returned task is placed in a pending state for lockTimeout duration.
+ // The task must be completed with Complete or extended with Extend before the lock expires.
+ // If the lock expires, the task is returned to the queue, where it may be picked up by another worker.
+ Dequeue(
+ ctx context.Context,
+ lockTimeout,
+ timeout time.Duration,
+ ) (*TaskInfo[T], error)
+
+ // Extend extends the lock on a task.
+ Extend(ctx context.Context, taskID TaskID) error
+
+ // Complete marks a task as complete, storing result for later retrieval.
+ Complete(ctx context.Context, taskID TaskID, result string) error
+
+ // Data returns the info of a task.
+ Data(ctx context.Context, taskID TaskID) (TaskInfo[T], error)
+
+ // Return returns a task to the queue and returns the new task ID.
+ // Increments the attempt counter.
+ // Tasks with too many attempts (MaxAttempts) are considered failed and aren't returned to the queue.
+ Return(ctx context.Context, taskID TaskID, err error) (TaskID, error)
+
+ // List returns a list of task IDs in the queue.
+ // The list is ordered by the time the task was added to the queue. The most recent task is first.
+ // The count parameter limits the number of tasks returned.
+ // The start and end parameters limit the range of tasks returned.
+ // End is exclusive.
+ // Start must be before end.
+ List(ctx context.Context, start, end TaskID, count int64) ([]TaskInfo[T], error)
+}
+
+type TaskInfo[T any] struct {
+ // ID is the unique identifier of the task. Generated by redis.
+ ID TaskID
+ // Data is the task data. Stored in stream.
+ Data T
+ // Attempts is the number of times the task has been attempted. Stored in stream.
+ Attempts uint8
+ // Result is the result of the task. Stored in a hash.
+ Result isTaskResult
+}
+
+// isTaskResult is a sealed marker interface: a task outcome is either a
+// TaskResultSuccess or a TaskResultError (or nil if not finished).
+type isTaskResult interface {
+ isTaskResult()
+}
+
+// TaskResultSuccess carries the result string of a completed task.
+type TaskResultSuccess struct {
+ Result string
+}
+
+// TaskResultError carries a failed task's error message and, if the task
+// was retried, the ID of the retry attempt.
+type TaskResultError struct {
+ Error string
+ NextAttempt TaskID
+}
+
+func (*TaskResultSuccess) isTaskResult() {}
+func (*TaskResultError) isTaskResult() {}
+
+// TaskID identifies a task in the queue. It mirrors a Redis stream entry
+// ID: a millisecond timestamp plus a sequence number.
+type TaskID struct {
+ timestamp time.Time
+ sequence uint64
+}
+
+// NewTaskID builds a TaskID from a timestamp and a sequence number.
+func NewTaskID(timestamp time.Time, sequence uint64) TaskID {
+ return TaskID{timestamp: timestamp, sequence: sequence}
+}
+
+// ParseTaskID parses a Redis stream entry ID of the form "<unix-ms>-<seq>".
+func ParseTaskID(s string) (TaskID, error) {
+ msPart, seqPart, found := strings.Cut(s, "-")
+ if !found {
+ return TaskID{}, errors.New("invalid task ID")
+ }
+
+ ms, err := strconv.ParseInt(msPart, 10, 64)
+ if err != nil {
+ return TaskID{}, err
+ }
+
+ seq, err := strconv.ParseUint(seqPart, 10, 64)
+ if err != nil {
+ return TaskID{}, err
+ }
+
+ return NewTaskID(time.UnixMilli(ms), seq), nil
+}
+
+// Timestamp reports the time component of the ID.
+func (t TaskID) Timestamp() time.Time {
+ return t.timestamp
+}
+
+// String renders the ID in Redis stream form: "<unix-ms>-<seq>".
+func (t TaskID) String() string {
+ return strconv.FormatInt(t.timestamp.UnixMilli(), 10) + "-" + strconv.FormatUint(t.sequence, 10)
+}
+
+// taskQueue is the Redis-streams-backed implementation of TaskQueue.
+type taskQueue[T any] struct {
+ rdb *redis.Client
+ encoding Encoding // how task data is serialized in the stream
+
+ streamKey string // "taskqueue:<name>": the stream holding tasks
+ groupName string // consumer group shared by all workers
+
+ workerName string // this worker's consumer name within the group
+}
+
+// New creates (or attaches to) the task queue named name, ensuring its
+// stream and consumer group exist. workerName identifies this consumer
+// within the group.
+func New[T any](ctx context.Context, rdb *redis.Client, name string, workerName string, opts ...Option[T]) (TaskQueue[T], error) {
+ tq := &taskQueue[T]{
+ rdb: rdb,
+ encoding: EncodingJSON,
+ streamKey: "taskqueue:" + name,
+ groupName: "default",
+ workerName: workerName,
+ }
+
+ for _, opt := range opts {
+ opt(tq)
+ }
+
+ // Create the stream and consumer group if they don't exist.
+ // BUSYGROUP means the group already exists, which is fine; match on the
+ // documented error prefix rather than the exact server message.
+ err := rdb.XGroupCreateMkStream(ctx, tq.streamKey, tq.groupName, "0").Err()
+ if err != nil && !strings.HasPrefix(err.Error(), "BUSYGROUP") {
+ return nil, err
+ }
+
+ return tq, nil
+}
+
+// Enqueue serializes data and appends it to the stream with a zero attempt
+// counter, returning the task populated with its Redis-generated ID.
+func (q *taskQueue[T]) Enqueue(ctx context.Context, data T) (TaskInfo[T], error) {
+ task := TaskInfo[T]{
+ Data: data,
+ Attempts: 0,
+ }
+
+ values, err := encode[T](task, q.encoding)
+ if err != nil {
+ return TaskInfo[T]{}, err
+ }
+
+ taskID, err := q.rdb.XAdd(ctx, &redis.XAddArgs{
+ Stream: q.streamKey,
+ Values: values,
+ }).Result()
+ if err != nil {
+ return TaskInfo[T]{}, err
+ }
+
+ id, err := ParseTaskID(taskID)
+ if err != nil {
+ return TaskInfo[T]{}, err
+ }
+ task.ID = id
+ return task, nil
+}
+
+// Dequeue first tries to reclaim a task whose lock expired (idle longer
+// than lockTimeout), then blocks up to timeout for a new task. Returns
+// (nil, nil) when no task became available.
+func (q *taskQueue[T]) Dequeue(ctx context.Context, lockTimeout, timeout time.Duration) (*TaskInfo[T], error) {
+ // Try to recover a task
+ task, err := q.recover(ctx, lockTimeout)
+ if err != nil {
+ return nil, err
+ }
+ if task != nil {
+ return task, nil
+ }
+
+ // Check for new tasks
+ ids, err := q.rdb.XReadGroup(ctx, &redis.XReadGroupArgs{
+ Group: q.groupName,
+ Consumer: q.workerName,
+ Streams: []string{q.streamKey, ">"},
+ Count: 1,
+ Block: timeout,
+ }).Result()
+ if err != nil && !errors.Is(err, redis.Nil) {
+ return nil, err
+ }
+
+ // redis.Nil means the blocking read timed out with nothing to deliver.
+ if len(ids) == 0 || len(ids[0].Messages) == 0 || errors.Is(err, redis.Nil) {
+ return nil, nil
+ }
+
+ msg := ids[0].Messages[0]
+ task = new(TaskInfo[T])
+ *task, err = decode[T](&msg, q.encoding)
+ if err != nil {
+ return nil, err
+ }
+ return task, nil
+}
+
+// Extend re-claims the task for this worker, resetting its idle time so
+// the pending lock does not expire.
+func (q *taskQueue[T]) Extend(ctx context.Context, taskID TaskID) error {
+ _, err := q.rdb.XClaim(ctx, &redis.XClaimArgs{
+ Stream: q.streamKey,
+ Group: q.groupName,
+ Consumer: q.workerName,
+ MinIdle: 0,
+ Messages: []string{taskID.String()},
+ }).Result()
+ if err != nil && !errors.Is(err, redis.Nil) {
+ return err
+ }
+ return nil
+}
+
+// Data fetches a task's stream entry and its stored result (if any).
+// Returns ErrTaskNotFound when no stream entry exists for taskID.
+func (q *taskQueue[T]) Data(ctx context.Context, taskID TaskID) (TaskInfo[T], error) {
+ msg, err := q.rdb.XRange(ctx, q.streamKey, taskID.String(), taskID.String()).Result()
+ if err != nil {
+ return TaskInfo[T]{}, err
+ }
+
+ if len(msg) == 0 {
+ return TaskInfo[T]{}, ErrTaskNotFound
+ }
+
+ t, err := decode[T](&msg[0], q.encoding)
+ if err != nil {
+ return TaskInfo[T]{}, err
+ }
+
+ t.Result, err = q.getResult(ctx, taskID)
+ if err != nil {
+ // Bug fix: previously returned a nil error here, silently handing the
+ // caller a zero TaskInfo; propagate the failure instead.
+ return TaskInfo[T]{}, err
+ }
+ return t, nil
+}
+
+// Complete acknowledges the task and stores result under its result hash.
+func (q *taskQueue[T]) Complete(ctx context.Context, taskID TaskID, result string) error {
+ return q.ack(ctx, taskID, false, result)
+}
+
+// retScript atomically re-adds a returned task to the stream and records
+// the new entry's ID on the old task's hash under "next_attempt", so the
+// failed attempt can be followed to its retry.
+var retScript = redis.NewScript(`
+local stream_key = KEYS[1]
+local hash_key = KEYS[2]
+
+-- Re-add the task to the stream
+local task_id = redis.call('XADD', stream_key, '*', unpack(ARGV))
+
+-- Update the hash key to point to the new task
+redis.call('HSET', hash_key, 'next_attempt', task_id)
+
+return task_id
+`)
+
+// Return acknowledges a failed attempt (recording err1's message, if any)
+// and re-enqueues the task with an incremented attempt counter. When the
+// task has reached MaxAttempts it is NOT re-enqueued and a zero TaskID
+// with a nil error is returned — callers must treat a zero ID as terminal
+// failure, not success.
+func (q *taskQueue[T]) Return(ctx context.Context, taskID TaskID, err1 error) (TaskID, error) {
+ msgs, err := q.rdb.XRange(ctx, q.streamKey, taskID.String(), taskID.String()).Result()
+ if err != nil {
+ return TaskID{}, err
+ }
+ if len(msgs) == 0 {
+ return TaskID{}, ErrTaskNotFound
+ }
+
+ var ackMsg string
+ if err1 != nil {
+ ackMsg = err1.Error()
+ }
+
+ // Ack the task
+ err = q.ack(ctx, taskID, true, ackMsg)
+ if err != nil {
+ return TaskID{}, err
+ }
+
+ msg := msgs[0]
+ task, err := decode[T](&msg, q.encoding)
+ if err != nil {
+ return TaskID{}, err
+ }
+
+ task.Attempts++
+ if int(task.Attempts) >= MaxAttempts {
+ // Task has failed
+ slog.ErrorContext(ctx, "task failed completely",
+ "taskID", taskID,
+ "data", task.Data,
+ "attempts", task.Attempts,
+ "maxAttempts", MaxAttempts,
+ )
+ return TaskID{}, nil
+ }
+
+ valuesMap, err := encode[T](task, q.encoding)
+ if err != nil {
+ return TaskID{}, err
+ }
+
+ // Flatten the field map into [k1, v1, k2, v2, ...] for the Lua script.
+ values := make([]string, 0, len(valuesMap)*2)
+ for k, v := range valuesMap {
+ values = append(values, k, v)
+ }
+
+ // KEYS[1] = stream, KEYS[2] = this attempt's result hash.
+ keys := []string{
+ q.streamKey,
+ fmt.Sprintf("%s:%s", q.streamKey, taskID.String()),
+ }
+ newTaskId, err := retScript.Run(ctx, q.rdb, keys, values).Result()
+ if err != nil {
+ return TaskID{}, err
+ }
+
+ return ParseTaskID(newTaskId.(string))
+}
+
+// List returns up to count tasks between start (inclusive) and end
+// (exclusive), newest first. A zero start/end means unbounded on that
+// side. Each returned task includes its stored result, if any.
+func (q *taskQueue[T]) List(ctx context.Context, start, end TaskID, count int64) ([]TaskInfo[T], error) {
+ if !start.timestamp.IsZero() && !end.timestamp.IsZero() && start.timestamp.After(end.timestamp) {
+ return nil, errors.New("start must be before end")
+ }
+
+ // Translate zero IDs into Redis's open-range markers "-" and "+";
+ // "(" marks the end ID as exclusive.
+ var startStr, endStr string
+ if !start.timestamp.IsZero() {
+ startStr = start.String()
+ } else {
+ startStr = "-"
+ }
+ if !end.timestamp.IsZero() {
+ endStr = "(" + end.String()
+ } else {
+ endStr = "+"
+ }
+
+ msgs, err := q.rdb.XRevRangeN(ctx, q.streamKey, endStr, startStr, count).Result()
+ if err != nil {
+ return nil, err
+ }
+ if len(msgs) == 0 {
+ return []TaskInfo[T]{}, nil
+ }
+
+ tasks := make([]TaskInfo[T], len(msgs))
+ for i := range msgs {
+ tasks[i], err = decode[T](&msgs[i], q.encoding)
+ if err != nil {
+ return nil, err
+ }
+
+ tasks[i].Result, err = q.getResult(ctx, tasks[i].ID)
+ if err != nil {
+ return nil, err
+ }
+ }
+
+ return tasks, nil
+}
+
+func (q *taskQueue[T]) getResult(ctx context.Context, taskID TaskID) (isTaskResult, error) {
+ key := fmt.Sprintf("%s:%s", q.streamKey, taskID.String())
+ results, err := q.rdb.HMGet(ctx, key, ResultKey, ErrorKey, NextAttemptKey).Result()
+ if err != nil {
+ return nil, err
+ }
+
+ var ret isTaskResult
+ if results[0] != nil {
+ ret = &TaskResultSuccess{Result: results[0].(string)}
+ } else if results[1] != nil {
+ ret = &TaskResultError{Error: results[1].(string)}
+ if results[2] != nil {
+ nextAttempt, err := ParseTaskID(results[2].(string))
+ if err != nil {
+ return nil, err
+ }
+ ret.(*TaskResultError).NextAttempt = nextAttempt
+ }
+ }
+ return ret, nil
+}
+
+// recover claims at most one pending task whose owner has been idle for at
+// least idleTimeout, transferring it to this worker. Returns (nil, nil)
+// when there is nothing to reclaim.
+func (q *taskQueue[T]) recover(ctx context.Context, idleTimeout time.Duration) (*TaskInfo[T], error) {
+ msgs, _, err := q.rdb.XAutoClaim(ctx, &redis.XAutoClaimArgs{
+ Stream: q.streamKey,
+ Group: q.groupName,
+ MinIdle: idleTimeout,
+ Start: "0",
+ Count: 1,
+ Consumer: q.workerName,
+ }).Result()
+ if err != nil {
+ return nil, err
+ }
+
+ if len(msgs) == 0 {
+ return nil, nil
+ }
+
+ msg := msgs[0]
+ task, err := decode[T](&msg, q.encoding)
+ if err != nil {
+ return nil, err
+ }
+ return &task, nil
+}
+
+// ack acknowledges the task in the consumer group and records its outcome
+// (error message or result string) on the task's result hash, atomically
+// via a transaction pipeline.
+func (q *taskQueue[T]) ack(ctx context.Context, taskID TaskID, errored bool, msg string) error {
+ key := fmt.Sprintf("%s:%s", q.streamKey, taskID.String())
+ _, err := q.rdb.TxPipelined(ctx, func(pipe redis.Pipeliner) error {
+ pipe.XAck(ctx, q.streamKey, q.groupName, taskID.String())
+ if errored {
+ pipe.HSet(ctx, key, ErrorKey, msg)
+ } else {
+ pipe.HSet(ctx, key, ResultKey, msg)
+ }
+ return nil
+ })
+ return err
+}
+
+// decode reconstructs a TaskInfo from a stream entry: the entry ID becomes
+// the TaskID, and the "attempts" and "data" fields are parsed, with the
+// payload deserialized per the configured encoding.
+func decode[T any](msg *redis.XMessage, encoding Encoding) (task TaskInfo[T], err error) {
+ task.ID, err = ParseTaskID(msg.ID)
+ if err != nil {
+ return
+ }
+
+ err = getField(msg, "attempts", &task.Attempts)
+ if err != nil {
+ return
+ }
+
+ var data string
+ err = getField(msg, "data", &data)
+ if err != nil {
+ return
+ }
+
+ switch encoding {
+ case EncodingJSON:
+ err = json.Unmarshal([]byte(data), &task.Data)
+ case EncodingGob:
+ // Gob payloads are stored base64-encoded (see encode).
+ var decoded []byte
+ decoded, err = base64.StdEncoding.DecodeString(data)
+ if err != nil {
+ return
+ }
+ err = gob.NewDecoder(bytes.NewReader(decoded)).Decode(&task.Data)
+ default:
+ err = errors.New("unsupported encoding")
+ }
+ return
+}
+
+// getField extracts a named stream field into v, which must be a pointer
+// to a string, bool, or any integer type; the string value from Redis is
+// parsed into the pointee via reflection.
+func getField(msg *redis.XMessage, field string, v any) error {
+ vVal, ok := msg.Values[field]
+ if !ok {
+ return errors.New("missing field")
+ }
+
+ // go-redis surfaces stream field values as strings.
+ vStr, ok := vVal.(string)
+ if !ok {
+ return errors.New("invalid field type")
+ }
+
+ value := reflect.ValueOf(v).Elem()
+ switch value.Kind() {
+ case reflect.String:
+ value.SetString(vStr)
+ case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
+ i, err := strconv.ParseInt(vStr, 10, 64)
+ if err != nil {
+ return err
+ }
+ value.SetInt(i)
+ case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
+ i, err := strconv.ParseUint(vStr, 10, 64)
+ if err != nil {
+ return err
+ }
+ value.SetUint(i)
+ case reflect.Bool:
+ b, err := strconv.ParseBool(vStr)
+ if err != nil {
+ return err
+ }
+ value.SetBool(b)
+ default:
+ return errors.New("unsupported field type")
+ }
+ return nil
+}
+
+// encode serializes a task into the field map stored on its Redis stream
+// entry: "attempts" as a decimal string and "data" as the payload encoded
+// per the queue's configured encoding (gob payloads are base64-wrapped).
+func encode[T any](task TaskInfo[T], encoding Encoding) (map[string]string, error) {
+ fields := map[string]string{
+ "attempts": strconv.FormatUint(uint64(task.Attempts), 10),
+ }
+
+ switch encoding {
+ case EncodingJSON:
+ data, err := json.Marshal(task.Data)
+ if err != nil {
+ return fields, err
+ }
+ fields["data"] = string(data)
+ case EncodingGob:
+ var buf bytes.Buffer
+ if err := gob.NewEncoder(&buf).Encode(task.Data); err != nil {
+ return fields, err
+ }
+ fields["data"] = base64.StdEncoding.EncodeToString(buf.Bytes())
+ default:
+ return fields, errors.New("unsupported encoding")
+ }
+ return fields, nil
+}
diff --git a/backend/internal/redis/taskqueue/queue_test.go b/backend/internal/redis/taskqueue/queue_test.go
new file mode 100644
index 0000000..ee95d39
--- /dev/null
+++ b/backend/internal/redis/taskqueue/queue_test.go
@@ -0,0 +1,467 @@
+package taskqueue
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "log"
+ "testing"
+ "time"
+
+ "github.com/ory/dockertest/v3"
+ "github.com/ory/dockertest/v3/docker"
+ "github.com/redis/go-redis/v9"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+var client *redis.Client
+
+func TestMain(m *testing.M) {
+ pool, err := dockertest.NewPool("")
+ if err != nil {
+ log.Fatalf("Could not create pool: %s", err)
+ }
+
+ err = pool.Client.Ping()
+ if err != nil {
+ log.Fatalf("Could not connect to Docker: %s", err)
+ }
+
+ //resource, err := pool.Run("redis", "7", nil)
+ resource, err := pool.RunWithOptions(&dockertest.RunOptions{
+ Repository: "redis",
+ Tag: "7",
+ }, func(config *docker.HostConfig) {
+ config.AutoRemove = true
+ config.RestartPolicy = docker.RestartPolicy{Name: "no"}
+ })
+ if err != nil {
+ log.Fatalf("Could not start resource: %s", err)
+ }
+
+ //_ = resource.Expire(60)
+
+ if err = pool.Retry(func() error {
+ client = redis.NewClient(&redis.Options{
+ Addr: fmt.Sprintf("localhost:%s", resource.GetPort("6379/tcp")),
+ })
+ return client.Ping(context.Background()).Err()
+ }); err != nil {
+ log.Fatalf("Could not connect to redis: %s", err)
+ }
+
+ defer func() {
+ if err = client.Close(); err != nil {
+ log.Printf("Could not close client: %s", err)
+ }
+ if err = pool.Purge(resource); err != nil {
+ log.Fatalf("Could not purge resource: %s", err)
+ }
+ }()
+
+ m.Run()
+}
+
// TestTaskQueue exercises the core queue lifecycle against the live Redis
// container started in TestMain: enqueue/dequeue round-trips, completion,
// lock expiry (task becomes visible again), and lock extension (task stays
// hidden). Each subtest runs against a freshly flushed database.
func TestTaskQueue(t *testing.T) {
	if testing.Short() {
		t.Skip()
	}

	// Deliberately short lock so the expiry/extension subtests finish fast.
	lockTimeout := 100 * time.Millisecond

	tests := []struct {
		name string
		f    func(t *testing.T)
	}{
		{
			name: "Create queue",
			f: func(t *testing.T) {
				q, err := New[string](context.Background(), client, "test", "worker")
				require.NoError(t, err)
				require.NotNil(t, q)
			},
		},
		{
			name: "enqueue & dequeue",
			f: func(t *testing.T) {
				q, err := New[string](context.Background(), client, "test", "worker")
				require.NoError(t, err)
				require.NotNil(t, q)

				taskId, err := q.Enqueue(context.Background(), "hello")
				require.NoError(t, err)
				require.NotEmpty(t, taskId)

				task, err := q.Dequeue(context.Background(), lockTimeout, 10*time.Millisecond)
				require.NoError(t, err)
				require.NotNil(t, task)
				require.Equal(t, "hello", task.Data)
			},
		},
		{
			// Struct payloads must survive the encode/decode round trip.
			name: "complex data",
			f: func(t *testing.T) {
				type foo struct {
					A int
					B string
				}

				q, err := New[foo](context.Background(), client, "test", "worker")
				require.NoError(t, err)
				require.NotNil(t, q)

				taskId, err := q.Enqueue(context.Background(), foo{A: 42, B: "hello"})
				require.NoError(t, err)
				require.NotEmpty(t, taskId)

				task, err := q.Dequeue(context.Background(), lockTimeout, 10*time.Millisecond)
				require.NoError(t, err)
				require.NotNil(t, task)
				require.Equal(t, foo{A: 42, B: "hello"}, task.Data)
			},
		},
		{
			// A task enqueued by one worker must be visible to another
			// worker on the same queue.
			name: "different workers",
			f: func(t *testing.T) {
				q1, err := New[string](context.Background(), client, "test", "worker1")
				require.NoError(t, err)
				require.NotNil(t, q1)

				q2, err := New[string](context.Background(), client, "test", "worker2")
				require.NoError(t, err)
				require.NotNil(t, q2)

				taskId, err := q1.Enqueue(context.Background(), "hello")
				require.NoError(t, err)
				require.NotEmpty(t, taskId)

				task, err := q2.Dequeue(context.Background(), lockTimeout, 10*time.Millisecond)
				require.NoError(t, err)
				require.NotNil(t, task)
				assert.Equal(t, "hello", task.Data)
			},
		},
		{
			// Completed tasks must not be re-delivered.
			name: "complete",
			f: func(t *testing.T) {
				q, err := New[string](context.Background(), client, "test", "worker")
				require.NoError(t, err)
				require.NotNil(t, q)

				// Enqueue a task
				taskId, err := q.Enqueue(context.Background(), "hello")
				require.NoError(t, err)
				require.NotEmpty(t, taskId)

				// Dequeue the task
				task, err := q.Dequeue(context.Background(), lockTimeout, 10*time.Millisecond)
				require.NoError(t, err)
				require.NotNil(t, task)
				assert.Equal(t, "hello", task.Data)

				// Complete the task
				err = q.Complete(context.Background(), task.ID, "done")
				require.NoError(t, err)

				// Try to dequeue the task again
				task, err = q.Dequeue(context.Background(), lockTimeout, 10*time.Millisecond)
				require.NoError(t, err)
				require.Nil(t, task)
			},
		},
		{
			// A task whose lock expired without completion must become
			// available to dequeue again.
			name: "timeout",
			f: func(t *testing.T) {
				q, err := New[string](context.Background(), client, "test", "worker")
				require.NoError(t, err)
				require.NotNil(t, q)

				// Enqueue a task
				taskId, err := q.Enqueue(context.Background(), "hello")
				require.NoError(t, err)
				require.NotEmpty(t, taskId)

				// Dequeue the task
				task, err := q.Dequeue(context.Background(), lockTimeout, 10*time.Millisecond)
				require.NoError(t, err)
				require.NotNil(t, task)
				assert.Equal(t, "hello", task.Data)

				// Wait for the lock to expire
				time.Sleep(lockTimeout + 10*time.Millisecond)

				// Try to dequeue the task again
				task, err = q.Dequeue(context.Background(), lockTimeout, 10*time.Millisecond)
				require.NoError(t, err)
				require.NotNil(t, task)
				assert.Equal(t, "hello", task.Data)
			},
		},
		{
			// Extending the lock (even after its nominal expiry) must keep
			// the task hidden from other dequeues.
			name: "extend",
			f: func(t *testing.T) {
				q, err := New[string](context.Background(), client, "test", "worker")
				require.NoError(t, err)
				require.NotNil(t, q)

				// Enqueue a task
				taskId, err := q.Enqueue(context.Background(), "hello")
				require.NoError(t, err)
				require.NotEmpty(t, taskId)

				// Dequeue the task
				task, err := q.Dequeue(context.Background(), lockTimeout, 10*time.Millisecond)
				require.NoError(t, err)
				require.NotNil(t, task)
				assert.Equal(t, "hello", task.Data)

				// Wait for the lock to expire
				time.Sleep(lockTimeout + 10*time.Millisecond)

				// Extend the lock
				err = q.Extend(context.Background(), task.ID)
				require.NoError(t, err)

				// Try to dequeue the task again
				task, err = q.Dequeue(context.Background(), lockTimeout, 10*time.Millisecond)
				require.NoError(t, err)
				require.Nil(t, task)
			},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// Isolate subtests from each other by flushing Redis first.
			if err := client.FlushDB(context.Background()).Err(); err != nil {
				t.Fatal(err)
			}

			tt.f(t)
		})
	}

	_ = client.FlushDB(context.Background())
}
+
// TestTaskQueue_List verifies queue listing: empty result, ordering (the
// assertions show newest-first), the count limit, half-open ID ranges, and
// that completed tasks remain listable with their stored result. Each
// subtest runs against a freshly flushed database.
func TestTaskQueue_List(t *testing.T) {
	if testing.Short() {
		t.Skip()
	}

	tests := []struct {
		name string
		f    func(t *testing.T)
	}{
		{
			name: "empty",
			f: func(t *testing.T) {
				q, err := New[any](context.Background(), client, "test", "worker")
				require.NoError(t, err)
				require.NotNil(t, q)

				// Zero-value TaskIDs act as open range bounds.
				tasks, err := q.List(context.Background(), TaskID{}, TaskID{}, 1)
				require.NoError(t, err)
				assert.Empty(t, tasks)
			},
		},
		{
			name: "single",
			f: func(t *testing.T) {
				q, err := New[string](context.Background(), client, "test", "worker")
				require.NoError(t, err)
				require.NotNil(t, q)

				taskID, err := q.Enqueue(context.Background(), "hello")
				require.NoError(t, err)

				tasks, err := q.List(context.Background(), TaskID{}, TaskID{}, 1)
				require.NoError(t, err)
				assert.Len(t, tasks, 1)
				assert.Equal(t, taskID, tasks[0])
			},
		},
		{
			// Newer tasks come first: the second enqueue is tasks[0].
			name: "multiple",
			f: func(t *testing.T) {
				q, err := New[string](context.Background(), client, "test", "worker")
				require.NoError(t, err)
				require.NotNil(t, q)

				taskID, err := q.Enqueue(context.Background(), "hello")
				require.NoError(t, err)
				taskID2, err := q.Enqueue(context.Background(), "world")
				require.NoError(t, err)

				tasks, err := q.List(context.Background(), TaskID{}, TaskID{}, 2)
				require.NoError(t, err)
				assert.Len(t, tasks, 2)
				assert.Equal(t, taskID, tasks[1])
				assert.Equal(t, taskID2, tasks[0])
			},
		},
		{
			// With limit 1, only the most recent task is returned.
			name: "multiple limited",
			f: func(t *testing.T) {
				q, err := New[string](context.Background(), client, "test", "worker")
				require.NoError(t, err)
				require.NotNil(t, q)

				_, err = q.Enqueue(context.Background(), "hello")
				require.NoError(t, err)
				taskID2, err := q.Enqueue(context.Background(), "world")
				require.NoError(t, err)

				tasks, err := q.List(context.Background(), TaskID{}, TaskID{}, 1)
				require.NoError(t, err)
				assert.Len(t, tasks, 1)
				assert.Equal(t, taskID2, tasks[0])
			},
		},
		{
			// The (start, end) bounds partition results around taskID2:
			// end excludes it, start includes it.
			name: "multiple time range",
			f: func(t *testing.T) {
				q, err := New[string](context.Background(), client, "test", "worker")
				require.NoError(t, err)
				require.NotNil(t, q)

				taskID, err := q.Enqueue(context.Background(), "hello")
				require.NoError(t, err)
				// Ensure the two task IDs land on distinct timestamps.
				time.Sleep(10 * time.Millisecond)
				taskID2, err := q.Enqueue(context.Background(), "world")
				require.NoError(t, err)

				tasks, err := q.List(context.Background(), TaskID{}, taskID2.ID, 100)
				require.NoError(t, err)
				assert.Len(t, tasks, 1)
				assert.Equal(t, taskID, tasks[0])

				tasks, err = q.List(context.Background(), taskID2.ID, TaskID{}, 100)
				require.NoError(t, err)
				assert.Len(t, tasks, 1)
				assert.Equal(t, taskID2, tasks[0])
			},
		},
		{
			// Completed tasks stay in the listing and expose their result.
			name: "completed tasks",
			f: func(t *testing.T) {
				q, err := New[string](context.Background(), client, "test", "worker")
				require.NoError(t, err)
				require.NotNil(t, q)

				task1, err := q.Enqueue(context.Background(), "hello")
				require.NoError(t, err)
				task2, err := q.Enqueue(context.Background(), "world")
				require.NoError(t, err)

				err = q.Complete(context.Background(), task1.ID, "done")
				require.NoError(t, err)

				tasks, err := q.List(context.Background(), TaskID{}, TaskID{}, 100)
				require.NoError(t, err)
				assert.Len(t, tasks, 2)
				assert.Equal(t, task2, tasks[0])

				assert.Equal(t, "hello", tasks[1].Data)
				require.IsType(t, &TaskResultSuccess{}, tasks[1].Result)
				assert.Equal(t, "done", tasks[1].Result.(*TaskResultSuccess).Result)
			},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// Isolate subtests from each other by flushing Redis first.
			if err := client.FlushDB(context.Background()).Err(); err != nil {
				t.Fatal(err)
			}

			tt.f(t)
		})
	}

	_ = client.FlushDB(context.Background())
}
+
// TestTaskQueue_Return verifies the retry path: returning a task with an
// error creates a fresh retry task with an incremented attempt counter and
// records a TaskResultError (linking to the next attempt) on the failed
// one. After three failed attempts the task is no longer re-delivered —
// the retry limit appears to be three; confirm against the queue
// implementation. Each subtest runs against a freshly flushed database.
func TestTaskQueue_Return(t *testing.T) {
	if testing.Short() {
		t.Skip()
	}

	lockTimeout := 100 * time.Millisecond

	tests := []struct {
		name string
		f    func(t *testing.T)
	}{
		{
			name: "simple",
			f: func(t *testing.T) {
				q, err := New[string](context.Background(), client, "test", "worker")
				require.NoError(t, err)
				require.NotNil(t, q)

				task1, err := q.Enqueue(context.Background(), "hello")
				require.NoError(t, err)

				id := claimAndFail(t, q, lockTimeout)

				// The retry task carries the same payload under a new ID
				// with one recorded attempt.
				task2, err := q.Dequeue(context.Background(), lockTimeout, 10*time.Millisecond)
				require.NoError(t, err)
				require.NotNil(t, task2)
				assert.Equal(t, task2.ID, id)
				assert.Equal(t, task1.Data, task2.Data)
				assert.Equal(t, uint8(1), task2.Attempts)

				// The original task records the failure and points at the
				// retry that superseded it.
				task1Data, err := q.Data(context.Background(), task1.ID)
				require.NoError(t, err)
				assert.Equal(t, task1Data.ID, task1.ID)
				assert.Equal(t, task1Data.Data, task1.Data)
				assert.IsType(t, &TaskResultError{}, task1Data.Result)
				assert.Equal(t, "failed", task1Data.Result.(*TaskResultError).Error)
				assert.Equal(t, task2.ID, task1Data.Result.(*TaskResultError).NextAttempt)
			},
		},
		{
			// After three consecutive failures the task is dropped.
			name: "failure",
			f: func(t *testing.T) {
				q, err := New[string](context.Background(), client, "test", "worker")
				require.NoError(t, err)
				require.NotNil(t, q)

				_, err = q.Enqueue(context.Background(), "hello")
				require.NoError(t, err)

				claimAndFail(t, q, lockTimeout)
				claimAndFail(t, q, lockTimeout)
				claimAndFail(t, q, lockTimeout)

				task3, err := q.Dequeue(context.Background(), lockTimeout, 10*time.Millisecond)
				require.NoError(t, err)
				assert.Nil(t, task3)
			},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// Isolate subtests from each other by flushing Redis first.
			if err := client.FlushDB(context.Background()).Err(); err != nil {
				t.Fatal(err)
			}

			tt.f(t)
		})
	}

	_ = client.FlushDB(context.Background())
}
+
+func claimAndFail[T any](t *testing.T, q TaskQueue[T], lockTimeout time.Duration) TaskID {
+ task, err := q.Dequeue(context.Background(), lockTimeout, 10*time.Millisecond)
+ require.NoError(t, err)
+ require.NotNil(t, task)
+
+ id, err := q.Return(context.Background(), task.ID, errors.New("failed"))
+ require.NoError(t, err)
+ assert.NotEqual(t, task.ID, id)
+ return id
+}
diff --git a/backend/internal/server/idb/stock/v1/stock.go b/backend/internal/server/idb/stock/v1/stock.go
new file mode 100644
index 0000000..8afc2b1
--- /dev/null
+++ b/backend/internal/server/idb/stock/v1/stock.go
@@ -0,0 +1,64 @@
+package stock
+
+import (
+ "context"
+ "fmt"
+ "log/slog"
+
+ pb "github.com/ansg191/ibd-trader-backend/api/gen/idb/stock/v1"
+ "github.com/ansg191/ibd-trader-backend/internal/database"
+ "github.com/ansg191/ibd-trader-backend/internal/leader/manager/ibd/scrape"
+ "github.com/ansg191/ibd-trader-backend/internal/redis/taskqueue"
+
+ "cloud.google.com/go/longrunning/autogen/longrunningpb"
+ "google.golang.org/grpc/codes"
+ "google.golang.org/grpc/status"
+ "google.golang.org/protobuf/types/known/anypb"
+ "google.golang.org/protobuf/types/known/timestamppb"
+)
+
// ScrapeOperationPrefix is the first path segment of long-running operation
// names minted by CreateStock ("scrape/<task id>").
const ScrapeOperationPrefix = "scrape"

// Server implements the StockService gRPC API. Stock creation is
// asynchronous: CreateStock only enqueues a scrape task and hands back a
// long-running operation handle.
type Server struct {
	pb.UnimplementedStockServiceServer

	// db provides database access; queue feeds the scrape workers.
	db    database.Executor
	queue taskqueue.TaskQueue[scrape.TaskInfo]
}

// New constructs a stock Server backed by the given database executor and
// scrape task queue.
func New(db database.Executor, queue taskqueue.TaskQueue[scrape.TaskInfo]) *Server {
	return &Server{db: db, queue: queue}
}
+
+func (s *Server) CreateStock(ctx context.Context, request *pb.CreateStockRequest) (*pb.CreateStockResponse, error) {
+ task, err := s.queue.Enqueue(ctx, scrape.TaskInfo{Symbol: request.Symbol})
+ if err != nil {
+ slog.ErrorContext(ctx, "failed to enqueue task", "err", err)
+ return nil, status.New(codes.Internal, "failed to enqueue task").Err()
+ }
+ op := &longrunningpb.Operation{
+ Name: fmt.Sprintf("%s/%s", ScrapeOperationPrefix, task.ID.String()),
+ Metadata: new(anypb.Any),
+ Done: false,
+ Result: nil,
+ }
+ err = op.Metadata.MarshalFrom(&pb.StockScrapeOperationMetadata{
+ Symbol: request.Symbol,
+ StartTime: timestamppb.New(task.ID.Timestamp()),
+ })
+ if err != nil {
+ slog.ErrorContext(ctx, "failed to marshal metadata", "err", err)
+ return nil, status.New(codes.Internal, "failed to marshal metadata").Err()
+ }
+ return &pb.CreateStockResponse{Operation: op}, nil
+}
+
+func (s *Server) GetStock(ctx context.Context, request *pb.GetStockRequest) (*pb.GetStockResponse, error) {
+ //TODO implement me
+ panic("implement me")
+}
+
+func (s *Server) ListStocks(ctx context.Context, request *pb.ListStocksRequest) (*pb.ListStocksResponse, error) {
+ //TODO implement me
+ panic("implement me")
+}
diff --git a/backend/internal/server/idb/user/v1/user.go b/backend/internal/server/idb/user/v1/user.go
new file mode 100644
index 0000000..2f32e03
--- /dev/null
+++ b/backend/internal/server/idb/user/v1/user.go
@@ -0,0 +1,159 @@
+package user
+
+import (
+ "context"
+ "errors"
+
+ pb "github.com/ansg191/ibd-trader-backend/api/gen/idb/user/v1"
+ "github.com/ansg191/ibd-trader-backend/internal/database"
+ "github.com/ansg191/ibd-trader-backend/internal/ibd"
+ "github.com/ansg191/ibd-trader-backend/internal/keys"
+
+ "github.com/mennanov/fmutils"
+ "google.golang.org/grpc/codes"
+ "google.golang.org/grpc/status"
+ "google.golang.org/protobuf/proto"
+)
+
// Server implements the UserService gRPC API. It persists users and their
// IBD credentials (encrypted via the KMS key named keyName) and talks to
// IBD through client for username checks and logins.
type Server struct {
	pb.UnimplementedUserServiceServer

	db      database.TransactionExecutor
	kms     keys.KeyManagementService
	keyName string      // KMS key used to encrypt stored IBD credentials
	client  *ibd.Client // upstream IBD client
}

// New constructs a user Server with its database, KMS, encryption key name,
// and IBD client dependencies.
func New(db database.TransactionExecutor, kms keys.KeyManagementService, keyName string, client *ibd.Client) *Server {
	return &Server{
		db:      db,
		kms:     kms,
		keyName: keyName,
		client:  client,
	}
}
+
+func (u *Server) CreateUser(ctx context.Context, request *pb.CreateUserRequest) (*pb.CreateUserResponse, error) {
+ err := database.AddUser(ctx, u.db, request.Subject)
+ if err != nil {
+ return nil, status.Errorf(codes.Internal, "unable to create user: %v", err)
+ }
+
+ user, err := database.GetUser(ctx, u.db, request.Subject)
+ if err != nil {
+ return nil, status.Errorf(codes.Internal, "unable to get user: %v", err)
+ }
+
+ return &pb.CreateUserResponse{
+ User: &pb.User{
+ Subject: user.Subject,
+ IbdUsername: user.IBDUsername,
+ IbdPassword: nil,
+ },
+ }, nil
+}
+
+func (u *Server) GetUser(ctx context.Context, request *pb.GetUserRequest) (*pb.GetUserResponse, error) {
+ user, err := database.GetUser(ctx, u.db, request.Subject)
+ if errors.Is(err, database.ErrUserNotFound) {
+ return nil, status.New(codes.NotFound, "user not found").Err()
+ }
+ if err != nil {
+ return nil, status.Errorf(codes.Internal, "unable to get user: %v", err)
+ }
+
+ return &pb.GetUserResponse{
+ User: &pb.User{
+ Subject: user.Subject,
+ IbdUsername: user.IBDUsername,
+ IbdPassword: nil,
+ },
+ }, nil
+}
+
+func (u *Server) UpdateUser(ctx context.Context, request *pb.UpdateUserRequest) (*pb.UpdateUserResponse, error) {
+ request.UpdateMask.Normalize()
+ if !request.UpdateMask.IsValid(request.User) {
+ return nil, status.Errorf(codes.InvalidArgument, "invalid update mask")
+ }
+
+ existingUserRes, err := u.GetUser(ctx, &pb.GetUserRequest{Subject: request.User.Subject})
+ if err != nil {
+ return nil, err
+ }
+ existingUser := existingUserRes.User
+
+ newUser := proto.Clone(existingUser).(*pb.User)
+ fmutils.Overwrite(request.User, newUser, request.UpdateMask.Paths)
+
+ // if IDB creds are both set and are different, update them
+ if (newUser.IbdPassword != nil && newUser.IbdUsername != nil) &&
+ (newUser.IbdPassword != existingUser.IbdPassword ||
+ newUser.IbdUsername != existingUser.IbdUsername) {
+ // Update IBD creds
+ err = database.AddIBDCreds(ctx, u.db, u.kms, u.keyName, newUser.Subject, *newUser.IbdUsername, *newUser.IbdPassword)
+ if err != nil {
+ return nil, status.Errorf(codes.Internal, "unable to update user: %v", err)
+ }
+ }
+
+ newUser.IbdPassword = nil
+ return &pb.UpdateUserResponse{
+ User: newUser,
+ }, nil
+}
+
+func (u *Server) CheckIBDUsername(ctx context.Context, req *pb.CheckIBDUsernameRequest) (*pb.CheckIBDUsernameResponse, error) {
+ username := req.IbdUsername
+ if username == "" {
+ return nil, status.Errorf(codes.InvalidArgument, "username cannot be empty")
+ }
+
+ // Check if the username exists
+ exists, err := u.client.CheckIBDUsername(ctx, username)
+ if err != nil {
+ return nil, status.Errorf(codes.Internal, "unable to check username: %v", err)
+ }
+
+ return &pb.CheckIBDUsernameResponse{
+ Exists: exists,
+ }, nil
+}
+
+func (u *Server) AuthenticateUser(ctx context.Context, req *pb.AuthenticateUserRequest) (*pb.AuthenticateUserResponse, error) {
+ // Check if user has cookies
+ cookies, err := database.GetCookies(ctx, u.db, u.kms, req.Subject, false)
+ if err != nil {
+ return nil, status.Errorf(codes.Internal, "unable to get cookies: %v", err)
+ }
+ if len(cookies) > 0 {
+ return &pb.AuthenticateUserResponse{
+ Authenticated: true,
+ }, nil
+ }
+
+ // Authenticate user
+ // Get IBD creds
+ username, password, err := database.GetIBDCreds(ctx, u.db, u.kms, req.Subject)
+ if errors.Is(err, database.ErrIBDCredsNotFound) {
+ return nil, status.New(codes.NotFound, "User has no IDB creds").Err()
+ }
+ if err != nil {
+ return nil, status.Errorf(codes.Internal, "unable to get IBD creds: %v", err)
+ }
+
+ // Authenticate user
+ cookie, err := u.client.Authenticate(ctx, username, password)
+ if errors.Is(err, ibd.ErrBadCredentials) {
+ return &pb.AuthenticateUserResponse{
+ Authenticated: false,
+ }, nil
+ }
+ if err != nil {
+ return nil, status.Errorf(codes.Internal, "unable to authenticate user: %v", err)
+ }
+
+ return &pb.AuthenticateUserResponse{
+ Authenticated: cookie != nil,
+ }, nil
+}
diff --git a/backend/internal/server/operations.go b/backend/internal/server/operations.go
new file mode 100644
index 0000000..2487427
--- /dev/null
+++ b/backend/internal/server/operations.go
@@ -0,0 +1,142 @@
+package server
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "log/slog"
+ "strings"
+
+ "cloud.google.com/go/longrunning/autogen/longrunningpb"
+ spb "github.com/ansg191/ibd-trader-backend/api/gen/idb/stock/v1"
+ "github.com/ansg191/ibd-trader-backend/internal/leader/manager/ibd/scrape"
+ "github.com/ansg191/ibd-trader-backend/internal/redis/taskqueue"
+ "github.com/ansg191/ibd-trader-backend/internal/server/idb/stock/v1"
+ epb "google.golang.org/genproto/googleapis/rpc/errdetails"
+ "google.golang.org/grpc/codes"
+ "google.golang.org/grpc/status"
+ "google.golang.org/protobuf/types/known/anypb"
+ "google.golang.org/protobuf/types/known/timestamppb"
+)
+
// operationServer implements the google.longrunning Operations API over the
// task queues. Currently only the scrape queue is exposed.
type operationServer struct {
	longrunningpb.UnimplementedOperationsServer

	scrape taskqueue.TaskQueue[scrape.TaskInfo]
}

// newOperationServer wraps the scrape queue in an Operations server.
func newOperationServer(scrapeQueue taskqueue.TaskQueue[scrape.TaskInfo]) *operationServer {
	return &operationServer{scrape: scrapeQueue}
}
+
// ListOperations pages through operations for the queue named by req.Name.
// The page token is the TaskID of the last task on the previous page; an
// empty token (zero TaskID) starts from the newest task.
func (o *operationServer) ListOperations(
	ctx context.Context,
	req *longrunningpb.ListOperationsRequest,
) (*longrunningpb.ListOperationsResponse, error) {
	var end taskqueue.TaskID
	if req.PageToken != "" {
		var err error
		end, err = taskqueue.ParseTaskID(req.PageToken)
		if err != nil {
			return nil, status.New(codes.InvalidArgument, err.Error()).Err()
		}
	} else {
		end = taskqueue.TaskID{}
	}

	switch req.Name {
	case stock.ScrapeOperationPrefix:
		// NOTE(review): req.PageSize == 0 lists zero tasks yet still takes
		// the nextPageToken branch below (0 == 0) — confirm intended.
		tasks, err := o.scrape.List(ctx, taskqueue.TaskID{}, end, int64(req.PageSize))
		if err != nil {
			return nil, status.New(codes.Internal, "unable to list IDs").Err()
		}

		ops := make([]*longrunningpb.Operation, len(tasks))
		for i, task := range tasks {
			// A task with any recorded result counts as done.
			ops[i] = &longrunningpb.Operation{
				Name:     fmt.Sprintf("%s/%s", stock.ScrapeOperationPrefix, task.ID.String()),
				Metadata: new(anypb.Any),
				Done:     task.Result != nil,
				Result:   nil,
			}
			err = ops[i].Metadata.MarshalFrom(&spb.StockScrapeOperationMetadata{
				Symbol:    task.Data.Symbol,
				StartTime: timestamppb.New(task.ID.Timestamp()),
			})
			if err != nil {
				return nil, status.New(codes.Internal, "unable to marshal metadata").Err()
			}

			switch res := task.Result.(type) {
			case *taskqueue.TaskResultSuccess:
				// NOTE(review): this aborts the ENTIRE listing as soon as
				// one successfully-completed task is encountered, rather
				// than marking just that operation — looks unintended;
				// confirm before relying on this endpoint.
				return nil, status.New(codes.Unimplemented, "not implemented").Err()
			case *taskqueue.TaskResultError:
				s := status.New(codes.Unknown, res.Error)
				// ErrorInfo fields are placeholders for now.
				s, err = s.WithDetails(
					&epb.ErrorInfo{
						Reason:   "",
						Domain:   "",
						Metadata: nil,
					})
				if err != nil {
					return nil, status.New(codes.Internal, "unable to marshal error details").Err()
				}
				ops[i].Result = &longrunningpb.Operation_Error{Error: s.Proto()}
			}
		}

		// A full page implies there may be more; hand back the oldest ID on
		// this page as the cursor for the next call.
		var nextPageToken string
		if len(tasks) == int(req.PageSize) {
			nextPageToken = tasks[len(tasks)-1].ID.String()
		} else {
			nextPageToken = ""
		}

		return &longrunningpb.ListOperationsResponse{
			Operations:    ops,
			NextPageToken: nextPageToken,
		}, nil
	default:
		return nil, status.New(codes.NotFound, "unknown operation type").Err()
	}
}
+
// GetOperation resolves an operation name of the form "<prefix>/<task id>"
// to the underlying task's state. Only the scrape prefix is supported.
// NOTE(review): unlike ListOperations, a finished task's Result (success or
// error) is not populated here — Done is set but Result stays nil; confirm
// whether that is intentional.
func (o *operationServer) GetOperation(ctx context.Context, req *longrunningpb.GetOperationRequest) (*longrunningpb.Operation, error) {
	prefix, id, ok := strings.Cut(req.Name, "/")
	if !ok || prefix == "" || id == "" {
		return nil, status.New(codes.InvalidArgument, "invalid operation name").Err()
	}

	taskID, err := taskqueue.ParseTaskID(id)
	if err != nil {
		return nil, status.New(codes.InvalidArgument, err.Error()).Err()
	}

	switch prefix {
	case stock.ScrapeOperationPrefix:
		task, err := o.scrape.Data(ctx, taskID)
		if errors.Is(err, taskqueue.ErrTaskNotFound) {
			return nil, status.New(codes.NotFound, "operation not found").Err()
		}
		if err != nil {
			slog.ErrorContext(ctx, "unable to get operation", "error", err)
			return nil, status.New(codes.Internal, "unable to get operation").Err()
		}
		// A task with any recorded result counts as done.
		op := &longrunningpb.Operation{
			Name:     req.Name,
			Metadata: new(anypb.Any),
			Done:     task.Result != nil,
			Result:   nil,
		}
		err = op.Metadata.MarshalFrom(&spb.StockScrapeOperationMetadata{
			Symbol:    task.Data.Symbol,
			StartTime: timestamppb.New(task.ID.Timestamp()),
		})
		if err != nil {
			return nil, status.New(codes.Internal, "unable to marshal metadata").Err()
		}
		return op, nil
	default:
		return nil, status.New(codes.NotFound, "unknown operation type").Err()
	}
}
diff --git a/backend/internal/server/server.go b/backend/internal/server/server.go
new file mode 100644
index 0000000..c525cfd
--- /dev/null
+++ b/backend/internal/server/server.go
@@ -0,0 +1,77 @@
+package server
+
+import (
+ "context"
+ "fmt"
+ "log/slog"
+ "net"
+
+ "cloud.google.com/go/longrunning/autogen/longrunningpb"
+ spb "github.com/ansg191/ibd-trader-backend/api/gen/idb/stock/v1"
+ upb "github.com/ansg191/ibd-trader-backend/api/gen/idb/user/v1"
+ "github.com/ansg191/ibd-trader-backend/internal/database"
+ "github.com/ansg191/ibd-trader-backend/internal/ibd"
+ "github.com/ansg191/ibd-trader-backend/internal/keys"
+ "github.com/ansg191/ibd-trader-backend/internal/leader/manager/ibd/scrape"
+ "github.com/ansg191/ibd-trader-backend/internal/redis/taskqueue"
+ "github.com/ansg191/ibd-trader-backend/internal/server/idb/stock/v1"
+ "github.com/ansg191/ibd-trader-backend/internal/server/idb/user/v1"
+ "github.com/redis/go-redis/v9"
+ "google.golang.org/grpc"
+ "google.golang.org/grpc/reflection"
+)
+
+//go:generate make -C ../../api/ generate
+
// Server bundles a configured gRPC server with the TCP port it should
// listen on. Construct with New; run with Serve.
type Server struct {
	s    *grpc.Server
	port uint16
}
+
// New wires up the gRPC server: it creates the scrape task queue on Redis
// (registering this process as "grpc-server"), registers the User, Stock,
// and long-running Operations services, and enables server reflection.
func New(
	ctx context.Context,
	port uint16,
	db database.TransactionExecutor,
	rClient *redis.Client,
	client *ibd.Client,
	kms keys.KeyManagementService,
	keyName string,
) (*Server, error) {
	// The stock service enqueues scrapes; the operations service reads the
	// same queue to report their progress.
	scrapeQueue, err := taskqueue.New(
		ctx,
		rClient,
		scrape.Queue,
		"grpc-server",
		taskqueue.WithEncoding[scrape.TaskInfo](scrape.QueueEncoding))
	if err != nil {
		return nil, err
	}

	s := grpc.NewServer()
	upb.RegisterUserServiceServer(s, user.New(db, kms, keyName, client))
	spb.RegisterStockServiceServer(s, stock.New(db, scrapeQueue))
	longrunningpb.RegisterOperationsServer(s, newOperationServer(scrapeQueue))
	// Reflection lets tools like grpcurl discover the services.
	reflection.Register(s)
	return &Server{s, port}, nil
}
+
+func (s *Server) Serve(ctx context.Context) error {
+ lis, err := net.Listen("tcp", fmt.Sprintf(":%d", s.port))
+ if err != nil {
+ return err
+ }
+
+ // Graceful shutdown
+ go func() {
+ <-ctx.Done()
+ slog.ErrorContext(ctx,
+ "Shutting down server",
+ "err", ctx.Err(),
+ "cause", context.Cause(ctx),
+ )
+ s.s.GracefulStop()
+ }()
+
+ slog.InfoContext(ctx, "Starting gRPC server", "port", s.port)
+ return s.s.Serve(lis)
+}
diff --git a/backend/internal/utils/money.go b/backend/internal/utils/money.go
new file mode 100644
index 0000000..2dc2286
--- /dev/null
+++ b/backend/internal/utils/money.go
@@ -0,0 +1,99 @@
+package utils
+
+import (
+ "fmt"
+ "strconv"
+ "strings"
+
+ "github.com/Rhymond/go-money"
+)
+
+// supported currencies
+var currencies = money.Currencies{
+ "USD": money.GetCurrency(money.USD),
+ "EUR": money.GetCurrency(money.EUR),
+ "GBP": money.GetCurrency(money.GBP),
+ "JPY": money.GetCurrency(money.JPY),
+ "CNY": money.GetCurrency(money.CNY),
+}
+
+func ParseMoney(s string) (*money.Money, error) {
+ for _, c := range currencies {
+ numPart, ok := isCurrency(s, c)
+ if !ok {
+ continue
+ }
+
+ // Parse the number part
+ num, err := strconv.ParseUint(numPart, 10, 64)
+ if err != nil {
+ return nil, fmt.Errorf("failed to parse number: %w", err)
+ }
+
+ return money.New(int64(num), c.Code), nil
+ }
+ return nil, fmt.Errorf("matching currency not found")
+}
+
+func isCurrency(s string, c *money.Currency) (string, bool) {
+ var numPart string
+ for _, tp := range c.Template {
+ switch tp {
+ case '$':
+ // There should be a matching grapheme in the s at this position
+ remaining, ok := strings.CutPrefix(s, c.Grapheme)
+ if !ok {
+ return "", false
+ }
+ s = remaining
+ case '1':
+ // There should be a number, thousands, or decimal separator in the s at this position
+ // Number of expected decimal places
+ decimalFound := -1
+ // Read from string until a non-number, non-thousands, non-decimal, or EOF is found
+ for len(s) > 0 && (string(s[0]) == c.Thousand || string(s[0]) == c.Decimal || '0' <= s[0] && s[0] <= '9') {
+ // If the character is a number
+ if '0' <= s[0] && s[0] <= '9' {
+ // If we've hit decimal limit, break
+ if decimalFound == 0 {
+ break
+ }
+ // add the number to the numPart
+ numPart += string(s[0])
+ // Decrement the decimal count
+ // If the decimal has been found, `decimalFound` is positive
+ // If the decimal hasn't been found, `decimalFound` is negative, and decrementing it does nothing
+ decimalFound--
+ }
+ // If decimal has been found (>= 0) and the character is a thousand separator or decimal separator,
+ // then the number is invalid
+ if decimalFound >= 0 && (string(s[0]) == c.Thousand || string(s[0]) == c.Decimal) {
+ return "", false
+ }
+ // If the character is a decimal separator, set `decimalFound` to the number of
+ // expected decimal places for the currency
+ if string(s[0]) == c.Decimal {
+ decimalFound = c.Fraction
+ }
+ // Move to the next character
+ s = s[1:]
+ }
+ if decimalFound > 0 {
+ // If there should be more decimal places, add them
+ numPart += strings.Repeat("0", decimalFound)
+ } else if decimalFound < 0 {
+ // If no decimal was found, add the expected number of decimal places
+ numPart += strings.Repeat("0", c.Fraction)
+ }
+ case ' ':
+ // There should be a space in the s at this position
+ if len(s) == 0 || s[0] != ' ' {
+ return "", false
+ }
+ s = s[1:]
+ default:
+ panic(fmt.Sprintf("unsupported template character: %c", tp))
+ }
+ }
+ return numPart, true
+}
diff --git a/backend/internal/utils/money_test.go b/backend/internal/utils/money_test.go
new file mode 100644
index 0000000..27ace06
--- /dev/null
+++ b/backend/internal/utils/money_test.go
@@ -0,0 +1,106 @@
+package utils
+
+import (
+ "testing"
+
+ "github.com/Rhymond/go-money"
+ "github.com/stretchr/testify/assert"
+)
+
// TestParseMoney checks end-to-end parsing of formatted amounts into minor
// units, covering integer and decimal forms, thousands separators, and a
// suffix-grapheme currency (CNY, "1,234.56 元").
func TestParseMoney(t *testing.T) {
	tests := []struct {
		name  string
		input string
		want  *money.Money
	}{
		{
			name:  "en-US int no comma",
			input: "$123",
			want:  money.New(12300, money.USD),
		},
		{
			name:  "en-US int comma",
			input: "$1,123",
			want:  money.New(112300, money.USD),
		},
		{
			name:  "en-US decimal comma",
			input: "$1,123.45",
			want:  money.New(112345, money.USD),
		},
		{
			name:  "zh-CN decimal comma",
			input: "1,234.56 \u5143",
			want:  money.New(123456, money.CNY),
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			m, err := ParseMoney(tt.input)
			assert.NoError(t, err)
			assert.Equal(t, tt.want, m)
		})
	}
}
+
// Test_isCurrency checks the template matcher directly: the returned
// numPart is a digits-only string in minor units (short decimals padded,
// missing decimals filled with the currency's fraction count), and inputs
// lacking the currency grapheme are rejected.
func Test_isCurrency(t *testing.T) {
	tests := []struct {
		name     string
		input    string
		currency string
		numPart  string
		ok       bool
	}{
		{
			name:     "en-US int no comma",
			input:    "$123",
			currency: money.USD,
			numPart:  "12300",
			ok:       true,
		},
		{
			name:     "en-US int comma",
			input:    "$1,123",
			currency: money.USD,
			numPart:  "112300",
			ok:       true,
		},
		{
			name:     "en-US decimal comma",
			input:    "$1,123.45",
			currency: money.USD,
			numPart:  "112345",
			ok:       true,
		},
		{
			// One decimal digit given, two expected: padded to "…50".
			name:     "en-US 1 decimal comma",
			input:    "$1,123.5",
			currency: money.USD,
			numPart:  "112350",
			ok:       true,
		},
		{
			// No "$" grapheme: must not match USD.
			name:     "en-US no grapheme",
			input:    "1,234.56",
			currency: money.USD,
			numPart:  "",
			ok:       false,
		},
		{
			name:     "zh-CN decimal comma",
			input:    "1,234.56 \u5143",
			currency: money.CNY,
			numPart:  "123456",
			ok:       true,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			c := money.GetCurrency(tt.currency)
			numPart, ok := isCurrency(tt.input, c)
			assert.Equal(t, tt.ok, ok)
			assert.Equal(t, tt.numPart, numPart)
		})
	}
}
diff --git a/backend/internal/worker/analyzer/analyzer.go b/backend/internal/worker/analyzer/analyzer.go
new file mode 100644
index 0000000..20621dd
--- /dev/null
+++ b/backend/internal/worker/analyzer/analyzer.go
@@ -0,0 +1,142 @@
+package analyzer
+
+import (
+ "context"
+ "log/slog"
+ "time"
+
+ "github.com/ansg191/ibd-trader-backend/internal/analyzer"
+ "github.com/ansg191/ibd-trader-backend/internal/database"
+ "github.com/ansg191/ibd-trader-backend/internal/redis/taskqueue"
+
+ "github.com/redis/go-redis/v9"
+)
+
const (
	// Queue is the Redis task-queue name for analysis work.
	Queue = "analyzer"
	// QueueEncoding is the serialization used for TaskInfo payloads.
	QueueEncoding = taskqueue.EncodingJSON

	// lockTimeout is how long a dequeued task stays locked to this worker.
	lockTimeout = 1 * time.Minute
	// dequeueTimeout bounds how long a single Dequeue call blocks waiting
	// for work.
	dequeueTimeout = 5 * time.Second
)
+
+func RunAnalyzer(
+ ctx context.Context,
+ redis *redis.Client,
+ analyzer analyzer.Analyzer,
+ db database.Executor,
+ name string,
+) error {
+ queue, err := taskqueue.New(
+ ctx,
+ redis,
+ Queue,
+ name,
+ taskqueue.WithEncoding[TaskInfo](QueueEncoding),
+ )
+ if err != nil {
+ return err
+ }
+
+ for {
+ select {
+ case <-ctx.Done():
+ return ctx.Err()
+ default:
+ waitForTask(ctx, queue, analyzer, db)
+ }
+ }
+}
+
+func waitForTask(
+ ctx context.Context,
+ queue taskqueue.TaskQueue[TaskInfo],
+ analyzer analyzer.Analyzer,
+ db database.Executor,
+) {
+ task, err := queue.Dequeue(ctx, lockTimeout, dequeueTimeout)
+ if err != nil {
+ slog.ErrorContext(ctx, "Failed to dequeue task", "error", err)
+ return
+ }
+ if task == nil {
+ // No task available.
+ return
+ }
+
+ errCh := make(chan error)
+ resCh := make(chan string)
+ defer close(errCh)
+ defer close(resCh)
+ go func() {
+ res, err := analyzeStock(ctx, analyzer, db, task.Data.ID)
+ if err != nil {
+ errCh <- err
+ return
+ }
+ resCh <- res
+ }()
+
+ ticker := time.NewTicker(lockTimeout / 5)
+ defer ticker.Stop()
+
+ for {
+ select {
+ case <-ctx.Done():
+ // Context was canceled. Return early.
+ return
+ case <-ticker.C:
+ // Extend the lock periodically.
+ func() {
+ ctx, cancel := context.WithTimeout(ctx, lockTimeout/5)
+ defer cancel()
+
+ err := queue.Extend(ctx, task.ID)
+ if err != nil {
+ slog.ErrorContext(ctx, "Failed to extend lock", "error", err)
+ }
+ }()
+ case err = <-errCh:
+ // analyzeStock has errored.
+ slog.ErrorContext(ctx, "Failed to analyze", "error", err)
+ _, err = queue.Return(ctx, task.ID, err)
+ if err != nil {
+ slog.ErrorContext(ctx, "Failed to return task", "error", err)
+ return
+ }
+ return
+ case res := <-resCh:
+ // analyzeStock has completed successfully.
+ slog.DebugContext(ctx, "Analyzed ID", "id", task.Data.ID, "result", res)
+ err = queue.Complete(ctx, task.ID, res)
+ if err != nil {
+ slog.ErrorContext(ctx, "Failed to complete task", "error", err)
+ return
+ }
+ return
+ }
+ }
+}
+
+func analyzeStock(ctx context.Context, a analyzer.Analyzer, db database.Executor, id string) (string, error) {
+ info, err := database.GetStockInfo(ctx, db, id)
+ if err != nil {
+ return "", err
+ }
+
+ analysis, err := a.Analyze(
+ ctx,
+ info.Symbol,
+ info.Price,
+ info.ChartAnalysis,
+ )
+ if err != nil {
+ return "", err
+ }
+
+ return database.AddAnalysis(ctx, db, id, analysis)
+}
+
// TaskInfo is the payload carried by entries on the analyzer task queue.
type TaskInfo struct {
	// ID is the database ID of the stock-info row to analyze.
	ID string `json:"id"`
}
diff --git a/backend/internal/worker/auth/auth.go b/backend/internal/worker/auth/auth.go
new file mode 100644
index 0000000..0daa112
--- /dev/null
+++ b/backend/internal/worker/auth/auth.go
@@ -0,0 +1,239 @@
+package auth
+
+import (
+ "context"
+ "database/sql"
+ "errors"
+ "fmt"
+ "log/slog"
+ "time"
+
+ "github.com/ansg191/ibd-trader-backend/internal/database"
+ "github.com/ansg191/ibd-trader-backend/internal/ibd"
+ "github.com/ansg191/ibd-trader-backend/internal/keys"
+ "github.com/ansg191/ibd-trader-backend/internal/leader/manager/ibd/auth"
+ "github.com/ansg191/ibd-trader-backend/internal/redis/taskqueue"
+
+ "github.com/redis/go-redis/v9"
+)
+
const (
	// lockTimeout is how long a dequeued task stays locked to this worker.
	lockTimeout = 1 * time.Minute
	// dequeueTimeout bounds how long a single Dequeue call blocks waiting
	// for work.
	dequeueTimeout = 5 * time.Second
)
+
+func RunAuthScraper(
+ ctx context.Context,
+ client *ibd.Client,
+ redis *redis.Client,
+ db database.TransactionExecutor,
+ kms keys.KeyManagementService,
+ name string,
+) error {
+ queue, err := taskqueue.New(
+ ctx,
+ redis,
+ auth.Queue,
+ name,
+ taskqueue.WithEncoding[auth.TaskInfo](auth.QueueEncoding),
+ )
+ if err != nil {
+ return err
+ }
+
+ for {
+ select {
+ case <-ctx.Done():
+ return ctx.Err()
+ default:
+ waitForTask(ctx, queue, client, db, kms)
+ }
+ }
+}
+
+func waitForTask(
+ ctx context.Context,
+ queue taskqueue.TaskQueue[auth.TaskInfo],
+ client *ibd.Client,
+ db database.TransactionExecutor,
+ kms keys.KeyManagementService,
+) {
+ task, err := queue.Dequeue(ctx, lockTimeout, dequeueTimeout)
+ if err != nil {
+ slog.ErrorContext(ctx, "Failed to dequeue task", "error", err)
+ return
+ }
+ if task == nil {
+ // No task available.
+ return
+ }
+ slog.DebugContext(ctx, "Picked up auth task", "task", task.ID, "user", task.Data.UserSubject)
+
+ ch := make(chan error)
+ defer close(ch)
+ go func() {
+ ch <- scrapeCookies(ctx, client, db, kms, task.Data.UserSubject)
+ }()
+
+ ticker := time.NewTicker(lockTimeout / 5)
+ defer ticker.Stop()
+
+ for {
+ select {
+ case <-ctx.Done():
+ // The context was canceled. Return early.
+ return
+ case <-ticker.C:
+ // Extend the lock periodically.
+ func() {
+ ctx, cancel := context.WithTimeout(ctx, lockTimeout/5)
+ defer cancel()
+
+ err := queue.Extend(ctx, task.ID)
+ if err != nil {
+ slog.ErrorContext(ctx, "Failed to extend lock", "error", err)
+ }
+ }()
+ case err = <-ch:
+ // scrapeCookies has completed.
+ if err != nil {
+ slog.ErrorContext(ctx, "Failed to scrape cookies", "error", err)
+ _, err = queue.Return(ctx, task.ID, err)
+ if err != nil {
+ slog.ErrorContext(ctx, "Failed to return task", "error", err)
+ return
+ }
+ } else {
+ err = queue.Complete(ctx, task.ID, "")
+ if err != nil {
+ slog.ErrorContext(ctx, "Failed to complete task", "error", err)
+ return
+ }
+ slog.DebugContext(ctx, "Authenticated user", "user", task.Data.UserSubject)
+ }
+ return
+ }
+ }
+}
+
+func scrapeCookies(
+ ctx context.Context,
+ client *ibd.Client,
+ db database.TransactionExecutor,
+ kms keys.KeyManagementService,
+ user string,
+) error {
+ ctx, cancel := context.WithTimeout(ctx, lockTimeout)
+ defer cancel()
+
+ // Check if the user has valid cookies
+ done, err := hasValidCookies(ctx, db, user)
+ if err != nil {
+ return fmt.Errorf("failed to check cookies: %w", err)
+ }
+ if done {
+ return nil
+ }
+
+ // Health check degraded cookies
+ done, err = healthCheckDegradedCookies(ctx, client, db, kms, user)
+ if err != nil {
+ return fmt.Errorf("failed to health check cookies: %w", err)
+ }
+ if done {
+ return nil
+ }
+
+ // No cookies are valid, so scrape new cookies
+ return scrapeNewCookies(ctx, client, db, kms, user)
+}
+
+func hasValidCookies(ctx context.Context, db database.Executor, user string) (bool, error) {
+ // Check if the user has non-degraded cookies
+ row := db.QueryRowContext(ctx, `
+SELECT 1
+FROM ibd_tokens
+WHERE user_subject = $1
+ AND expires_at > NOW()
+ AND degraded = FALSE;`, user)
+
+ var exists bool
+ err := row.Scan(&exists)
+ if errors.Is(err, sql.ErrNoRows) {
+ return false, nil
+ }
+ if err != nil {
+ return false, fmt.Errorf("failed to get non-degraded cookies: %w", err)
+ }
+
+ return true, nil
+}
+
+func healthCheckDegradedCookies(
+ ctx context.Context,
+ client *ibd.Client,
+ db database.Executor,
+ kms keys.KeyManagementService,
+ user string,
+) (bool, error) {
+ // Check if the user has degraded cookies
+ cookies, err := database.GetCookies(ctx, db, kms, user, true)
+ if err != nil {
+ return false, fmt.Errorf("failed to get degraded cookies: %w", err)
+ }
+
+ valid := false
+ for _, cookie := range cookies {
+ slog.DebugContext(ctx, "Health checking cookie", "cookie", cookie.ID)
+
+ // Health check the cookie
+ up, err := client.UserInfo(ctx, cookie.ToHTTPCookie())
+ if err != nil {
+ slog.ErrorContext(ctx, "Failed to health check cookie", "error", err)
+ continue
+ }
+
+ if up.Status != ibd.UserStatusSubscriber {
+ continue
+ }
+
+ // Cookie is valid
+ valid = true
+
+ // Update the cookie
+ err = database.RepairCookie(ctx, db, cookie.ID)
+ if err != nil {
+ slog.ErrorContext(ctx, "Failed to repair cookie", "error", err)
+ }
+ }
+
+ return valid, nil
+}
+
+func scrapeNewCookies(
+ ctx context.Context,
+ client *ibd.Client,
+ db database.TransactionExecutor,
+ kms keys.KeyManagementService,
+ user string,
+) error {
+ // Get the user's credentials
+ username, password, err := database.GetIBDCreds(ctx, db, kms, user)
+ if err != nil {
+ return fmt.Errorf("failed to get IBD credentials: %w", err)
+ }
+
+ // Scrape the user's cookies
+ cookie, err := client.Authenticate(ctx, username, password)
+ if err != nil {
+ return fmt.Errorf("failed to authenticate user: %w", err)
+ }
+
+ // Store the cookie
+ err = database.AddCookie(ctx, db, kms, user, cookie)
+ if err != nil {
+ return fmt.Errorf("failed to store cookie: %w", err)
+ }
+
+ return nil
+}
diff --git a/backend/internal/worker/scraper/scraper.go b/backend/internal/worker/scraper/scraper.go
new file mode 100644
index 0000000..c5c1b6c
--- /dev/null
+++ b/backend/internal/worker/scraper/scraper.go
@@ -0,0 +1,198 @@
+package scraper
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "log/slog"
+ "time"
+
+ "github.com/ansg191/ibd-trader-backend/internal/database"
+ "github.com/ansg191/ibd-trader-backend/internal/ibd"
+ "github.com/ansg191/ibd-trader-backend/internal/leader/manager/ibd/scrape"
+ "github.com/ansg191/ibd-trader-backend/internal/redis/taskqueue"
+ "github.com/ansg191/ibd-trader-backend/internal/worker/analyzer"
+
+ "github.com/redis/go-redis/v9"
+)
+
const (
	// lockTimeout is how long a dequeued task stays locked to this worker.
	lockTimeout = 1 * time.Minute
	// dequeueTimeout bounds how long a single Dequeue call blocks waiting
	// for work.
	dequeueTimeout = 5 * time.Second
)
+
+func RunScraper(
+ ctx context.Context,
+ redis *redis.Client,
+ client *ibd.Client,
+ db database.TransactionExecutor,
+ name string,
+) error {
+ queue, err := taskqueue.New(
+ ctx,
+ redis,
+ scrape.Queue,
+ name,
+ taskqueue.WithEncoding[scrape.TaskInfo](scrape.QueueEncoding),
+ )
+ if err != nil {
+ return err
+ }
+
+ aQueue, err := taskqueue.New(
+ ctx,
+ redis,
+ analyzer.Queue,
+ name,
+ taskqueue.WithEncoding[analyzer.TaskInfo](analyzer.QueueEncoding),
+ )
+ if err != nil {
+ return err
+ }
+
+ for {
+ select {
+ case <-ctx.Done():
+ return ctx.Err()
+ default:
+ waitForTask(ctx, queue, aQueue, client, db)
+ }
+ }
+}
+
+func waitForTask(
+ ctx context.Context,
+ queue taskqueue.TaskQueue[scrape.TaskInfo],
+ aQueue taskqueue.TaskQueue[analyzer.TaskInfo],
+ client *ibd.Client,
+ db database.TransactionExecutor,
+) {
+ task, err := queue.Dequeue(ctx, lockTimeout, dequeueTimeout)
+ if err != nil {
+ slog.ErrorContext(ctx, "Failed to dequeue task", "error", err)
+ return
+ }
+ if task == nil {
+ // No task available.
+ return
+ }
+
+ errCh := make(chan error)
+ resCh := make(chan string)
+ defer close(errCh)
+ defer close(resCh)
+ go func() {
+ res, err := scrapeUrl(ctx, client, db, aQueue, task.Data.Symbol)
+ if err != nil {
+ errCh <- err
+ return
+ }
+ resCh <- res
+ }()
+
+ ticker := time.NewTicker(lockTimeout / 5)
+ defer ticker.Stop()
+
+ for {
+ select {
+ case <-ctx.Done():
+ // Context was canceled. Return early.
+ return
+ case <-ticker.C:
+ // Extend the lock periodically.
+ func() {
+ ctx, cancel := context.WithTimeout(ctx, lockTimeout/5)
+ defer cancel()
+
+ err := queue.Extend(ctx, task.ID)
+ if err != nil {
+ slog.ErrorContext(ctx, "Failed to extend lock", "error", err)
+ }
+ }()
+ case err = <-errCh:
+ // scrapeUrl has errored.
+ slog.ErrorContext(ctx, "Failed to scrape URL", "error", err)
+ _, err = queue.Return(ctx, task.ID, err)
+ if err != nil {
+ slog.ErrorContext(ctx, "Failed to return task", "error", err)
+ return
+ }
+ return
+ case res := <-resCh:
+ // scrapeUrl has completed successfully.
+ slog.DebugContext(ctx, "Scraped URL", "symbol", task.Data.Symbol)
+ err = queue.Complete(ctx, task.ID, res)
+ if err != nil {
+ slog.ErrorContext(ctx, "Failed to complete task", "error", err)
+ return
+ }
+ return
+ }
+ }
+}
+
+func scrapeUrl(
+ ctx context.Context,
+ client *ibd.Client,
+ db database.TransactionExecutor,
+ aQueue taskqueue.TaskQueue[analyzer.TaskInfo],
+ symbol string,
+) (string, error) {
+ ctx, cancel := context.WithTimeout(ctx, lockTimeout)
+ defer cancel()
+
+ stockUrl, err := getStockUrl(ctx, db, client, symbol)
+ if err != nil {
+ return "", fmt.Errorf("failed to get stock url: %w", err)
+ }
+
+ // Scrape the stock info.
+ info, err := client.StockInfo(ctx, stockUrl)
+ if err != nil {
+ return "", fmt.Errorf("failed to get stock info: %w", err)
+ }
+
+ // Add stock info to the database.
+ id, err := database.AddStockInfo(ctx, db, info)
+ if err != nil {
+ return "", fmt.Errorf("failed to add stock info: %w", err)
+ }
+
+ // Add the stock to the analyzer queue.
+ _, err = aQueue.Enqueue(ctx, analyzer.TaskInfo{ID: id})
+ if err != nil {
+ return "", fmt.Errorf("failed to enqueue analysis task: %w", err)
+ }
+
+ slog.DebugContext(ctx, "Added stock info", "id", id)
+
+ return id, nil
+}
+
+func getStockUrl(ctx context.Context, db database.TransactionExecutor, client *ibd.Client, symbol string) (string, error) {
+ // Get the stock from the database.
+ stock, err := database.GetStock(ctx, db, symbol)
+ if err == nil {
+ return stock.IBDUrl, nil
+ }
+ if !errors.Is(err, database.ErrStockNotFound) {
+ return "", fmt.Errorf("failed to get stock: %w", err)
+ }
+
+ // If stock isn't found in the database, get the stock from IBD.
+ stock, err = client.Search(ctx, symbol)
+ if errors.Is(err, ibd.ErrSymbolNotFound) {
+ return "", fmt.Errorf("symbol not found: %w", err)
+ }
+ if err != nil {
+ return "", fmt.Errorf("failed to search for symbol: %w", err)
+ }
+
+ // Add the stock to the database.
+ err = database.AddStock(ctx, db, stock)
+ if err != nil {
+ return "", fmt.Errorf("failed to add stock: %w", err)
+ }
+
+ return stock.IBDUrl, nil
+}
diff --git a/backend/internal/worker/worker.go b/backend/internal/worker/worker.go
new file mode 100644
index 0000000..6017fb7
--- /dev/null
+++ b/backend/internal/worker/worker.go
@@ -0,0 +1,151 @@
+package worker
+
+import (
+ "context"
+ "crypto/rand"
+ "encoding/base64"
+ "io"
+ "log/slog"
+ "os"
+ "time"
+
+ "github.com/ansg191/ibd-trader-backend/internal/analyzer"
+ "github.com/ansg191/ibd-trader-backend/internal/database"
+ "github.com/ansg191/ibd-trader-backend/internal/ibd"
+ "github.com/ansg191/ibd-trader-backend/internal/keys"
+ "github.com/ansg191/ibd-trader-backend/internal/leader/manager"
+ analyzer2 "github.com/ansg191/ibd-trader-backend/internal/worker/analyzer"
+ "github.com/ansg191/ibd-trader-backend/internal/worker/auth"
+ "github.com/ansg191/ibd-trader-backend/internal/worker/scraper"
+
+ "github.com/redis/go-redis/v9"
+ "golang.org/x/sync/errgroup"
+)
+
const (
	// HeartbeatInterval is how often the worker refreshes its heartbeat.
	HeartbeatInterval = 5 * time.Second
	// HeartbeatTTL is the expiry on the heartbeat key; a worker that misses
	// heartbeats for this long is considered gone.
	HeartbeatTTL = 30 * time.Second
)
+
+func StartWorker(
+ ctx context.Context,
+ ibdClient *ibd.Client,
+ client *redis.Client,
+ db database.TransactionExecutor,
+ kms keys.KeyManagementService,
+ a analyzer.Analyzer,
+) error {
+ // Get the worker name.
+ name, err := workerName()
+ if err != nil {
+ return err
+ }
+ slog.InfoContext(ctx, "Starting worker", "worker", name)
+
+ g, ctx := errgroup.WithContext(ctx)
+
+ g.Go(func() error {
+ return workerRegistrationLoop(ctx, client, name)
+ })
+ g.Go(func() error {
+ return scraper.RunScraper(ctx, client, ibdClient, db, name)
+ })
+ g.Go(func() error {
+ return auth.RunAuthScraper(ctx, ibdClient, client, db, kms, name)
+ })
+ g.Go(func() error {
+ return analyzer2.RunAnalyzer(ctx, client, a, db, name)
+ })
+
+ return g.Wait()
+}
+
+func workerRegistrationLoop(ctx context.Context, client *redis.Client, name string) error {
+ sendHeartbeat(ctx, client, name)
+
+ ticker := time.NewTicker(HeartbeatInterval)
+ defer ticker.Stop()
+
+ for {
+ select {
+ case <-ticker.C:
+ sendHeartbeat(ctx, client, name)
+ case <-ctx.Done():
+ removeWorker(ctx, client, name)
+ return ctx.Err()
+ }
+ }
+}
+
+// sendHeartbeat sends a heartbeat for the worker.
+// It ensures that the worker is in the active workers set and its heartbeat exists.
+func sendHeartbeat(ctx context.Context, client *redis.Client, name string) {
+ ctx, cancel := context.WithTimeout(ctx, HeartbeatInterval)
+ defer cancel()
+
+ // Add the worker to the active workers set.
+ if err := client.SAdd(ctx, manager.ActiveWorkersSet, name).Err(); err != nil {
+ slog.ErrorContext(ctx,
+ "Unable to add worker to active workers set",
+ "worker", name,
+ "error", err,
+ )
+ return
+ }
+
+ // Set the worker's heartbeat.
+ heartbeatKey := manager.WorkerHeartbeatKey(name)
+ if err := client.Set(ctx, heartbeatKey, time.Now().Unix(), HeartbeatTTL).Err(); err != nil {
+ slog.ErrorContext(ctx,
+ "Unable to set worker heartbeat",
+ "worker", name,
+ "error", err,
+ )
+ return
+ }
+}
+
+// removeWorker removes the worker from the active workers set.
+func removeWorker(ctx context.Context, client *redis.Client, name string) {
+ if ctx.Err() != nil {
+ // If the context is canceled, create a new uncanceled context.
+ ctx = context.WithoutCancel(ctx)
+ }
+ ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
+ defer cancel()
+
+ // Remove the worker from the active workers set.
+ if err := client.SRem(ctx, manager.ActiveWorkersSet, name).Err(); err != nil {
+ slog.ErrorContext(ctx,
+ "Unable to remove worker from active workers set",
+ "worker", name,
+ "error", err,
+ )
+ return
+ }
+
+ // Remove the worker's heartbeat.
+ heartbeatKey := manager.WorkerHeartbeatKey(name)
+ if err := client.Del(ctx, heartbeatKey).Err(); err != nil {
+ slog.ErrorContext(ctx,
+ "Unable to remove worker heartbeat",
+ "worker", name,
+ "error", err,
+ )
+ return
+ }
+}
+
// workerName builds a unique worker identity of the form
// "<hostname>-<random>", where the suffix encodes 12 cryptographically
// random bytes in URL-safe base64.
func workerName() (string, error) {
	host, err := os.Hostname()
	if err != nil {
		return "", err
	}

	suffix := make([]byte, 12)
	if _, err := io.ReadFull(rand.Reader, suffix); err != nil {
		return "", err
	}
	return host + "-" + base64.URLEncoding.EncodeToString(suffix), nil
}