diff options
Diffstat (limited to 'backend/internal/ibd/client_test.go')
-rw-r--r-- | backend/internal/ibd/client_test.go | 196 |
1 files changed, 149 insertions, 47 deletions
diff --git a/backend/internal/ibd/client_test.go b/backend/internal/ibd/client_test.go index d2dc1b2..2368a31 100644 --- a/backend/internal/ibd/client_test.go +++ b/backend/internal/ibd/client_test.go @@ -2,30 +2,155 @@ package ibd import ( "context" + "database/sql" + "fmt" + "log" + "math/rand/v2" "testing" "time" "github.com/ansg191/ibd-trader-backend/internal/database" + "github.com/ansg191/ibd-trader-backend/internal/keys" + _ "github.com/lib/pq" + "github.com/ory/dockertest/v3" + "github.com/ory/dockertest/v3/docker" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) -func TestClient_getCookie(t *testing.T) { - t.Parallel() +var ( + db *sql.DB + maxTime = time.Date(2100, 1, 1, 0, 0, 0, 0, time.UTC) + letterRunes = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ") +) - t.Run("no cookies", func(t *testing.T) { - t.Parallel() +func TestMain(m *testing.M) { + pool, err := dockertest.NewPool("") + if err != nil { + log.Fatalf("Could not create pool: %s", err) + } + + err = pool.Client.Ping() + if err != nil { + log.Fatalf("Could not connect to Docker: %s", err) + } + + resource, err := pool.RunWithOptions(&dockertest.RunOptions{ + Repository: "postgres", + Tag: "16", + Env: []string{ + "POSTGRES_PASSWORD=secret", + "POSTGRES_USER=ibd-client-test", + "POSTGRES_DB=ibd-client-test", + "listen_addresses='*'", + }, + Cmd: []string{ + "postgres", + "-c", + "log_statement=all", + }, + }, func(config *docker.HostConfig) { + config.AutoRemove = true + config.RestartPolicy = docker.RestartPolicy{Name: "no"} + }) + if err != nil { + log.Fatalf("Could not start resource: %s", err) + } + + hostAndPort := resource.GetHostPort("5432/tcp") + databaseUrl := fmt.Sprintf("postgres://ibd-client-test:secret@%s/ibd-client-test?sslmode=disable", hostAndPort) + + // Kill container after 120 seconds + _ = resource.Expire(120) + + pool.MaxWait = 120 * time.Second + if err = pool.Retry(func() error { + db, err = sql.Open("postgres", databaseUrl) + if err != nil { + return err + } + return db.Ping() + }); err != nil { + log.Fatalf("Could not connect to database: %s", err) + } + + err = database.Migrate(context.Background(), databaseUrl) + if err != nil { + log.Fatalf("Could not migrate database: %s", err) + } + + defer func() { + if err := pool.Purge(resource); err != nil { + log.Fatalf("Could not purge resource: %s", err) + } + }() + + m.Run() +} + +func randStringRunes(n int) string { + b := make([]rune, n) + for i := range b { + b[i] = letterRunes[rand.IntN(len(letterRunes))] + } + return string(b) +} + +func addCookie(t *testing.T) (user, token string) { + t.Helper() + + // Randomly generate a user and token + user = randStringRunes(8) + token = randStringRunes(16) + + ciphertext, key, err := keys.Encrypt(context.Background(), new(kmsStub), "", []byte(token)) + require.NoError(t, err) + + tx, err := db.Begin() + require.NoError(t, err) + + var keyID uint + err = tx.QueryRow(` +INSERT INTO keys (kms_key_name, encrypted_key) + VALUES ('', $1) + RETURNING id; +`, key).Scan(&keyID) + require.NoError(t, err) + + _, err = tx.Exec(` +INSERT +INTO users (subject, encryption_key) +VALUES ($1, $2); +`, user, keyID) + require.NoError(t, err) + + _, err = tx.Exec(` +INSERT +INTO ibd_tokens (user_subject, token, encryption_key, expires_at) +VALUES ($1, $2, $3, $4);`, + user, + ciphertext, + keyID, + maxTime, + ) + require.NoError(t, err) + + err = tx.Commit() + require.NoError(t, err) + + return user, token +} - client := NewClient(new(emptyCookieSourceStub)) +func TestClient_getCookie(t *testing.T) { + t.Run("no cookies", func(t *testing.T) { + client := NewClient(db, new(kmsStub)) _, _, err := client.getCookie(context.Background(), nil) assert.ErrorIs(t, err, ErrNoAvailableCookies) }) t.Run("no cookies by subject", func(t *testing.T) { - t.Parallel() - - client := NewClient(new(emptyCookieSourceStub)) + client := NewClient(db, new(kmsStub)) subject := "test" _, _, err := client.getCookie(context.Background(), &subject) @@ -33,67 +158,44 @@ func TestClient_getCookie(t *testing.T) { }) t.Run("get any cookie", func(t *testing.T) { - t.Parallel() + _, token := addCookie(t) - client := NewClient(new(cookieSourceStub)) + client := NewClient(db, new(kmsStub)) - id, cookie, err := client.getCookie(context.Background(), nil) + _, cookie, err := client.getCookie(context.Background(), nil) require.NoError(t, err) - assert.Equal(t, uint(42), id) assert.Equal(t, cookieName, cookie.Name) - assert.Equal(t, "test-token", cookie.Value) + assert.Equal(t, token, cookie.Value) assert.Equal(t, "/", cookie.Path) - assert.Equal(t, time.Unix(0, 0), cookie.Expires) + assert.Equal(t, maxTime, cookie.Expires) assert.Equal(t, "investors.com", cookie.Domain) }) t.Run("get cookie by subject", func(t *testing.T) { - t.Parallel() + subject, token := addCookie(t) - client := NewClient(new(cookieSourceStub)) + client := NewClient(db, new(kmsStub)) - subject := "test" - id, cookie, err := client.getCookie(context.Background(), &subject) + _, cookie, err := client.getCookie(context.Background(), &subject) require.NoError(t, err) - assert.Equal(t, uint(42), id) assert.Equal(t, cookieName, cookie.Name) - assert.Equal(t, "test-token", cookie.Value) + assert.Equal(t, token, cookie.Value) assert.Equal(t, "/", cookie.Path) - assert.Equal(t, time.Unix(0, 0), cookie.Expires) + assert.Equal(t, maxTime, cookie.Expires) assert.Equal(t, "investors.com", cookie.Domain) }) } -type emptyCookieSourceStub struct{} - -func (c *emptyCookieSourceStub) GetAnyCookie(_ context.Context) (*database.IBDCookie, error) { - return nil, nil -} - -func (c *emptyCookieSourceStub) GetCookies(_ context.Context, _ string, _ bool) ([]database.IBDCookie, error) { - return nil, nil -} +type kmsStub struct{} -func (c *emptyCookieSourceStub) ReportCookieFailure(_ context.Context, _ uint) error { +func (k *kmsStub) Close() error { return nil } -var testCookie = database.IBDCookie{ - ID: 42, - Token: "test-token", - Expiry: time.Unix(0, 0), +func (k *kmsStub) Encrypt(_ context.Context, _ string, plaintext []byte) ([]byte, error) { + return plaintext, nil } -type cookieSourceStub struct{} - -func (c *cookieSourceStub) GetAnyCookie(_ context.Context) (*database.IBDCookie, error) { - return &testCookie, nil -} - -func (c *cookieSourceStub) GetCookies(_ context.Context, _ string, _ bool) ([]database.IBDCookie, error) { - return []database.IBDCookie{testCookie}, nil -} - -func (c *cookieSourceStub) ReportCookieFailure(_ context.Context, _ uint) error { - return nil +func (k *kmsStub) Decrypt(_ context.Context, _ string, ciphertext []byte) ([]byte, error) { + return ciphertext, nil } |