package ibd

import (
	"context"
	"database/sql"
	"fmt"
	"log"
	"math/rand/v2"
	"testing"
	"time"

	"github.com/ansg191/ibd-trader-backend/internal/database"
	"github.com/ansg191/ibd-trader-backend/internal/keys"
	_ "github.com/lib/pq" // registers the "postgres" database/sql driver
	"github.com/ory/dockertest/v3"
	"github.com/ory/dockertest/v3/docker"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

var (
	// db is the connection pool to the throwaway PostgreSQL container; it is
	// initialised by TestMain and shared by every test in the package.
	db *sql.DB
	// maxTime is a far-future expiry used for tokens inserted by addCookie.
	maxTime = time.Date(2100, 1, 1, 0, 0, 0, 0, time.UTC)
	// letterRunes is the alphabet randStringRunes draws from.
	letterRunes = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ")
)

// TestMain starts a PostgreSQL 16 container via dockertest, connects db to it,
// applies the schema migrations, runs the package's tests and finally purges
// the container.
func TestMain(m *testing.M) {
	pool, err := dockertest.NewPool("")
	if err != nil {
		log.Fatalf("Could not create pool: %s", err)
	}

	err = pool.Client.Ping()
	if err != nil {
		log.Fatalf("Could not connect to Docker: %s", err)
	}

	resource, err := pool.RunWithOptions(&dockertest.RunOptions{
		Repository: "postgres",
		Tag:        "16",
		Env: []string{
			"POSTGRES_PASSWORD=secret",
			"POSTGRES_USER=ibd-client-test",
			"POSTGRES_DB=ibd-client-test",
			"listen_addresses='*'",
		},
		Cmd: []string{
			"postgres",
			"-c",
			"log_statement=all",
		},
	}, func(config *docker.HostConfig) {
		config.AutoRemove = true
		config.RestartPolicy = docker.RestartPolicy{Name: "no"}
	})
	if err != nil {
		log.Fatalf("Could not start resource: %s", err)
	}

	// log.Fatalf exits without running deferred calls, so setup failures from
	// here on purge the container explicitly instead of leaving it running
	// until it expires.
	fatalf := func(format string, args ...any) {
		if perr := pool.Purge(resource); perr != nil {
			log.Printf("Could not purge resource: %s", perr)
		}
		log.Fatalf(format, args...)
	}

	hostAndPort := resource.GetHostPort("5432/tcp")
	databaseURL := fmt.Sprintf("postgres://ibd-client-test:secret@%s/ibd-client-test?sslmode=disable", hostAndPort)

	// Kill container after 120 seconds. This is a best-effort safety net in
	// case the process dies before purging, so its error is deliberately ignored.
	_ = resource.Expire(120)

	pool.MaxWait = 120 * time.Second
	if err = pool.Retry(func() error {
		conn, err := sql.Open("postgres", databaseURL)
		if err != nil {
			return err
		}
		if err = conn.Ping(); err != nil {
			// Release the handle from this failed attempt; the next retry
			// opens a fresh one.
			_ = conn.Close()
			return err
		}
		db = conn
		return nil
	}); err != nil {
		fatalf("Could not connect to database: %s", err)
	}

	if err = database.Migrate(context.Background(), databaseURL); err != nil {
		fatalf("Could not migrate database: %s", err)
	}

	defer func() {
		// Close the pool before removing the server it points at.
		_ = db.Close()
		if err := pool.Purge(resource); err != nil {
			log.Fatalf("Could not purge resource: %s", err)
		}
	}()

	// Since Go 1.15, returning from TestMain makes the testing wrapper exit
	// with m.Run's result, so the deferred cleanup above still runs.
	m.Run()
}

// randStringRunes returns a string of n letters picked at random from
// letterRunes. It uses math/rand/v2, so it is only suitable for test data.
func randStringRunes(n int) string {
	b := make([]rune, n)
	for i := range b {
		b[i] = letterRunes[rand.IntN(len(letterRunes))]
	}
	return string(b)
}

// addCookie inserts, in a single transaction, a key row, a user with a
// random 8-letter subject and an ibd_tokens row holding a random 16-letter
// token (encrypted with kmsStub) that expires at maxTime. It returns the
// user's subject and the plaintext token.
func addCookie(t *testing.T) (user, token string) {
	t.Helper()

	// Randomly generate a user and token
	user = randStringRunes(8)
	token = randStringRunes(16)

	ciphertext, key, err := keys.Encrypt(context.Background(), new(kmsStub), "", []byte(token))
	require.NoError(t, err)

	tx, err := db.Begin()
	require.NoError(t, err)
	// If a require below aborts the test, roll back so the transaction does
	// not keep holding a pooled connection. After a successful Commit this
	// returns sql.ErrTxDone, which is intentionally ignored.
	defer func() { _ = tx.Rollback() }()

	var keyID uint
	err = tx.QueryRow(`
INSERT INTO keys (kms_key_name, encrypted_key)
VALUES ('', $1)
RETURNING id;
`, key).Scan(&keyID)
	require.NoError(t, err)

	_, err = tx.Exec(`
INSERT INTO users (subject, encryption_key)
VALUES ($1, $2);
`, user, keyID)
	require.NoError(t, err)

	_, err = tx.Exec(`
INSERT INTO ibd_tokens (user_subject, token, encryption_key, expires_at)
VALUES ($1, $2, $3, $4);`,
		user,
		ciphertext,
		keyID,
		maxTime,
	)
	require.NoError(t, err)

	err = tx.Commit()
	require.NoError(t, err)

	return user, token
}

// TestClient_getCookie exercises Client.getCookie with and without a subject
// filter. The subtests share the database and run in order: the first two
// rely on no cookies having been inserted yet.
func TestClient_getCookie(t *testing.T) {
	t.Run("no cookies", func(t *testing.T) {
		client := NewClient(db, new(kmsStub))

		_, _, err := client.getCookie(context.Background(), nil)
		assert.ErrorIs(t, err, ErrNoAvailableCookies)
	})

	t.Run("no cookies by subject", func(t *testing.T) {
		client := NewClient(db, new(kmsStub))

		subject := "test"
		_, _, err := client.getCookie(context.Background(), &subject)
		assert.ErrorIs(t, err, ErrNoAvailableCookies)
	})

	t.Run("get any cookie", func(t *testing.T) {
		_, token := addCookie(t)

		client := NewClient(db, new(kmsStub))

		_, cookie, err := client.getCookie(context.Background(), nil)
		require.NoError(t, err)
		assert.Equal(t, cookieName, cookie.Name)
		assert.Equal(t, token, cookie.Value)
		assert.Equal(t, "/", cookie.Path)
		// Compare instants rather than struct equality: a time read back from
		// the database may carry a different *time.Location than maxTime.
		assert.WithinDuration(t, maxTime, cookie.Expires, 0)
		assert.Equal(t, "investors.com", cookie.Domain)
	})

	t.Run("get cookie by subject", func(t *testing.T) {
		subject, token := addCookie(t)

		client := NewClient(db, new(kmsStub))

		_, cookie, err := client.getCookie(context.Background(), &subject)
		require.NoError(t, err)
		assert.Equal(t, cookieName, cookie.Name)
		assert.Equal(t, token, cookie.Value)
		assert.Equal(t, "/", cookie.Path)
		// Compare instants rather than struct equality (see above).
		assert.WithinDuration(t, maxTime, cookie.Expires, 0)
		assert.Equal(t, "investors.com", cookie.Domain)
	})
}

// kmsStub is a no-op KMS for tests: Encrypt and Decrypt return their input
// unchanged and Close does nothing.
type kmsStub struct{}

// Close does nothing and always returns nil.
func (k *kmsStub) Close() error {
	return nil
}

// Encrypt returns plaintext unchanged; the key name is ignored.
func (k *kmsStub) Encrypt(_ context.Context, _ string, plaintext []byte) ([]byte, error) {
	return plaintext, nil
}

// Decrypt returns ciphertext unchanged; the key name is ignored.
func (k *kmsStub) Decrypt(_ context.Context, _ string, ciphertext []byte) ([]byte, error) {
	return ciphertext, nil
}