aboutsummaryrefslogtreecommitdiff
path: root/backend/internal/ibd/client_test.go
diff options
context:
space:
mode:
Diffstat (limited to 'backend/internal/ibd/client_test.go')
-rw-r--r--backend/internal/ibd/client_test.go201
1 files changed, 201 insertions, 0 deletions
diff --git a/backend/internal/ibd/client_test.go b/backend/internal/ibd/client_test.go
new file mode 100644
index 0000000..2368a31
--- /dev/null
+++ b/backend/internal/ibd/client_test.go
@@ -0,0 +1,201 @@
+package ibd
+
+import (
+ "context"
+ "database/sql"
+ "fmt"
+ "log"
+ "math/rand/v2"
+ "testing"
+ "time"
+
+ "github.com/ansg191/ibd-trader-backend/internal/database"
+ "github.com/ansg191/ibd-trader-backend/internal/keys"
+ _ "github.com/lib/pq"
+ "github.com/ory/dockertest/v3"
+ "github.com/ory/dockertest/v3/docker"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+var (
+ db *sql.DB
+ maxTime = time.Date(2100, 1, 1, 0, 0, 0, 0, time.UTC)
+ letterRunes = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ")
+)
+
+func TestMain(m *testing.M) {
+ pool, err := dockertest.NewPool("")
+ if err != nil {
+ log.Fatalf("Could not create pool: %s", err)
+ }
+
+ err = pool.Client.Ping()
+ if err != nil {
+ log.Fatalf("Could not connect to Docker: %s", err)
+ }
+
+ resource, err := pool.RunWithOptions(&dockertest.RunOptions{
+ Repository: "postgres",
+ Tag: "16",
+ Env: []string{
+ "POSTGRES_PASSWORD=secret",
+ "POSTGRES_USER=ibd-client-test",
+ "POSTGRES_DB=ibd-client-test",
+ "listen_addresses='*'",
+ },
+ Cmd: []string{
+ "postgres",
+ "-c",
+ "log_statement=all",
+ },
+ }, func(config *docker.HostConfig) {
+ config.AutoRemove = true
+ config.RestartPolicy = docker.RestartPolicy{Name: "no"}
+ })
+ if err != nil {
+ log.Fatalf("Could not start resource: %s", err)
+ }
+
+ hostAndPort := resource.GetHostPort("5432/tcp")
+ databaseUrl := fmt.Sprintf("postgres://ibd-client-test:secret@%s/ibd-client-test?sslmode=disable", hostAndPort)
+
+ // Kill container after 120 seconds
+ _ = resource.Expire(120)
+
+ pool.MaxWait = 120 * time.Second
+ if err = pool.Retry(func() error {
+ db, err = sql.Open("postgres", databaseUrl)
+ if err != nil {
+ return err
+ }
+ return db.Ping()
+ }); err != nil {
+ log.Fatalf("Could not connect to database: %s", err)
+ }
+
+ err = database.Migrate(context.Background(), databaseUrl)
+ if err != nil {
+ log.Fatalf("Could not migrate database: %s", err)
+ }
+
+ defer func() {
+ if err := pool.Purge(resource); err != nil {
+ log.Fatalf("Could not purge resource: %s", err)
+ }
+ }()
+
+ m.Run()
+}
+
+func randStringRunes(n int) string {
+ b := make([]rune, n)
+ for i := range b {
+ b[i] = letterRunes[rand.IntN(len(letterRunes))]
+ }
+ return string(b)
+}
+
+func addCookie(t *testing.T) (user, token string) {
+ t.Helper()
+
+ // Randomly generate a user and token
+ user = randStringRunes(8)
+ token = randStringRunes(16)
+
+ ciphertext, key, err := keys.Encrypt(context.Background(), new(kmsStub), "", []byte(token))
+ require.NoError(t, err)
+
+ tx, err := db.Begin()
+ require.NoError(t, err)
+
+ var keyID uint
+ err = tx.QueryRow(`
+INSERT INTO keys (kms_key_name, encrypted_key)
+ VALUES ('', $1)
+ RETURNING id;
+`, key).Scan(&keyID)
+ require.NoError(t, err)
+
+ _, err = tx.Exec(`
+INSERT
+INTO users (subject, encryption_key)
+VALUES ($1, $2);
+`, user, keyID)
+ require.NoError(t, err)
+
+ _, err = tx.Exec(`
+INSERT
+INTO ibd_tokens (user_subject, token, encryption_key, expires_at)
+VALUES ($1, $2, $3, $4);`,
+ user,
+ ciphertext,
+ keyID,
+ maxTime,
+ )
+ require.NoError(t, err)
+
+ err = tx.Commit()
+ require.NoError(t, err)
+
+ return user, token
+}
+
+func TestClient_getCookie(t *testing.T) {
+ t.Run("no cookies", func(t *testing.T) {
+ client := NewClient(db, new(kmsStub))
+
+ _, _, err := client.getCookie(context.Background(), nil)
+ assert.ErrorIs(t, err, ErrNoAvailableCookies)
+ })
+
+ t.Run("no cookies by subject", func(t *testing.T) {
+ client := NewClient(db, new(kmsStub))
+
+ subject := "test"
+ _, _, err := client.getCookie(context.Background(), &subject)
+ assert.ErrorIs(t, err, ErrNoAvailableCookies)
+ })
+
+ t.Run("get any cookie", func(t *testing.T) {
+ _, token := addCookie(t)
+
+ client := NewClient(db, new(kmsStub))
+
+ _, cookie, err := client.getCookie(context.Background(), nil)
+ require.NoError(t, err)
+ assert.Equal(t, cookieName, cookie.Name)
+ assert.Equal(t, token, cookie.Value)
+ assert.Equal(t, "/", cookie.Path)
+ assert.Equal(t, maxTime, cookie.Expires)
+ assert.Equal(t, "investors.com", cookie.Domain)
+ })
+
+ t.Run("get cookie by subject", func(t *testing.T) {
+ subject, token := addCookie(t)
+
+ client := NewClient(db, new(kmsStub))
+
+ _, cookie, err := client.getCookie(context.Background(), &subject)
+ require.NoError(t, err)
+ assert.Equal(t, cookieName, cookie.Name)
+ assert.Equal(t, token, cookie.Value)
+ assert.Equal(t, "/", cookie.Path)
+ assert.Equal(t, maxTime, cookie.Expires)
+ assert.Equal(t, "investors.com", cookie.Domain)
+ })
+}
+
+type kmsStub struct{}
+
+func (k *kmsStub) Close() error {
+ return nil
+}
+
+func (k *kmsStub) Encrypt(_ context.Context, _ string, plaintext []byte) ([]byte, error) {
+ return plaintext, nil
+}
+
+func (k *kmsStub) Decrypt(_ context.Context, _ string, ciphertext []byte) ([]byte, error) {
+ return ciphertext, nil
+}