aboutsummaryrefslogtreecommitdiff
path: root/backend/internal/ibd
diff options
context:
space:
mode:
authorGravatar Anshul Gupta <ansg191@anshulg.com> 2024-08-07 18:56:01 -0700
committerGravatar GitHub <noreply@github.com> 2024-08-07 18:56:01 -0700
commit08993e2f8497341079010d3d06361c99492c4c07 (patch)
treec65d6d571c928410faace1fa51c2ea3f49fce003 /backend/internal/ibd
parent3de4ebb7560851ccbefe296c197456fe80c22901 (diff)
parentb8aef1a7fb24815c7d93bc30c7b289b4f5896779 (diff)
downloadibd-trader-08993e2f8497341079010d3d06361c99492c4c07.tar.gz
ibd-trader-08993e2f8497341079010d3d06361c99492c4c07.tar.zst
ibd-trader-08993e2f8497341079010d3d06361c99492c4c07.zip
Merge pull request #1 from ansg191/refactor-database
Diffstat (limited to 'backend/internal/ibd')
-rw-r--r--backend/internal/ibd/auth_test.go4
-rw-r--r--backend/internal/ibd/client.go13
-rw-r--r--backend/internal/ibd/client_test.go196
-rw-r--r--backend/internal/ibd/ibd50.go4
-rw-r--r--backend/internal/ibd/search_test.go8
5 files changed, 167 insertions, 58 deletions
diff --git a/backend/internal/ibd/auth_test.go b/backend/internal/ibd/auth_test.go
index 54ea98a..157b507 100644
--- a/backend/internal/ibd/auth_test.go
+++ b/backend/internal/ibd/auth_test.go
@@ -163,7 +163,7 @@ func TestClient_Authenticate(t *testing.T) {
return resp, nil
})
- client := NewClient(nil, newTransport(tp))
+ client := NewClient(nil, nil, newTransport(tp))
cookie, err := client.Authenticate(context.Background(), "abc", "xyz")
require.NoError(t, err)
@@ -189,7 +189,7 @@ func TestClient_Authenticate_401(t *testing.T) {
return httpmock.NewStringResponse(http.StatusUnauthorized, `{"name":"ValidationError","code":"ERR016","message":"Wrong username or password","description":"Wrong username or password"}`), nil
})
- client := NewClient(nil, newTransport(tp))
+ client := NewClient(nil, nil, newTransport(tp))
cookie, err := client.Authenticate(context.Background(), "abc", "xyz")
assert.Nil(t, cookie)
diff --git a/backend/internal/ibd/client.go b/backend/internal/ibd/client.go
index 25c5173..c8575e3 100644
--- a/backend/internal/ibd/client.go
+++ b/backend/internal/ibd/client.go
@@ -9,6 +9,7 @@ import (
"github.com/ansg191/ibd-trader-backend/internal/database"
"github.com/ansg191/ibd-trader-backend/internal/ibd/transport"
+ "github.com/ansg191/ibd-trader-backend/internal/keys"
)
var ErrNoAvailableCookies = errors.New("no available cookies")
@@ -16,20 +17,22 @@ var ErrNoAvailableTransports = errors.New("no available transports")
type Client struct {
transports []transport.Transport
- cookies database.CookieSource
+ db database.Executor
+ kms keys.KeyManagementService
}
func NewClient(
- cookies database.CookieSource,
+ db database.Executor,
+ kms keys.KeyManagementService,
transports ...transport.Transport,
) *Client {
- return &Client{transports, cookies}
+ return &Client{transports, db, kms}
}
func (c *Client) getCookie(ctx context.Context, subject *string) (uint, *http.Cookie, error) {
if subject == nil {
// No subject requirement, get any cookie
- cookie, err := c.cookies.GetAnyCookie(ctx)
+ cookie, err := database.GetAnyCookie(ctx, c.db, c.kms)
if err != nil {
return 0, nil, err
}
@@ -41,7 +44,7 @@ func (c *Client) getCookie(ctx context.Context, subject *string) (uint, *http.Co
}
// Get cookie by subject
- cookies, err := c.cookies.GetCookies(ctx, *subject, false)
+ cookies, err := database.GetCookies(ctx, c.db, c.kms, *subject, false)
if err != nil {
return 0, nil, err
}
diff --git a/backend/internal/ibd/client_test.go b/backend/internal/ibd/client_test.go
index d2dc1b2..2368a31 100644
--- a/backend/internal/ibd/client_test.go
+++ b/backend/internal/ibd/client_test.go
@@ -2,30 +2,155 @@ package ibd
import (
"context"
+ "database/sql"
+ "fmt"
+ "log"
+ "math/rand/v2"
"testing"
"time"
"github.com/ansg191/ibd-trader-backend/internal/database"
+ "github.com/ansg191/ibd-trader-backend/internal/keys"
+ _ "github.com/lib/pq"
+ "github.com/ory/dockertest/v3"
+ "github.com/ory/dockertest/v3/docker"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
-func TestClient_getCookie(t *testing.T) {
- t.Parallel()
+var (
+ db *sql.DB
+ maxTime = time.Date(2100, 1, 1, 0, 0, 0, 0, time.UTC)
+ letterRunes = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ")
+)
- t.Run("no cookies", func(t *testing.T) {
- t.Parallel()
+func TestMain(m *testing.M) {
+ pool, err := dockertest.NewPool("")
+ if err != nil {
+ log.Fatalf("Could not create pool: %s", err)
+ }
+
+ err = pool.Client.Ping()
+ if err != nil {
+ log.Fatalf("Could not connect to Docker: %s", err)
+ }
+
+ resource, err := pool.RunWithOptions(&dockertest.RunOptions{
+ Repository: "postgres",
+ Tag: "16",
+ Env: []string{
+ "POSTGRES_PASSWORD=secret",
+ "POSTGRES_USER=ibd-client-test",
+ "POSTGRES_DB=ibd-client-test",
+ "listen_addresses='*'",
+ },
+ Cmd: []string{
+ "postgres",
+ "-c",
+ "log_statement=all",
+ },
+ }, func(config *docker.HostConfig) {
+ config.AutoRemove = true
+ config.RestartPolicy = docker.RestartPolicy{Name: "no"}
+ })
+ if err != nil {
+ log.Fatalf("Could not start resource: %s", err)
+ }
+
+ hostAndPort := resource.GetHostPort("5432/tcp")
+ databaseUrl := fmt.Sprintf("postgres://ibd-client-test:secret@%s/ibd-client-test?sslmode=disable", hostAndPort)
+
+ // Kill container after 120 seconds
+ _ = resource.Expire(120)
+
+ pool.MaxWait = 120 * time.Second
+ if err = pool.Retry(func() error {
+ db, err = sql.Open("postgres", databaseUrl)
+ if err != nil {
+ return err
+ }
+ return db.Ping()
+ }); err != nil {
+ log.Fatalf("Could not connect to database: %s", err)
+ }
+
+ err = database.Migrate(context.Background(), databaseUrl)
+ if err != nil {
+ log.Fatalf("Could not migrate database: %s", err)
+ }
+
+ defer func() {
+ if err := pool.Purge(resource); err != nil {
+ log.Fatalf("Could not purge resource: %s", err)
+ }
+ }()
+
+ m.Run()
+}
+
+func randStringRunes(n int) string {
+ b := make([]rune, n)
+ for i := range b {
+ b[i] = letterRunes[rand.IntN(len(letterRunes))]
+ }
+ return string(b)
+}
+
+func addCookie(t *testing.T) (user, token string) {
+ t.Helper()
+
+ // Randomly generate a user and token
+ user = randStringRunes(8)
+ token = randStringRunes(16)
+
+ ciphertext, key, err := keys.Encrypt(context.Background(), new(kmsStub), "", []byte(token))
+ require.NoError(t, err)
+
+ tx, err := db.Begin()
+ require.NoError(t, err)
+
+ var keyID uint
+ err = tx.QueryRow(`
+INSERT INTO keys (kms_key_name, encrypted_key)
+ VALUES ('', $1)
+ RETURNING id;
+`, key).Scan(&keyID)
+ require.NoError(t, err)
+
+ _, err = tx.Exec(`
+INSERT
+INTO users (subject, encryption_key)
+VALUES ($1, $2);
+`, user, keyID)
+ require.NoError(t, err)
+
+ _, err = tx.Exec(`
+INSERT
+INTO ibd_tokens (user_subject, token, encryption_key, expires_at)
+VALUES ($1, $2, $3, $4);`,
+ user,
+ ciphertext,
+ keyID,
+ maxTime,
+ )
+ require.NoError(t, err)
+
+ err = tx.Commit()
+ require.NoError(t, err)
+
+ return user, token
+}
- client := NewClient(new(emptyCookieSourceStub))
+func TestClient_getCookie(t *testing.T) {
+ t.Run("no cookies", func(t *testing.T) {
+ client := NewClient(db, new(kmsStub))
_, _, err := client.getCookie(context.Background(), nil)
assert.ErrorIs(t, err, ErrNoAvailableCookies)
})
t.Run("no cookies by subject", func(t *testing.T) {
- t.Parallel()
-
- client := NewClient(new(emptyCookieSourceStub))
+ client := NewClient(db, new(kmsStub))
subject := "test"
_, _, err := client.getCookie(context.Background(), &subject)
@@ -33,67 +158,44 @@ func TestClient_getCookie(t *testing.T) {
})
t.Run("get any cookie", func(t *testing.T) {
- t.Parallel()
+ _, token := addCookie(t)
- client := NewClient(new(cookieSourceStub))
+ client := NewClient(db, new(kmsStub))
- id, cookie, err := client.getCookie(context.Background(), nil)
+ _, cookie, err := client.getCookie(context.Background(), nil)
require.NoError(t, err)
- assert.Equal(t, uint(42), id)
assert.Equal(t, cookieName, cookie.Name)
- assert.Equal(t, "test-token", cookie.Value)
+ assert.Equal(t, token, cookie.Value)
assert.Equal(t, "/", cookie.Path)
- assert.Equal(t, time.Unix(0, 0), cookie.Expires)
+ assert.Equal(t, maxTime, cookie.Expires)
assert.Equal(t, "investors.com", cookie.Domain)
})
t.Run("get cookie by subject", func(t *testing.T) {
- t.Parallel()
+ subject, token := addCookie(t)
- client := NewClient(new(cookieSourceStub))
+ client := NewClient(db, new(kmsStub))
- subject := "test"
- id, cookie, err := client.getCookie(context.Background(), &subject)
+ _, cookie, err := client.getCookie(context.Background(), &subject)
require.NoError(t, err)
- assert.Equal(t, uint(42), id)
assert.Equal(t, cookieName, cookie.Name)
- assert.Equal(t, "test-token", cookie.Value)
+ assert.Equal(t, token, cookie.Value)
assert.Equal(t, "/", cookie.Path)
- assert.Equal(t, time.Unix(0, 0), cookie.Expires)
+ assert.Equal(t, maxTime, cookie.Expires)
assert.Equal(t, "investors.com", cookie.Domain)
})
}
-type emptyCookieSourceStub struct{}
-
-func (c *emptyCookieSourceStub) GetAnyCookie(_ context.Context) (*database.IBDCookie, error) {
- return nil, nil
-}
-
-func (c *emptyCookieSourceStub) GetCookies(_ context.Context, _ string, _ bool) ([]database.IBDCookie, error) {
- return nil, nil
-}
+type kmsStub struct{}
-func (c *emptyCookieSourceStub) ReportCookieFailure(_ context.Context, _ uint) error {
+func (k *kmsStub) Close() error {
return nil
}
-var testCookie = database.IBDCookie{
- ID: 42,
- Token: "test-token",
- Expiry: time.Unix(0, 0),
+func (k *kmsStub) Encrypt(_ context.Context, _ string, plaintext []byte) ([]byte, error) {
+ return plaintext, nil
}
-type cookieSourceStub struct{}
-
-func (c *cookieSourceStub) GetAnyCookie(_ context.Context) (*database.IBDCookie, error) {
- return &testCookie, nil
-}
-
-func (c *cookieSourceStub) GetCookies(_ context.Context, _ string, _ bool) ([]database.IBDCookie, error) {
- return []database.IBDCookie{testCookie}, nil
-}
-
-func (c *cookieSourceStub) ReportCookieFailure(_ context.Context, _ uint) error {
- return nil
+func (k *kmsStub) Decrypt(_ context.Context, _ string, ciphertext []byte) ([]byte, error) {
+ return ciphertext, nil
}
diff --git a/backend/internal/ibd/ibd50.go b/backend/internal/ibd/ibd50.go
index ea02f82..52e28aa 100644
--- a/backend/internal/ibd/ibd50.go
+++ b/backend/internal/ibd/ibd50.go
@@ -9,6 +9,8 @@ import (
"net/http"
"net/url"
"strconv"
+
+ "github.com/ansg191/ibd-trader-backend/internal/database"
)
const ibd50Url = "https://research.investors.com/Services/SiteAjaxService.asmx/GetIBD50?sortcolumn1=%22ibd100rank%22&sortOrder1=%22asc%22&sortcolumn2=%22%22&sortOrder2=%22ASC%22"
@@ -63,7 +65,7 @@ func (c *Client) GetIBD50(ctx context.Context) ([]*Stock, error) {
// If there are less than 10 stocks in the IBD50 list, it's likely that authentication failed.
if len(ibd50Resp.D.ETablesDataList) < 10 {
// Report cookie failure to DB
- if err = c.cookies.ReportCookieFailure(ctx, cookieId); err != nil {
+ if err = database.ReportCookieFailure(ctx, c.db, cookieId); err != nil {
slog.Error("Failed to report cookie failure", "error", err)
}
return nil, errors.New("failed to get IBD50 list")
diff --git a/backend/internal/ibd/search_test.go b/backend/internal/ibd/search_test.go
index 99157cf..05e93dc 100644
--- a/backend/internal/ibd/search_test.go
+++ b/backend/internal/ibd/search_test.go
@@ -162,8 +162,6 @@ const emptySearchResponseJSON = `
}`
func TestClient_Search(t *testing.T) {
- t.Parallel()
-
tests := []struct {
name string
response string
@@ -195,7 +193,11 @@ func TestClient_Search(t *testing.T) {
tp := httpmock.NewMockTransport()
tp.RegisterResponder("GET", searchUrl, httpmock.NewStringResponder(200, tt.response))
- client := NewClient(new(cookieSourceStub), transport.NewStandardTransport(&http.Client{Transport: tp}))
+ client := NewClient(
+ db,
+ new(kmsStub),
+ transport.NewStandardTransport(&http.Client{Transport: tp}),
+ )
tt.f(t, client)
})