author | 2024-08-08 16:53:59 -0700
---|---
committer | 2024-08-08 16:53:59 -0700
commit | f34b92ded11b07f78575ac62c260a380c468e5ea (patch)
tree | 8ffdc68ed0f2e253e7f9feff3aa90a1182e5946c /backend
parent | a439618cdc8168bad617d04875697b572f3ed41d (diff)
download | ibd-trader-f34b92ded11b07f78575ac62c260a380c468e5ea.tar.gz, ibd-trader-f34b92ded11b07f78575ac62c260a380c468e5ea.tar.zst, ibd-trader-f34b92ded11b07f78575ac62c260a380c468e5ea.zip
Rework redis taskqueue to store task results
Diffstat (limited to 'backend')
-rw-r--r-- | backend/internal/database/stocks.go | 16
-rw-r--r-- | backend/internal/redis/taskqueue/queue.go | 169
-rw-r--r-- | backend/internal/redis/taskqueue/queue_test.go | 61
-rw-r--r-- | backend/internal/server/operations.go | 21
-rw-r--r-- | backend/internal/worker/analyzer/analyzer.go | 49
-rw-r--r-- | backend/internal/worker/auth/auth.go | 2
-rw-r--r-- | backend/internal/worker/scraper/scraper.go | 55
7 files changed, 220 insertions, 153 deletions
diff --git a/backend/internal/database/stocks.go b/backend/internal/database/stocks.go index 0627a72..24f5fe7 100644 --- a/backend/internal/database/stocks.go +++ b/backend/internal/database/stocks.go @@ -151,8 +151,13 @@ WHERE r.id = $1;`, id) return &info, nil } -func AddAnalysis(ctx context.Context, exec Executor, ratingId string, analysis *analyzer.Analysis) error { - _, err := exec.ExecContext(ctx, ` +func AddAnalysis( + ctx context.Context, + exec Executor, + ratingId string, + analysis *analyzer.Analysis, +) (id string, err error) { + err = exec.QueryRowContext(ctx, ` UPDATE chart_analysis ca SET processed = true, action = $2, @@ -161,14 +166,15 @@ SET processed = true, confidence = $5 FROM ratings r WHERE r.id = $1 - AND r.chart_analysis = ca.id;`, + AND r.chart_analysis = ca.id +RETURNING ca.id;`, ratingId, analysis.Action, analysis.Price.Display(), analysis.Reason, analysis.Confidence, - ) - return err + ).Scan(&id) + return id, err } type Stock struct { diff --git a/backend/internal/redis/taskqueue/queue.go b/backend/internal/redis/taskqueue/queue.go index 1298a76..a4b799e 100644 --- a/backend/internal/redis/taskqueue/queue.go +++ b/backend/internal/redis/taskqueue/queue.go @@ -7,6 +7,7 @@ import ( "encoding/gob" "encoding/json" "errors" + "fmt" "log/slog" "reflect" "strconv" @@ -21,6 +22,10 @@ type Encoding uint8 const ( EncodingJSON Encoding = iota EncodingGob + + ResultKey = "result" + ErrorKey = "error" + NextAttemptKey = "next_attempt" ) var MaxAttempts = 3 @@ -48,7 +53,7 @@ type TaskQueue[T any] interface { Extend(ctx context.Context, taskID TaskID) error // Complete marks a task as complete. Optionally, an error can be provided to store additional information. - Complete(ctx context.Context, taskID TaskID, err error) error + Complete(ctx context.Context, taskID TaskID, result string) error // Data returns the info of a task. Data(ctx context.Context, taskID TaskID) (TaskInfo[T], error) @@ -74,12 +79,26 @@ type TaskInfo[T any] struct { Data T // Attempts is the number of times the task has been attempted. Stored in stream. Attempts uint8 - // Done is true if the task has been completed. True if ID in completed hash - Done bool - // Error is the error message if the task has failed. Stored in completed hash. - Error string + // Result is the result of the task. Stored in a hash. 
+ Result isTaskResult +} + +type isTaskResult interface { + isTaskResult() +} + +type TaskResultSuccess struct { + Result string +} + +type TaskResultError struct { + Error string + NextAttempt TaskID } +func (*TaskResultSuccess) isTaskResult() {} +func (*TaskResultError) isTaskResult() {} + type TaskID struct { timestamp time.Time sequence uint64 @@ -125,19 +144,16 @@ type taskQueue[T any] struct { streamKey string groupName string - completedSetKey string - workerName string } func New[T any](ctx context.Context, rdb *redis.Client, name string, workerName string, opts ...Option[T]) (TaskQueue[T], error) { tq := &taskQueue[T]{ - rdb: rdb, - encoding: EncodingJSON, - streamKey: "taskqueue:" + name, - groupName: "default", - completedSetKey: "taskqueue:" + name + ":completed", - workerName: workerName, + rdb: rdb, + encoding: EncodingJSON, + streamKey: "taskqueue:" + name, + groupName: "default", + workerName: workerName, } for _, opt := range opts { @@ -244,35 +260,30 @@ func (q *taskQueue[T]) Data(ctx context.Context, taskID TaskID) (TaskInfo[T], er return TaskInfo[T]{}, err } - tErr, err := q.rdb.HGet(ctx, q.completedSetKey, taskID.String()).Result() - if err != nil && !errors.Is(err, redis.Nil) { - return TaskInfo[T]{}, err - } - - if errors.Is(err, redis.Nil) { - return t, nil + t.Result, err = q.getResult(ctx, taskID) + if err != nil { + return TaskInfo[T]{}, nil } - - t.Done = true - t.Error = tErr return t, nil } -func (q *taskQueue[T]) Complete(ctx context.Context, taskID TaskID, err error) error { - _, err = q.rdb.TxPipelined(ctx, func(pipe redis.Pipeliner) error { - pipe.XAck(ctx, q.streamKey, q.groupName, taskID.String()) - //xdel = pipe.XDel(ctx, q.streamKey, taskID.String()) - //pipe.SAdd(ctx, q.completedSetKey, taskID.String()) - if err != nil { - pipe.HSet(ctx, q.completedSetKey, taskID.String(), err.Error()) - } else { - pipe.HSet(ctx, q.completedSetKey, taskID.String(), "") - } - return nil - }) - return err +func (q *taskQueue[T]) Complete(ctx context.Context, taskID TaskID, result string) error { + return q.ack(ctx, taskID, false, result) } +var retScript = redis.NewScript(` +local stream_key = KEYS[1] +local hash_key = KEYS[2] + +-- Re-add the task to the stream +local task_id = redis.call('XADD', stream_key, '*', unpack(ARGV)) + +-- Update the hash key to point to the new task +redis.call('HSET', hash_key, 'next_attempt', task_id) + +return task_id +`) + func (q *taskQueue[T]) Return(ctx context.Context, taskID TaskID, err1 error) (TaskID, error) { msgs, err := q.rdb.XRange(ctx, q.streamKey, taskID.String(), taskID.String()).Result() if err != nil { @@ -282,8 +293,13 @@ func (q *taskQueue[T]) Return(ctx context.Context, taskID TaskID, err1 error) (T return TaskID{}, ErrTaskNotFound } - // Complete the task - err = q.Complete(ctx, taskID, err1) + var ackMsg string + if err1 != nil { + ackMsg = err1.Error() + } + + // Ack the task + err = q.ack(ctx, taskID, true, ackMsg) if err != nil { return TaskID{}, err } @@ -306,18 +322,26 @@ func (q *taskQueue[T]) Return(ctx context.Context, taskID TaskID, err1 error) (T return TaskID{}, nil } - values, err := encode[T](task, q.encoding) + valuesMap, err := encode[T](task, q.encoding) if err != nil { return TaskID{}, err } - newTaskId, err := q.rdb.XAdd(ctx, &redis.XAddArgs{ - Stream: q.streamKey, - Values: values, - }).Result() + + values := make([]string, 0, len(valuesMap)*2) + for k, v := range valuesMap { + values = append(values, k, v) + } + + keys := []string{ + q.streamKey, + fmt.Sprintf("%s:%s", q.streamKey, 
taskID.String()), + } + newTaskId, err := retScript.Run(ctx, q.rdb, keys, values).Result() if err != nil { return TaskID{}, err } - return ParseTaskID(newTaskId) + + return ParseTaskID(newTaskId.(string)) } func (q *taskQueue[T]) List(ctx context.Context, start, end TaskID, count int64) ([]TaskInfo[T], error) { @@ -345,32 +369,45 @@ func (q *taskQueue[T]) List(ctx context.Context, start, end TaskID, count int64) return []TaskInfo[T]{}, nil } - ids := make([]string, len(msgs)) - for i, msg := range msgs { - ids[i] = msg.ID - } - errs, err := q.rdb.HMGet(ctx, q.completedSetKey, ids...).Result() - if err != nil { - return nil, err - } - if len(errs) != len(msgs) { - return nil, errors.New("SMIsMember returned wrong number of results") - } - tasks := make([]TaskInfo[T], len(msgs)) for i := range msgs { tasks[i], err = decode[T](&msgs[i], q.encoding) if err != nil { return nil, err } - tasks[i].Done = errs[i] != nil - if tasks[i].Done { - tasks[i].Error = errs[i].(string) + + tasks[i].Result, err = q.getResult(ctx, tasks[i].ID) + if err != nil { + return nil, err } } + return tasks, nil } +func (q *taskQueue[T]) getResult(ctx context.Context, taskID TaskID) (isTaskResult, error) { + key := fmt.Sprintf("%s:%s", q.streamKey, taskID.String()) + results, err := q.rdb.HMGet(ctx, key, ResultKey, ErrorKey, NextAttemptKey).Result() + if err != nil { + return nil, err + } + + var ret isTaskResult + if results[0] != nil { + ret = &TaskResultSuccess{Result: results[0].(string)} + } else if results[1] != nil { + ret = &TaskResultError{Error: results[1].(string)} + if results[2] != nil { + nextAttempt, err := ParseTaskID(results[2].(string)) + if err != nil { + return nil, err + } + ret.(*TaskResultError).NextAttempt = nextAttempt + } + } + return ret, nil +} + func (q *taskQueue[T]) recover(ctx context.Context, idleTimeout time.Duration) (*TaskInfo[T], error) { msgs, _, err := q.rdb.XAutoClaim(ctx, &redis.XAutoClaimArgs{ Stream: q.streamKey, @@ -396,6 +433,20 @@ func (q *taskQueue[T]) recover(ctx context.Context, idleTimeout time.Duration) ( return &task, nil } +func (q *taskQueue[T]) ack(ctx context.Context, taskID TaskID, errored bool, msg string) error { + key := fmt.Sprintf("%s:%s", q.streamKey, taskID.String()) + _, err := q.rdb.TxPipelined(ctx, func(pipe redis.Pipeliner) error { + pipe.XAck(ctx, q.streamKey, q.groupName, taskID.String()) + if errored { + pipe.HSet(ctx, key, ErrorKey, msg) + } else { + pipe.HSet(ctx, key, ResultKey, msg) + } + return nil + }) + return err +} + func decode[T any](msg *redis.XMessage, encoding Encoding) (task TaskInfo[T], err error) { task.ID, err = ParseTaskID(msg.ID) if err != nil { diff --git a/backend/internal/redis/taskqueue/queue_test.go b/backend/internal/redis/taskqueue/queue_test.go index 774caa8..ee95d39 100644 --- a/backend/internal/redis/taskqueue/queue_test.go +++ b/backend/internal/redis/taskqueue/queue_test.go @@ -40,7 +40,7 @@ func TestMain(m *testing.M) { log.Fatalf("Could not start resource: %s", err) } - _ = resource.Expire(60) + //_ = resource.Expire(60) if err = pool.Retry(func() error { client = redis.NewClient(&redis.Options{ @@ -161,7 +161,7 @@ func TestTaskQueue(t *testing.T) { assert.Equal(t, "hello", task.Data) // Complete the task - err = q.Complete(context.Background(), task.ID, nil) + err = q.Complete(context.Background(), task.ID, "done") require.NoError(t, err) // Try to dequeue the task again @@ -354,7 +354,7 @@ func TestTaskQueue_List(t *testing.T) { task2, err := q.Enqueue(context.Background(), "world") require.NoError(t, err) - err = 
q.Complete(context.Background(), task1.ID, nil) + err = q.Complete(context.Background(), task1.ID, "done") require.NoError(t, err) tasks, err := q.List(context.Background(), TaskID{}, TaskID{}, 100) @@ -363,33 +363,8 @@ func TestTaskQueue_List(t *testing.T) { assert.Equal(t, task2, tasks[0]) assert.Equal(t, "hello", tasks[1].Data) - assert.Equal(t, true, tasks[1].Done) - assert.Equal(t, "", tasks[1].Error) - }, - }, - { - name: "failed tasks", - f: func(t *testing.T) { - q, err := New[string](context.Background(), client, "test", "worker") - require.NoError(t, err) - require.NotNil(t, q) - - task1, err := q.Enqueue(context.Background(), "hello") - require.NoError(t, err) - task2, err := q.Enqueue(context.Background(), "world") - require.NoError(t, err) - - err = q.Complete(context.Background(), task1.ID, errors.New("failed")) - require.NoError(t, err) - - tasks, err := q.List(context.Background(), TaskID{}, TaskID{}, 100) - require.NoError(t, err) - assert.Len(t, tasks, 2) - assert.Equal(t, task2, tasks[0]) - - assert.Equal(t, "hello", tasks[1].Data) - assert.Equal(t, true, tasks[1].Done) - assert.Equal(t, "failed", tasks[1].Error) + require.IsType(t, &TaskResultSuccess{}, tasks[1].Result) + assert.Equal(t, "done", tasks[1].Result.(*TaskResultSuccess).Result) }, }, } @@ -430,12 +405,20 @@ func TestTaskQueue_Return(t *testing.T) { id := claimAndFail(t, q, lockTimeout) - task3, err := q.Dequeue(context.Background(), lockTimeout, 10*time.Millisecond) + task2, err := q.Dequeue(context.Background(), lockTimeout, 10*time.Millisecond) + require.NoError(t, err) + require.NotNil(t, task2) + assert.Equal(t, task2.ID, id) + assert.Equal(t, task1.Data, task2.Data) + assert.Equal(t, uint8(1), task2.Attempts) + + task1Data, err := q.Data(context.Background(), task1.ID) require.NoError(t, err) - require.NotNil(t, task3) - assert.Equal(t, task3.ID, id) - assert.Equal(t, task1.Data, task3.Data) - assert.Equal(t, uint8(1), task3.Attempts) + assert.Equal(t, task1Data.ID, task1.ID) + assert.Equal(t, task1Data.Data, task1.Data) + assert.IsType(t, &TaskResultError{}, task1Data.Result) + assert.Equal(t, "failed", task1Data.Result.(*TaskResultError).Error) + assert.Equal(t, task2.ID, task1Data.Result.(*TaskResultError).NextAttempt) }, }, { @@ -473,12 +456,12 @@ func TestTaskQueue_Return(t *testing.T) { } func claimAndFail[T any](t *testing.T, q TaskQueue[T], lockTimeout time.Duration) TaskID { - task2, err := q.Dequeue(context.Background(), lockTimeout, 10*time.Millisecond) + task, err := q.Dequeue(context.Background(), lockTimeout, 10*time.Millisecond) require.NoError(t, err) - require.NotNil(t, task2) + require.NotNil(t, task) - id, err := q.Return(context.Background(), task2.ID, errors.New("failed")) + id, err := q.Return(context.Background(), task.ID, errors.New("failed")) require.NoError(t, err) - assert.NotEqual(t, task2.ID, id) + assert.NotEqual(t, task.ID, id) return id } diff --git a/backend/internal/server/operations.go b/backend/internal/server/operations.go index dab67f4..2487427 100644 --- a/backend/internal/server/operations.go +++ b/backend/internal/server/operations.go @@ -12,6 +12,7 @@ import ( "github.com/ansg191/ibd-trader-backend/internal/leader/manager/ibd/scrape" "github.com/ansg191/ibd-trader-backend/internal/redis/taskqueue" "github.com/ansg191/ibd-trader-backend/internal/server/idb/stock/v1" + epb "google.golang.org/genproto/googleapis/rpc/errdetails" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" "google.golang.org/protobuf/types/known/anypb" @@ -55,7 +56,7 @@ func (o 
*operationServer) ListOperations( ops[i] = &longrunningpb.Operation{ Name: fmt.Sprintf("%s/%s", stock.ScrapeOperationPrefix, task.ID.String()), Metadata: new(anypb.Any), - Done: task.Done, + Done: task.Result != nil, Result: nil, } err = ops[i].Metadata.MarshalFrom(&spb.StockScrapeOperationMetadata{ @@ -66,8 +67,20 @@ func (o *operationServer) ListOperations( return nil, status.New(codes.Internal, "unable to marshal metadata").Err() } - if task.Done && task.Error != "" { - s := status.New(codes.Unknown, task.Error) + switch res := task.Result.(type) { + case *taskqueue.TaskResultSuccess: + return nil, status.New(codes.Unimplemented, "not implemented").Err() + case *taskqueue.TaskResultError: + s := status.New(codes.Unknown, res.Error) + s, err = s.WithDetails( + &epb.ErrorInfo{ + Reason: "", + Domain: "", + Metadata: nil, + }) + if err != nil { + return nil, status.New(codes.Internal, "unable to marshal error details").Err() + } ops[i].Result = &longrunningpb.Operation_Error{Error: s.Proto()} } } @@ -112,7 +125,7 @@ func (o *operationServer) GetOperation(ctx context.Context, req *longrunningpb.G op := &longrunningpb.Operation{ Name: req.Name, Metadata: new(anypb.Any), - Done: task.Done, + Done: task.Result != nil, Result: nil, } err = op.Metadata.MarshalFrom(&spb.StockScrapeOperationMetadata{ diff --git a/backend/internal/worker/analyzer/analyzer.go b/backend/internal/worker/analyzer/analyzer.go index ea8069e..20621dd 100644 --- a/backend/internal/worker/analyzer/analyzer.go +++ b/backend/internal/worker/analyzer/analyzer.go @@ -64,10 +64,17 @@ func waitForTask( return } - ch := make(chan error) - defer close(ch) + errCh := make(chan error) + resCh := make(chan string) + defer close(errCh) + defer close(resCh) go func() { - ch <- analyzeStock(ctx, analyzer, db, task.Data.ID) + res, err := analyzeStock(ctx, analyzer, db, task.Data.ID) + if err != nil { + errCh <- err + return + } + resCh <- res }() ticker := time.NewTicker(lockTimeout / 5) @@ -89,32 +96,32 @@ func waitForTask( slog.ErrorContext(ctx, "Failed to extend lock", "error", err) } }() - case err = <-ch: - // scrapeUrl has completed. + case err = <-errCh: + // analyzeStock has errored. + slog.ErrorContext(ctx, "Failed to analyze", "error", err) + _, err = queue.Return(ctx, task.ID, err) if err != nil { - slog.ErrorContext(ctx, "Failed to analyze", "error", err) - _, err = queue.Return(ctx, task.ID, err) - if err != nil { - slog.ErrorContext(ctx, "Failed to return task", "error", err) - return - } - } else { - slog.DebugContext(ctx, "Analyzed ID", "id", task.Data.ID) - err = queue.Complete(ctx, task.ID, nil) - if err != nil { - slog.ErrorContext(ctx, "Failed to complete task", "error", err) - return - } + slog.ErrorContext(ctx, "Failed to return task", "error", err) + return + } + return + case res := <-resCh: + // analyzeStock has completed successfully. 
+ slog.DebugContext(ctx, "Analyzed ID", "id", task.Data.ID, "result", res) + err = queue.Complete(ctx, task.ID, res) + if err != nil { + slog.ErrorContext(ctx, "Failed to complete task", "error", err) + return } return } } } -func analyzeStock(ctx context.Context, a analyzer.Analyzer, db database.Executor, id string) error { +func analyzeStock(ctx context.Context, a analyzer.Analyzer, db database.Executor, id string) (string, error) { info, err := database.GetStockInfo(ctx, db, id) if err != nil { - return err + return "", err } analysis, err := a.Analyze( @@ -124,7 +131,7 @@ func analyzeStock(ctx context.Context, a analyzer.Analyzer, db database.Executor info.ChartAnalysis, ) if err != nil { - return err + return "", err } return database.AddAnalysis(ctx, db, id, analysis) diff --git a/backend/internal/worker/auth/auth.go b/backend/internal/worker/auth/auth.go index 2043b5e..0daa112 100644 --- a/backend/internal/worker/auth/auth.go +++ b/backend/internal/worker/auth/auth.go @@ -104,7 +104,7 @@ func waitForTask( return } } else { - err = queue.Complete(ctx, task.ID, nil) + err = queue.Complete(ctx, task.ID, "") if err != nil { slog.ErrorContext(ctx, "Failed to complete task", "error", err) return diff --git a/backend/internal/worker/scraper/scraper.go b/backend/internal/worker/scraper/scraper.go index 4788834..c5c1b6c 100644 --- a/backend/internal/worker/scraper/scraper.go +++ b/backend/internal/worker/scraper/scraper.go @@ -77,10 +77,17 @@ func waitForTask( return } - ch := make(chan error) + errCh := make(chan error) + resCh := make(chan string) + defer close(errCh) + defer close(resCh) go func() { - defer close(ch) - ch <- scrapeUrl(ctx, client, db, aQueue, task.Data.Symbol) + res, err := scrapeUrl(ctx, client, db, aQueue, task.Data.Symbol) + if err != nil { + errCh <- err + return + } + resCh <- res }() ticker := time.NewTicker(lockTimeout / 5) @@ -102,22 +109,22 @@ func waitForTask( slog.ErrorContext(ctx, "Failed to extend lock", "error", err) } }() - case err = <-ch: - // scrapeUrl has completed. + case err = <-errCh: + // scrapeUrl has errored. + slog.ErrorContext(ctx, "Failed to scrape URL", "error", err) + _, err = queue.Return(ctx, task.ID, err) if err != nil { - slog.ErrorContext(ctx, "Failed to scrape URL", "error", err) - _, err = queue.Return(ctx, task.ID, err) - if err != nil { - slog.ErrorContext(ctx, "Failed to return task", "error", err) - return - } - } else { - slog.DebugContext(ctx, "Scraped URL", "symbol", task.Data.Symbol) - err = queue.Complete(ctx, task.ID, nil) - if err != nil { - slog.ErrorContext(ctx, "Failed to complete task", "error", err) - return - } + slog.ErrorContext(ctx, "Failed to return task", "error", err) + return + } + return + case res := <-resCh: + // scrapeUrl has completed successfully. + slog.DebugContext(ctx, "Scraped URL", "symbol", task.Data.Symbol) + err = queue.Complete(ctx, task.ID, res) + if err != nil { + slog.ErrorContext(ctx, "Failed to complete task", "error", err) + return } return } @@ -130,36 +137,36 @@ func scrapeUrl( db database.TransactionExecutor, aQueue taskqueue.TaskQueue[analyzer.TaskInfo], symbol string, -) error { +) (string, error) { ctx, cancel := context.WithTimeout(ctx, lockTimeout) defer cancel() stockUrl, err := getStockUrl(ctx, db, client, symbol) if err != nil { - return fmt.Errorf("failed to get stock url: %w", err) + return "", fmt.Errorf("failed to get stock url: %w", err) } // Scrape the stock info. 
info, err := client.StockInfo(ctx, stockUrl) if err != nil { - return fmt.Errorf("failed to get stock info: %w", err) + return "", fmt.Errorf("failed to get stock info: %w", err) } // Add stock info to the database. id, err := database.AddStockInfo(ctx, db, info) if err != nil { - return fmt.Errorf("failed to add stock info: %w", err) + return "", fmt.Errorf("failed to add stock info: %w", err) } // Add the stock to the analyzer queue. _, err = aQueue.Enqueue(ctx, analyzer.TaskInfo{ID: id}) if err != nil { - return fmt.Errorf("failed to enqueue analysis task: %w", err) + return "", fmt.Errorf("failed to enqueue analysis task: %w", err) } slog.DebugContext(ctx, "Added stock info", "id", id) - return nil + return id, nil } func getStockUrl(ctx context.Context, db database.TransactionExecutor, client *ibd.Client, symbol string) (string, error) { |
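
The core of this change is replacing the `Done`/`Error` pair on `TaskInfo` with a single `Result isTaskResult` field, where `TaskResultSuccess` carries the stored result string and `TaskResultError` carries the failure message plus the ID of the retry attempt. Below is a minimal sketch of how a caller might branch on that sum type; the types are copied out of the diff for illustration, and `NextAttempt` is simplified to a string rather than a `TaskID`.

```go
package main

import "fmt"

// Copies of the result types from the diff, trimmed for illustration.
type isTaskResult interface{ isTaskResult() }

type TaskResultSuccess struct{ Result string }

type TaskResultError struct {
	Error       string
	NextAttempt string // a TaskID in the real code; a string stands in here
}

func (*TaskResultSuccess) isTaskResult() {}
func (*TaskResultError) isTaskResult()   {}

// describe covers the three states a task can now be in: pending (nil),
// completed with a stored result, or failed with an error that may point
// at the retry attempt created by Return.
func describe(r isTaskResult) string {
	switch res := r.(type) {
	case nil:
		return "pending: no result recorded yet"
	case *TaskResultSuccess:
		return "done: " + res.Result
	case *TaskResultError:
		if res.NextAttempt != "" {
			return fmt.Sprintf("failed (%s), retried as %s", res.Error, res.NextAttempt)
		}
		return "failed: " + res.Error
	default:
		return "unknown result type"
	}
}

func main() {
	fmt.Println(describe(nil))
	fmt.Println(describe(&TaskResultSuccess{Result: "chart-analysis-42"}))
	fmt.Println(describe(&TaskResultError{Error: "failed", NextAttempt: "1723161239000-1"}))
}
```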
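
`AddAnalysis` switches from `ExecContext` to `QueryRowContext` so the `UPDATE ... RETURNING ca.id` clause can hand back the id of the `chart_analysis` row it modified; the analyzer worker then stores that id as the task result. A minimal sketch of the pattern follows, with fewer columns than the real query and a hypothetical DSN.

```go
package main

import (
	"context"
	"database/sql"
	"fmt"

	_ "github.com/lib/pq" // Postgres driver; RETURNING is Postgres syntax
)

// addAnalysis mirrors the shape of the reworked AddAnalysis: an UPDATE that
// returns the id of the chart_analysis row it modified, so the caller can
// record it as the task result. Columns are reduced to one for brevity.
func addAnalysis(ctx context.Context, db *sql.DB, ratingID, action string) (id string, err error) {
	err = db.QueryRowContext(ctx, `
UPDATE chart_analysis ca
SET processed = true,
    action    = $2
FROM ratings r
WHERE r.id = $1
  AND r.chart_analysis = ca.id
RETURNING ca.id;`,
		ratingID, action,
	).Scan(&id)
	// If the UPDATE matches no row, Scan reports sql.ErrNoRows.
	return id, err
}

func main() {
	// Hypothetical DSN; the real connection settings live elsewhere in the repo.
	db, err := sql.Open("postgres", "postgres://localhost/ibd?sslmode=disable")
	if err != nil {
		panic(err)
	}
	defer db.Close()

	id, err := addAnalysis(context.Background(), db, "rating-123", "buy")
	fmt.Println(id, err)
}
```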
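
The old `taskqueue:<name>:completed` hash is gone; each attempt now gets its own hash at `<streamKey>:<taskID>` with `result`, `error`, and `next_attempt` fields, written by `ack` and read back by `getResult`. The sketch below shows the same round trip with go-redis v9, assuming a local Redis instance and the field-name constants defined in the diff.

```go
package main

import (
	"context"
	"fmt"

	"github.com/redis/go-redis/v9"
)

// Field names copied from the constants introduced in the diff.
const (
	ResultKey      = "result"
	ErrorKey       = "error"
	NextAttemptKey = "next_attempt"
)

// resultKey mirrors the "<streamKey>:<taskID>" scheme used by ack and getResult.
func resultKey(streamKey, taskID string) string {
	return fmt.Sprintf("%s:%s", streamKey, taskID)
}

// markDone acknowledges the stream entry and records the outcome in the
// per-task hash, the same shape as taskQueue.ack.
func markDone(ctx context.Context, rdb *redis.Client, streamKey, group, taskID string, errored bool, msg string) error {
	key := resultKey(streamKey, taskID)
	_, err := rdb.TxPipelined(ctx, func(pipe redis.Pipeliner) error {
		pipe.XAck(ctx, streamKey, group, taskID)
		if errored {
			pipe.HSet(ctx, key, ErrorKey, msg)
		} else {
			pipe.HSet(ctx, key, ResultKey, msg)
		}
		return nil
	})
	return err
}

// readOutcome fetches all three fields in one HMGET, as getResult does;
// missing fields come back as nil and are left empty here.
func readOutcome(ctx context.Context, rdb *redis.Client, streamKey, taskID string) (result, errMsg, nextAttempt string, err error) {
	vals, err := rdb.HMGet(ctx, resultKey(streamKey, taskID), ResultKey, ErrorKey, NextAttemptKey).Result()
	if err != nil {
		return "", "", "", err
	}
	get := func(v interface{}) string {
		s, _ := v.(string)
		return s
	}
	return get(vals[0]), get(vals[1]), get(vals[2]), nil
}

func main() {
	ctx := context.Background()
	rdb := redis.NewClient(&redis.Options{Addr: "localhost:6379"}) // assumed local instance
	_ = markDone(ctx, rdb, "taskqueue:test", "default", "1723161239000-0", false, "done")
	res, errMsg, next, err := readOutcome(ctx, rdb, "taskqueue:test", "1723161239000-0")
	fmt.Println(res, errMsg, next, err)
}
```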
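
`Return` no longer routes through `Complete`; it re-enqueues via a Lua script so the `XADD` of the retry and the `HSET` of `next_attempt` on the failed attempt happen atomically on the server. Below is a sketch of running an equivalent script; the stream field names passed as `ARGV` are hypothetical, since in the real `Return` they come from `encode`.

```go
package main

import (
	"context"
	"fmt"

	"github.com/redis/go-redis/v9"
)

// Same shape as the retScript in the diff: XADD the retry and record its ID
// as next_attempt on the failed attempt's hash in one atomic script call.
var retryScript = redis.NewScript(`
local stream_key = KEYS[1]
local hash_key = KEYS[2]

-- Re-add the task to the stream
local task_id = redis.call('XADD', stream_key, '*', unpack(ARGV))

-- Point the old attempt's hash at the new task
redis.call('HSET', hash_key, 'next_attempt', task_id)

return task_id
`)

// retry re-enqueues a failed task. fields is the flattened field/value set
// for XADD; the names used in main are hypothetical.
func retry(ctx context.Context, rdb *redis.Client, streamKey, failedID string, fields map[string]string) (string, error) {
	args := make([]interface{}, 0, len(fields)*2)
	for k, v := range fields {
		args = append(args, k, v)
	}
	keys := []string{streamKey, fmt.Sprintf("%s:%s", streamKey, failedID)}
	id, err := retryScript.Run(ctx, rdb, keys, args...).Result()
	if err != nil {
		return "", err
	}
	return id.(string), nil
}

func main() {
	ctx := context.Background()
	rdb := redis.NewClient(&redis.Options{Addr: "localhost:6379"}) // assumed local instance
	newID, err := retry(ctx, rdb, "taskqueue:test", "1723161239000-0",
		map[string]string{"data": `"hello"`, "attempts": "1"})
	fmt.Println(newID, err)
}
```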
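
On the gRPC side, `Done` is now derived from `task.Result != nil`, and a failed task is surfaced as a `google.rpc.Status` with `ErrorInfo` details attached (the diff currently leaves `Reason`/`Domain` empty and returns `Unimplemented` for the success branch). The sketch below shows only the error mapping, with stand-in result types and hypothetical `Reason`/`Domain` values.

```go
package main

import (
	"fmt"

	"cloud.google.com/go/longrunning/autogen/longrunningpb"
	epb "google.golang.org/genproto/googleapis/rpc/errdetails"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/status"
)

// Stand-ins for the taskqueue result types, copied for illustration.
type isTaskResult interface{ isTaskResult() }

type TaskResultSuccess struct{ Result string }
type TaskResultError struct{ Error string }

func (*TaskResultSuccess) isTaskResult() {}
func (*TaskResultError) isTaskResult()   {}

// toOperation shows the mapping the handler now performs: Done flips to true
// as soon as any result is stored, and a failed task becomes a google.rpc.Status
// with ErrorInfo details attached.
func toOperation(name string, res isTaskResult) (*longrunningpb.Operation, error) {
	op := &longrunningpb.Operation{
		Name: name,
		Done: res != nil,
	}
	if errRes, ok := res.(*TaskResultError); ok {
		s := status.New(codes.Unknown, errRes.Error)
		s, err := s.WithDetails(&epb.ErrorInfo{
			Reason: "TASK_FAILED",          // hypothetical; the diff leaves this empty
			Domain: "taskqueue.ibd-trader", // hypothetical; the diff leaves this empty
		})
		if err != nil {
			return nil, err
		}
		op.Result = &longrunningpb.Operation_Error{Error: s.Proto()}
	}
	return op, nil
}

func main() {
	op, err := toOperation("operations/scrape/1723161239000-0", &TaskResultError{Error: "failed"})
	fmt.Println(op, err)
}
```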
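
Both the scraper and analyzer workers now split the single completion channel into `errCh`/`resCh`, so a success can carry its result string into `Complete` while a failure goes through `Return`, and a ticker keeps extending the queue lock in the meantime. A condensed, self-contained sketch of that loop is below; the `queue` interface and `noopQueue` are stand-ins, and the real workers call `Extend` from a goroutine.

```go
package main

import (
	"context"
	"fmt"
	"log/slog"
	"time"
)

// queue is a trimmed stand-in for the parts of taskqueue.TaskQueue the loop needs.
type queue interface {
	Extend(ctx context.Context, taskID string) error
	Complete(ctx context.Context, taskID string, result string) error
	Return(ctx context.Context, taskID string, err error) (string, error)
}

// runTask runs job in a goroutine and waits on either errCh or resCh,
// extending the queue lock on every tick until one of them fires.
func runTask(ctx context.Context, q queue, taskID string, lockTimeout time.Duration,
	job func(ctx context.Context) (string, error)) {
	errCh := make(chan error)
	resCh := make(chan string)
	defer close(errCh)
	defer close(resCh)
	go func() {
		res, err := job(ctx)
		if err != nil {
			errCh <- err
			return
		}
		resCh <- res
	}()

	ticker := time.NewTicker(lockTimeout / 5)
	defer ticker.Stop()

	for {
		select {
		case <-ticker.C:
			// Keep the claim alive while the job is still running.
			if err := q.Extend(ctx, taskID); err != nil {
				slog.ErrorContext(ctx, "Failed to extend lock", "error", err)
			}
		case err := <-errCh:
			// Failure: hand the task back so Return can schedule a retry.
			slog.ErrorContext(ctx, "Task failed", "error", err)
			if _, rerr := q.Return(ctx, taskID, err); rerr != nil {
				slog.ErrorContext(ctx, "Failed to return task", "error", rerr)
			}
			return
		case res := <-resCh:
			// Success: store the result string alongside the ack.
			if err := q.Complete(ctx, taskID, res); err != nil {
				slog.ErrorContext(ctx, "Failed to complete task", "error", err)
			}
			return
		}
	}
}

// noopQueue lets the sketch run without Redis.
type noopQueue struct{}

func (noopQueue) Extend(context.Context, string) error { return nil }
func (noopQueue) Complete(_ context.Context, id, res string) error {
	fmt.Println("complete", id, res)
	return nil
}
func (noopQueue) Return(_ context.Context, id string, err error) (string, error) {
	fmt.Println("return", id, err)
	return "", nil
}

func main() {
	runTask(context.Background(), noopQueue{}, "1723161239000-0", time.Second,
		func(context.Context) (string, error) { return "ok", nil })
}
```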