diff options
author | 2024-08-11 13:15:50 -0700 | |
---|---|---|
committer | 2024-08-11 13:15:50 -0700 | |
commit | 6a3c21fb0b1c126849f2bbff494403bbe901448e (patch) | |
tree | 5d7805524357c2c8a9819c39d2051a4e3633a1d5 /backend/internal/redis/taskqueue | |
parent | 29c6040a51616e9e4cf6c70ee16391b2a3b238c9 (diff) | |
parent | f34b92ded11b07f78575ac62c260a380c468e5ea (diff) | |
download | ibd-trader-6a3c21fb0b1c126849f2bbff494403bbe901448e.tar.gz ibd-trader-6a3c21fb0b1c126849f2bbff494403bbe901448e.tar.zst ibd-trader-6a3c21fb0b1c126849f2bbff494403bbe901448e.zip |
Merge remote-tracking branch 'backend/main'
Diffstat (limited to 'backend/internal/redis/taskqueue')
-rw-r--r-- | backend/internal/redis/taskqueue/options.go | 9 | ||||
-rw-r--r-- | backend/internal/redis/taskqueue/queue.go | 545 | ||||
-rw-r--r-- | backend/internal/redis/taskqueue/queue_test.go | 467 |
3 files changed, 1021 insertions, 0 deletions
diff --git a/backend/internal/redis/taskqueue/options.go b/backend/internal/redis/taskqueue/options.go new file mode 100644 index 0000000..2d5a23f --- /dev/null +++ b/backend/internal/redis/taskqueue/options.go @@ -0,0 +1,9 @@ +package taskqueue + +type Option[T any] func(*taskQueue[T]) + +func WithEncoding[T any](encoding Encoding) Option[T] { + return func(o *taskQueue[T]) { + o.encoding = encoding + } +} diff --git a/backend/internal/redis/taskqueue/queue.go b/backend/internal/redis/taskqueue/queue.go new file mode 100644 index 0000000..a4b799e --- /dev/null +++ b/backend/internal/redis/taskqueue/queue.go @@ -0,0 +1,545 @@ +package taskqueue + +import ( + "bytes" + "context" + "encoding/base64" + "encoding/gob" + "encoding/json" + "errors" + "fmt" + "log/slog" + "reflect" + "strconv" + "strings" + "time" + + "github.com/redis/go-redis/v9" +) + +type Encoding uint8 + +const ( + EncodingJSON Encoding = iota + EncodingGob + + ResultKey = "result" + ErrorKey = "error" + NextAttemptKey = "next_attempt" +) + +var MaxAttempts = 3 +var ErrTaskNotFound = errors.New("task not found") + +type TaskQueue[T any] interface { + // Enqueue adds a task to the queue. + // Returns the generated task ID. + Enqueue(ctx context.Context, data T) (TaskInfo[T], error) + + // Dequeue removes a task from the queue and returns it. + // The task data is placed into dataOut. + // + // Dequeue blocks until a task is available, timeout, or the context is canceled. + // The returned task is placed in a pending state for lockTimeout duration. + // The task must be completed with Complete or extended with Extend before the lock expires. + // If the lock expires, the task is returned to the queue, where it may be picked up by another worker. + Dequeue( + ctx context.Context, + lockTimeout, + timeout time.Duration, + ) (*TaskInfo[T], error) + + // Extend extends the lock on a task. + Extend(ctx context.Context, taskID TaskID) error + + // Complete marks a task as complete. Optionally, an error can be provided to store additional information. + Complete(ctx context.Context, taskID TaskID, result string) error + + // Data returns the info of a task. + Data(ctx context.Context, taskID TaskID) (TaskInfo[T], error) + + // Return returns a task to the queue and returns the new task ID. + // Increments the attempt counter. + // Tasks with too many attempts (MaxAttempts) are considered failed and aren't returned to the queue. + Return(ctx context.Context, taskID TaskID, err error) (TaskID, error) + + // List returns a list of task IDs in the queue. + // The list is ordered by the time the task was added to the queue. The most recent task is first. + // The count parameter limits the number of tasks returned. + // The start and end parameters limit the range of tasks returned. + // End is exclusive. + // Start must be before end. + List(ctx context.Context, start, end TaskID, count int64) ([]TaskInfo[T], error) +} + +type TaskInfo[T any] struct { + // ID is the unique identifier of the task. Generated by redis. + ID TaskID + // Data is the task data. Stored in stream. + Data T + // Attempts is the number of times the task has been attempted. Stored in stream. + Attempts uint8 + // Result is the result of the task. Stored in a hash. + Result isTaskResult +} + +type isTaskResult interface { + isTaskResult() +} + +type TaskResultSuccess struct { + Result string +} + +type TaskResultError struct { + Error string + NextAttempt TaskID +} + +func (*TaskResultSuccess) isTaskResult() {} +func (*TaskResultError) isTaskResult() {} + +type TaskID struct { + timestamp time.Time + sequence uint64 +} + +func NewTaskID(timestamp time.Time, sequence uint64) TaskID { + return TaskID{timestamp, sequence} +} + +func ParseTaskID(s string) (TaskID, error) { + tPart, sPart, ok := strings.Cut(s, "-") + if !ok { + return TaskID{}, errors.New("invalid task ID") + } + + timestamp, err := strconv.ParseInt(tPart, 10, 64) + if err != nil { + return TaskID{}, err + } + + sequence, err := strconv.ParseUint(sPart, 10, 64) + if err != nil { + return TaskID{}, err + } + + return NewTaskID(time.UnixMilli(timestamp), sequence), nil +} + +func (t TaskID) Timestamp() time.Time { + return t.timestamp +} + +func (t TaskID) String() string { + tPart := strconv.FormatInt(t.timestamp.UnixMilli(), 10) + sPart := strconv.FormatUint(t.sequence, 10) + return tPart + "-" + sPart +} + +type taskQueue[T any] struct { + rdb *redis.Client + encoding Encoding + + streamKey string + groupName string + + workerName string +} + +func New[T any](ctx context.Context, rdb *redis.Client, name string, workerName string, opts ...Option[T]) (TaskQueue[T], error) { + tq := &taskQueue[T]{ + rdb: rdb, + encoding: EncodingJSON, + streamKey: "taskqueue:" + name, + groupName: "default", + workerName: workerName, + } + + for _, opt := range opts { + opt(tq) + } + + // Create the stream if it doesn't exist + err := rdb.XGroupCreateMkStream(ctx, tq.streamKey, tq.groupName, "0").Err() + if err != nil && err.Error() != "BUSYGROUP Consumer Group name already exists" { + return nil, err + } + + return tq, nil +} + +func (q *taskQueue[T]) Enqueue(ctx context.Context, data T) (TaskInfo[T], error) { + task := TaskInfo[T]{ + Data: data, + Attempts: 0, + } + + values, err := encode[T](task, q.encoding) + if err != nil { + return TaskInfo[T]{}, err + } + + taskID, err := q.rdb.XAdd(ctx, &redis.XAddArgs{ + Stream: q.streamKey, + Values: values, + }).Result() + if err != nil { + return TaskInfo[T]{}, err + } + + id, err := ParseTaskID(taskID) + if err != nil { + return TaskInfo[T]{}, err + } + task.ID = id + return task, nil +} + +func (q *taskQueue[T]) Dequeue(ctx context.Context, lockTimeout, timeout time.Duration) (*TaskInfo[T], error) { + // Try to recover a task + task, err := q.recover(ctx, lockTimeout) + if err != nil { + return nil, err + } + if task != nil { + return task, nil + } + + // Check for new tasks + ids, err := q.rdb.XReadGroup(ctx, &redis.XReadGroupArgs{ + Group: q.groupName, + Consumer: q.workerName, + Streams: []string{q.streamKey, ">"}, + Count: 1, + Block: timeout, + }).Result() + if err != nil && !errors.Is(err, redis.Nil) { + return nil, err + } + + if len(ids) == 0 || len(ids[0].Messages) == 0 || errors.Is(err, redis.Nil) { + return nil, nil + } + + msg := ids[0].Messages[0] + task = new(TaskInfo[T]) + *task, err = decode[T](&msg, q.encoding) + if err != nil { + return nil, err + } + return task, nil +} + +func (q *taskQueue[T]) Extend(ctx context.Context, taskID TaskID) error { + _, err := q.rdb.XClaim(ctx, &redis.XClaimArgs{ + Stream: q.streamKey, + Group: q.groupName, + Consumer: q.workerName, + MinIdle: 0, + Messages: []string{taskID.String()}, + }).Result() + if err != nil && !errors.Is(err, redis.Nil) { + return err + } + return nil +} + +func (q *taskQueue[T]) Data(ctx context.Context, taskID TaskID) (TaskInfo[T], error) { + msg, err := q.rdb.XRange(ctx, q.streamKey, taskID.String(), taskID.String()).Result() + if err != nil { + return TaskInfo[T]{}, err + } + + if len(msg) == 0 { + return TaskInfo[T]{}, ErrTaskNotFound + } + + t, err := decode[T](&msg[0], q.encoding) + if err != nil { + return TaskInfo[T]{}, err + } + + t.Result, err = q.getResult(ctx, taskID) + if err != nil { + return TaskInfo[T]{}, nil + } + return t, nil +} + +func (q *taskQueue[T]) Complete(ctx context.Context, taskID TaskID, result string) error { + return q.ack(ctx, taskID, false, result) +} + +var retScript = redis.NewScript(` +local stream_key = KEYS[1] +local hash_key = KEYS[2] + +-- Re-add the task to the stream +local task_id = redis.call('XADD', stream_key, '*', unpack(ARGV)) + +-- Update the hash key to point to the new task +redis.call('HSET', hash_key, 'next_attempt', task_id) + +return task_id +`) + +func (q *taskQueue[T]) Return(ctx context.Context, taskID TaskID, err1 error) (TaskID, error) { + msgs, err := q.rdb.XRange(ctx, q.streamKey, taskID.String(), taskID.String()).Result() + if err != nil { + return TaskID{}, err + } + if len(msgs) == 0 { + return TaskID{}, ErrTaskNotFound + } + + var ackMsg string + if err1 != nil { + ackMsg = err1.Error() + } + + // Ack the task + err = q.ack(ctx, taskID, true, ackMsg) + if err != nil { + return TaskID{}, err + } + + msg := msgs[0] + task, err := decode[T](&msg, q.encoding) + if err != nil { + return TaskID{}, err + } + + task.Attempts++ + if int(task.Attempts) >= MaxAttempts { + // Task has failed + slog.ErrorContext(ctx, "task failed completely", + "taskID", taskID, + "data", task.Data, + "attempts", task.Attempts, + "maxAttempts", MaxAttempts, + ) + return TaskID{}, nil + } + + valuesMap, err := encode[T](task, q.encoding) + if err != nil { + return TaskID{}, err + } + + values := make([]string, 0, len(valuesMap)*2) + for k, v := range valuesMap { + values = append(values, k, v) + } + + keys := []string{ + q.streamKey, + fmt.Sprintf("%s:%s", q.streamKey, taskID.String()), + } + newTaskId, err := retScript.Run(ctx, q.rdb, keys, values).Result() + if err != nil { + return TaskID{}, err + } + + return ParseTaskID(newTaskId.(string)) +} + +func (q *taskQueue[T]) List(ctx context.Context, start, end TaskID, count int64) ([]TaskInfo[T], error) { + if !start.timestamp.IsZero() && !end.timestamp.IsZero() && start.timestamp.After(end.timestamp) { + return nil, errors.New("start must be before end") + } + + var startStr, endStr string + if !start.timestamp.IsZero() { + startStr = start.String() + } else { + startStr = "-" + } + if !end.timestamp.IsZero() { + endStr = "(" + end.String() + } else { + endStr = "+" + } + + msgs, err := q.rdb.XRevRangeN(ctx, q.streamKey, endStr, startStr, count).Result() + if err != nil { + return nil, err + } + if len(msgs) == 0 { + return []TaskInfo[T]{}, nil + } + + tasks := make([]TaskInfo[T], len(msgs)) + for i := range msgs { + tasks[i], err = decode[T](&msgs[i], q.encoding) + if err != nil { + return nil, err + } + + tasks[i].Result, err = q.getResult(ctx, tasks[i].ID) + if err != nil { + return nil, err + } + } + + return tasks, nil +} + +func (q *taskQueue[T]) getResult(ctx context.Context, taskID TaskID) (isTaskResult, error) { + key := fmt.Sprintf("%s:%s", q.streamKey, taskID.String()) + results, err := q.rdb.HMGet(ctx, key, ResultKey, ErrorKey, NextAttemptKey).Result() + if err != nil { + return nil, err + } + + var ret isTaskResult + if results[0] != nil { + ret = &TaskResultSuccess{Result: results[0].(string)} + } else if results[1] != nil { + ret = &TaskResultError{Error: results[1].(string)} + if results[2] != nil { + nextAttempt, err := ParseTaskID(results[2].(string)) + if err != nil { + return nil, err + } + ret.(*TaskResultError).NextAttempt = nextAttempt + } + } + return ret, nil +} + +func (q *taskQueue[T]) recover(ctx context.Context, idleTimeout time.Duration) (*TaskInfo[T], error) { + msgs, _, err := q.rdb.XAutoClaim(ctx, &redis.XAutoClaimArgs{ + Stream: q.streamKey, + Group: q.groupName, + MinIdle: idleTimeout, + Start: "0", + Count: 1, + Consumer: q.workerName, + }).Result() + if err != nil { + return nil, err + } + + if len(msgs) == 0 { + return nil, nil + } + + msg := msgs[0] + task, err := decode[T](&msg, q.encoding) + if err != nil { + return nil, err + } + return &task, nil +} + +func (q *taskQueue[T]) ack(ctx context.Context, taskID TaskID, errored bool, msg string) error { + key := fmt.Sprintf("%s:%s", q.streamKey, taskID.String()) + _, err := q.rdb.TxPipelined(ctx, func(pipe redis.Pipeliner) error { + pipe.XAck(ctx, q.streamKey, q.groupName, taskID.String()) + if errored { + pipe.HSet(ctx, key, ErrorKey, msg) + } else { + pipe.HSet(ctx, key, ResultKey, msg) + } + return nil + }) + return err +} + +func decode[T any](msg *redis.XMessage, encoding Encoding) (task TaskInfo[T], err error) { + task.ID, err = ParseTaskID(msg.ID) + if err != nil { + return + } + + err = getField(msg, "attempts", &task.Attempts) + if err != nil { + return + } + + var data string + err = getField(msg, "data", &data) + if err != nil { + return + } + + switch encoding { + case EncodingJSON: + err = json.Unmarshal([]byte(data), &task.Data) + case EncodingGob: + var decoded []byte + decoded, err = base64.StdEncoding.DecodeString(data) + if err != nil { + return + } + err = gob.NewDecoder(bytes.NewReader(decoded)).Decode(&task.Data) + default: + err = errors.New("unsupported encoding") + } + return +} + +func getField(msg *redis.XMessage, field string, v any) error { + vVal, ok := msg.Values[field] + if !ok { + return errors.New("missing field") + } + + vStr, ok := vVal.(string) + if !ok { + return errors.New("invalid field type") + } + + value := reflect.ValueOf(v).Elem() + switch value.Kind() { + case reflect.String: + value.SetString(vStr) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + i, err := strconv.ParseInt(vStr, 10, 64) + if err != nil { + return err + } + value.SetInt(i) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + i, err := strconv.ParseUint(vStr, 10, 64) + if err != nil { + return err + } + value.SetUint(i) + case reflect.Bool: + b, err := strconv.ParseBool(vStr) + if err != nil { + return err + } + value.SetBool(b) + default: + return errors.New("unsupported field type") + } + return nil +} + +func encode[T any](task TaskInfo[T], encoding Encoding) (ret map[string]string, err error) { + ret = make(map[string]string) + ret["attempts"] = strconv.FormatUint(uint64(task.Attempts), 10) + + switch encoding { + case EncodingJSON: + var data []byte + data, err = json.Marshal(task.Data) + if err != nil { + return + } + ret["data"] = string(data) + case EncodingGob: + var data bytes.Buffer + err = gob.NewEncoder(&data).Encode(task.Data) + if err != nil { + return + } + ret["data"] = base64.StdEncoding.EncodeToString(data.Bytes()) + default: + err = errors.New("unsupported encoding") + } + return +} diff --git a/backend/internal/redis/taskqueue/queue_test.go b/backend/internal/redis/taskqueue/queue_test.go new file mode 100644 index 0000000..ee95d39 --- /dev/null +++ b/backend/internal/redis/taskqueue/queue_test.go @@ -0,0 +1,467 @@ +package taskqueue + +import ( + "context" + "errors" + "fmt" + "log" + "testing" + "time" + + "github.com/ory/dockertest/v3" + "github.com/ory/dockertest/v3/docker" + "github.com/redis/go-redis/v9" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +var client *redis.Client + +func TestMain(m *testing.M) { + pool, err := dockertest.NewPool("") + if err != nil { + log.Fatalf("Could not create pool: %s", err) + } + + err = pool.Client.Ping() + if err != nil { + log.Fatalf("Could not connect to Docker: %s", err) + } + + //resource, err := pool.Run("redis", "7", nil) + resource, err := pool.RunWithOptions(&dockertest.RunOptions{ + Repository: "redis", + Tag: "7", + }, func(config *docker.HostConfig) { + config.AutoRemove = true + config.RestartPolicy = docker.RestartPolicy{Name: "no"} + }) + if err != nil { + log.Fatalf("Could not start resource: %s", err) + } + + //_ = resource.Expire(60) + + if err = pool.Retry(func() error { + client = redis.NewClient(&redis.Options{ + Addr: fmt.Sprintf("localhost:%s", resource.GetPort("6379/tcp")), + }) + return client.Ping(context.Background()).Err() + }); err != nil { + log.Fatalf("Could not connect to redis: %s", err) + } + + defer func() { + if err = client.Close(); err != nil { + log.Printf("Could not close client: %s", err) + } + if err = pool.Purge(resource); err != nil { + log.Fatalf("Could not purge resource: %s", err) + } + }() + + m.Run() +} + +func TestTaskQueue(t *testing.T) { + if testing.Short() { + t.Skip() + } + + lockTimeout := 100 * time.Millisecond + + tests := []struct { + name string + f func(t *testing.T) + }{ + { + name: "Create queue", + f: func(t *testing.T) { + q, err := New[string](context.Background(), client, "test", "worker") + require.NoError(t, err) + require.NotNil(t, q) + }, + }, + { + name: "enqueue & dequeue", + f: func(t *testing.T) { + q, err := New[string](context.Background(), client, "test", "worker") + require.NoError(t, err) + require.NotNil(t, q) + + taskId, err := q.Enqueue(context.Background(), "hello") + require.NoError(t, err) + require.NotEmpty(t, taskId) + + task, err := q.Dequeue(context.Background(), lockTimeout, 10*time.Millisecond) + require.NoError(t, err) + require.NotNil(t, task) + require.Equal(t, "hello", task.Data) + }, + }, + { + name: "complex data", + f: func(t *testing.T) { + type foo struct { + A int + B string + } + + q, err := New[foo](context.Background(), client, "test", "worker") + require.NoError(t, err) + require.NotNil(t, q) + + taskId, err := q.Enqueue(context.Background(), foo{A: 42, B: "hello"}) + require.NoError(t, err) + require.NotEmpty(t, taskId) + + task, err := q.Dequeue(context.Background(), lockTimeout, 10*time.Millisecond) + require.NoError(t, err) + require.NotNil(t, task) + require.Equal(t, foo{A: 42, B: "hello"}, task.Data) + }, + }, + { + name: "different workers", + f: func(t *testing.T) { + q1, err := New[string](context.Background(), client, "test", "worker1") + require.NoError(t, err) + require.NotNil(t, q1) + + q2, err := New[string](context.Background(), client, "test", "worker2") + require.NoError(t, err) + require.NotNil(t, q2) + + taskId, err := q1.Enqueue(context.Background(), "hello") + require.NoError(t, err) + require.NotEmpty(t, taskId) + + task, err := q2.Dequeue(context.Background(), lockTimeout, 10*time.Millisecond) + require.NoError(t, err) + require.NotNil(t, task) + assert.Equal(t, "hello", task.Data) + }, + }, + { + name: "complete", + f: func(t *testing.T) { + q, err := New[string](context.Background(), client, "test", "worker") + require.NoError(t, err) + require.NotNil(t, q) + + // Enqueue a task + taskId, err := q.Enqueue(context.Background(), "hello") + require.NoError(t, err) + require.NotEmpty(t, taskId) + + // Dequeue the task + task, err := q.Dequeue(context.Background(), lockTimeout, 10*time.Millisecond) + require.NoError(t, err) + require.NotNil(t, task) + assert.Equal(t, "hello", task.Data) + + // Complete the task + err = q.Complete(context.Background(), task.ID, "done") + require.NoError(t, err) + + // Try to dequeue the task again + task, err = q.Dequeue(context.Background(), lockTimeout, 10*time.Millisecond) + require.NoError(t, err) + require.Nil(t, task) + }, + }, + { + name: "timeout", + f: func(t *testing.T) { + q, err := New[string](context.Background(), client, "test", "worker") + require.NoError(t, err) + require.NotNil(t, q) + + // Enqueue a task + taskId, err := q.Enqueue(context.Background(), "hello") + require.NoError(t, err) + require.NotEmpty(t, taskId) + + // Dequeue the task + task, err := q.Dequeue(context.Background(), lockTimeout, 10*time.Millisecond) + require.NoError(t, err) + require.NotNil(t, task) + assert.Equal(t, "hello", task.Data) + + // Wait for the lock to expire + time.Sleep(lockTimeout + 10*time.Millisecond) + + // Try to dequeue the task again + task, err = q.Dequeue(context.Background(), lockTimeout, 10*time.Millisecond) + require.NoError(t, err) + require.NotNil(t, task) + assert.Equal(t, "hello", task.Data) + }, + }, + { + name: "extend", + f: func(t *testing.T) { + q, err := New[string](context.Background(), client, "test", "worker") + require.NoError(t, err) + require.NotNil(t, q) + + // Enqueue a task + taskId, err := q.Enqueue(context.Background(), "hello") + require.NoError(t, err) + require.NotEmpty(t, taskId) + + // Dequeue the task + task, err := q.Dequeue(context.Background(), lockTimeout, 10*time.Millisecond) + require.NoError(t, err) + require.NotNil(t, task) + assert.Equal(t, "hello", task.Data) + + // Wait for the lock to expire + time.Sleep(lockTimeout + 10*time.Millisecond) + + // Extend the lock + err = q.Extend(context.Background(), task.ID) + require.NoError(t, err) + + // Try to dequeue the task again + task, err = q.Dequeue(context.Background(), lockTimeout, 10*time.Millisecond) + require.NoError(t, err) + require.Nil(t, task) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := client.FlushDB(context.Background()).Err(); err != nil { + t.Fatal(err) + } + + tt.f(t) + }) + } + + _ = client.FlushDB(context.Background()) +} + +func TestTaskQueue_List(t *testing.T) { + if testing.Short() { + t.Skip() + } + + tests := []struct { + name string + f func(t *testing.T) + }{ + { + name: "empty", + f: func(t *testing.T) { + q, err := New[any](context.Background(), client, "test", "worker") + require.NoError(t, err) + require.NotNil(t, q) + + tasks, err := q.List(context.Background(), TaskID{}, TaskID{}, 1) + require.NoError(t, err) + assert.Empty(t, tasks) + }, + }, + { + name: "single", + f: func(t *testing.T) { + q, err := New[string](context.Background(), client, "test", "worker") + require.NoError(t, err) + require.NotNil(t, q) + + taskID, err := q.Enqueue(context.Background(), "hello") + require.NoError(t, err) + + tasks, err := q.List(context.Background(), TaskID{}, TaskID{}, 1) + require.NoError(t, err) + assert.Len(t, tasks, 1) + assert.Equal(t, taskID, tasks[0]) + }, + }, + { + name: "multiple", + f: func(t *testing.T) { + q, err := New[string](context.Background(), client, "test", "worker") + require.NoError(t, err) + require.NotNil(t, q) + + taskID, err := q.Enqueue(context.Background(), "hello") + require.NoError(t, err) + taskID2, err := q.Enqueue(context.Background(), "world") + require.NoError(t, err) + + tasks, err := q.List(context.Background(), TaskID{}, TaskID{}, 2) + require.NoError(t, err) + assert.Len(t, tasks, 2) + assert.Equal(t, taskID, tasks[1]) + assert.Equal(t, taskID2, tasks[0]) + }, + }, + { + name: "multiple limited", + f: func(t *testing.T) { + q, err := New[string](context.Background(), client, "test", "worker") + require.NoError(t, err) + require.NotNil(t, q) + + _, err = q.Enqueue(context.Background(), "hello") + require.NoError(t, err) + taskID2, err := q.Enqueue(context.Background(), "world") + require.NoError(t, err) + + tasks, err := q.List(context.Background(), TaskID{}, TaskID{}, 1) + require.NoError(t, err) + assert.Len(t, tasks, 1) + assert.Equal(t, taskID2, tasks[0]) + }, + }, + { + name: "multiple time range", + f: func(t *testing.T) { + q, err := New[string](context.Background(), client, "test", "worker") + require.NoError(t, err) + require.NotNil(t, q) + + taskID, err := q.Enqueue(context.Background(), "hello") + require.NoError(t, err) + time.Sleep(10 * time.Millisecond) + taskID2, err := q.Enqueue(context.Background(), "world") + require.NoError(t, err) + + tasks, err := q.List(context.Background(), TaskID{}, taskID2.ID, 100) + require.NoError(t, err) + assert.Len(t, tasks, 1) + assert.Equal(t, taskID, tasks[0]) + + tasks, err = q.List(context.Background(), taskID2.ID, TaskID{}, 100) + require.NoError(t, err) + assert.Len(t, tasks, 1) + assert.Equal(t, taskID2, tasks[0]) + }, + }, + { + name: "completed tasks", + f: func(t *testing.T) { + q, err := New[string](context.Background(), client, "test", "worker") + require.NoError(t, err) + require.NotNil(t, q) + + task1, err := q.Enqueue(context.Background(), "hello") + require.NoError(t, err) + task2, err := q.Enqueue(context.Background(), "world") + require.NoError(t, err) + + err = q.Complete(context.Background(), task1.ID, "done") + require.NoError(t, err) + + tasks, err := q.List(context.Background(), TaskID{}, TaskID{}, 100) + require.NoError(t, err) + assert.Len(t, tasks, 2) + assert.Equal(t, task2, tasks[0]) + + assert.Equal(t, "hello", tasks[1].Data) + require.IsType(t, &TaskResultSuccess{}, tasks[1].Result) + assert.Equal(t, "done", tasks[1].Result.(*TaskResultSuccess).Result) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := client.FlushDB(context.Background()).Err(); err != nil { + t.Fatal(err) + } + + tt.f(t) + }) + } + + _ = client.FlushDB(context.Background()) +} + +func TestTaskQueue_Return(t *testing.T) { + if testing.Short() { + t.Skip() + } + + lockTimeout := 100 * time.Millisecond + + tests := []struct { + name string + f func(t *testing.T) + }{ + { + name: "simple", + f: func(t *testing.T) { + q, err := New[string](context.Background(), client, "test", "worker") + require.NoError(t, err) + require.NotNil(t, q) + + task1, err := q.Enqueue(context.Background(), "hello") + require.NoError(t, err) + + id := claimAndFail(t, q, lockTimeout) + + task2, err := q.Dequeue(context.Background(), lockTimeout, 10*time.Millisecond) + require.NoError(t, err) + require.NotNil(t, task2) + assert.Equal(t, task2.ID, id) + assert.Equal(t, task1.Data, task2.Data) + assert.Equal(t, uint8(1), task2.Attempts) + + task1Data, err := q.Data(context.Background(), task1.ID) + require.NoError(t, err) + assert.Equal(t, task1Data.ID, task1.ID) + assert.Equal(t, task1Data.Data, task1.Data) + assert.IsType(t, &TaskResultError{}, task1Data.Result) + assert.Equal(t, "failed", task1Data.Result.(*TaskResultError).Error) + assert.Equal(t, task2.ID, task1Data.Result.(*TaskResultError).NextAttempt) + }, + }, + { + name: "failure", + f: func(t *testing.T) { + q, err := New[string](context.Background(), client, "test", "worker") + require.NoError(t, err) + require.NotNil(t, q) + + _, err = q.Enqueue(context.Background(), "hello") + require.NoError(t, err) + + claimAndFail(t, q, lockTimeout) + claimAndFail(t, q, lockTimeout) + claimAndFail(t, q, lockTimeout) + + task3, err := q.Dequeue(context.Background(), lockTimeout, 10*time.Millisecond) + require.NoError(t, err) + assert.Nil(t, task3) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := client.FlushDB(context.Background()).Err(); err != nil { + t.Fatal(err) + } + + tt.f(t) + }) + } + + _ = client.FlushDB(context.Background()) +} + +func claimAndFail[T any](t *testing.T, q TaskQueue[T], lockTimeout time.Duration) TaskID { + task, err := q.Dequeue(context.Background(), lockTimeout, 10*time.Millisecond) + require.NoError(t, err) + require.NotNil(t, task) + + id, err := q.Return(context.Background(), task.ID, errors.New("failed")) + require.NoError(t, err) + assert.NotEqual(t, task.ID, id) + return id +} |