diff options
Diffstat (limited to 'backend/internal/ibd/client_test.go')
-rw-r--r-- | backend/internal/ibd/client_test.go | 201 |
1 files changed, 201 insertions, 0 deletions
diff --git a/backend/internal/ibd/client_test.go b/backend/internal/ibd/client_test.go new file mode 100644 index 0000000..2368a31 --- /dev/null +++ b/backend/internal/ibd/client_test.go @@ -0,0 +1,201 @@ +package ibd + +import ( + "context" + "database/sql" + "fmt" + "log" + "math/rand/v2" + "testing" + "time" + + "github.com/ansg191/ibd-trader-backend/internal/database" + "github.com/ansg191/ibd-trader-backend/internal/keys" + _ "github.com/lib/pq" + "github.com/ory/dockertest/v3" + "github.com/ory/dockertest/v3/docker" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +var ( + db *sql.DB + maxTime = time.Date(2100, 1, 1, 0, 0, 0, 0, time.UTC) + letterRunes = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ") +) + +func TestMain(m *testing.M) { + pool, err := dockertest.NewPool("") + if err != nil { + log.Fatalf("Could not create pool: %s", err) + } + + err = pool.Client.Ping() + if err != nil { + log.Fatalf("Could not connect to Docker: %s", err) + } + + resource, err := pool.RunWithOptions(&dockertest.RunOptions{ + Repository: "postgres", + Tag: "16", + Env: []string{ + "POSTGRES_PASSWORD=secret", + "POSTGRES_USER=ibd-client-test", + "POSTGRES_DB=ibd-client-test", + "listen_addresses='*'", + }, + Cmd: []string{ + "postgres", + "-c", + "log_statement=all", + }, + }, func(config *docker.HostConfig) { + config.AutoRemove = true + config.RestartPolicy = docker.RestartPolicy{Name: "no"} + }) + if err != nil { + log.Fatalf("Could not start resource: %s", err) + } + + hostAndPort := resource.GetHostPort("5432/tcp") + databaseUrl := fmt.Sprintf("postgres://ibd-client-test:secret@%s/ibd-client-test?sslmode=disable", hostAndPort) + + // Kill container after 120 seconds + _ = resource.Expire(120) + + pool.MaxWait = 120 * time.Second + if err = pool.Retry(func() error { + db, err = sql.Open("postgres", databaseUrl) + if err != nil { + return err + } + return db.Ping() + }); err != nil { + log.Fatalf("Could not connect to database: %s", err) + } + + err = database.Migrate(context.Background(), databaseUrl) + if err != nil { + log.Fatalf("Could not migrate database: %s", err) + } + + defer func() { + if err := pool.Purge(resource); err != nil { + log.Fatalf("Could not purge resource: %s", err) + } + }() + + m.Run() +} + +func randStringRunes(n int) string { + b := make([]rune, n) + for i := range b { + b[i] = letterRunes[rand.IntN(len(letterRunes))] + } + return string(b) +} + +func addCookie(t *testing.T) (user, token string) { + t.Helper() + + // Randomly generate a user and token + user = randStringRunes(8) + token = randStringRunes(16) + + ciphertext, key, err := keys.Encrypt(context.Background(), new(kmsStub), "", []byte(token)) + require.NoError(t, err) + + tx, err := db.Begin() + require.NoError(t, err) + + var keyID uint + err = tx.QueryRow(` +INSERT INTO keys (kms_key_name, encrypted_key) + VALUES ('', $1) + RETURNING id; +`, key).Scan(&keyID) + require.NoError(t, err) + + _, err = tx.Exec(` +INSERT +INTO users (subject, encryption_key) +VALUES ($1, $2); +`, user, keyID) + require.NoError(t, err) + + _, err = tx.Exec(` +INSERT +INTO ibd_tokens (user_subject, token, encryption_key, expires_at) +VALUES ($1, $2, $3, $4);`, + user, + ciphertext, + keyID, + maxTime, + ) + require.NoError(t, err) + + err = tx.Commit() + require.NoError(t, err) + + return user, token +} + +func TestClient_getCookie(t *testing.T) { + t.Run("no cookies", func(t *testing.T) { + client := NewClient(db, new(kmsStub)) + + _, _, err := client.getCookie(context.Background(), nil) + assert.ErrorIs(t, err, ErrNoAvailableCookies) + }) + + t.Run("no cookies by subject", func(t *testing.T) { + client := NewClient(db, new(kmsStub)) + + subject := "test" + _, _, err := client.getCookie(context.Background(), &subject) + assert.ErrorIs(t, err, ErrNoAvailableCookies) + }) + + t.Run("get any cookie", func(t *testing.T) { + _, token := addCookie(t) + + client := NewClient(db, new(kmsStub)) + + _, cookie, err := client.getCookie(context.Background(), nil) + require.NoError(t, err) + assert.Equal(t, cookieName, cookie.Name) + assert.Equal(t, token, cookie.Value) + assert.Equal(t, "/", cookie.Path) + assert.Equal(t, maxTime, cookie.Expires) + assert.Equal(t, "investors.com", cookie.Domain) + }) + + t.Run("get cookie by subject", func(t *testing.T) { + subject, token := addCookie(t) + + client := NewClient(db, new(kmsStub)) + + _, cookie, err := client.getCookie(context.Background(), &subject) + require.NoError(t, err) + assert.Equal(t, cookieName, cookie.Name) + assert.Equal(t, token, cookie.Value) + assert.Equal(t, "/", cookie.Path) + assert.Equal(t, maxTime, cookie.Expires) + assert.Equal(t, "investors.com", cookie.Domain) + }) +} + +type kmsStub struct{} + +func (k *kmsStub) Close() error { + return nil +} + +func (k *kmsStub) Encrypt(_ context.Context, _ string, plaintext []byte) ([]byte, error) { + return plaintext, nil +} + +func (k *kmsStub) Decrypt(_ context.Context, _ string, ciphertext []byte) ([]byte, error) { + return ciphertext, nil +} |