diff options
author | 2024-08-07 18:56:01 -0700 | |
---|---|---|
committer | 2024-08-07 18:56:01 -0700 | |
commit | 08993e2f8497341079010d3d06361c99492c4c07 (patch) | |
tree | c65d6d571c928410faace1fa51c2ea3f49fce003 /backend/internal/ibd | |
parent | 3de4ebb7560851ccbefe296c197456fe80c22901 (diff) | |
parent | b8aef1a7fb24815c7d93bc30c7b289b4f5896779 (diff) | |
download | ibd-trader-08993e2f8497341079010d3d06361c99492c4c07.tar.gz ibd-trader-08993e2f8497341079010d3d06361c99492c4c07.tar.zst ibd-trader-08993e2f8497341079010d3d06361c99492c4c07.zip |
Merge pull request #1 from ansg191/refactor-database
Diffstat (limited to 'backend/internal/ibd')
-rw-r--r-- | backend/internal/ibd/auth_test.go | 4 | ||||
-rw-r--r-- | backend/internal/ibd/client.go | 13 | ||||
-rw-r--r-- | backend/internal/ibd/client_test.go | 196 | ||||
-rw-r--r-- | backend/internal/ibd/ibd50.go | 4 | ||||
-rw-r--r-- | backend/internal/ibd/search_test.go | 8 |
5 files changed, 167 insertions, 58 deletions
diff --git a/backend/internal/ibd/auth_test.go b/backend/internal/ibd/auth_test.go index 54ea98a..157b507 100644 --- a/backend/internal/ibd/auth_test.go +++ b/backend/internal/ibd/auth_test.go @@ -163,7 +163,7 @@ func TestClient_Authenticate(t *testing.T) { return resp, nil }) - client := NewClient(nil, newTransport(tp)) + client := NewClient(nil, nil, newTransport(tp)) cookie, err := client.Authenticate(context.Background(), "abc", "xyz") require.NoError(t, err) @@ -189,7 +189,7 @@ func TestClient_Authenticate_401(t *testing.T) { return httpmock.NewStringResponse(http.StatusUnauthorized, `{"name":"ValidationError","code":"ERR016","message":"Wrong username or password","description":"Wrong username or password"}`), nil }) - client := NewClient(nil, newTransport(tp)) + client := NewClient(nil, nil, newTransport(tp)) cookie, err := client.Authenticate(context.Background(), "abc", "xyz") assert.Nil(t, cookie) diff --git a/backend/internal/ibd/client.go b/backend/internal/ibd/client.go index 25c5173..c8575e3 100644 --- a/backend/internal/ibd/client.go +++ b/backend/internal/ibd/client.go @@ -9,6 +9,7 @@ import ( "github.com/ansg191/ibd-trader-backend/internal/database" "github.com/ansg191/ibd-trader-backend/internal/ibd/transport" + "github.com/ansg191/ibd-trader-backend/internal/keys" ) var ErrNoAvailableCookies = errors.New("no available cookies") @@ -16,20 +17,22 @@ var ErrNoAvailableTransports = errors.New("no available transports") type Client struct { transports []transport.Transport - cookies database.CookieSource + db database.Executor + kms keys.KeyManagementService } func NewClient( - cookies database.CookieSource, + db database.Executor, + kms keys.KeyManagementService, transports ...transport.Transport, ) *Client { - return &Client{transports, cookies} + return &Client{transports, db, kms} } func (c *Client) getCookie(ctx context.Context, subject *string) (uint, *http.Cookie, error) { if subject == nil { // No subject requirement, get any cookie - cookie, err := c.cookies.GetAnyCookie(ctx) + cookie, err := database.GetAnyCookie(ctx, c.db, c.kms) if err != nil { return 0, nil, err } @@ -41,7 +44,7 @@ func (c *Client) getCookie(ctx context.Context, subject *string) (uint, *http.Co } // Get cookie by subject - cookies, err := c.cookies.GetCookies(ctx, *subject, false) + cookies, err := database.GetCookies(ctx, c.db, c.kms, *subject, false) if err != nil { return 0, nil, err } diff --git a/backend/internal/ibd/client_test.go b/backend/internal/ibd/client_test.go index d2dc1b2..2368a31 100644 --- a/backend/internal/ibd/client_test.go +++ b/backend/internal/ibd/client_test.go @@ -2,30 +2,155 @@ package ibd import ( "context" + "database/sql" + "fmt" + "log" + "math/rand/v2" "testing" "time" "github.com/ansg191/ibd-trader-backend/internal/database" + "github.com/ansg191/ibd-trader-backend/internal/keys" + _ "github.com/lib/pq" + "github.com/ory/dockertest/v3" + "github.com/ory/dockertest/v3/docker" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) -func TestClient_getCookie(t *testing.T) { - t.Parallel() +var ( + db *sql.DB + maxTime = time.Date(2100, 1, 1, 0, 0, 0, 0, time.UTC) + letterRunes = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ") +) - t.Run("no cookies", func(t *testing.T) { - t.Parallel() +func TestMain(m *testing.M) { + pool, err := dockertest.NewPool("") + if err != nil { + log.Fatalf("Could not create pool: %s", err) + } + + err = pool.Client.Ping() + if err != nil { + log.Fatalf("Could not connect to Docker: %s", err) + } + + resource, err := pool.RunWithOptions(&dockertest.RunOptions{ + Repository: "postgres", + Tag: "16", + Env: []string{ + "POSTGRES_PASSWORD=secret", + "POSTGRES_USER=ibd-client-test", + "POSTGRES_DB=ibd-client-test", + "listen_addresses='*'", + }, + Cmd: []string{ + "postgres", + "-c", + "log_statement=all", + }, + }, func(config *docker.HostConfig) { + config.AutoRemove = true + config.RestartPolicy = docker.RestartPolicy{Name: "no"} + }) + if err != nil { + log.Fatalf("Could not start resource: %s", err) + } + + hostAndPort := resource.GetHostPort("5432/tcp") + databaseUrl := fmt.Sprintf("postgres://ibd-client-test:secret@%s/ibd-client-test?sslmode=disable", hostAndPort) + + // Kill container after 120 seconds + _ = resource.Expire(120) + + pool.MaxWait = 120 * time.Second + if err = pool.Retry(func() error { + db, err = sql.Open("postgres", databaseUrl) + if err != nil { + return err + } + return db.Ping() + }); err != nil { + log.Fatalf("Could not connect to database: %s", err) + } + + err = database.Migrate(context.Background(), databaseUrl) + if err != nil { + log.Fatalf("Could not migrate database: %s", err) + } + + defer func() { + if err := pool.Purge(resource); err != nil { + log.Fatalf("Could not purge resource: %s", err) + } + }() + + m.Run() +} + +func randStringRunes(n int) string { + b := make([]rune, n) + for i := range b { + b[i] = letterRunes[rand.IntN(len(letterRunes))] + } + return string(b) +} + +func addCookie(t *testing.T) (user, token string) { + t.Helper() + + // Randomly generate a user and token + user = randStringRunes(8) + token = randStringRunes(16) + + ciphertext, key, err := keys.Encrypt(context.Background(), new(kmsStub), "", []byte(token)) + require.NoError(t, err) + + tx, err := db.Begin() + require.NoError(t, err) + + var keyID uint + err = tx.QueryRow(` +INSERT INTO keys (kms_key_name, encrypted_key) + VALUES ('', $1) + RETURNING id; +`, key).Scan(&keyID) + require.NoError(t, err) + + _, err = tx.Exec(` +INSERT +INTO users (subject, encryption_key) +VALUES ($1, $2); +`, user, keyID) + require.NoError(t, err) + + _, err = tx.Exec(` +INSERT +INTO ibd_tokens (user_subject, token, encryption_key, expires_at) +VALUES ($1, $2, $3, $4);`, + user, + ciphertext, + keyID, + maxTime, + ) + require.NoError(t, err) + + err = tx.Commit() + require.NoError(t, err) + + return user, token +} - client := NewClient(new(emptyCookieSourceStub)) +func TestClient_getCookie(t *testing.T) { + t.Run("no cookies", func(t *testing.T) { + client := NewClient(db, new(kmsStub)) _, _, err := client.getCookie(context.Background(), nil) assert.ErrorIs(t, err, ErrNoAvailableCookies) }) t.Run("no cookies by subject", func(t *testing.T) { - t.Parallel() - - client := NewClient(new(emptyCookieSourceStub)) + client := NewClient(db, new(kmsStub)) subject := "test" _, _, err := client.getCookie(context.Background(), &subject) @@ -33,67 +158,44 @@ func TestClient_getCookie(t *testing.T) { }) t.Run("get any cookie", func(t *testing.T) { - t.Parallel() + _, token := addCookie(t) - client := NewClient(new(cookieSourceStub)) + client := NewClient(db, new(kmsStub)) - id, cookie, err := client.getCookie(context.Background(), nil) + _, cookie, err := client.getCookie(context.Background(), nil) require.NoError(t, err) - assert.Equal(t, uint(42), id) assert.Equal(t, cookieName, cookie.Name) - assert.Equal(t, "test-token", cookie.Value) + assert.Equal(t, token, cookie.Value) assert.Equal(t, "/", cookie.Path) - assert.Equal(t, time.Unix(0, 0), cookie.Expires) + assert.Equal(t, maxTime, cookie.Expires) assert.Equal(t, "investors.com", cookie.Domain) }) t.Run("get cookie by subject", func(t *testing.T) { - t.Parallel() + subject, token := addCookie(t) - client := NewClient(new(cookieSourceStub)) + client := NewClient(db, new(kmsStub)) - subject := "test" - id, cookie, err := client.getCookie(context.Background(), &subject) + _, cookie, err := client.getCookie(context.Background(), &subject) require.NoError(t, err) - assert.Equal(t, uint(42), id) assert.Equal(t, cookieName, cookie.Name) - assert.Equal(t, "test-token", cookie.Value) + assert.Equal(t, token, cookie.Value) assert.Equal(t, "/", cookie.Path) - assert.Equal(t, time.Unix(0, 0), cookie.Expires) + assert.Equal(t, maxTime, cookie.Expires) assert.Equal(t, "investors.com", cookie.Domain) }) } -type emptyCookieSourceStub struct{} - -func (c *emptyCookieSourceStub) GetAnyCookie(_ context.Context) (*database.IBDCookie, error) { - return nil, nil -} - -func (c *emptyCookieSourceStub) GetCookies(_ context.Context, _ string, _ bool) ([]database.IBDCookie, error) { - return nil, nil -} +type kmsStub struct{} -func (c *emptyCookieSourceStub) ReportCookieFailure(_ context.Context, _ uint) error { +func (k *kmsStub) Close() error { return nil } -var testCookie = database.IBDCookie{ - ID: 42, - Token: "test-token", - Expiry: time.Unix(0, 0), +func (k *kmsStub) Encrypt(_ context.Context, _ string, plaintext []byte) ([]byte, error) { + return plaintext, nil } -type cookieSourceStub struct{} - -func (c *cookieSourceStub) GetAnyCookie(_ context.Context) (*database.IBDCookie, error) { - return &testCookie, nil -} - -func (c *cookieSourceStub) GetCookies(_ context.Context, _ string, _ bool) ([]database.IBDCookie, error) { - return []database.IBDCookie{testCookie}, nil -} - -func (c *cookieSourceStub) ReportCookieFailure(_ context.Context, _ uint) error { - return nil +func (k *kmsStub) Decrypt(_ context.Context, _ string, ciphertext []byte) ([]byte, error) { + return ciphertext, nil } diff --git a/backend/internal/ibd/ibd50.go b/backend/internal/ibd/ibd50.go index ea02f82..52e28aa 100644 --- a/backend/internal/ibd/ibd50.go +++ b/backend/internal/ibd/ibd50.go @@ -9,6 +9,8 @@ import ( "net/http" "net/url" "strconv" + + "github.com/ansg191/ibd-trader-backend/internal/database" ) const ibd50Url = "https://research.investors.com/Services/SiteAjaxService.asmx/GetIBD50?sortcolumn1=%22ibd100rank%22&sortOrder1=%22asc%22&sortcolumn2=%22%22&sortOrder2=%22ASC%22" @@ -63,7 +65,7 @@ func (c *Client) GetIBD50(ctx context.Context) ([]*Stock, error) { // If there are less than 10 stocks in the IBD50 list, it's likely that authentication failed. if len(ibd50Resp.D.ETablesDataList) < 10 { // Report cookie failure to DB - if err = c.cookies.ReportCookieFailure(ctx, cookieId); err != nil { + if err = database.ReportCookieFailure(ctx, c.db, cookieId); err != nil { slog.Error("Failed to report cookie failure", "error", err) } return nil, errors.New("failed to get IBD50 list") diff --git a/backend/internal/ibd/search_test.go b/backend/internal/ibd/search_test.go index 99157cf..05e93dc 100644 --- a/backend/internal/ibd/search_test.go +++ b/backend/internal/ibd/search_test.go @@ -162,8 +162,6 @@ const emptySearchResponseJSON = ` }` func TestClient_Search(t *testing.T) { - t.Parallel() - tests := []struct { name string response string @@ -195,7 +193,11 @@ func TestClient_Search(t *testing.T) { tp := httpmock.NewMockTransport() tp.RegisterResponder("GET", searchUrl, httpmock.NewStringResponder(200, tt.response)) - client := NewClient(new(cookieSourceStub), transport.NewStandardTransport(&http.Client{Transport: tp})) + client := NewClient( + db, + new(kmsStub), + transport.NewStandardTransport(&http.Client{Transport: tp}), + ) tt.f(t, client) }) |