aboutsummaryrefslogtreecommitdiff
path: root/backend/internal/redis/taskqueue/queue.go
diff options
context:
space:
mode:
Diffstat (limited to 'backend/internal/redis/taskqueue/queue.go')
-rw-r--r--backend/internal/redis/taskqueue/queue.go545
1 files changed, 545 insertions, 0 deletions
diff --git a/backend/internal/redis/taskqueue/queue.go b/backend/internal/redis/taskqueue/queue.go
new file mode 100644
index 0000000..a4b799e
--- /dev/null
+++ b/backend/internal/redis/taskqueue/queue.go
@@ -0,0 +1,545 @@
+package taskqueue
+
+import (
+ "bytes"
+ "context"
+ "encoding/base64"
+ "encoding/gob"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "log/slog"
+ "reflect"
+ "strconv"
+ "strings"
+ "time"
+
+ "github.com/redis/go-redis/v9"
+)
+
// Encoding selects how task payloads are serialized into stream fields.
type Encoding uint8

const (
	// EncodingJSON serializes task data with encoding/json (the default).
	EncodingJSON Encoding = iota
	// EncodingGob serializes task data with encoding/gob, base64-encoded
	// so it survives the string round-trip through redis.
	EncodingGob

	// Field names of the per-task result hash ("<streamKey>:<taskID>").
	ResultKey = "result"
	ErrorKey = "error"
	NextAttemptKey = "next_attempt"
)

// MaxAttempts is the number of attempts before a task is considered
// permanently failed and is no longer re-enqueued by Return.
var MaxAttempts = 3

// ErrTaskNotFound is returned when no stream entry exists for a task ID.
var ErrTaskNotFound = errors.New("task not found")
+
// TaskQueue is a redis-stream-backed work queue with locking, retries,
// and per-task result storage.
type TaskQueue[T any] interface {
	// Enqueue adds a task to the queue.
	// Returns the generated task ID.
	Enqueue(ctx context.Context, data T) (TaskInfo[T], error)

	// Dequeue removes a task from the queue and returns it.
	// The task data is placed into dataOut.
	//
	// Dequeue blocks until a task is available, timeout, or the context is canceled.
	// The returned task is placed in a pending state for lockTimeout duration.
	// The task must be completed with Complete or extended with Extend before the lock expires.
	// If the lock expires, the task is returned to the queue, where it may be picked up by another worker.
	Dequeue(
		ctx context.Context,
		lockTimeout,
		timeout time.Duration,
	) (*TaskInfo[T], error)

	// Extend extends the lock on a task.
	Extend(ctx context.Context, taskID TaskID) error

	// Complete marks a task as successfully finished, storing result for
	// later retrieval via Data. (Error outcomes are recorded by Return.)
	Complete(ctx context.Context, taskID TaskID, result string) error

	// Data returns the info of a task.
	Data(ctx context.Context, taskID TaskID) (TaskInfo[T], error)

	// Return returns a task to the queue and returns the new task ID.
	// Increments the attempt counter.
	// Tasks with too many attempts (MaxAttempts) are considered failed and aren't returned to the queue.
	Return(ctx context.Context, taskID TaskID, err error) (TaskID, error)

	// List returns a list of task IDs in the queue.
	// The list is ordered by the time the task was added to the queue. The most recent task is first.
	// The count parameter limits the number of tasks returned.
	// The start and end parameters limit the range of tasks returned.
	// End is exclusive.
	// Start must be before end.
	List(ctx context.Context, start, end TaskID, count int64) ([]TaskInfo[T], error)
}
+
// TaskInfo carries a task's identity, payload, attempt count, and outcome.
type TaskInfo[T any] struct {
	// ID is the unique identifier of the task. Generated by redis.
	ID TaskID
	// Data is the task data. Stored in stream.
	Data T
	// Attempts is the number of times the task has been attempted. Stored in stream.
	Attempts uint8
	// Result is the result of the task. Stored in a hash.
	// Nil until the task has been completed or returned with an error.
	Result isTaskResult
}

// isTaskResult is a sealed marker interface: an outcome is either a
// *TaskResultSuccess or a *TaskResultError.
type isTaskResult interface {
	isTaskResult()
}

// TaskResultSuccess is the outcome recorded by Complete.
type TaskResultSuccess struct {
	Result string
}

// TaskResultError is the outcome recorded when a task is returned with an error.
type TaskResultError struct {
	Error string
	// NextAttempt is the ID of the re-enqueued attempt, if one was scheduled.
	NextAttempt TaskID
}

func (*TaskResultSuccess) isTaskResult() {}
func (*TaskResultError) isTaskResult() {}
+
// TaskID identifies a task. It mirrors a redis stream entry ID: a
// millisecond unix timestamp plus a per-millisecond sequence number.
type TaskID struct {
	timestamp time.Time
	sequence  uint64
}

// NewTaskID builds a TaskID from its timestamp and sequence components.
func NewTaskID(timestamp time.Time, sequence uint64) TaskID {
	return TaskID{timestamp: timestamp, sequence: sequence}
}

// ParseTaskID parses a redis stream ID of the form "<unix-millis>-<sequence>".
func ParseTaskID(s string) (TaskID, error) {
	millisPart, seqPart, found := strings.Cut(s, "-")
	if !found {
		return TaskID{}, errors.New("invalid task ID")
	}

	millis, err := strconv.ParseInt(millisPart, 10, 64)
	if err != nil {
		return TaskID{}, err
	}

	seq, err := strconv.ParseUint(seqPart, 10, 64)
	if err != nil {
		return TaskID{}, err
	}

	return TaskID{timestamp: time.UnixMilli(millis), sequence: seq}, nil
}

// Timestamp reports the creation time encoded in the ID.
func (t TaskID) Timestamp() time.Time {
	return t.timestamp
}

// String renders the ID in redis stream format: "<unix-millis>-<sequence>".
func (t TaskID) String() string {
	return strconv.FormatInt(t.timestamp.UnixMilli(), 10) + "-" + strconv.FormatUint(t.sequence, 10)
}
+
// taskQueue is the redis-backed implementation of TaskQueue.
type taskQueue[T any] struct {
	rdb *redis.Client
	// encoding selects the payload serialization (JSON by default).
	encoding Encoding

	// streamKey is the redis stream holding the tasks ("taskqueue:<name>").
	streamKey string
	// groupName is the consumer group shared by all workers of this queue.
	groupName string

	// workerName identifies this consumer within the group.
	workerName string
}
+
+func New[T any](ctx context.Context, rdb *redis.Client, name string, workerName string, opts ...Option[T]) (TaskQueue[T], error) {
+ tq := &taskQueue[T]{
+ rdb: rdb,
+ encoding: EncodingJSON,
+ streamKey: "taskqueue:" + name,
+ groupName: "default",
+ workerName: workerName,
+ }
+
+ for _, opt := range opts {
+ opt(tq)
+ }
+
+ // Create the stream if it doesn't exist
+ err := rdb.XGroupCreateMkStream(ctx, tq.streamKey, tq.groupName, "0").Err()
+ if err != nil && err.Error() != "BUSYGROUP Consumer Group name already exists" {
+ return nil, err
+ }
+
+ return tq, nil
+}
+
+func (q *taskQueue[T]) Enqueue(ctx context.Context, data T) (TaskInfo[T], error) {
+ task := TaskInfo[T]{
+ Data: data,
+ Attempts: 0,
+ }
+
+ values, err := encode[T](task, q.encoding)
+ if err != nil {
+ return TaskInfo[T]{}, err
+ }
+
+ taskID, err := q.rdb.XAdd(ctx, &redis.XAddArgs{
+ Stream: q.streamKey,
+ Values: values,
+ }).Result()
+ if err != nil {
+ return TaskInfo[T]{}, err
+ }
+
+ id, err := ParseTaskID(taskID)
+ if err != nil {
+ return TaskInfo[T]{}, err
+ }
+ task.ID = id
+ return task, nil
+}
+
// Dequeue returns one task, preferring tasks abandoned by other workers
// (pending longer than lockTimeout) before reading new stream entries.
// It blocks up to timeout waiting for a new entry and returns (nil, nil)
// when nothing became available.
func (q *taskQueue[T]) Dequeue(ctx context.Context, lockTimeout, timeout time.Duration) (*TaskInfo[T], error) {
	// Try to recover a task
	task, err := q.recover(ctx, lockTimeout)
	if err != nil {
		return nil, err
	}
	if task != nil {
		return task, nil
	}

	// Check for new tasks. ">" requests entries never delivered to this
	// group; redis.Nil means the blocking read timed out with nothing.
	ids, err := q.rdb.XReadGroup(ctx, &redis.XReadGroupArgs{
		Group:    q.groupName,
		Consumer: q.workerName,
		Streams:  []string{q.streamKey, ">"},
		Count:    1,
		Block:    timeout,
	}).Result()
	if err != nil && !errors.Is(err, redis.Nil) {
		return nil, err
	}

	// No stream, no message, or timeout: report "nothing available".
	if len(ids) == 0 || len(ids[0].Messages) == 0 || errors.Is(err, redis.Nil) {
		return nil, nil
	}

	msg := ids[0].Messages[0]
	task = new(TaskInfo[T])
	*task, err = decode[T](&msg, q.encoding)
	if err != nil {
		return nil, err
	}
	return task, nil
}
+
+func (q *taskQueue[T]) Extend(ctx context.Context, taskID TaskID) error {
+ _, err := q.rdb.XClaim(ctx, &redis.XClaimArgs{
+ Stream: q.streamKey,
+ Group: q.groupName,
+ Consumer: q.workerName,
+ MinIdle: 0,
+ Messages: []string{taskID.String()},
+ }).Result()
+ if err != nil && !errors.Is(err, redis.Nil) {
+ return err
+ }
+ return nil
+}
+
+func (q *taskQueue[T]) Data(ctx context.Context, taskID TaskID) (TaskInfo[T], error) {
+ msg, err := q.rdb.XRange(ctx, q.streamKey, taskID.String(), taskID.String()).Result()
+ if err != nil {
+ return TaskInfo[T]{}, err
+ }
+
+ if len(msg) == 0 {
+ return TaskInfo[T]{}, ErrTaskNotFound
+ }
+
+ t, err := decode[T](&msg[0], q.encoding)
+ if err != nil {
+ return TaskInfo[T]{}, err
+ }
+
+ t.Result, err = q.getResult(ctx, taskID)
+ if err != nil {
+ return TaskInfo[T]{}, nil
+ }
+ return t, nil
+}
+
+func (q *taskQueue[T]) Complete(ctx context.Context, taskID TaskID, result string) error {
+ return q.ack(ctx, taskID, false, result)
+}
+
// retScript atomically re-enqueues a failed task and records, in the old
// task's result hash, the ID of the new attempt (next_attempt).
// KEYS[1] = stream key, KEYS[2] = old task's result hash key.
// ARGV    = flattened field/value pairs for the new stream entry.
var retScript = redis.NewScript(`
local stream_key = KEYS[1]
local hash_key = KEYS[2]

-- Re-add the task to the stream
local task_id = redis.call('XADD', stream_key, '*', unpack(ARGV))

-- Update the hash key to point to the new task
redis.call('HSET', hash_key, 'next_attempt', task_id)

return task_id
`)
+
// Return puts a failed task back on the queue for another attempt and
// returns the new attempt's ID. The old entry is acked with err1's message
// stored as its error result. Once the attempt counter reaches MaxAttempts
// the task is dropped: logged, zero TaskID and nil error returned.
func (q *taskQueue[T]) Return(ctx context.Context, taskID TaskID, err1 error) (TaskID, error) {
	// Fetch the entry first so we can re-enqueue its payload after acking.
	msgs, err := q.rdb.XRange(ctx, q.streamKey, taskID.String(), taskID.String()).Result()
	if err != nil {
		return TaskID{}, err
	}
	if len(msgs) == 0 {
		return TaskID{}, ErrTaskNotFound
	}

	var ackMsg string
	if err1 != nil {
		ackMsg = err1.Error()
	}

	// Ack the task
	err = q.ack(ctx, taskID, true, ackMsg)
	if err != nil {
		return TaskID{}, err
	}

	msg := msgs[0]
	task, err := decode[T](&msg, q.encoding)
	if err != nil {
		return TaskID{}, err
	}

	task.Attempts++
	if int(task.Attempts) >= MaxAttempts {
		// Task has failed
		slog.ErrorContext(ctx, "task failed completely",
			"taskID", taskID,
			"data", task.Data,
			"attempts", task.Attempts,
			"maxAttempts", MaxAttempts,
		)
		return TaskID{}, nil
	}

	valuesMap, err := encode[T](task, q.encoding)
	if err != nil {
		return TaskID{}, err
	}

	// Flatten the field map into ARGV pairs for the re-enqueue script.
	values := make([]string, 0, len(valuesMap)*2)
	for k, v := range valuesMap {
		values = append(values, k, v)
	}

	// KEYS[1] = stream, KEYS[2] = old task's result hash (gets next_attempt).
	keys := []string{
		q.streamKey,
		fmt.Sprintf("%s:%s", q.streamKey, taskID.String()),
	}
	newTaskId, err := retScript.Run(ctx, q.rdb, keys, values).Result()
	if err != nil {
		return TaskID{}, err
	}

	return ParseTaskID(newTaskId.(string))
}
+
+func (q *taskQueue[T]) List(ctx context.Context, start, end TaskID, count int64) ([]TaskInfo[T], error) {
+ if !start.timestamp.IsZero() && !end.timestamp.IsZero() && start.timestamp.After(end.timestamp) {
+ return nil, errors.New("start must be before end")
+ }
+
+ var startStr, endStr string
+ if !start.timestamp.IsZero() {
+ startStr = start.String()
+ } else {
+ startStr = "-"
+ }
+ if !end.timestamp.IsZero() {
+ endStr = "(" + end.String()
+ } else {
+ endStr = "+"
+ }
+
+ msgs, err := q.rdb.XRevRangeN(ctx, q.streamKey, endStr, startStr, count).Result()
+ if err != nil {
+ return nil, err
+ }
+ if len(msgs) == 0 {
+ return []TaskInfo[T]{}, nil
+ }
+
+ tasks := make([]TaskInfo[T], len(msgs))
+ for i := range msgs {
+ tasks[i], err = decode[T](&msgs[i], q.encoding)
+ if err != nil {
+ return nil, err
+ }
+
+ tasks[i].Result, err = q.getResult(ctx, tasks[i].ID)
+ if err != nil {
+ return nil, err
+ }
+ }
+
+ return tasks, nil
+}
+
+func (q *taskQueue[T]) getResult(ctx context.Context, taskID TaskID) (isTaskResult, error) {
+ key := fmt.Sprintf("%s:%s", q.streamKey, taskID.String())
+ results, err := q.rdb.HMGet(ctx, key, ResultKey, ErrorKey, NextAttemptKey).Result()
+ if err != nil {
+ return nil, err
+ }
+
+ var ret isTaskResult
+ if results[0] != nil {
+ ret = &TaskResultSuccess{Result: results[0].(string)}
+ } else if results[1] != nil {
+ ret = &TaskResultError{Error: results[1].(string)}
+ if results[2] != nil {
+ nextAttempt, err := ParseTaskID(results[2].(string))
+ if err != nil {
+ return nil, err
+ }
+ ret.(*TaskResultError).NextAttempt = nextAttempt
+ }
+ }
+ return ret, nil
+}
+
+func (q *taskQueue[T]) recover(ctx context.Context, idleTimeout time.Duration) (*TaskInfo[T], error) {
+ msgs, _, err := q.rdb.XAutoClaim(ctx, &redis.XAutoClaimArgs{
+ Stream: q.streamKey,
+ Group: q.groupName,
+ MinIdle: idleTimeout,
+ Start: "0",
+ Count: 1,
+ Consumer: q.workerName,
+ }).Result()
+ if err != nil {
+ return nil, err
+ }
+
+ if len(msgs) == 0 {
+ return nil, nil
+ }
+
+ msg := msgs[0]
+ task, err := decode[T](&msg, q.encoding)
+ if err != nil {
+ return nil, err
+ }
+ return &task, nil
+}
+
+func (q *taskQueue[T]) ack(ctx context.Context, taskID TaskID, errored bool, msg string) error {
+ key := fmt.Sprintf("%s:%s", q.streamKey, taskID.String())
+ _, err := q.rdb.TxPipelined(ctx, func(pipe redis.Pipeliner) error {
+ pipe.XAck(ctx, q.streamKey, q.groupName, taskID.String())
+ if errored {
+ pipe.HSet(ctx, key, ErrorKey, msg)
+ } else {
+ pipe.HSet(ctx, key, ResultKey, msg)
+ }
+ return nil
+ })
+ return err
+}
+
+func decode[T any](msg *redis.XMessage, encoding Encoding) (task TaskInfo[T], err error) {
+ task.ID, err = ParseTaskID(msg.ID)
+ if err != nil {
+ return
+ }
+
+ err = getField(msg, "attempts", &task.Attempts)
+ if err != nil {
+ return
+ }
+
+ var data string
+ err = getField(msg, "data", &data)
+ if err != nil {
+ return
+ }
+
+ switch encoding {
+ case EncodingJSON:
+ err = json.Unmarshal([]byte(data), &task.Data)
+ case EncodingGob:
+ var decoded []byte
+ decoded, err = base64.StdEncoding.DecodeString(data)
+ if err != nil {
+ return
+ }
+ err = gob.NewDecoder(bytes.NewReader(decoded)).Decode(&task.Data)
+ default:
+ err = errors.New("unsupported encoding")
+ }
+ return
+}
+
+func getField(msg *redis.XMessage, field string, v any) error {
+ vVal, ok := msg.Values[field]
+ if !ok {
+ return errors.New("missing field")
+ }
+
+ vStr, ok := vVal.(string)
+ if !ok {
+ return errors.New("invalid field type")
+ }
+
+ value := reflect.ValueOf(v).Elem()
+ switch value.Kind() {
+ case reflect.String:
+ value.SetString(vStr)
+ case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
+ i, err := strconv.ParseInt(vStr, 10, 64)
+ if err != nil {
+ return err
+ }
+ value.SetInt(i)
+ case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
+ i, err := strconv.ParseUint(vStr, 10, 64)
+ if err != nil {
+ return err
+ }
+ value.SetUint(i)
+ case reflect.Bool:
+ b, err := strconv.ParseBool(vStr)
+ if err != nil {
+ return err
+ }
+ value.SetBool(b)
+ default:
+ return errors.New("unsupported field type")
+ }
+ return nil
+}
+
+func encode[T any](task TaskInfo[T], encoding Encoding) (ret map[string]string, err error) {
+ ret = make(map[string]string)
+ ret["attempts"] = strconv.FormatUint(uint64(task.Attempts), 10)
+
+ switch encoding {
+ case EncodingJSON:
+ var data []byte
+ data, err = json.Marshal(task.Data)
+ if err != nil {
+ return
+ }
+ ret["data"] = string(data)
+ case EncodingGob:
+ var data bytes.Buffer
+ err = gob.NewEncoder(&data).Encode(task.Data)
+ if err != nil {
+ return
+ }
+ ret["data"] = base64.StdEncoding.EncodeToString(data.Bytes())
+ default:
+ err = errors.New("unsupported encoding")
+ }
+ return
+}