diff options
author | 2024-08-08 16:53:59 -0700 | |
---|---|---|
committer | 2024-08-08 16:53:59 -0700 | |
commit | f34b92ded11b07f78575ac62c260a380c468e5ea (patch) | |
tree | 8ffdc68ed0f2e253e7f9feff3aa90a1182e5946c /backend/internal/redis | |
parent | a439618cdc8168bad617d04875697b572f3ed41d (diff) | |
download | ibd-trader-f34b92ded11b07f78575ac62c260a380c468e5ea.tar.gz ibd-trader-f34b92ded11b07f78575ac62c260a380c468e5ea.tar.zst ibd-trader-f34b92ded11b07f78575ac62c260a380c468e5ea.zip |
Rework redis taskqueue to store task results
Diffstat (limited to 'backend/internal/redis')
-rw-r--r-- | backend/internal/redis/taskqueue/queue.go | 169 | ||||
-rw-r--r-- | backend/internal/redis/taskqueue/queue_test.go | 61 |
2 files changed, 132 insertions, 98 deletions
diff --git a/backend/internal/redis/taskqueue/queue.go b/backend/internal/redis/taskqueue/queue.go index 1298a76..a4b799e 100644 --- a/backend/internal/redis/taskqueue/queue.go +++ b/backend/internal/redis/taskqueue/queue.go @@ -7,6 +7,7 @@ import ( "encoding/gob" "encoding/json" "errors" + "fmt" "log/slog" "reflect" "strconv" @@ -21,6 +22,10 @@ type Encoding uint8 const ( EncodingJSON Encoding = iota EncodingGob + + ResultKey = "result" + ErrorKey = "error" + NextAttemptKey = "next_attempt" ) var MaxAttempts = 3 @@ -48,7 +53,7 @@ type TaskQueue[T any] interface { Extend(ctx context.Context, taskID TaskID) error // Complete marks a task as complete. Optionally, an error can be provided to store additional information. - Complete(ctx context.Context, taskID TaskID, err error) error + Complete(ctx context.Context, taskID TaskID, result string) error // Data returns the info of a task. Data(ctx context.Context, taskID TaskID) (TaskInfo[T], error) @@ -74,12 +79,26 @@ type TaskInfo[T any] struct { Data T // Attempts is the number of times the task has been attempted. Stored in stream. Attempts uint8 - // Done is true if the task has been completed. True if ID in completed hash - Done bool - // Error is the error message if the task has failed. Stored in completed hash. - Error string + // Result is the result of the task. Stored in a hash. + Result isTaskResult +} + +type isTaskResult interface { + isTaskResult() +} + +type TaskResultSuccess struct { + Result string +} + +type TaskResultError struct { + Error string + NextAttempt TaskID } +func (*TaskResultSuccess) isTaskResult() {} +func (*TaskResultError) isTaskResult() {} + type TaskID struct { timestamp time.Time sequence uint64 @@ -125,19 +144,16 @@ type taskQueue[T any] struct { streamKey string groupName string - completedSetKey string - workerName string } func New[T any](ctx context.Context, rdb *redis.Client, name string, workerName string, opts ...Option[T]) (TaskQueue[T], error) { tq := &taskQueue[T]{ - rdb: rdb, - encoding: EncodingJSON, - streamKey: "taskqueue:" + name, - groupName: "default", - completedSetKey: "taskqueue:" + name + ":completed", - workerName: workerName, + rdb: rdb, + encoding: EncodingJSON, + streamKey: "taskqueue:" + name, + groupName: "default", + workerName: workerName, } for _, opt := range opts { @@ -244,35 +260,30 @@ func (q *taskQueue[T]) Data(ctx context.Context, taskID TaskID) (TaskInfo[T], er return TaskInfo[T]{}, err } - tErr, err := q.rdb.HGet(ctx, q.completedSetKey, taskID.String()).Result() - if err != nil && !errors.Is(err, redis.Nil) { - return TaskInfo[T]{}, err - } - - if errors.Is(err, redis.Nil) { - return t, nil + t.Result, err = q.getResult(ctx, taskID) + if err != nil { + return TaskInfo[T]{}, nil } - - t.Done = true - t.Error = tErr return t, nil } -func (q *taskQueue[T]) Complete(ctx context.Context, taskID TaskID, err error) error { - _, err = q.rdb.TxPipelined(ctx, func(pipe redis.Pipeliner) error { - pipe.XAck(ctx, q.streamKey, q.groupName, taskID.String()) - //xdel = pipe.XDel(ctx, q.streamKey, taskID.String()) - //pipe.SAdd(ctx, q.completedSetKey, taskID.String()) - if err != nil { - pipe.HSet(ctx, q.completedSetKey, taskID.String(), err.Error()) - } else { - pipe.HSet(ctx, q.completedSetKey, taskID.String(), "") - } - return nil - }) - return err +func (q *taskQueue[T]) Complete(ctx context.Context, taskID TaskID, result string) error { + return q.ack(ctx, taskID, false, result) } +var retScript = redis.NewScript(` +local stream_key = KEYS[1] +local hash_key = KEYS[2] + +-- Re-add the task to the stream +local task_id = redis.call('XADD', stream_key, '*', unpack(ARGV)) + +-- Update the hash key to point to the new task +redis.call('HSET', hash_key, 'next_attempt', task_id) + +return task_id +`) + func (q *taskQueue[T]) Return(ctx context.Context, taskID TaskID, err1 error) (TaskID, error) { msgs, err := q.rdb.XRange(ctx, q.streamKey, taskID.String(), taskID.String()).Result() if err != nil { @@ -282,8 +293,13 @@ func (q *taskQueue[T]) Return(ctx context.Context, taskID TaskID, err1 error) (T return TaskID{}, ErrTaskNotFound } - // Complete the task - err = q.Complete(ctx, taskID, err1) + var ackMsg string + if err1 != nil { + ackMsg = err1.Error() + } + + // Ack the task + err = q.ack(ctx, taskID, true, ackMsg) if err != nil { return TaskID{}, err } @@ -306,18 +322,26 @@ func (q *taskQueue[T]) Return(ctx context.Context, taskID TaskID, err1 error) (T return TaskID{}, nil } - values, err := encode[T](task, q.encoding) + valuesMap, err := encode[T](task, q.encoding) if err != nil { return TaskID{}, err } - newTaskId, err := q.rdb.XAdd(ctx, &redis.XAddArgs{ - Stream: q.streamKey, - Values: values, - }).Result() + + values := make([]string, 0, len(valuesMap)*2) + for k, v := range valuesMap { + values = append(values, k, v) + } + + keys := []string{ + q.streamKey, + fmt.Sprintf("%s:%s", q.streamKey, taskID.String()), + } + newTaskId, err := retScript.Run(ctx, q.rdb, keys, values).Result() if err != nil { return TaskID{}, err } - return ParseTaskID(newTaskId) + + return ParseTaskID(newTaskId.(string)) } func (q *taskQueue[T]) List(ctx context.Context, start, end TaskID, count int64) ([]TaskInfo[T], error) { @@ -345,32 +369,45 @@ func (q *taskQueue[T]) List(ctx context.Context, start, end TaskID, count int64) return []TaskInfo[T]{}, nil } - ids := make([]string, len(msgs)) - for i, msg := range msgs { - ids[i] = msg.ID - } - errs, err := q.rdb.HMGet(ctx, q.completedSetKey, ids...).Result() - if err != nil { - return nil, err - } - if len(errs) != len(msgs) { - return nil, errors.New("SMIsMember returned wrong number of results") - } - tasks := make([]TaskInfo[T], len(msgs)) for i := range msgs { tasks[i], err = decode[T](&msgs[i], q.encoding) if err != nil { return nil, err } - tasks[i].Done = errs[i] != nil - if tasks[i].Done { - tasks[i].Error = errs[i].(string) + + tasks[i].Result, err = q.getResult(ctx, tasks[i].ID) + if err != nil { + return nil, err } } + return tasks, nil } +func (q *taskQueue[T]) getResult(ctx context.Context, taskID TaskID) (isTaskResult, error) { + key := fmt.Sprintf("%s:%s", q.streamKey, taskID.String()) + results, err := q.rdb.HMGet(ctx, key, ResultKey, ErrorKey, NextAttemptKey).Result() + if err != nil { + return nil, err + } + + var ret isTaskResult + if results[0] != nil { + ret = &TaskResultSuccess{Result: results[0].(string)} + } else if results[1] != nil { + ret = &TaskResultError{Error: results[1].(string)} + if results[2] != nil { + nextAttempt, err := ParseTaskID(results[2].(string)) + if err != nil { + return nil, err + } + ret.(*TaskResultError).NextAttempt = nextAttempt + } + } + return ret, nil +} + func (q *taskQueue[T]) recover(ctx context.Context, idleTimeout time.Duration) (*TaskInfo[T], error) { msgs, _, err := q.rdb.XAutoClaim(ctx, &redis.XAutoClaimArgs{ Stream: q.streamKey, @@ -396,6 +433,20 @@ func (q *taskQueue[T]) recover(ctx context.Context, idleTimeout time.Duration) ( return &task, nil } +func (q *taskQueue[T]) ack(ctx context.Context, taskID TaskID, errored bool, msg string) error { + key := fmt.Sprintf("%s:%s", q.streamKey, taskID.String()) + _, err := q.rdb.TxPipelined(ctx, func(pipe redis.Pipeliner) error { + pipe.XAck(ctx, q.streamKey, q.groupName, taskID.String()) + if errored { + pipe.HSet(ctx, key, ErrorKey, msg) + } else { + pipe.HSet(ctx, key, ResultKey, msg) + } + return nil + }) + return err +} + func decode[T any](msg *redis.XMessage, encoding Encoding) (task TaskInfo[T], err error) { task.ID, err = ParseTaskID(msg.ID) if err != nil { diff --git a/backend/internal/redis/taskqueue/queue_test.go b/backend/internal/redis/taskqueue/queue_test.go index 774caa8..ee95d39 100644 --- a/backend/internal/redis/taskqueue/queue_test.go +++ b/backend/internal/redis/taskqueue/queue_test.go @@ -40,7 +40,7 @@ func TestMain(m *testing.M) { log.Fatalf("Could not start resource: %s", err) } - _ = resource.Expire(60) + //_ = resource.Expire(60) if err = pool.Retry(func() error { client = redis.NewClient(&redis.Options{ @@ -161,7 +161,7 @@ func TestTaskQueue(t *testing.T) { assert.Equal(t, "hello", task.Data) // Complete the task - err = q.Complete(context.Background(), task.ID, nil) + err = q.Complete(context.Background(), task.ID, "done") require.NoError(t, err) // Try to dequeue the task again @@ -354,7 +354,7 @@ func TestTaskQueue_List(t *testing.T) { task2, err := q.Enqueue(context.Background(), "world") require.NoError(t, err) - err = q.Complete(context.Background(), task1.ID, nil) + err = q.Complete(context.Background(), task1.ID, "done") require.NoError(t, err) tasks, err := q.List(context.Background(), TaskID{}, TaskID{}, 100) @@ -363,33 +363,8 @@ func TestTaskQueue_List(t *testing.T) { assert.Equal(t, task2, tasks[0]) assert.Equal(t, "hello", tasks[1].Data) - assert.Equal(t, true, tasks[1].Done) - assert.Equal(t, "", tasks[1].Error) - }, - }, - { - name: "failed tasks", - f: func(t *testing.T) { - q, err := New[string](context.Background(), client, "test", "worker") - require.NoError(t, err) - require.NotNil(t, q) - - task1, err := q.Enqueue(context.Background(), "hello") - require.NoError(t, err) - task2, err := q.Enqueue(context.Background(), "world") - require.NoError(t, err) - - err = q.Complete(context.Background(), task1.ID, errors.New("failed")) - require.NoError(t, err) - - tasks, err := q.List(context.Background(), TaskID{}, TaskID{}, 100) - require.NoError(t, err) - assert.Len(t, tasks, 2) - assert.Equal(t, task2, tasks[0]) - - assert.Equal(t, "hello", tasks[1].Data) - assert.Equal(t, true, tasks[1].Done) - assert.Equal(t, "failed", tasks[1].Error) + require.IsType(t, &TaskResultSuccess{}, tasks[1].Result) + assert.Equal(t, "done", tasks[1].Result.(*TaskResultSuccess).Result) }, }, } @@ -430,12 +405,20 @@ func TestTaskQueue_Return(t *testing.T) { id := claimAndFail(t, q, lockTimeout) - task3, err := q.Dequeue(context.Background(), lockTimeout, 10*time.Millisecond) + task2, err := q.Dequeue(context.Background(), lockTimeout, 10*time.Millisecond) + require.NoError(t, err) + require.NotNil(t, task2) + assert.Equal(t, task2.ID, id) + assert.Equal(t, task1.Data, task2.Data) + assert.Equal(t, uint8(1), task2.Attempts) + + task1Data, err := q.Data(context.Background(), task1.ID) require.NoError(t, err) - require.NotNil(t, task3) - assert.Equal(t, task3.ID, id) - assert.Equal(t, task1.Data, task3.Data) - assert.Equal(t, uint8(1), task3.Attempts) + assert.Equal(t, task1Data.ID, task1.ID) + assert.Equal(t, task1Data.Data, task1.Data) + assert.IsType(t, &TaskResultError{}, task1Data.Result) + assert.Equal(t, "failed", task1Data.Result.(*TaskResultError).Error) + assert.Equal(t, task2.ID, task1Data.Result.(*TaskResultError).NextAttempt) }, }, { @@ -473,12 +456,12 @@ func TestTaskQueue_Return(t *testing.T) { } func claimAndFail[T any](t *testing.T, q TaskQueue[T], lockTimeout time.Duration) TaskID { - task2, err := q.Dequeue(context.Background(), lockTimeout, 10*time.Millisecond) + task, err := q.Dequeue(context.Background(), lockTimeout, 10*time.Millisecond) require.NoError(t, err) - require.NotNil(t, task2) + require.NotNil(t, task) - id, err := q.Return(context.Background(), task2.ID, errors.New("failed")) + id, err := q.Return(context.Background(), task.ID, errors.New("failed")) require.NoError(t, err) - assert.NotEqual(t, task2.ID, id) + assert.NotEqual(t, task.ID, id) return id } |