diff options
Diffstat (limited to 'backend/internal/server')
-rw-r--r-- | backend/internal/server/idb/stock/v1/stock.go | 64 | ||||
-rw-r--r-- | backend/internal/server/idb/user/v1/user.go | 159 | ||||
-rw-r--r-- | backend/internal/server/operations.go | 142 | ||||
-rw-r--r-- | backend/internal/server/server.go | 77 |
4 files changed, 442 insertions, 0 deletions
diff --git a/backend/internal/server/idb/stock/v1/stock.go b/backend/internal/server/idb/stock/v1/stock.go new file mode 100644 index 0000000..8afc2b1 --- /dev/null +++ b/backend/internal/server/idb/stock/v1/stock.go @@ -0,0 +1,64 @@ +package stock + +import ( + "context" + "fmt" + "log/slog" + + pb "github.com/ansg191/ibd-trader-backend/api/gen/idb/stock/v1" + "github.com/ansg191/ibd-trader-backend/internal/database" + "github.com/ansg191/ibd-trader-backend/internal/leader/manager/ibd/scrape" + "github.com/ansg191/ibd-trader-backend/internal/redis/taskqueue" + + "cloud.google.com/go/longrunning/autogen/longrunningpb" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + "google.golang.org/protobuf/types/known/anypb" + "google.golang.org/protobuf/types/known/timestamppb" +) + +const ScrapeOperationPrefix = "scrape" + +type Server struct { + pb.UnimplementedStockServiceServer + + db database.Executor + queue taskqueue.TaskQueue[scrape.TaskInfo] +} + +func New(db database.Executor, queue taskqueue.TaskQueue[scrape.TaskInfo]) *Server { + return &Server{db: db, queue: queue} +} + +func (s *Server) CreateStock(ctx context.Context, request *pb.CreateStockRequest) (*pb.CreateStockResponse, error) { + task, err := s.queue.Enqueue(ctx, scrape.TaskInfo{Symbol: request.Symbol}) + if err != nil { + slog.ErrorContext(ctx, "failed to enqueue task", "err", err) + return nil, status.New(codes.Internal, "failed to enqueue task").Err() + } + op := &longrunningpb.Operation{ + Name: fmt.Sprintf("%s/%s", ScrapeOperationPrefix, task.ID.String()), + Metadata: new(anypb.Any), + Done: false, + Result: nil, + } + err = op.Metadata.MarshalFrom(&pb.StockScrapeOperationMetadata{ + Symbol: request.Symbol, + StartTime: timestamppb.New(task.ID.Timestamp()), + }) + if err != nil { + slog.ErrorContext(ctx, "failed to marshal metadata", "err", err) + return nil, status.New(codes.Internal, "failed to marshal metadata").Err() + } + return &pb.CreateStockResponse{Operation: op}, nil +} + +func (s *Server) GetStock(ctx context.Context, request *pb.GetStockRequest) (*pb.GetStockResponse, error) { + //TODO implement me + panic("implement me") +} + +func (s *Server) ListStocks(ctx context.Context, request *pb.ListStocksRequest) (*pb.ListStocksResponse, error) { + //TODO implement me + panic("implement me") +} diff --git a/backend/internal/server/idb/user/v1/user.go b/backend/internal/server/idb/user/v1/user.go new file mode 100644 index 0000000..2f32e03 --- /dev/null +++ b/backend/internal/server/idb/user/v1/user.go @@ -0,0 +1,159 @@ +package user + +import ( + "context" + "errors" + + pb "github.com/ansg191/ibd-trader-backend/api/gen/idb/user/v1" + "github.com/ansg191/ibd-trader-backend/internal/database" + "github.com/ansg191/ibd-trader-backend/internal/ibd" + "github.com/ansg191/ibd-trader-backend/internal/keys" + + "github.com/mennanov/fmutils" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + "google.golang.org/protobuf/proto" +) + +type Server struct { + pb.UnimplementedUserServiceServer + + db database.TransactionExecutor + kms keys.KeyManagementService + keyName string + client *ibd.Client +} + +func New(db database.TransactionExecutor, kms keys.KeyManagementService, keyName string, client *ibd.Client) *Server { + return &Server{ + db: db, + kms: kms, + keyName: keyName, + client: client, + } +} + +func (u *Server) CreateUser(ctx context.Context, request *pb.CreateUserRequest) (*pb.CreateUserResponse, error) { + err := database.AddUser(ctx, u.db, request.Subject) + if err != nil { + return nil, status.Errorf(codes.Internal, "unable to create user: %v", err) + } + + user, err := database.GetUser(ctx, u.db, request.Subject) + if err != nil { + return nil, status.Errorf(codes.Internal, "unable to get user: %v", err) + } + + return &pb.CreateUserResponse{ + User: &pb.User{ + Subject: user.Subject, + IbdUsername: user.IBDUsername, + IbdPassword: nil, + }, + }, nil +} + +func (u *Server) GetUser(ctx context.Context, request *pb.GetUserRequest) (*pb.GetUserResponse, error) { + user, err := database.GetUser(ctx, u.db, request.Subject) + if errors.Is(err, database.ErrUserNotFound) { + return nil, status.New(codes.NotFound, "user not found").Err() + } + if err != nil { + return nil, status.Errorf(codes.Internal, "unable to get user: %v", err) + } + + return &pb.GetUserResponse{ + User: &pb.User{ + Subject: user.Subject, + IbdUsername: user.IBDUsername, + IbdPassword: nil, + }, + }, nil +} + +func (u *Server) UpdateUser(ctx context.Context, request *pb.UpdateUserRequest) (*pb.UpdateUserResponse, error) { + request.UpdateMask.Normalize() + if !request.UpdateMask.IsValid(request.User) { + return nil, status.Errorf(codes.InvalidArgument, "invalid update mask") + } + + existingUserRes, err := u.GetUser(ctx, &pb.GetUserRequest{Subject: request.User.Subject}) + if err != nil { + return nil, err + } + existingUser := existingUserRes.User + + newUser := proto.Clone(existingUser).(*pb.User) + fmutils.Overwrite(request.User, newUser, request.UpdateMask.Paths) + + // if IDB creds are both set and are different, update them + if (newUser.IbdPassword != nil && newUser.IbdUsername != nil) && + (newUser.IbdPassword != existingUser.IbdPassword || + newUser.IbdUsername != existingUser.IbdUsername) { + // Update IBD creds + err = database.AddIBDCreds(ctx, u.db, u.kms, u.keyName, newUser.Subject, *newUser.IbdUsername, *newUser.IbdPassword) + if err != nil { + return nil, status.Errorf(codes.Internal, "unable to update user: %v", err) + } + } + + newUser.IbdPassword = nil + return &pb.UpdateUserResponse{ + User: newUser, + }, nil +} + +func (u *Server) CheckIBDUsername(ctx context.Context, req *pb.CheckIBDUsernameRequest) (*pb.CheckIBDUsernameResponse, error) { + username := req.IbdUsername + if username == "" { + return nil, status.Errorf(codes.InvalidArgument, "username cannot be empty") + } + + // Check if the username exists + exists, err := u.client.CheckIBDUsername(ctx, username) + if err != nil { + return nil, status.Errorf(codes.Internal, "unable to check username: %v", err) + } + + return &pb.CheckIBDUsernameResponse{ + Exists: exists, + }, nil +} + +func (u *Server) AuthenticateUser(ctx context.Context, req *pb.AuthenticateUserRequest) (*pb.AuthenticateUserResponse, error) { + // Check if user has cookies + cookies, err := database.GetCookies(ctx, u.db, u.kms, req.Subject, false) + if err != nil { + return nil, status.Errorf(codes.Internal, "unable to get cookies: %v", err) + } + if len(cookies) > 0 { + return &pb.AuthenticateUserResponse{ + Authenticated: true, + }, nil + } + + // Authenticate user + // Get IBD creds + username, password, err := database.GetIBDCreds(ctx, u.db, u.kms, req.Subject) + if errors.Is(err, database.ErrIBDCredsNotFound) { + return nil, status.New(codes.NotFound, "User has no IDB creds").Err() + } + if err != nil { + return nil, status.Errorf(codes.Internal, "unable to get IBD creds: %v", err) + } + + // Authenticate user + cookie, err := u.client.Authenticate(ctx, username, password) + if errors.Is(err, ibd.ErrBadCredentials) { + return &pb.AuthenticateUserResponse{ + Authenticated: false, + }, nil + } + if err != nil { + return nil, status.Errorf(codes.Internal, "unable to authenticate user: %v", err) + } + + return &pb.AuthenticateUserResponse{ + Authenticated: cookie != nil, + }, nil +} diff --git a/backend/internal/server/operations.go b/backend/internal/server/operations.go new file mode 100644 index 0000000..2487427 --- /dev/null +++ b/backend/internal/server/operations.go @@ -0,0 +1,142 @@ +package server + +import ( + "context" + "errors" + "fmt" + "log/slog" + "strings" + + "cloud.google.com/go/longrunning/autogen/longrunningpb" + spb "github.com/ansg191/ibd-trader-backend/api/gen/idb/stock/v1" + "github.com/ansg191/ibd-trader-backend/internal/leader/manager/ibd/scrape" + "github.com/ansg191/ibd-trader-backend/internal/redis/taskqueue" + "github.com/ansg191/ibd-trader-backend/internal/server/idb/stock/v1" + epb "google.golang.org/genproto/googleapis/rpc/errdetails" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + "google.golang.org/protobuf/types/known/anypb" + "google.golang.org/protobuf/types/known/timestamppb" +) + +type operationServer struct { + longrunningpb.UnimplementedOperationsServer + + scrape taskqueue.TaskQueue[scrape.TaskInfo] +} + +func newOperationServer(scrapeQueue taskqueue.TaskQueue[scrape.TaskInfo]) *operationServer { + return &operationServer{scrape: scrapeQueue} +} + +func (o *operationServer) ListOperations( + ctx context.Context, + req *longrunningpb.ListOperationsRequest, +) (*longrunningpb.ListOperationsResponse, error) { + var end taskqueue.TaskID + if req.PageToken != "" { + var err error + end, err = taskqueue.ParseTaskID(req.PageToken) + if err != nil { + return nil, status.New(codes.InvalidArgument, err.Error()).Err() + } + } else { + end = taskqueue.TaskID{} + } + + switch req.Name { + case stock.ScrapeOperationPrefix: + tasks, err := o.scrape.List(ctx, taskqueue.TaskID{}, end, int64(req.PageSize)) + if err != nil { + return nil, status.New(codes.Internal, "unable to list IDs").Err() + } + + ops := make([]*longrunningpb.Operation, len(tasks)) + for i, task := range tasks { + ops[i] = &longrunningpb.Operation{ + Name: fmt.Sprintf("%s/%s", stock.ScrapeOperationPrefix, task.ID.String()), + Metadata: new(anypb.Any), + Done: task.Result != nil, + Result: nil, + } + err = ops[i].Metadata.MarshalFrom(&spb.StockScrapeOperationMetadata{ + Symbol: task.Data.Symbol, + StartTime: timestamppb.New(task.ID.Timestamp()), + }) + if err != nil { + return nil, status.New(codes.Internal, "unable to marshal metadata").Err() + } + + switch res := task.Result.(type) { + case *taskqueue.TaskResultSuccess: + return nil, status.New(codes.Unimplemented, "not implemented").Err() + case *taskqueue.TaskResultError: + s := status.New(codes.Unknown, res.Error) + s, err = s.WithDetails( + &epb.ErrorInfo{ + Reason: "", + Domain: "", + Metadata: nil, + }) + if err != nil { + return nil, status.New(codes.Internal, "unable to marshal error details").Err() + } + ops[i].Result = &longrunningpb.Operation_Error{Error: s.Proto()} + } + } + + var nextPageToken string + if len(tasks) == int(req.PageSize) { + nextPageToken = tasks[len(tasks)-1].ID.String() + } else { + nextPageToken = "" + } + + return &longrunningpb.ListOperationsResponse{ + Operations: ops, + NextPageToken: nextPageToken, + }, nil + default: + return nil, status.New(codes.NotFound, "unknown operation type").Err() + } +} + +func (o *operationServer) GetOperation(ctx context.Context, req *longrunningpb.GetOperationRequest) (*longrunningpb.Operation, error) { + prefix, id, ok := strings.Cut(req.Name, "/") + if !ok || prefix == "" || id == "" { + return nil, status.New(codes.InvalidArgument, "invalid operation name").Err() + } + + taskID, err := taskqueue.ParseTaskID(id) + if err != nil { + return nil, status.New(codes.InvalidArgument, err.Error()).Err() + } + + switch prefix { + case stock.ScrapeOperationPrefix: + task, err := o.scrape.Data(ctx, taskID) + if errors.Is(err, taskqueue.ErrTaskNotFound) { + return nil, status.New(codes.NotFound, "operation not found").Err() + } + if err != nil { + slog.ErrorContext(ctx, "unable to get operation", "error", err) + return nil, status.New(codes.Internal, "unable to get operation").Err() + } + op := &longrunningpb.Operation{ + Name: req.Name, + Metadata: new(anypb.Any), + Done: task.Result != nil, + Result: nil, + } + err = op.Metadata.MarshalFrom(&spb.StockScrapeOperationMetadata{ + Symbol: task.Data.Symbol, + StartTime: timestamppb.New(task.ID.Timestamp()), + }) + if err != nil { + return nil, status.New(codes.Internal, "unable to marshal metadata").Err() + } + return op, nil + default: + return nil, status.New(codes.NotFound, "unknown operation type").Err() + } +} diff --git a/backend/internal/server/server.go b/backend/internal/server/server.go new file mode 100644 index 0000000..c525cfd --- /dev/null +++ b/backend/internal/server/server.go @@ -0,0 +1,77 @@ +package server + +import ( + "context" + "fmt" + "log/slog" + "net" + + "cloud.google.com/go/longrunning/autogen/longrunningpb" + spb "github.com/ansg191/ibd-trader-backend/api/gen/idb/stock/v1" + upb "github.com/ansg191/ibd-trader-backend/api/gen/idb/user/v1" + "github.com/ansg191/ibd-trader-backend/internal/database" + "github.com/ansg191/ibd-trader-backend/internal/ibd" + "github.com/ansg191/ibd-trader-backend/internal/keys" + "github.com/ansg191/ibd-trader-backend/internal/leader/manager/ibd/scrape" + "github.com/ansg191/ibd-trader-backend/internal/redis/taskqueue" + "github.com/ansg191/ibd-trader-backend/internal/server/idb/stock/v1" + "github.com/ansg191/ibd-trader-backend/internal/server/idb/user/v1" + "github.com/redis/go-redis/v9" + "google.golang.org/grpc" + "google.golang.org/grpc/reflection" +) + +//go:generate make -C ../../api/ generate + +type Server struct { + s *grpc.Server + port uint16 +} + +func New( + ctx context.Context, + port uint16, + db database.TransactionExecutor, + rClient *redis.Client, + client *ibd.Client, + kms keys.KeyManagementService, + keyName string, +) (*Server, error) { + scrapeQueue, err := taskqueue.New( + ctx, + rClient, + scrape.Queue, + "grpc-server", + taskqueue.WithEncoding[scrape.TaskInfo](scrape.QueueEncoding)) + if err != nil { + return nil, err + } + + s := grpc.NewServer() + upb.RegisterUserServiceServer(s, user.New(db, kms, keyName, client)) + spb.RegisterStockServiceServer(s, stock.New(db, scrapeQueue)) + longrunningpb.RegisterOperationsServer(s, newOperationServer(scrapeQueue)) + reflection.Register(s) + return &Server{s, port}, nil +} + +func (s *Server) Serve(ctx context.Context) error { + lis, err := net.Listen("tcp", fmt.Sprintf(":%d", s.port)) + if err != nil { + return err + } + + // Graceful shutdown + go func() { + <-ctx.Done() + slog.ErrorContext(ctx, + "Shutting down server", + "err", ctx.Err(), + "cause", context.Cause(ctx), + ) + s.s.GracefulStop() + }() + + slog.InfoContext(ctx, "Starting gRPC server", "port", s.port) + return s.s.Serve(lis) +} |