aboutsummaryrefslogtreecommitdiff
path: root/backend/internal/redis
diff options
context:
space:
mode:
authorGravatar Anshul Gupta <ansg191@anshulg.com> 2024-08-08 16:53:59 -0700
committerGravatar Anshul Gupta <ansg191@anshulg.com> 2024-08-08 16:53:59 -0700
commitf34b92ded11b07f78575ac62c260a380c468e5ea (patch)
tree8ffdc68ed0f2e253e7f9feff3aa90a1182e5946c /backend/internal/redis
parenta439618cdc8168bad617d04875697b572f3ed41d (diff)
downloadibd-trader-f34b92ded11b07f78575ac62c260a380c468e5ea.tar.gz
ibd-trader-f34b92ded11b07f78575ac62c260a380c468e5ea.tar.zst
ibd-trader-f34b92ded11b07f78575ac62c260a380c468e5ea.zip
Rework redis taskqueue to store task results
Diffstat (limited to 'backend/internal/redis')
-rw-r--r--backend/internal/redis/taskqueue/queue.go169
-rw-r--r--backend/internal/redis/taskqueue/queue_test.go61
2 files changed, 132 insertions, 98 deletions
diff --git a/backend/internal/redis/taskqueue/queue.go b/backend/internal/redis/taskqueue/queue.go
index 1298a76..a4b799e 100644
--- a/backend/internal/redis/taskqueue/queue.go
+++ b/backend/internal/redis/taskqueue/queue.go
@@ -7,6 +7,7 @@ import (
"encoding/gob"
"encoding/json"
"errors"
+ "fmt"
"log/slog"
"reflect"
"strconv"
@@ -21,6 +22,10 @@ type Encoding uint8
const (
EncodingJSON Encoding = iota
EncodingGob
+
+ ResultKey = "result"
+ ErrorKey = "error"
+ NextAttemptKey = "next_attempt"
)
var MaxAttempts = 3
@@ -48,7 +53,7 @@ type TaskQueue[T any] interface {
Extend(ctx context.Context, taskID TaskID) error
// Complete marks a task as complete. Optionally, an error can be provided to store additional information.
- Complete(ctx context.Context, taskID TaskID, err error) error
+ Complete(ctx context.Context, taskID TaskID, result string) error
// Data returns the info of a task.
Data(ctx context.Context, taskID TaskID) (TaskInfo[T], error)
@@ -74,12 +79,26 @@ type TaskInfo[T any] struct {
Data T
// Attempts is the number of times the task has been attempted. Stored in stream.
Attempts uint8
- // Done is true if the task has been completed. True if ID in completed hash
- Done bool
- // Error is the error message if the task has failed. Stored in completed hash.
- Error string
+ // Result is the result of the task. Stored in a hash.
+ Result isTaskResult
+}
+
+type isTaskResult interface {
+ isTaskResult()
+}
+
+type TaskResultSuccess struct {
+ Result string
+}
+
+type TaskResultError struct {
+ Error string
+ NextAttempt TaskID
}
+func (*TaskResultSuccess) isTaskResult() {}
+func (*TaskResultError) isTaskResult() {}
+
type TaskID struct {
timestamp time.Time
sequence uint64
@@ -125,19 +144,16 @@ type taskQueue[T any] struct {
streamKey string
groupName string
- completedSetKey string
-
workerName string
}
func New[T any](ctx context.Context, rdb *redis.Client, name string, workerName string, opts ...Option[T]) (TaskQueue[T], error) {
tq := &taskQueue[T]{
- rdb: rdb,
- encoding: EncodingJSON,
- streamKey: "taskqueue:" + name,
- groupName: "default",
- completedSetKey: "taskqueue:" + name + ":completed",
- workerName: workerName,
+ rdb: rdb,
+ encoding: EncodingJSON,
+ streamKey: "taskqueue:" + name,
+ groupName: "default",
+ workerName: workerName,
}
for _, opt := range opts {
@@ -244,35 +260,30 @@ func (q *taskQueue[T]) Data(ctx context.Context, taskID TaskID) (TaskInfo[T], er
return TaskInfo[T]{}, err
}
- tErr, err := q.rdb.HGet(ctx, q.completedSetKey, taskID.String()).Result()
- if err != nil && !errors.Is(err, redis.Nil) {
- return TaskInfo[T]{}, err
- }
-
- if errors.Is(err, redis.Nil) {
- return t, nil
+ t.Result, err = q.getResult(ctx, taskID)
+ if err != nil {
+ return TaskInfo[T]{}, nil
}
-
- t.Done = true
- t.Error = tErr
return t, nil
}
-func (q *taskQueue[T]) Complete(ctx context.Context, taskID TaskID, err error) error {
- _, err = q.rdb.TxPipelined(ctx, func(pipe redis.Pipeliner) error {
- pipe.XAck(ctx, q.streamKey, q.groupName, taskID.String())
- //xdel = pipe.XDel(ctx, q.streamKey, taskID.String())
- //pipe.SAdd(ctx, q.completedSetKey, taskID.String())
- if err != nil {
- pipe.HSet(ctx, q.completedSetKey, taskID.String(), err.Error())
- } else {
- pipe.HSet(ctx, q.completedSetKey, taskID.String(), "")
- }
- return nil
- })
- return err
+func (q *taskQueue[T]) Complete(ctx context.Context, taskID TaskID, result string) error {
+ return q.ack(ctx, taskID, false, result)
}
+var retScript = redis.NewScript(`
+local stream_key = KEYS[1]
+local hash_key = KEYS[2]
+
+-- Re-add the task to the stream
+local task_id = redis.call('XADD', stream_key, '*', unpack(ARGV))
+
+-- Update the hash key to point to the new task
+redis.call('HSET', hash_key, 'next_attempt', task_id)
+
+return task_id
+`)
+
func (q *taskQueue[T]) Return(ctx context.Context, taskID TaskID, err1 error) (TaskID, error) {
msgs, err := q.rdb.XRange(ctx, q.streamKey, taskID.String(), taskID.String()).Result()
if err != nil {
@@ -282,8 +293,13 @@ func (q *taskQueue[T]) Return(ctx context.Context, taskID TaskID, err1 error) (T
return TaskID{}, ErrTaskNotFound
}
- // Complete the task
- err = q.Complete(ctx, taskID, err1)
+ var ackMsg string
+ if err1 != nil {
+ ackMsg = err1.Error()
+ }
+
+ // Ack the task
+ err = q.ack(ctx, taskID, true, ackMsg)
if err != nil {
return TaskID{}, err
}
@@ -306,18 +322,26 @@ func (q *taskQueue[T]) Return(ctx context.Context, taskID TaskID, err1 error) (T
return TaskID{}, nil
}
- values, err := encode[T](task, q.encoding)
+ valuesMap, err := encode[T](task, q.encoding)
if err != nil {
return TaskID{}, err
}
- newTaskId, err := q.rdb.XAdd(ctx, &redis.XAddArgs{
- Stream: q.streamKey,
- Values: values,
- }).Result()
+
+ values := make([]string, 0, len(valuesMap)*2)
+ for k, v := range valuesMap {
+ values = append(values, k, v)
+ }
+
+ keys := []string{
+ q.streamKey,
+ fmt.Sprintf("%s:%s", q.streamKey, taskID.String()),
+ }
+ newTaskId, err := retScript.Run(ctx, q.rdb, keys, values).Result()
if err != nil {
return TaskID{}, err
}
- return ParseTaskID(newTaskId)
+
+ return ParseTaskID(newTaskId.(string))
}
func (q *taskQueue[T]) List(ctx context.Context, start, end TaskID, count int64) ([]TaskInfo[T], error) {
@@ -345,32 +369,45 @@ func (q *taskQueue[T]) List(ctx context.Context, start, end TaskID, count int64)
return []TaskInfo[T]{}, nil
}
- ids := make([]string, len(msgs))
- for i, msg := range msgs {
- ids[i] = msg.ID
- }
- errs, err := q.rdb.HMGet(ctx, q.completedSetKey, ids...).Result()
- if err != nil {
- return nil, err
- }
- if len(errs) != len(msgs) {
- return nil, errors.New("SMIsMember returned wrong number of results")
- }
-
tasks := make([]TaskInfo[T], len(msgs))
for i := range msgs {
tasks[i], err = decode[T](&msgs[i], q.encoding)
if err != nil {
return nil, err
}
- tasks[i].Done = errs[i] != nil
- if tasks[i].Done {
- tasks[i].Error = errs[i].(string)
+
+ tasks[i].Result, err = q.getResult(ctx, tasks[i].ID)
+ if err != nil {
+ return nil, err
}
}
+
return tasks, nil
}
+func (q *taskQueue[T]) getResult(ctx context.Context, taskID TaskID) (isTaskResult, error) {
+ key := fmt.Sprintf("%s:%s", q.streamKey, taskID.String())
+ results, err := q.rdb.HMGet(ctx, key, ResultKey, ErrorKey, NextAttemptKey).Result()
+ if err != nil {
+ return nil, err
+ }
+
+ var ret isTaskResult
+ if results[0] != nil {
+ ret = &TaskResultSuccess{Result: results[0].(string)}
+ } else if results[1] != nil {
+ ret = &TaskResultError{Error: results[1].(string)}
+ if results[2] != nil {
+ nextAttempt, err := ParseTaskID(results[2].(string))
+ if err != nil {
+ return nil, err
+ }
+ ret.(*TaskResultError).NextAttempt = nextAttempt
+ }
+ }
+ return ret, nil
+}
+
func (q *taskQueue[T]) recover(ctx context.Context, idleTimeout time.Duration) (*TaskInfo[T], error) {
msgs, _, err := q.rdb.XAutoClaim(ctx, &redis.XAutoClaimArgs{
Stream: q.streamKey,
@@ -396,6 +433,20 @@ func (q *taskQueue[T]) recover(ctx context.Context, idleTimeout time.Duration) (
return &task, nil
}
+func (q *taskQueue[T]) ack(ctx context.Context, taskID TaskID, errored bool, msg string) error {
+ key := fmt.Sprintf("%s:%s", q.streamKey, taskID.String())
+ _, err := q.rdb.TxPipelined(ctx, func(pipe redis.Pipeliner) error {
+ pipe.XAck(ctx, q.streamKey, q.groupName, taskID.String())
+ if errored {
+ pipe.HSet(ctx, key, ErrorKey, msg)
+ } else {
+ pipe.HSet(ctx, key, ResultKey, msg)
+ }
+ return nil
+ })
+ return err
+}
+
func decode[T any](msg *redis.XMessage, encoding Encoding) (task TaskInfo[T], err error) {
task.ID, err = ParseTaskID(msg.ID)
if err != nil {
diff --git a/backend/internal/redis/taskqueue/queue_test.go b/backend/internal/redis/taskqueue/queue_test.go
index 774caa8..ee95d39 100644
--- a/backend/internal/redis/taskqueue/queue_test.go
+++ b/backend/internal/redis/taskqueue/queue_test.go
@@ -40,7 +40,7 @@ func TestMain(m *testing.M) {
log.Fatalf("Could not start resource: %s", err)
}
- _ = resource.Expire(60)
+ //_ = resource.Expire(60)
if err = pool.Retry(func() error {
client = redis.NewClient(&redis.Options{
@@ -161,7 +161,7 @@ func TestTaskQueue(t *testing.T) {
assert.Equal(t, "hello", task.Data)
// Complete the task
- err = q.Complete(context.Background(), task.ID, nil)
+ err = q.Complete(context.Background(), task.ID, "done")
require.NoError(t, err)
// Try to dequeue the task again
@@ -354,7 +354,7 @@ func TestTaskQueue_List(t *testing.T) {
task2, err := q.Enqueue(context.Background(), "world")
require.NoError(t, err)
- err = q.Complete(context.Background(), task1.ID, nil)
+ err = q.Complete(context.Background(), task1.ID, "done")
require.NoError(t, err)
tasks, err := q.List(context.Background(), TaskID{}, TaskID{}, 100)
@@ -363,33 +363,8 @@ func TestTaskQueue_List(t *testing.T) {
assert.Equal(t, task2, tasks[0])
assert.Equal(t, "hello", tasks[1].Data)
- assert.Equal(t, true, tasks[1].Done)
- assert.Equal(t, "", tasks[1].Error)
- },
- },
- {
- name: "failed tasks",
- f: func(t *testing.T) {
- q, err := New[string](context.Background(), client, "test", "worker")
- require.NoError(t, err)
- require.NotNil(t, q)
-
- task1, err := q.Enqueue(context.Background(), "hello")
- require.NoError(t, err)
- task2, err := q.Enqueue(context.Background(), "world")
- require.NoError(t, err)
-
- err = q.Complete(context.Background(), task1.ID, errors.New("failed"))
- require.NoError(t, err)
-
- tasks, err := q.List(context.Background(), TaskID{}, TaskID{}, 100)
- require.NoError(t, err)
- assert.Len(t, tasks, 2)
- assert.Equal(t, task2, tasks[0])
-
- assert.Equal(t, "hello", tasks[1].Data)
- assert.Equal(t, true, tasks[1].Done)
- assert.Equal(t, "failed", tasks[1].Error)
+ require.IsType(t, &TaskResultSuccess{}, tasks[1].Result)
+ assert.Equal(t, "done", tasks[1].Result.(*TaskResultSuccess).Result)
},
},
}
@@ -430,12 +405,20 @@ func TestTaskQueue_Return(t *testing.T) {
id := claimAndFail(t, q, lockTimeout)
- task3, err := q.Dequeue(context.Background(), lockTimeout, 10*time.Millisecond)
+ task2, err := q.Dequeue(context.Background(), lockTimeout, 10*time.Millisecond)
+ require.NoError(t, err)
+ require.NotNil(t, task2)
+ assert.Equal(t, task2.ID, id)
+ assert.Equal(t, task1.Data, task2.Data)
+ assert.Equal(t, uint8(1), task2.Attempts)
+
+ task1Data, err := q.Data(context.Background(), task1.ID)
require.NoError(t, err)
- require.NotNil(t, task3)
- assert.Equal(t, task3.ID, id)
- assert.Equal(t, task1.Data, task3.Data)
- assert.Equal(t, uint8(1), task3.Attempts)
+ assert.Equal(t, task1Data.ID, task1.ID)
+ assert.Equal(t, task1Data.Data, task1.Data)
+ assert.IsType(t, &TaskResultError{}, task1Data.Result)
+ assert.Equal(t, "failed", task1Data.Result.(*TaskResultError).Error)
+ assert.Equal(t, task2.ID, task1Data.Result.(*TaskResultError).NextAttempt)
},
},
{
@@ -473,12 +456,12 @@ func TestTaskQueue_Return(t *testing.T) {
}
func claimAndFail[T any](t *testing.T, q TaskQueue[T], lockTimeout time.Duration) TaskID {
- task2, err := q.Dequeue(context.Background(), lockTimeout, 10*time.Millisecond)
+ task, err := q.Dequeue(context.Background(), lockTimeout, 10*time.Millisecond)
require.NoError(t, err)
- require.NotNil(t, task2)
+ require.NotNil(t, task)
- id, err := q.Return(context.Background(), task2.ID, errors.New("failed"))
+ id, err := q.Return(context.Background(), task.ID, errors.New("failed"))
require.NoError(t, err)
- assert.NotEqual(t, task2.ID, id)
+ assert.NotEqual(t, task.ID, id)
return id
}