package taskqueue

import (
	"context"
	"errors"
	"fmt"
	"log"
	"testing"
	"time"

	"github.com/ory/dockertest/v3"
	"github.com/ory/dockertest/v3/docker"
	"github.com/redis/go-redis/v9"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// client is the Redis client shared by every test in this file. It is
// assigned in TestMain before any test runs.
var client *redis.Client

// TestMain starts a throwaway Redis 7 container through dockertest, points the
// package-level client at it, runs the tests, and finally closes the client
// and purges the container.
func TestMain(m *testing.M) {
	pool, err := dockertest.NewPool("")
	if err != nil {
		log.Fatalf("Could not create pool: %s", err)
	}

	err = pool.Client.Ping()
	if err != nil {
		log.Fatalf("Could not connect to Docker: %s", err)
	}

	//resource, err := pool.Run("redis", "7", nil)
	resource, err := pool.RunWithOptions(&dockertest.RunOptions{
		Repository: "redis",
		Tag:        "7",
	}, func(config *docker.HostConfig) {
		config.AutoRemove = true
		config.RestartPolicy = docker.RestartPolicy{Name: "no"}
	})
	if err != nil {
		log.Fatalf("Could not start resource: %s", err)
	}
	//_ = resource.Expire(60)

	// Build the client once and retry only the ping. Creating a new client
	// inside the retry callback would leave one unclosed client behind for
	// every failed attempt.
	client = redis.NewClient(&redis.Options{
		Addr: fmt.Sprintf("localhost:%s", resource.GetPort("6379/tcp")),
	})
	if err = pool.Retry(func() error {
		return client.Ping(context.Background()).Err()
	}); err != nil {
		// log.Fatalf exits immediately and skips deferred functions, so
		// release the client and the container by hand first.
		if closeErr := client.Close(); closeErr != nil {
			log.Printf("Could not close client: %s", closeErr)
		}
		if purgeErr := pool.Purge(resource); purgeErr != nil {
			log.Printf("Could not purge resource: %s", purgeErr)
		}
		log.Fatalf("Could not connect to redis: %s", err)
	}

	defer func() {
		if err = client.Close(); err != nil {
			log.Printf("Could not close client: %s", err)
		}
		if err = pool.Purge(resource); err != nil {
			log.Fatalf("Could not purge resource: %s", err)
		}
	}()

	// TestMain returns instead of calling os.Exit so that the deferred
	// cleanup above runs; the testing package exits with m.Run's result
	// once TestMain returns.
	m.Run()
}

// runOnFlushedDB runs f as a subtest called name, flushing the current Redis
// database first so that state left behind by earlier subtests cannot leak in.
func runOnFlushedDB(t *testing.T, name string, f func(t *testing.T)) {
	t.Run(name, func(t *testing.T) {
		if err := client.FlushDB(context.Background()).Err(); err != nil {
			t.Fatal(err)
		}
		f(t)
	})
}

// TestTaskQueue covers queue creation, enqueue/dequeue round trips, completing
// a task, and what Dequeue returns after the lock timeout has elapsed or after
// Extend has been called. Skipped in -short mode.
func TestTaskQueue(t *testing.T) {
	if testing.Short() {
		t.Skip()
	}

	ctx := context.Background()
	lockTimeout := 100 * time.Millisecond

	tests := []struct {
		name string
		f    func(t *testing.T)
	}{
		{
			name: "Create queue",
			f: func(t *testing.T) {
				q, err := New[string](ctx, client, "test", "worker")
				require.NoError(t, err)
				require.NotNil(t, q)
			},
		},
		{
			name: "enqueue & dequeue",
			f: func(t *testing.T) {
				q, err := New[string](ctx, client, "test", "worker")
				require.NoError(t, err)
				require.NotNil(t, q)

				enqueued, err := q.Enqueue(ctx, "hello")
				require.NoError(t, err)
				require.NotEmpty(t, enqueued)

				task, err := q.Dequeue(ctx, lockTimeout, 10*time.Millisecond)
				require.NoError(t, err)
				require.NotNil(t, task)
				require.Equal(t, "hello", task.Data)
			},
		},
		{
			name: "complex data",
			f: func(t *testing.T) {
				type foo struct {
					A int
					B string
				}

				q, err := New[foo](ctx, client, "test", "worker")
				require.NoError(t, err)
				require.NotNil(t, q)

				enqueued, err := q.Enqueue(ctx, foo{A: 42, B: "hello"})
				require.NoError(t, err)
				require.NotEmpty(t, enqueued)

				task, err := q.Dequeue(ctx, lockTimeout, 10*time.Millisecond)
				require.NoError(t, err)
				require.NotNil(t, task)
				require.Equal(t, foo{A: 42, B: "hello"}, task.Data)
			},
		},
		{
			name: "different workers",
			f: func(t *testing.T) {
				q1, err := New[string](ctx, client, "test", "worker1")
				require.NoError(t, err)
				require.NotNil(t, q1)

				q2, err := New[string](ctx, client, "test", "worker2")
				require.NoError(t, err)
				require.NotNil(t, q2)

				// A task enqueued through one worker's handle must be
				// dequeueable through the other's.
				enqueued, err := q1.Enqueue(ctx, "hello")
				require.NoError(t, err)
				require.NotEmpty(t, enqueued)

				task, err := q2.Dequeue(ctx, lockTimeout, 10*time.Millisecond)
				require.NoError(t, err)
				require.NotNil(t, task)
				assert.Equal(t, "hello", task.Data)
			},
		},
		{
			name: "complete",
			f: func(t *testing.T) {
				q, err := New[string](ctx, client, "test", "worker")
				require.NoError(t, err)
				require.NotNil(t, q)

				// Enqueue a task
				enqueued, err := q.Enqueue(ctx, "hello")
				require.NoError(t, err)
				require.NotEmpty(t, enqueued)

				// Dequeue the task
				task, err := q.Dequeue(ctx, lockTimeout, 10*time.Millisecond)
				require.NoError(t, err)
				require.NotNil(t, task)
				assert.Equal(t, "hello", task.Data)

				// Complete the task
				err = q.Complete(ctx, task.ID, "done")
				require.NoError(t, err)

				// Try to dequeue the task again
				task, err = q.Dequeue(ctx, lockTimeout, 10*time.Millisecond)
				require.NoError(t, err)
				require.Nil(t, task)
			},
		},
		{
			name: "timeout",
			f: func(t *testing.T) {
				q, err := New[string](ctx, client, "test", "worker")
				require.NoError(t, err)
				require.NotNil(t, q)

				// Enqueue a task
				enqueued, err := q.Enqueue(ctx, "hello")
				require.NoError(t, err)
				require.NotEmpty(t, enqueued)

				// Dequeue the task
				task, err := q.Dequeue(ctx, lockTimeout, 10*time.Millisecond)
				require.NoError(t, err)
				require.NotNil(t, task)
				assert.Equal(t, "hello", task.Data)

				// Wait for the lock to expire
				time.Sleep(lockTimeout + 10*time.Millisecond)

				// The task was never completed, so it must be handed out again.
				task, err = q.Dequeue(ctx, lockTimeout, 10*time.Millisecond)
				require.NoError(t, err)
				require.NotNil(t, task)
				assert.Equal(t, "hello", task.Data)
			},
		},
		{
			name: "extend",
			f: func(t *testing.T) {
				q, err := New[string](ctx, client, "test", "worker")
				require.NoError(t, err)
				require.NotNil(t, q)

				// Enqueue a task
				enqueued, err := q.Enqueue(ctx, "hello")
				require.NoError(t, err)
				require.NotEmpty(t, enqueued)

				// Dequeue the task
				task, err := q.Dequeue(ctx, lockTimeout, 10*time.Millisecond)
				require.NoError(t, err)
				require.NotNil(t, task)
				assert.Equal(t, "hello", task.Data)

				// Sleep past lockTimeout before extending.
				time.Sleep(lockTimeout + 10*time.Millisecond)

				// Extend the lock
				err = q.Extend(ctx, task.ID)
				require.NoError(t, err)

				// After Extend the task must not be handed out again.
				task, err = q.Dequeue(ctx, lockTimeout, 10*time.Millisecond)
				require.NoError(t, err)
				require.Nil(t, task)
			},
		},
	}

	for _, tt := range tests {
		runOnFlushedDB(t, tt.name, tt.f)
	}

	// Best-effort cleanup; a failure here is deliberately ignored because the
	// next subtest flushes the database again before it runs.
	_ = client.FlushDB(ctx)
}

// TestTaskQueue_List covers List on an empty queue, with one and several
// tasks, with a result limit, with start/end bounds, and with a completed
// task. Skipped in -short mode.
func TestTaskQueue_List(t *testing.T) {
	if testing.Short() {
		t.Skip()
	}

	ctx := context.Background()

	tests := []struct {
		name string
		f    func(t *testing.T)
	}{
		{
			name: "empty",
			f: func(t *testing.T) {
				q, err := New[any](ctx, client, "test", "worker")
				require.NoError(t, err)
				require.NotNil(t, q)

				tasks, err := q.List(ctx, TaskID{}, TaskID{}, 1)
				require.NoError(t, err)
				assert.Empty(t, tasks)
			},
		},
		{
			name: "single",
			f: func(t *testing.T) {
				q, err := New[string](ctx, client, "test", "worker")
				require.NoError(t, err)
				require.NotNil(t, q)

				first, err := q.Enqueue(ctx, "hello")
				require.NoError(t, err)

				tasks, err := q.List(ctx, TaskID{}, TaskID{}, 1)
				require.NoError(t, err)
				// require, not assert: the index below would panic otherwise.
				require.Len(t, tasks, 1)
				assert.Equal(t, first, tasks[0])
			},
		},
		{
			name: "multiple",
			f: func(t *testing.T) {
				q, err := New[string](ctx, client, "test", "worker")
				require.NoError(t, err)
				require.NotNil(t, q)

				first, err := q.Enqueue(ctx, "hello")
				require.NoError(t, err)

				second, err := q.Enqueue(ctx, "world")
				require.NoError(t, err)

				tasks, err := q.List(ctx, TaskID{}, TaskID{}, 2)
				require.NoError(t, err)
				require.Len(t, tasks, 2)
				// The most recently enqueued task is expected first.
				assert.Equal(t, first, tasks[1])
				assert.Equal(t, second, tasks[0])
			},
		},
		{
			name: "multiple limited",
			f: func(t *testing.T) {
				q, err := New[string](ctx, client, "test", "worker")
				require.NoError(t, err)
				require.NotNil(t, q)

				_, err = q.Enqueue(ctx, "hello")
				require.NoError(t, err)

				second, err := q.Enqueue(ctx, "world")
				require.NoError(t, err)

				tasks, err := q.List(ctx, TaskID{}, TaskID{}, 1)
				require.NoError(t, err)
				require.Len(t, tasks, 1)
				assert.Equal(t, second, tasks[0])
			},
		},
		{
			name: "multiple time range",
			f: func(t *testing.T) {
				q, err := New[string](ctx, client, "test", "worker")
				require.NoError(t, err)
				require.NotNil(t, q)

				first, err := q.Enqueue(ctx, "hello")
				require.NoError(t, err)

				time.Sleep(10 * time.Millisecond)

				second, err := q.Enqueue(ctx, "world")
				require.NoError(t, err)

				// With second.ID as the third argument only the first task
				// is expected back.
				tasks, err := q.List(ctx, TaskID{}, second.ID, 100)
				require.NoError(t, err)
				require.Len(t, tasks, 1)
				assert.Equal(t, first, tasks[0])

				// With second.ID as the second argument only the second task
				// is expected back.
				tasks, err = q.List(ctx, second.ID, TaskID{}, 100)
				require.NoError(t, err)
				require.Len(t, tasks, 1)
				assert.Equal(t, second, tasks[0])
			},
		},
		{
			name: "completed tasks",
			f: func(t *testing.T) {
				q, err := New[string](ctx, client, "test", "worker")
				require.NoError(t, err)
				require.NotNil(t, q)

				task1, err := q.Enqueue(ctx, "hello")
				require.NoError(t, err)

				task2, err := q.Enqueue(ctx, "world")
				require.NoError(t, err)

				err = q.Complete(ctx, task1.ID, "done")
				require.NoError(t, err)

				// The completed task must still be listed, carrying its result.
				tasks, err := q.List(ctx, TaskID{}, TaskID{}, 100)
				require.NoError(t, err)
				require.Len(t, tasks, 2)
				assert.Equal(t, task2, tasks[0])
				assert.Equal(t, "hello", tasks[1].Data)
				require.IsType(t, &TaskResultSuccess{}, tasks[1].Result)
				assert.Equal(t, "done", tasks[1].Result.(*TaskResultSuccess).Result)
			},
		},
	}

	for _, tt := range tests {
		runOnFlushedDB(t, tt.name, tt.f)
	}

	// Best-effort cleanup; a failure here is deliberately ignored because the
	// next subtest flushes the database again before it runs.
	_ = client.FlushDB(ctx)
}

// TestTaskQueue_Return covers handing a claimed task back with an error: the
// task is offered again under a new ID, and after three failed attempts it is
// no longer offered. Skipped in -short mode.
func TestTaskQueue_Return(t *testing.T) {
	if testing.Short() {
		t.Skip()
	}

	ctx := context.Background()
	lockTimeout := 100 * time.Millisecond

	tests := []struct {
		name string
		f    func(t *testing.T)
	}{
		{
			name: "simple",
			f: func(t *testing.T) {
				q, err := New[string](ctx, client, "test", "worker")
				require.NoError(t, err)
				require.NotNil(t, q)

				task1, err := q.Enqueue(ctx, "hello")
				require.NoError(t, err)

				id := claimAndFail(t, q, lockTimeout)

				// The retry must be dequeueable under the ID Return reported.
				task2, err := q.Dequeue(ctx, lockTimeout, 10*time.Millisecond)
				require.NoError(t, err)
				require.NotNil(t, task2)
				assert.Equal(t, id, task2.ID)
				assert.Equal(t, task1.Data, task2.Data)
				assert.Equal(t, uint8(1), task2.Attempts)

				// The original task must record the failure and link to the retry.
				task1Data, err := q.Data(ctx, task1.ID)
				require.NoError(t, err)
				assert.Equal(t, task1.ID, task1Data.ID)
				assert.Equal(t, task1.Data, task1Data.Data)
				// require, not assert: the type assertions below would panic otherwise.
				require.IsType(t, &TaskResultError{}, task1Data.Result)
				assert.Equal(t, "failed", task1Data.Result.(*TaskResultError).Error)
				assert.Equal(t, task2.ID, task1Data.Result.(*TaskResultError).NextAttempt)
			},
		},
		{
			name: "failure",
			f: func(t *testing.T) {
				q, err := New[string](ctx, client, "test", "worker")
				require.NoError(t, err)
				require.NotNil(t, q)

				_, err = q.Enqueue(ctx, "hello")
				require.NoError(t, err)

				// NOTE(review): presumably three attempts is the queue's retry
				// limit; confirm against the implementation.
				claimAndFail(t, q, lockTimeout)
				claimAndFail(t, q, lockTimeout)
				claimAndFail(t, q, lockTimeout)

				task3, err := q.Dequeue(ctx, lockTimeout, 10*time.Millisecond)
				require.NoError(t, err)
				assert.Nil(t, task3)
			},
		},
	}

	for _, tt := range tests {
		runOnFlushedDB(t, tt.name, tt.f)
	}

	// Best-effort cleanup; a failure here is deliberately ignored because the
	// next subtest flushes the database again before it runs.
	_ = client.FlushDB(ctx)
}

// claimAndFail dequeues the next task from q and immediately returns it with
// the error "failed". It requires both calls to succeed, asserts that the ID
// reported by Return differs from the claimed task's ID, and returns that ID.
func claimAndFail[T any](t *testing.T, q TaskQueue[T], lockTimeout time.Duration) TaskID {
	t.Helper()

	task, err := q.Dequeue(context.Background(), lockTimeout, 10*time.Millisecond)
	require.NoError(t, err)
	require.NotNil(t, task)

	id, err := q.Return(context.Background(), task.ID, errors.New("failed"))
	require.NoError(t, err)
	assert.NotEqual(t, task.ID, id)

	return id
}