aboutsummaryrefslogtreecommitdiff
path: root/backend/internal/analyzer/openai/openai.go
diff options
context:
space:
mode:
Diffstat (limited to 'backend/internal/analyzer/openai/openai.go')
-rw-r--r--backend/internal/analyzer/openai/openai.go126
1 files changed, 126 insertions, 0 deletions
diff --git a/backend/internal/analyzer/openai/openai.go b/backend/internal/analyzer/openai/openai.go
new file mode 100644
index 0000000..0419c57
--- /dev/null
+++ b/backend/internal/analyzer/openai/openai.go
@@ -0,0 +1,126 @@
+package openai
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "math"
+ "strconv"
+ "strings"
+ "time"
+
+ "github.com/ansg191/ibd-trader-backend/internal/analyzer"
+ "github.com/ansg191/ibd-trader-backend/internal/utils"
+
+ "github.com/Rhymond/go-money"
+ "github.com/sashabaranov/go-openai"
+)
+
+type Client interface {
+ CreateChatCompletion(
+ ctx context.Context,
+ request openai.ChatCompletionRequest,
+ ) (response openai.ChatCompletionResponse, err error)
+}
+
+type Analyzer struct {
+ client Client
+ model string
+ systemMsg string
+ temperature float32
+}
+
+func NewAnalyzer(opts ...Option) *Analyzer {
+ a := &Analyzer{
+ client: nil,
+ model: defaultModel,
+ systemMsg: defaultSystemMsg,
+ temperature: defaultTemperature,
+ }
+ for _, option := range opts {
+ option(a)
+ }
+ if a.client == nil {
+ panic("client is required")
+ }
+
+ return a
+}
+
+func (a *Analyzer) Analyze(
+ ctx context.Context,
+ symbol string,
+ price *money.Money,
+ rawAnalysis string,
+) (*analyzer.Analysis, error) {
+ usrMsg := fmt.Sprintf(
+ "%s\n%s\n%s\n%s\n",
+ time.Now().Format(time.RFC3339),
+ symbol,
+ price.Display(),
+ rawAnalysis,
+ )
+ res, err := a.client.CreateChatCompletion(ctx, openai.ChatCompletionRequest{
+ Model: a.model,
+ Messages: []openai.ChatCompletionMessage{
+ {
+ Role: openai.ChatMessageRoleSystem,
+ Content: a.systemMsg,
+ },
+ {
+ Role: openai.ChatMessageRoleUser,
+ Content: usrMsg,
+ },
+ },
+ MaxTokens: 0,
+ Temperature: a.temperature,
+ ResponseFormat: &openai.ChatCompletionResponseFormat{Type: openai.ChatCompletionResponseFormatTypeJSONObject},
+ })
+ if err != nil {
+ return nil, err
+ }
+
+ var resp response
+ if err = json.Unmarshal([]byte(res.Choices[0].Message.Content), &resp); err != nil {
+ return nil, fmt.Errorf("failed to unmarshal gpt response: %w", err)
+ }
+
+ var action analyzer.ChartAction
+ switch strings.ToLower(resp.Action) {
+ case "buy":
+ action = analyzer.Buy
+ case "sell":
+ action = analyzer.Sell
+ case "hold":
+ action = analyzer.Hold
+ default:
+ action = analyzer.Unknown
+ }
+
+ m, err := utils.ParseMoney(resp.Price)
+ if err != nil {
+ return nil, fmt.Errorf("failed to parse price: %w", err)
+ }
+
+ confidence, err := strconv.ParseFloat(resp.Confidence, 64)
+ if err != nil {
+ return nil, fmt.Errorf("failed to parse confidence: %w", err)
+ }
+ if confidence < 0 || confidence > 100 {
+ return nil, fmt.Errorf("confidence must be between 0 and 100, got %f", confidence)
+ }
+
+ return &analyzer.Analysis{
+ Action: action,
+ Price: m,
+ Reason: resp.Reason,
+ Confidence: uint8(math.Floor(confidence)),
+ }, nil
+}
+
+type response struct {
+ Action string `json:"action"`
+ Price string `json:"price"`
+ Reason string `json:"reason"`
+ Confidence string `json:"confidence"`
+}