package taskqueue

import (
	"bytes"
	"context"
	"encoding/base64"
	"encoding/gob"
	"encoding/json"
	"errors"
	"fmt"
	"log/slog"
	"reflect"
	"strconv"
	"strings"
	"time"

	"github.com/redis/go-redis/v9"
)

type Encoding uint8

const (
	EncodingJSON Encoding = iota
	EncodingGob

	ResultKey      = "result"
	ErrorKey       = "error"
	NextAttemptKey = "next_attempt"
)

var MaxAttempts = 3

var ErrTaskNotFound = errors.New("task not found")

type TaskQueue[T any] interface {
	// Enqueue adds a task to the queue and returns its info, including the
	// generated task ID.
	Enqueue(ctx context.Context, data T) (TaskInfo[T], error)

	// Dequeue removes a task from the queue and returns it.
	//
	// Dequeue blocks until a task is available, the timeout elapses, or the
	// context is canceled. The returned task is placed in a pending state for
	// the lockTimeout duration. The task must be completed with Complete or
	// extended with Extend before the lock expires. If the lock expires, the
	// task is returned to the queue, where it may be picked up by another
	// worker.
	Dequeue(
		ctx context.Context,
		lockTimeout, timeout time.Duration,
	) (*TaskInfo[T], error)

	// Extend extends the lock on a task.
	Extend(ctx context.Context, taskID TaskID) error

	// Complete marks a task as complete. A result string can be provided to
	// store additional information.
	Complete(ctx context.Context, taskID TaskID, result string) error

	// Data returns the info of a task.
	Data(ctx context.Context, taskID TaskID) (TaskInfo[T], error)

	// Return returns a task to the queue and returns the new task ID.
	// It increments the attempt counter. Tasks that have reached MaxAttempts
	// are considered failed and are not returned to the queue.
	Return(ctx context.Context, taskID TaskID, err error) (TaskID, error)

	// List returns a list of tasks in the queue, ordered by the time they
	// were added to the queue; the most recent task is first.
	// The count parameter limits the number of tasks returned.
	// The start and end parameters limit the range of tasks returned;
	// end is exclusive and start must be before end. A zero start or end
	// leaves that side of the range open.
	List(ctx context.Context, start, end TaskID, count int64) ([]TaskInfo[T], error)
}

type TaskInfo[T any] struct {
	// ID is the unique identifier of the task. Generated by Redis.
	ID TaskID
	// Data is the task data. Stored in the stream.
	Data T
	// Attempts is the number of times the task has been attempted. Stored in
	// the stream.
	Attempts uint8
	// Result is the result of the task. Stored in a hash.
	Result isTaskResult
}
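// The following is an illustrative usage sketch, not part of the package API:
// a producer enqueues a payload and a worker processes it while keeping the
// lock alive. The examplePayload type, the "crawl" queue name, the "worker-1"
// worker name, and the process helper are all hypothetical.
type examplePayload struct {
	URL string `json:"url"`
}

func exampleWorker(ctx context.Context, rdb *redis.Client) error {
	q, err := New[examplePayload](ctx, rdb, "crawl", "worker-1")
	if err != nil {
		return err
	}

	// Producer side: add a task to the queue.
	if _, err := q.Enqueue(ctx, examplePayload{URL: "https://example.com"}); err != nil {
		return err
	}

	// Consumer side: hold the lock for 30s, wait up to 5s for a task.
	task, err := q.Dequeue(ctx, 30*time.Second, 5*time.Second)
	if err != nil {
		return err
	}
	if task == nil {
		// No task became available before the timeout.
		return nil
	}

	// Extend the lock periodically while the task is being worked on.
	done := make(chan struct{})
	defer close(done)
	go func() {
		ticker := time.NewTicker(10 * time.Second)
		defer ticker.Stop()
		for {
			select {
			case <-done:
				return
			case <-ticker.C:
				if err := q.Extend(ctx, task.ID); err != nil {
					slog.ErrorContext(ctx, "extending task lock", "error", err)
				}
			}
		}
	}()

	// Do the work, then either complete the task or return it to the queue so
	// another attempt can be made.
	if err := process(ctx, task.Data); err != nil {
		_, retErr := q.Return(ctx, task.ID, err)
		return retErr
	}
	return q.Complete(ctx, task.ID, "ok")
}

// process stands in for the application-specific work; hypothetical.
func process(ctx context.Context, p examplePayload) error {
	return nil
}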
type isTaskResult interface {
	isTaskResult()
}

type TaskResultSuccess struct {
	Result string
}

type TaskResultError struct {
	Error       string
	NextAttempt TaskID
}

func (*TaskResultSuccess) isTaskResult() {}
func (*TaskResultError) isTaskResult()   {}

type TaskID struct {
	timestamp time.Time
	sequence  uint64
}

func NewTaskID(timestamp time.Time, sequence uint64) TaskID {
	return TaskID{timestamp, sequence}
}

func ParseTaskID(s string) (TaskID, error) {
	tPart, sPart, ok := strings.Cut(s, "-")
	if !ok {
		return TaskID{}, errors.New("invalid task ID")
	}

	timestamp, err := strconv.ParseInt(tPart, 10, 64)
	if err != nil {
		return TaskID{}, err
	}

	sequence, err := strconv.ParseUint(sPart, 10, 64)
	if err != nil {
		return TaskID{}, err
	}

	return NewTaskID(time.UnixMilli(timestamp), sequence), nil
}

func (t TaskID) Timestamp() time.Time {
	return t.timestamp
}

func (t TaskID) String() string {
	tPart := strconv.FormatInt(t.timestamp.UnixMilli(), 10)
	sPart := strconv.FormatUint(t.sequence, 10)
	return tPart + "-" + sPart
}

type taskQueue[T any] struct {
	rdb        *redis.Client
	encoding   Encoding
	streamKey  string
	groupName  string
	workerName string
}

func New[T any](ctx context.Context, rdb *redis.Client, name string, workerName string, opts ...Option[T]) (TaskQueue[T], error) {
	tq := &taskQueue[T]{
		rdb:        rdb,
		encoding:   EncodingJSON,
		streamKey:  "taskqueue:" + name,
		groupName:  "default",
		workerName: workerName,
	}

	for _, opt := range opts {
		opt(tq)
	}

	// Create the stream and consumer group if they don't exist yet.
	err := rdb.XGroupCreateMkStream(ctx, tq.streamKey, tq.groupName, "0").Err()
	if err != nil && err.Error() != "BUSYGROUP Consumer Group name already exists" {
		return nil, err
	}

	return tq, nil
}

func (q *taskQueue[T]) Enqueue(ctx context.Context, data T) (TaskInfo[T], error) {
	task := TaskInfo[T]{
		Data:     data,
		Attempts: 0,
	}

	values, err := encode[T](task, q.encoding)
	if err != nil {
		return TaskInfo[T]{}, err
	}

	taskID, err := q.rdb.XAdd(ctx, &redis.XAddArgs{
		Stream: q.streamKey,
		Values: values,
	}).Result()
	if err != nil {
		return TaskInfo[T]{}, err
	}

	id, err := ParseTaskID(taskID)
	if err != nil {
		return TaskInfo[T]{}, err
	}
	task.ID = id

	return task, nil
}

func (q *taskQueue[T]) Dequeue(ctx context.Context, lockTimeout, timeout time.Duration) (*TaskInfo[T], error) {
	// Try to recover an abandoned task first.
	task, err := q.recover(ctx, lockTimeout)
	if err != nil {
		return nil, err
	}
	if task != nil {
		return task, nil
	}

	// Check for new tasks.
	ids, err := q.rdb.XReadGroup(ctx, &redis.XReadGroupArgs{
		Group:    q.groupName,
		Consumer: q.workerName,
		Streams:  []string{q.streamKey, ">"},
		Count:    1,
		Block:    timeout,
	}).Result()
	if err != nil && !errors.Is(err, redis.Nil) {
		return nil, err
	}

	if len(ids) == 0 || len(ids[0].Messages) == 0 || errors.Is(err, redis.Nil) {
		return nil, nil
	}

	msg := ids[0].Messages[0]
	task = new(TaskInfo[T])
	*task, err = decode[T](&msg, q.encoding)
	if err != nil {
		return nil, err
	}

	return task, nil
}

func (q *taskQueue[T]) Extend(ctx context.Context, taskID TaskID) error {
	// Re-claiming the message resets its idle time, which extends the lock.
	_, err := q.rdb.XClaim(ctx, &redis.XClaimArgs{
		Stream:   q.streamKey,
		Group:    q.groupName,
		Consumer: q.workerName,
		MinIdle:  0,
		Messages: []string{taskID.String()},
	}).Result()
	if err != nil && !errors.Is(err, redis.Nil) {
		return err
	}

	return nil
}

func (q *taskQueue[T]) Data(ctx context.Context, taskID TaskID) (TaskInfo[T], error) {
	msg, err := q.rdb.XRange(ctx, q.streamKey, taskID.String(), taskID.String()).Result()
	if err != nil {
		return TaskInfo[T]{}, err
	}
	if len(msg) == 0 {
		return TaskInfo[T]{}, ErrTaskNotFound
	}

	t, err := decode[T](&msg[0], q.encoding)
	if err != nil {
		return TaskInfo[T]{}, err
	}

	t.Result, err = q.getResult(ctx, taskID)
	if err != nil {
		return TaskInfo[T]{}, err
	}

	return t, nil
}
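// inspectResultSketch is an illustrative sketch, not part of the package API:
// it looks up a task with Data and reports its outcome, following the
// next_attempt pointer of failed attempts to the most recent retry. The
// function name is hypothetical; examplePayload is the sample type from the
// sketch above.
func inspectResultSketch(ctx context.Context, q TaskQueue[examplePayload], id TaskID) error {
	for {
		info, err := q.Data(ctx, id)
		if err != nil {
			return err
		}

		switch res := info.Result.(type) {
		case *TaskResultSuccess:
			slog.InfoContext(ctx, "task succeeded", "taskID", id.String(), "result", res.Result)
			return nil
		case *TaskResultError:
			if res.NextAttempt.Timestamp().IsZero() {
				// No retry was scheduled; the task failed for good.
				slog.ErrorContext(ctx, "task failed", "taskID", id.String(), "error", res.Error)
				return nil
			}
			// Follow the chain to the retried attempt.
			id = res.NextAttempt
		default:
			// No result stored yet: the task is still queued or in progress.
			slog.InfoContext(ctx, "task not finished yet", "taskID", id.String())
			return nil
		}
	}
}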
func (q *taskQueue[T]) Complete(ctx context.Context, taskID TaskID, result string) error {
	return q.ack(ctx, taskID, false, result)
}

var retScript = redis.NewScript(`
local stream_key = KEYS[1]
local hash_key = KEYS[2]

-- Re-add the task to the stream
local task_id = redis.call('XADD', stream_key, '*', unpack(ARGV))

-- Update the hash key to point to the new task
redis.call('HSET', hash_key, 'next_attempt', task_id)

return task_id
`)

func (q *taskQueue[T]) Return(ctx context.Context, taskID TaskID, err1 error) (TaskID, error) {
	msgs, err := q.rdb.XRange(ctx, q.streamKey, taskID.String(), taskID.String()).Result()
	if err != nil {
		return TaskID{}, err
	}
	if len(msgs) == 0 {
		return TaskID{}, ErrTaskNotFound
	}

	var ackMsg string
	if err1 != nil {
		ackMsg = err1.Error()
	}

	// Ack the task
	err = q.ack(ctx, taskID, true, ackMsg)
	if err != nil {
		return TaskID{}, err
	}

	msg := msgs[0]
	task, err := decode[T](&msg, q.encoding)
	if err != nil {
		return TaskID{}, err
	}

	task.Attempts++
	if int(task.Attempts) >= MaxAttempts {
		// Task has failed for good; don't requeue it.
		slog.ErrorContext(ctx, "task failed completely",
			"taskID", taskID,
			"data", task.Data,
			"attempts", task.Attempts,
			"maxAttempts", MaxAttempts,
		)
		return TaskID{}, nil
	}

	valuesMap, err := encode[T](task, q.encoding)
	if err != nil {
		return TaskID{}, err
	}

	values := make([]string, 0, len(valuesMap)*2)
	for k, v := range valuesMap {
		values = append(values, k, v)
	}

	keys := []string{
		q.streamKey,
		fmt.Sprintf("%s:%s", q.streamKey, taskID.String()),
	}

	newTaskID, err := retScript.Run(ctx, q.rdb, keys, values).Result()
	if err != nil {
		return TaskID{}, err
	}

	return ParseTaskID(newTaskID.(string))
}

func (q *taskQueue[T]) List(ctx context.Context, start, end TaskID, count int64) ([]TaskInfo[T], error) {
	if !start.timestamp.IsZero() && !end.timestamp.IsZero() && start.timestamp.After(end.timestamp) {
		return nil, errors.New("start must be before end")
	}

	var startStr, endStr string
	if !start.timestamp.IsZero() {
		startStr = start.String()
	} else {
		startStr = "-"
	}
	if !end.timestamp.IsZero() {
		endStr = "(" + end.String()
	} else {
		endStr = "+"
	}

	msgs, err := q.rdb.XRevRangeN(ctx, q.streamKey, endStr, startStr, count).Result()
	if err != nil {
		return nil, err
	}
	if len(msgs) == 0 {
		return []TaskInfo[T]{}, nil
	}

	tasks := make([]TaskInfo[T], len(msgs))
	for i := range msgs {
		tasks[i], err = decode[T](&msgs[i], q.encoding)
		if err != nil {
			return nil, err
		}

		tasks[i].Result, err = q.getResult(ctx, tasks[i].ID)
		if err != nil {
			return nil, err
		}
	}

	return tasks, nil
}

func (q *taskQueue[T]) getResult(ctx context.Context, taskID TaskID) (isTaskResult, error) {
	key := fmt.Sprintf("%s:%s", q.streamKey, taskID.String())

	results, err := q.rdb.HMGet(ctx, key, ResultKey, ErrorKey, NextAttemptKey).Result()
	if err != nil {
		return nil, err
	}

	var ret isTaskResult
	if results[0] != nil {
		ret = &TaskResultSuccess{Result: results[0].(string)}
	} else if results[1] != nil {
		ret = &TaskResultError{Error: results[1].(string)}
		if results[2] != nil {
			nextAttempt, err := ParseTaskID(results[2].(string))
			if err != nil {
				return nil, err
			}
			ret.(*TaskResultError).NextAttempt = nextAttempt
		}
	}

	return ret, nil
}

func (q *taskQueue[T]) recover(ctx context.Context, idleTimeout time.Duration) (*TaskInfo[T], error) {
	msgs, _, err := q.rdb.XAutoClaim(ctx, &redis.XAutoClaimArgs{
		Stream:   q.streamKey,
		Group:    q.groupName,
		MinIdle:  idleTimeout,
		Start:    "0",
		Count:    1,
		Consumer: q.workerName,
	}).Result()
	if err != nil {
		return nil, err
	}
	if len(msgs) == 0 {
		return nil, nil
	}

	msg := msgs[0]
	task, err := decode[T](&msg, q.encoding)
	if err != nil {
		return nil, err
	}

	return &task, nil
}
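// listRecentSketch is an illustrative sketch, not part of the package API: it
// fetches the ten most recent tasks (zero-value TaskIDs select an open range)
// and logs their state. The function name is hypothetical.
func listRecentSketch(ctx context.Context, q TaskQueue[examplePayload]) error {
	tasks, err := q.List(ctx, TaskID{}, TaskID{}, 10)
	if err != nil {
		return err
	}

	for _, t := range tasks {
		slog.InfoContext(ctx, "task",
			"taskID", t.ID.String(),
			"attempts", t.Attempts,
			"hasResult", t.Result != nil,
		)
	}

	return nil
}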
func (q *taskQueue[T]) ack(ctx context.Context, taskID TaskID, errored bool, msg string) error {
	key := fmt.Sprintf("%s:%s", q.streamKey, taskID.String())

	_, err := q.rdb.TxPipelined(ctx, func(pipe redis.Pipeliner) error {
		pipe.XAck(ctx, q.streamKey, q.groupName, taskID.String())
		if errored {
			pipe.HSet(ctx, key, ErrorKey, msg)
		} else {
			pipe.HSet(ctx, key, ResultKey, msg)
		}
		return nil
	})

	return err
}

func decode[T any](msg *redis.XMessage, encoding Encoding) (task TaskInfo[T], err error) {
	task.ID, err = ParseTaskID(msg.ID)
	if err != nil {
		return
	}

	err = getField(msg, "attempts", &task.Attempts)
	if err != nil {
		return
	}

	var data string
	err = getField(msg, "data", &data)
	if err != nil {
		return
	}

	switch encoding {
	case EncodingJSON:
		err = json.Unmarshal([]byte(data), &task.Data)
	case EncodingGob:
		var decoded []byte
		decoded, err = base64.StdEncoding.DecodeString(data)
		if err != nil {
			return
		}
		err = gob.NewDecoder(bytes.NewReader(decoded)).Decode(&task.Data)
	default:
		err = errors.New("unsupported encoding")
	}

	return
}

func getField(msg *redis.XMessage, field string, v any) error {
	vVal, ok := msg.Values[field]
	if !ok {
		return errors.New("missing field")
	}
	vStr, ok := vVal.(string)
	if !ok {
		return errors.New("invalid field type")
	}

	value := reflect.ValueOf(v).Elem()
	switch value.Kind() {
	case reflect.String:
		value.SetString(vStr)
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		i, err := strconv.ParseInt(vStr, 10, 64)
		if err != nil {
			return err
		}
		value.SetInt(i)
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		i, err := strconv.ParseUint(vStr, 10, 64)
		if err != nil {
			return err
		}
		value.SetUint(i)
	case reflect.Bool:
		b, err := strconv.ParseBool(vStr)
		if err != nil {
			return err
		}
		value.SetBool(b)
	default:
		return errors.New("unsupported field type")
	}

	return nil
}

func encode[T any](task TaskInfo[T], encoding Encoding) (ret map[string]string, err error) {
	ret = make(map[string]string)
	ret["attempts"] = strconv.FormatUint(uint64(task.Attempts), 10)

	switch encoding {
	case EncodingJSON:
		var data []byte
		data, err = json.Marshal(task.Data)
		if err != nil {
			return
		}
		ret["data"] = string(data)
	case EncodingGob:
		var data bytes.Buffer
		err = gob.NewEncoder(&data).Encode(task.Data)
		if err != nil {
			return
		}
		ret["data"] = base64.StdEncoding.EncodeToString(data.Bytes())
	default:
		err = errors.New("unsupported encoding")
	}

	return
}
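// encodedFormSketch is an illustrative sketch of the stream entry layout
// produced by encode and read back by decode: an "attempts" counter plus a
// "data" field holding the JSON-marshaled payload (for EncodingGob the payload
// is gob-encoded and base64-wrapped instead). The function name and payload
// value are hypothetical.
func encodedFormSketch() (map[string]string, error) {
	// For EncodingJSON this yields roughly:
	//   map[string]string{"attempts": "0", "data": `{"url":"https://example.com"}`}
	return encode(TaskInfo[examplePayload]{
		Data: examplePayload{URL: "https://example.com"},
	}, EncodingJSON)
}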