aboutsummaryrefslogtreecommitdiff
path: root/backend/internal/ibd/client_test.go
diff options
context:
space:
mode:
Diffstat (limited to 'backend/internal/ibd/client_test.go')
-rw-r--r--backend/internal/ibd/client_test.go196
1 files changed, 149 insertions, 47 deletions
diff --git a/backend/internal/ibd/client_test.go b/backend/internal/ibd/client_test.go
index d2dc1b2..2368a31 100644
--- a/backend/internal/ibd/client_test.go
+++ b/backend/internal/ibd/client_test.go
@@ -2,30 +2,155 @@ package ibd
import (
"context"
+ "database/sql"
+ "fmt"
+ "log"
+ "math/rand/v2"
"testing"
"time"
"github.com/ansg191/ibd-trader-backend/internal/database"
+ "github.com/ansg191/ibd-trader-backend/internal/keys"
+ _ "github.com/lib/pq"
+ "github.com/ory/dockertest/v3"
+ "github.com/ory/dockertest/v3/docker"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
-func TestClient_getCookie(t *testing.T) {
- t.Parallel()
+var (
+ db *sql.DB
+ maxTime = time.Date(2100, 1, 1, 0, 0, 0, 0, time.UTC)
+ letterRunes = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ")
+)
- t.Run("no cookies", func(t *testing.T) {
- t.Parallel()
+func TestMain(m *testing.M) {
+ pool, err := dockertest.NewPool("")
+ if err != nil {
+ log.Fatalf("Could not create pool: %s", err)
+ }
+
+ err = pool.Client.Ping()
+ if err != nil {
+ log.Fatalf("Could not connect to Docker: %s", err)
+ }
+
+ resource, err := pool.RunWithOptions(&dockertest.RunOptions{
+ Repository: "postgres",
+ Tag: "16",
+ Env: []string{
+ "POSTGRES_PASSWORD=secret",
+ "POSTGRES_USER=ibd-client-test",
+ "POSTGRES_DB=ibd-client-test",
+ "listen_addresses='*'",
+ },
+ Cmd: []string{
+ "postgres",
+ "-c",
+ "log_statement=all",
+ },
+ }, func(config *docker.HostConfig) {
+ config.AutoRemove = true
+ config.RestartPolicy = docker.RestartPolicy{Name: "no"}
+ })
+ if err != nil {
+ log.Fatalf("Could not start resource: %s", err)
+ }
+
+ hostAndPort := resource.GetHostPort("5432/tcp")
+ databaseUrl := fmt.Sprintf("postgres://ibd-client-test:secret@%s/ibd-client-test?sslmode=disable", hostAndPort)
+
+ // Kill container after 120 seconds
+ _ = resource.Expire(120)
+
+ pool.MaxWait = 120 * time.Second
+ if err = pool.Retry(func() error {
+ db, err = sql.Open("postgres", databaseUrl)
+ if err != nil {
+ return err
+ }
+ return db.Ping()
+ }); err != nil {
+ log.Fatalf("Could not connect to database: %s", err)
+ }
+
+ err = database.Migrate(context.Background(), databaseUrl)
+ if err != nil {
+ log.Fatalf("Could not migrate database: %s", err)
+ }
+
+ defer func() {
+ if err := pool.Purge(resource); err != nil {
+ log.Fatalf("Could not purge resource: %s", err)
+ }
+ }()
+
+ m.Run()
+}
+
+func randStringRunes(n int) string {
+ b := make([]rune, n)
+ for i := range b {
+ b[i] = letterRunes[rand.IntN(len(letterRunes))]
+ }
+ return string(b)
+}
+
+func addCookie(t *testing.T) (user, token string) {
+ t.Helper()
+
+ // Randomly generate a user and token
+ user = randStringRunes(8)
+ token = randStringRunes(16)
+
+ ciphertext, key, err := keys.Encrypt(context.Background(), new(kmsStub), "", []byte(token))
+ require.NoError(t, err)
+
+ tx, err := db.Begin()
+ require.NoError(t, err)
+
+ var keyID uint
+ err = tx.QueryRow(`
+INSERT INTO keys (kms_key_name, encrypted_key)
+ VALUES ('', $1)
+ RETURNING id;
+`, key).Scan(&keyID)
+ require.NoError(t, err)
+
+ _, err = tx.Exec(`
+INSERT
+INTO users (subject, encryption_key)
+VALUES ($1, $2);
+`, user, keyID)
+ require.NoError(t, err)
+
+ _, err = tx.Exec(`
+INSERT
+INTO ibd_tokens (user_subject, token, encryption_key, expires_at)
+VALUES ($1, $2, $3, $4);`,
+ user,
+ ciphertext,
+ keyID,
+ maxTime,
+ )
+ require.NoError(t, err)
+
+ err = tx.Commit()
+ require.NoError(t, err)
+
+ return user, token
+}
- client := NewClient(new(emptyCookieSourceStub))
+func TestClient_getCookie(t *testing.T) {
+ t.Run("no cookies", func(t *testing.T) {
+ client := NewClient(db, new(kmsStub))
_, _, err := client.getCookie(context.Background(), nil)
assert.ErrorIs(t, err, ErrNoAvailableCookies)
})
t.Run("no cookies by subject", func(t *testing.T) {
- t.Parallel()
-
- client := NewClient(new(emptyCookieSourceStub))
+ client := NewClient(db, new(kmsStub))
subject := "test"
_, _, err := client.getCookie(context.Background(), &subject)
@@ -33,67 +158,44 @@ func TestClient_getCookie(t *testing.T) {
})
t.Run("get any cookie", func(t *testing.T) {
- t.Parallel()
+ _, token := addCookie(t)
- client := NewClient(new(cookieSourceStub))
+ client := NewClient(db, new(kmsStub))
- id, cookie, err := client.getCookie(context.Background(), nil)
+ _, cookie, err := client.getCookie(context.Background(), nil)
require.NoError(t, err)
- assert.Equal(t, uint(42), id)
assert.Equal(t, cookieName, cookie.Name)
- assert.Equal(t, "test-token", cookie.Value)
+ assert.Equal(t, token, cookie.Value)
assert.Equal(t, "/", cookie.Path)
- assert.Equal(t, time.Unix(0, 0), cookie.Expires)
+ assert.Equal(t, maxTime, cookie.Expires)
assert.Equal(t, "investors.com", cookie.Domain)
})
t.Run("get cookie by subject", func(t *testing.T) {
- t.Parallel()
+ subject, token := addCookie(t)
- client := NewClient(new(cookieSourceStub))
+ client := NewClient(db, new(kmsStub))
- subject := "test"
- id, cookie, err := client.getCookie(context.Background(), &subject)
+ _, cookie, err := client.getCookie(context.Background(), &subject)
require.NoError(t, err)
- assert.Equal(t, uint(42), id)
assert.Equal(t, cookieName, cookie.Name)
- assert.Equal(t, "test-token", cookie.Value)
+ assert.Equal(t, token, cookie.Value)
assert.Equal(t, "/", cookie.Path)
- assert.Equal(t, time.Unix(0, 0), cookie.Expires)
+ assert.Equal(t, maxTime, cookie.Expires)
assert.Equal(t, "investors.com", cookie.Domain)
})
}
-type emptyCookieSourceStub struct{}
-
-func (c *emptyCookieSourceStub) GetAnyCookie(_ context.Context) (*database.IBDCookie, error) {
- return nil, nil
-}
-
-func (c *emptyCookieSourceStub) GetCookies(_ context.Context, _ string, _ bool) ([]database.IBDCookie, error) {
- return nil, nil
-}
+type kmsStub struct{}
-func (c *emptyCookieSourceStub) ReportCookieFailure(_ context.Context, _ uint) error {
+func (k *kmsStub) Close() error {
return nil
}
-var testCookie = database.IBDCookie{
- ID: 42,
- Token: "test-token",
- Expiry: time.Unix(0, 0),
+func (k *kmsStub) Encrypt(_ context.Context, _ string, plaintext []byte) ([]byte, error) {
+ return plaintext, nil
}
-type cookieSourceStub struct{}
-
-func (c *cookieSourceStub) GetAnyCookie(_ context.Context) (*database.IBDCookie, error) {
- return &testCookie, nil
-}
-
-func (c *cookieSourceStub) GetCookies(_ context.Context, _ string, _ bool) ([]database.IBDCookie, error) {
- return []database.IBDCookie{testCookie}, nil
-}
-
-func (c *cookieSourceStub) ReportCookieFailure(_ context.Context, _ uint) error {
- return nil
+func (k *kmsStub) Decrypt(_ context.Context, _ string, ciphertext []byte) ([]byte, error) {
+ return ciphertext, nil
}