path: root/backend/internal/redis/taskqueue
author Anshul Gupta <ansg191@anshulg.com> 2024-08-11 13:15:50 -0700
committer Anshul Gupta <ansg191@anshulg.com> 2024-08-11 13:15:50 -0700
commit 6a3c21fb0b1c126849f2bbff494403bbe901448e (patch)
tree 5d7805524357c2c8a9819c39d2051a4e3633a1d5 /backend/internal/redis/taskqueue
parent 29c6040a51616e9e4cf6c70ee16391b2a3b238c9 (diff)
parent f34b92ded11b07f78575ac62c260a380c468e5ea (diff)
Merge remote-tracking branch 'backend/main'
Diffstat (limited to 'backend/internal/redis/taskqueue')
-rw-r--r-- backend/internal/redis/taskqueue/options.go    9
-rw-r--r-- backend/internal/redis/taskqueue/queue.go      545
-rw-r--r-- backend/internal/redis/taskqueue/queue_test.go 467
3 files changed, 1021 insertions, 0 deletions
diff --git a/backend/internal/redis/taskqueue/options.go b/backend/internal/redis/taskqueue/options.go
new file mode 100644
index 0000000..2d5a23f
--- /dev/null
+++ b/backend/internal/redis/taskqueue/options.go
@@ -0,0 +1,9 @@
+package taskqueue
+
+// Option configures a task queue created by New.
+type Option[T any] func(*taskQueue[T])
+
+// WithEncoding sets the encoding used to serialize task data into the stream.
+// The default is EncodingJSON.
+func WithEncoding[T any](encoding Encoding) Option[T] {
+ return func(o *taskQueue[T]) {
+ o.encoding = encoding
+ }
+}
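+
+// A minimal sketch of option use (hypothetical payload type and queue names):
+//
+//	q, err := New[payload](ctx, rdb, "scrape", "worker-1", WithEncoding[payload](EncodingGob))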
diff --git a/backend/internal/redis/taskqueue/queue.go b/backend/internal/redis/taskqueue/queue.go
new file mode 100644
index 0000000..a4b799e
--- /dev/null
+++ b/backend/internal/redis/taskqueue/queue.go
@@ -0,0 +1,545 @@
+package taskqueue
+
+import (
+ "bytes"
+ "context"
+ "encoding/base64"
+ "encoding/gob"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "log/slog"
+ "reflect"
+ "strconv"
+ "strings"
+ "time"
+
+ "github.com/redis/go-redis/v9"
+)
+
+// Encoding selects how task data is serialized into the Redis stream.
+type Encoding uint8
+
+const (
+ EncodingJSON Encoding = iota
+ EncodingGob
+
+ // Field names of the per-task result hash (streamKey:taskID).
+ ResultKey = "result"
+ ErrorKey = "error"
+ NextAttemptKey = "next_attempt"
+)
+
+// MaxAttempts is the maximum number of times a task is attempted before it is
+// considered permanently failed.
+var MaxAttempts = 3
+
+// ErrTaskNotFound is returned when the given task ID is not in the stream.
+var ErrTaskNotFound = errors.New("task not found")
+
+type TaskQueue[T any] interface {
+ // Enqueue adds a task to the queue and returns the stored TaskInfo,
+ // including the Redis-generated task ID.
+ Enqueue(ctx context.Context, data T) (TaskInfo[T], error)
+
+ // Dequeue removes a task from the queue and returns it.
+ //
+ // Dequeue blocks until a task is available, the timeout elapses, or the
+ // context is canceled. If no task arrives before the timeout, Dequeue
+ // returns (nil, nil).
+ // The returned task is placed in a pending state for lockTimeout duration.
+ // The task must be completed with Complete or extended with Extend before the lock expires.
+ // If the lock expires, the task is returned to the queue, where it may be picked up by another worker.
+ Dequeue(
+ ctx context.Context,
+ lockTimeout,
+ timeout time.Duration,
+ ) (*TaskInfo[T], error)
+
+ // Extend extends the lock on a task.
+ Extend(ctx context.Context, taskID TaskID) error
+
+ // Complete marks a task as complete and stores result in the task's result hash.
+ Complete(ctx context.Context, taskID TaskID, result string) error
+
+ // Data returns the stored info of a task, including its result (if any).
+ Data(ctx context.Context, taskID TaskID) (TaskInfo[T], error)
+
+ // Return returns a failed task to the queue and returns the new task ID.
+ // Increments the attempt counter.
+ // Tasks that have reached MaxAttempts are considered permanently failed and
+ // are not re-queued; in that case the zero TaskID is returned.
+ Return(ctx context.Context, taskID TaskID, err error) (TaskID, error)
+
+ // List returns tasks in the queue, ordered from most recently added to
+ // oldest. The count parameter limits the number of tasks returned.
+ // The start and end parameters bound the range: start is inclusive, end is
+ // exclusive, and start must not be after end. A zero TaskID leaves that
+ // side of the range unbounded.
+ List(ctx context.Context, start, end TaskID, count int64) ([]TaskInfo[T], error)
+}
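+
+// A minimal worker-loop sketch against this interface (hypothetical handler
+// and queue names; error handling elided):
+//
+//	q, _ := New[string](ctx, rdb, "emails", "worker-1")
+//	_, _ = q.Enqueue(ctx, "hello")
+//	for {
+//		task, err := q.Dequeue(ctx, 30*time.Second, 5*time.Second)
+//		if err != nil || task == nil {
+//			continue // transient error, or timed out with no task
+//		}
+//		if err := handle(task.Data); err != nil {
+//			_, _ = q.Return(ctx, task.ID, err)
+//			continue
+//		}
+//		_ = q.Complete(ctx, task.ID, "done")
+//	}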
+
+type TaskInfo[T any] struct {
+ // ID is the unique identifier of the task. Generated by Redis.
+ ID TaskID
+ // Data is the task data. Stored in stream.
+ Data T
+ // Attempts is the number of times the task has been attempted. Stored in stream.
+ Attempts uint8
+ // Result is the result of the task. Stored in a hash.
+ Result isTaskResult
+}
+
+type isTaskResult interface {
+ isTaskResult()
+}
+
+type TaskResultSuccess struct {
+ Result string
+}
+
+type TaskResultError struct {
+ Error string
+ NextAttempt TaskID
+}
+
+func (*TaskResultSuccess) isTaskResult() {}
+func (*TaskResultError) isTaskResult() {}
+
+type TaskID struct {
+ timestamp time.Time
+ sequence uint64
+}
+
+func NewTaskID(timestamp time.Time, sequence uint64) TaskID {
+ return TaskID{timestamp, sequence}
+}
+
+func ParseTaskID(s string) (TaskID, error) {
+ tPart, sPart, ok := strings.Cut(s, "-")
+ if !ok {
+ return TaskID{}, errors.New("invalid task ID")
+ }
+
+ timestamp, err := strconv.ParseInt(tPart, 10, 64)
+ if err != nil {
+ return TaskID{}, err
+ }
+
+ sequence, err := strconv.ParseUint(sPart, 10, 64)
+ if err != nil {
+ return TaskID{}, err
+ }
+
+ return NewTaskID(time.UnixMilli(timestamp), sequence), nil
+}
+
+func (t TaskID) Timestamp() time.Time {
+ return t.timestamp
+}
+
+func (t TaskID) String() string {
+ tPart := strconv.FormatInt(t.timestamp.UnixMilli(), 10)
+ sPart := strconv.FormatUint(t.sequence, 10)
+ return tPart + "-" + sPart
+}
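+
+// For example, NewTaskID(time.UnixMilli(1723407350000), 0).String() is
+// "1723407350000-0", the same <millis>-<sequence> form Redis uses for stream
+// entry IDs.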
+
+type taskQueue[T any] struct {
+ rdb *redis.Client
+ encoding Encoding
+
+ streamKey string
+ groupName string
+
+ workerName string
+}
+
+func New[T any](ctx context.Context, rdb *redis.Client, name string, workerName string, opts ...Option[T]) (TaskQueue[T], error) {
+ tq := &taskQueue[T]{
+ rdb: rdb,
+ encoding: EncodingJSON,
+ streamKey: "taskqueue:" + name,
+ groupName: "default",
+ workerName: workerName,
+ }
+
+ for _, opt := range opts {
+ opt(tq)
+ }
+
+ // Create the stream and consumer group if they don't already exist;
+ // a BUSYGROUP error means the group is already there and is ignored.
+ err := rdb.XGroupCreateMkStream(ctx, tq.streamKey, tq.groupName, "0").Err()
+ if err != nil && err.Error() != "BUSYGROUP Consumer Group name already exists" {
+ return nil, err
+ }
+
+ return tq, nil
+}
+
+func (q *taskQueue[T]) Enqueue(ctx context.Context, data T) (TaskInfo[T], error) {
+ task := TaskInfo[T]{
+ Data: data,
+ Attempts: 0,
+ }
+
+ values, err := encode[T](task, q.encoding)
+ if err != nil {
+ return TaskInfo[T]{}, err
+ }
+
+ taskID, err := q.rdb.XAdd(ctx, &redis.XAddArgs{
+ Stream: q.streamKey,
+ Values: values,
+ }).Result()
+ if err != nil {
+ return TaskInfo[T]{}, err
+ }
+
+ id, err := ParseTaskID(taskID)
+ if err != nil {
+ return TaskInfo[T]{}, err
+ }
+ task.ID = id
+ return task, nil
+}
+
+func (q *taskQueue[T]) Dequeue(ctx context.Context, lockTimeout, timeout time.Duration) (*TaskInfo[T], error) {
+ // First, try to recover a task whose lock has expired
+ task, err := q.recover(ctx, lockTimeout)
+ if err != nil {
+ return nil, err
+ }
+ if task != nil {
+ return task, nil
+ }
+
+ // Check for new tasks
+ ids, err := q.rdb.XReadGroup(ctx, &redis.XReadGroupArgs{
+ Group: q.groupName,
+ Consumer: q.workerName,
+ Streams: []string{q.streamKey, ">"},
+ Count: 1,
+ Block: timeout,
+ }).Result()
+ if err != nil && !errors.Is(err, redis.Nil) {
+ return nil, err
+ }
+
+ if len(ids) == 0 || len(ids[0].Messages) == 0 || errors.Is(err, redis.Nil) {
+ return nil, nil
+ }
+
+ msg := ids[0].Messages[0]
+ task = new(TaskInfo[T])
+ *task, err = decode[T](&msg, q.encoding)
+ if err != nil {
+ return nil, err
+ }
+ return task, nil
+}
+
+func (q *taskQueue[T]) Extend(ctx context.Context, taskID TaskID) error {
+ // Claiming with MinIdle 0 re-claims the message unconditionally, resetting
+ // its idle time and thereby extending the lock.
+ _, err := q.rdb.XClaim(ctx, &redis.XClaimArgs{
+ Stream: q.streamKey,
+ Group: q.groupName,
+ Consumer: q.workerName,
+ MinIdle: 0,
+ Messages: []string{taskID.String()},
+ }).Result()
+ if err != nil && !errors.Is(err, redis.Nil) {
+ return err
+ }
+ return nil
+}
+
+func (q *taskQueue[T]) Data(ctx context.Context, taskID TaskID) (TaskInfo[T], error) {
+ msg, err := q.rdb.XRange(ctx, q.streamKey, taskID.String(), taskID.String()).Result()
+ if err != nil {
+ return TaskInfo[T]{}, err
+ }
+
+ if len(msg) == 0 {
+ return TaskInfo[T]{}, ErrTaskNotFound
+ }
+
+ t, err := decode[T](&msg[0], q.encoding)
+ if err != nil {
+ return TaskInfo[T]{}, err
+ }
+
+ t.Result, err = q.getResult(ctx, taskID)
+ if err != nil {
+ return TaskInfo[T]{}, err
+ }
+ return t, nil
+}
+
+func (q *taskQueue[T]) Complete(ctx context.Context, taskID TaskID, result string) error {
+ return q.ack(ctx, taskID, false, result)
+}
+
+var retScript = redis.NewScript(`
+local stream_key = KEYS[1]
+local hash_key = KEYS[2]
+
+-- Re-add the task to the stream
+local task_id = redis.call('XADD', stream_key, '*', unpack(ARGV))
+
+-- Update the hash key to point to the new task
+redis.call('HSET', hash_key, 'next_attempt', task_id)
+
+return task_id
+`)
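+
+// The script re-adds the task and records the new ID in a single atomic step,
+// so next_attempt never points at an attempt that was not actually enqueued.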
+
+func (q *taskQueue[T]) Return(ctx context.Context, taskID TaskID, taskErr error) (TaskID, error) {
+ msgs, err := q.rdb.XRange(ctx, q.streamKey, taskID.String(), taskID.String()).Result()
+ if err != nil {
+ return TaskID{}, err
+ }
+ if len(msgs) == 0 {
+ return TaskID{}, ErrTaskNotFound
+ }
+
+ var ackMsg string
+ if taskErr != nil {
+ ackMsg = taskErr.Error()
+ }
+
+ // Ack the task
+ err = q.ack(ctx, taskID, true, ackMsg)
+ if err != nil {
+ return TaskID{}, err
+ }
+
+ msg := msgs[0]
+ task, err := decode[T](&msg, q.encoding)
+ if err != nil {
+ return TaskID{}, err
+ }
+
+ task.Attempts++
+ if int(task.Attempts) >= MaxAttempts {
+ // Task has failed
+ slog.ErrorContext(ctx, "task failed completely",
+ "taskID", taskID,
+ "data", task.Data,
+ "attempts", task.Attempts,
+ "maxAttempts", MaxAttempts,
+ )
+ return TaskID{}, nil
+ }
+
+ valuesMap, err := encode[T](task, q.encoding)
+ if err != nil {
+ return TaskID{}, err
+ }
+
+ values := make([]string, 0, len(valuesMap)*2)
+ for k, v := range valuesMap {
+ values = append(values, k, v)
+ }
+
+ keys := []string{
+ q.streamKey,
+ fmt.Sprintf("%s:%s", q.streamKey, taskID.String()),
+ }
+ newTaskID, err := retScript.Run(ctx, q.rdb, keys, values).Result()
+ if err != nil {
+ return TaskID{}, err
+ }
+
+ return ParseTaskID(newTaskID.(string))
+}
+
+func (q *taskQueue[T]) List(ctx context.Context, start, end TaskID, count int64) ([]TaskInfo[T], error) {
+ if !start.timestamp.IsZero() && !end.timestamp.IsZero() && start.timestamp.After(end.timestamp) {
+ return nil, errors.New("start must be before end")
+ }
+
+ var startStr, endStr string
+ if !start.timestamp.IsZero() {
+ startStr = start.String()
+ } else {
+ startStr = "-"
+ }
+ if !end.timestamp.IsZero() {
+ endStr = "(" + end.String()
+ } else {
+ endStr = "+"
+ }
+
+ msgs, err := q.rdb.XRevRangeN(ctx, q.streamKey, endStr, startStr, count).Result()
+ if err != nil {
+ return nil, err
+ }
+ if len(msgs) == 0 {
+ return []TaskInfo[T]{}, nil
+ }
+
+ tasks := make([]TaskInfo[T], len(msgs))
+ for i := range msgs {
+ tasks[i], err = decode[T](&msgs[i], q.encoding)
+ if err != nil {
+ return nil, err
+ }
+
+ tasks[i].Result, err = q.getResult(ctx, tasks[i].ID)
+ if err != nil {
+ return nil, err
+ }
+ }
+
+ return tasks, nil
+}
+
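+// getResult reads the per-task hash (streamKey:taskID): a "result" field
+// marks success; an "error" field marks failure, with "next_attempt" linking
+// to the retry. A task with no hash entry yields a nil result.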
+func (q *taskQueue[T]) getResult(ctx context.Context, taskID TaskID) (isTaskResult, error) {
+ key := fmt.Sprintf("%s:%s", q.streamKey, taskID.String())
+ results, err := q.rdb.HMGet(ctx, key, ResultKey, ErrorKey, NextAttemptKey).Result()
+ if err != nil {
+ return nil, err
+ }
+
+ var ret isTaskResult
+ if results[0] != nil {
+ ret = &TaskResultSuccess{Result: results[0].(string)}
+ } else if results[1] != nil {
+ res := &TaskResultError{Error: results[1].(string)}
+ if results[2] != nil {
+ nextAttempt, err := ParseTaskID(results[2].(string))
+ if err != nil {
+ return nil, err
+ }
+ res.NextAttempt = nextAttempt
+ }
+ ret = res
+ }
+ return ret, nil
+}
+
+// recover claims and returns a task that has been idle for at least
+// idleTimeout (i.e. whose lock has expired), or nil if there is none.
+func (q *taskQueue[T]) recover(ctx context.Context, idleTimeout time.Duration) (*TaskInfo[T], error) {
+ msgs, _, err := q.rdb.XAutoClaim(ctx, &redis.XAutoClaimArgs{
+ Stream: q.streamKey,
+ Group: q.groupName,
+ MinIdle: idleTimeout,
+ Start: "0",
+ Count: 1,
+ Consumer: q.workerName,
+ }).Result()
+ if err != nil {
+ return nil, err
+ }
+
+ if len(msgs) == 0 {
+ return nil, nil
+ }
+
+ msg := msgs[0]
+ task, err := decode[T](&msg, q.encoding)
+ if err != nil {
+ return nil, err
+ }
+ return &task, nil
+}
+
+// ack acknowledges the task in the consumer group and records its outcome in
+// the per-task result hash, under ErrorKey if errored or ResultKey otherwise.
+func (q *taskQueue[T]) ack(ctx context.Context, taskID TaskID, errored bool, msg string) error {
+ key := fmt.Sprintf("%s:%s", q.streamKey, taskID.String())
+ _, err := q.rdb.TxPipelined(ctx, func(pipe redis.Pipeliner) error {
+ pipe.XAck(ctx, q.streamKey, q.groupName, taskID.String())
+ if errored {
+ pipe.HSet(ctx, key, ErrorKey, msg)
+ } else {
+ pipe.HSet(ctx, key, ResultKey, msg)
+ }
+ return nil
+ })
+ return err
+}
+
+func decode[T any](msg *redis.XMessage, encoding Encoding) (task TaskInfo[T], err error) {
+ task.ID, err = ParseTaskID(msg.ID)
+ if err != nil {
+ return
+ }
+
+ err = getField(msg, "attempts", &task.Attempts)
+ if err != nil {
+ return
+ }
+
+ var data string
+ err = getField(msg, "data", &data)
+ if err != nil {
+ return
+ }
+
+ switch encoding {
+ case EncodingJSON:
+ err = json.Unmarshal([]byte(data), &task.Data)
+ case EncodingGob:
+ var decoded []byte
+ decoded, err = base64.StdEncoding.DecodeString(data)
+ if err != nil {
+ return
+ }
+ err = gob.NewDecoder(bytes.NewReader(decoded)).Decode(&task.Data)
+ default:
+ err = errors.New("unsupported encoding")
+ }
+ return
+}
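+
+// As a concrete sketch: with EncodingJSON, a TaskInfo[string]{Data: "hi",
+// Attempts: 1} round-trips through the stream fields
+// {"attempts": "1", "data": "\"hi\""}.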
+
+func getField(msg *redis.XMessage, field string, v any) error {
+ vVal, ok := msg.Values[field]
+ if !ok {
+ return errors.New("missing field")
+ }
+
+ vStr, ok := vVal.(string)
+ if !ok {
+ return errors.New("invalid field type")
+ }
+
+ value := reflect.ValueOf(v).Elem()
+ switch value.Kind() {
+ case reflect.String:
+ value.SetString(vStr)
+ case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
+ i, err := strconv.ParseInt(vStr, 10, 64)
+ if err != nil {
+ return err
+ }
+ value.SetInt(i)
+ case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
+ i, err := strconv.ParseUint(vStr, 10, 64)
+ if err != nil {
+ return err
+ }
+ value.SetUint(i)
+ case reflect.Bool:
+ b, err := strconv.ParseBool(vStr)
+ if err != nil {
+ return err
+ }
+ value.SetBool(b)
+ default:
+ return errors.New("unsupported field type")
+ }
+ return nil
+}
+
+func encode[T any](task TaskInfo[T], encoding Encoding) (ret map[string]string, err error) {
+ ret = make(map[string]string)
+ ret["attempts"] = strconv.FormatUint(uint64(task.Attempts), 10)
+
+ switch encoding {
+ case EncodingJSON:
+ var data []byte
+ data, err = json.Marshal(task.Data)
+ if err != nil {
+ return
+ }
+ ret["data"] = string(data)
+ case EncodingGob:
+ var data bytes.Buffer
+ err = gob.NewEncoder(&data).Encode(task.Data)
+ if err != nil {
+ return
+ }
+ ret["data"] = base64.StdEncoding.EncodeToString(data.Bytes())
+ default:
+ err = errors.New("unsupported encoding")
+ }
+ return
+}
diff --git a/backend/internal/redis/taskqueue/queue_test.go b/backend/internal/redis/taskqueue/queue_test.go
new file mode 100644
index 0000000..ee95d39
--- /dev/null
+++ b/backend/internal/redis/taskqueue/queue_test.go
@@ -0,0 +1,467 @@
+package taskqueue
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "log"
+ "testing"
+ "time"
+
+ "github.com/ory/dockertest/v3"
+ "github.com/ory/dockertest/v3/docker"
+ "github.com/redis/go-redis/v9"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+var client *redis.Client
+
+func TestMain(m *testing.M) {
+ pool, err := dockertest.NewPool("")
+ if err != nil {
+ log.Fatalf("Could not create pool: %s", err)
+ }
+
+ err = pool.Client.Ping()
+ if err != nil {
+ log.Fatalf("Could not connect to Docker: %s", err)
+ }
+
+ resource, err := pool.RunWithOptions(&dockertest.RunOptions{
+ Repository: "redis",
+ Tag: "7",
+ }, func(config *docker.HostConfig) {
+ config.AutoRemove = true
+ config.RestartPolicy = docker.RestartPolicy{Name: "no"}
+ })
+ if err != nil {
+ log.Fatalf("Could not start resource: %s", err)
+ }
+
+ if err = pool.Retry(func() error {
+ client = redis.NewClient(&redis.Options{
+ Addr: fmt.Sprintf("localhost:%s", resource.GetPort("6379/tcp")),
+ })
+ return client.Ping(context.Background()).Err()
+ }); err != nil {
+ log.Fatalf("Could not connect to redis: %s", err)
+ }
+
+ defer func() {
+ if err = client.Close(); err != nil {
+ log.Printf("Could not close client: %s", err)
+ }
+ if err = pool.Purge(resource); err != nil {
+ log.Fatalf("Could not purge resource: %s", err)
+ }
+ }()
+
+ m.Run()
+}
+
+func TestTaskQueue(t *testing.T) {
+ if testing.Short() {
+ t.Skip()
+ }
+
+ lockTimeout := 100 * time.Millisecond
+
+ tests := []struct {
+ name string
+ f func(t *testing.T)
+ }{
+ {
+ name: "Create queue",
+ f: func(t *testing.T) {
+ q, err := New[string](context.Background(), client, "test", "worker")
+ require.NoError(t, err)
+ require.NotNil(t, q)
+ },
+ },
+ {
+ name: "enqueue & dequeue",
+ f: func(t *testing.T) {
+ q, err := New[string](context.Background(), client, "test", "worker")
+ require.NoError(t, err)
+ require.NotNil(t, q)
+
+ taskId, err := q.Enqueue(context.Background(), "hello")
+ require.NoError(t, err)
+ require.NotEmpty(t, taskId)
+
+ task, err := q.Dequeue(context.Background(), lockTimeout, 10*time.Millisecond)
+ require.NoError(t, err)
+ require.NotNil(t, task)
+ require.Equal(t, "hello", task.Data)
+ },
+ },
+ {
+ name: "complex data",
+ f: func(t *testing.T) {
+ type foo struct {
+ A int
+ B string
+ }
+
+ q, err := New[foo](context.Background(), client, "test", "worker")
+ require.NoError(t, err)
+ require.NotNil(t, q)
+
+ taskId, err := q.Enqueue(context.Background(), foo{A: 42, B: "hello"})
+ require.NoError(t, err)
+ require.NotEmpty(t, taskId)
+
+ task, err := q.Dequeue(context.Background(), lockTimeout, 10*time.Millisecond)
+ require.NoError(t, err)
+ require.NotNil(t, task)
+ require.Equal(t, foo{A: 42, B: "hello"}, task.Data)
+ },
+ },
+ {
+ name: "different workers",
+ f: func(t *testing.T) {
+ q1, err := New[string](context.Background(), client, "test", "worker1")
+ require.NoError(t, err)
+ require.NotNil(t, q1)
+
+ q2, err := New[string](context.Background(), client, "test", "worker2")
+ require.NoError(t, err)
+ require.NotNil(t, q2)
+
+ taskId, err := q1.Enqueue(context.Background(), "hello")
+ require.NoError(t, err)
+ require.NotEmpty(t, taskId)
+
+ task, err := q2.Dequeue(context.Background(), lockTimeout, 10*time.Millisecond)
+ require.NoError(t, err)
+ require.NotNil(t, task)
+ assert.Equal(t, "hello", task.Data)
+ },
+ },
+ {
+ name: "complete",
+ f: func(t *testing.T) {
+ q, err := New[string](context.Background(), client, "test", "worker")
+ require.NoError(t, err)
+ require.NotNil(t, q)
+
+ // Enqueue a task
+ taskId, err := q.Enqueue(context.Background(), "hello")
+ require.NoError(t, err)
+ require.NotEmpty(t, taskId)
+
+ // Dequeue the task
+ task, err := q.Dequeue(context.Background(), lockTimeout, 10*time.Millisecond)
+ require.NoError(t, err)
+ require.NotNil(t, task)
+ assert.Equal(t, "hello", task.Data)
+
+ // Complete the task
+ err = q.Complete(context.Background(), task.ID, "done")
+ require.NoError(t, err)
+
+ // Try to dequeue the task again
+ task, err = q.Dequeue(context.Background(), lockTimeout, 10*time.Millisecond)
+ require.NoError(t, err)
+ require.Nil(t, task)
+ },
+ },
+ {
+ name: "timeout",
+ f: func(t *testing.T) {
+ q, err := New[string](context.Background(), client, "test", "worker")
+ require.NoError(t, err)
+ require.NotNil(t, q)
+
+ // Enqueue a task
+ taskId, err := q.Enqueue(context.Background(), "hello")
+ require.NoError(t, err)
+ require.NotEmpty(t, taskId)
+
+ // Dequeue the task
+ task, err := q.Dequeue(context.Background(), lockTimeout, 10*time.Millisecond)
+ require.NoError(t, err)
+ require.NotNil(t, task)
+ assert.Equal(t, "hello", task.Data)
+
+ // Wait for the lock to expire
+ time.Sleep(lockTimeout + 10*time.Millisecond)
+
+ // Try to dequeue the task again
+ task, err = q.Dequeue(context.Background(), lockTimeout, 10*time.Millisecond)
+ require.NoError(t, err)
+ require.NotNil(t, task)
+ assert.Equal(t, "hello", task.Data)
+ },
+ },
+ {
+ name: "extend",
+ f: func(t *testing.T) {
+ q, err := New[string](context.Background(), client, "test", "worker")
+ require.NoError(t, err)
+ require.NotNil(t, q)
+
+ // Enqueue a task
+ taskId, err := q.Enqueue(context.Background(), "hello")
+ require.NoError(t, err)
+ require.NotEmpty(t, taskId)
+
+ // Dequeue the task
+ task, err := q.Dequeue(context.Background(), lockTimeout, 10*time.Millisecond)
+ require.NoError(t, err)
+ require.NotNil(t, task)
+ assert.Equal(t, "hello", task.Data)
+
+ // Wait for the lock to expire
+ time.Sleep(lockTimeout + 10*time.Millisecond)
+
+ // Extend the lock
+ err = q.Extend(context.Background(), task.ID)
+ require.NoError(t, err)
+
+ // Try to dequeue the task again
+ task, err = q.Dequeue(context.Background(), lockTimeout, 10*time.Millisecond)
+ require.NoError(t, err)
+ require.Nil(t, task)
+ },
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ if err := client.FlushDB(context.Background()).Err(); err != nil {
+ t.Fatal(err)
+ }
+
+ tt.f(t)
+ })
+ }
+
+ _ = client.FlushDB(context.Background())
+}
+
+func TestTaskQueue_List(t *testing.T) {
+ if testing.Short() {
+ t.Skip()
+ }
+
+ tests := []struct {
+ name string
+ f func(t *testing.T)
+ }{
+ {
+ name: "empty",
+ f: func(t *testing.T) {
+ q, err := New[any](context.Background(), client, "test", "worker")
+ require.NoError(t, err)
+ require.NotNil(t, q)
+
+ tasks, err := q.List(context.Background(), TaskID{}, TaskID{}, 1)
+ require.NoError(t, err)
+ assert.Empty(t, tasks)
+ },
+ },
+ {
+ name: "single",
+ f: func(t *testing.T) {
+ q, err := New[string](context.Background(), client, "test", "worker")
+ require.NoError(t, err)
+ require.NotNil(t, q)
+
+ taskID, err := q.Enqueue(context.Background(), "hello")
+ require.NoError(t, err)
+
+ tasks, err := q.List(context.Background(), TaskID{}, TaskID{}, 1)
+ require.NoError(t, err)
+ assert.Len(t, tasks, 1)
+ assert.Equal(t, taskID, tasks[0])
+ },
+ },
+ {
+ name: "multiple",
+ f: func(t *testing.T) {
+ q, err := New[string](context.Background(), client, "test", "worker")
+ require.NoError(t, err)
+ require.NotNil(t, q)
+
+ taskID, err := q.Enqueue(context.Background(), "hello")
+ require.NoError(t, err)
+ taskID2, err := q.Enqueue(context.Background(), "world")
+ require.NoError(t, err)
+
+ tasks, err := q.List(context.Background(), TaskID{}, TaskID{}, 2)
+ require.NoError(t, err)
+ assert.Len(t, tasks, 2)
+ assert.Equal(t, taskID, tasks[1])
+ assert.Equal(t, taskID2, tasks[0])
+ },
+ },
+ {
+ name: "multiple limited",
+ f: func(t *testing.T) {
+ q, err := New[string](context.Background(), client, "test", "worker")
+ require.NoError(t, err)
+ require.NotNil(t, q)
+
+ _, err = q.Enqueue(context.Background(), "hello")
+ require.NoError(t, err)
+ taskID2, err := q.Enqueue(context.Background(), "world")
+ require.NoError(t, err)
+
+ tasks, err := q.List(context.Background(), TaskID{}, TaskID{}, 1)
+ require.NoError(t, err)
+ assert.Len(t, tasks, 1)
+ assert.Equal(t, taskID2, tasks[0])
+ },
+ },
+ {
+ name: "multiple time range",
+ f: func(t *testing.T) {
+ q, err := New[string](context.Background(), client, "test", "worker")
+ require.NoError(t, err)
+ require.NotNil(t, q)
+
+ taskID, err := q.Enqueue(context.Background(), "hello")
+ require.NoError(t, err)
+ time.Sleep(10 * time.Millisecond)
+ taskID2, err := q.Enqueue(context.Background(), "world")
+ require.NoError(t, err)
+
+ tasks, err := q.List(context.Background(), TaskID{}, taskID2.ID, 100)
+ require.NoError(t, err)
+ assert.Len(t, tasks, 1)
+ assert.Equal(t, taskID, tasks[0])
+
+ tasks, err = q.List(context.Background(), taskID2.ID, TaskID{}, 100)
+ require.NoError(t, err)
+ assert.Len(t, tasks, 1)
+ assert.Equal(t, taskID2, tasks[0])
+ },
+ },
+ {
+ name: "completed tasks",
+ f: func(t *testing.T) {
+ q, err := New[string](context.Background(), client, "test", "worker")
+ require.NoError(t, err)
+ require.NotNil(t, q)
+
+ task1, err := q.Enqueue(context.Background(), "hello")
+ require.NoError(t, err)
+ task2, err := q.Enqueue(context.Background(), "world")
+ require.NoError(t, err)
+
+ err = q.Complete(context.Background(), task1.ID, "done")
+ require.NoError(t, err)
+
+ tasks, err := q.List(context.Background(), TaskID{}, TaskID{}, 100)
+ require.NoError(t, err)
+ assert.Len(t, tasks, 2)
+ assert.Equal(t, task2, tasks[0])
+
+ assert.Equal(t, "hello", tasks[1].Data)
+ require.IsType(t, &TaskResultSuccess{}, tasks[1].Result)
+ assert.Equal(t, "done", tasks[1].Result.(*TaskResultSuccess).Result)
+ },
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ if err := client.FlushDB(context.Background()).Err(); err != nil {
+ t.Fatal(err)
+ }
+
+ tt.f(t)
+ })
+ }
+
+ _ = client.FlushDB(context.Background())
+}
+
+func TestTaskQueue_Return(t *testing.T) {
+ if testing.Short() {
+ t.Skip()
+ }
+
+ lockTimeout := 100 * time.Millisecond
+
+ tests := []struct {
+ name string
+ f func(t *testing.T)
+ }{
+ {
+ name: "simple",
+ f: func(t *testing.T) {
+ q, err := New[string](context.Background(), client, "test", "worker")
+ require.NoError(t, err)
+ require.NotNil(t, q)
+
+ task1, err := q.Enqueue(context.Background(), "hello")
+ require.NoError(t, err)
+
+ id := claimAndFail(t, q, lockTimeout)
+
+ task2, err := q.Dequeue(context.Background(), lockTimeout, 10*time.Millisecond)
+ require.NoError(t, err)
+ require.NotNil(t, task2)
+ assert.Equal(t, task2.ID, id)
+ assert.Equal(t, task1.Data, task2.Data)
+ assert.Equal(t, uint8(1), task2.Attempts)
+
+ task1Data, err := q.Data(context.Background(), task1.ID)
+ require.NoError(t, err)
+ assert.Equal(t, task1Data.ID, task1.ID)
+ assert.Equal(t, task1Data.Data, task1.Data)
+ assert.IsType(t, &TaskResultError{}, task1Data.Result)
+ assert.Equal(t, "failed", task1Data.Result.(*TaskResultError).Error)
+ assert.Equal(t, task2.ID, task1Data.Result.(*TaskResultError).NextAttempt)
+ },
+ },
+ {
+ name: "failure",
+ f: func(t *testing.T) {
+ q, err := New[string](context.Background(), client, "test", "worker")
+ require.NoError(t, err)
+ require.NotNil(t, q)
+
+ _, err = q.Enqueue(context.Background(), "hello")
+ require.NoError(t, err)
+
+ claimAndFail(t, q, lockTimeout)
+ claimAndFail(t, q, lockTimeout)
+ claimAndFail(t, q, lockTimeout)
+
+ task3, err := q.Dequeue(context.Background(), lockTimeout, 10*time.Millisecond)
+ require.NoError(t, err)
+ assert.Nil(t, task3)
+ },
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ if err := client.FlushDB(context.Background()).Err(); err != nil {
+ t.Fatal(err)
+ }
+
+ tt.f(t)
+ })
+ }
+
+ _ = client.FlushDB(context.Background())
+}
+
+func claimAndFail[T any](t *testing.T, q TaskQueue[T], lockTimeout time.Duration) TaskID {
+ task, err := q.Dequeue(context.Background(), lockTimeout, 10*time.Millisecond)
+ require.NoError(t, err)
+ require.NotNil(t, task)
+
+ id, err := q.Return(context.Background(), task.ID, errors.New("failed"))
+ require.NoError(t, err)
+ assert.NotEqual(t, task.ID, id)
+ return id
+}