diff options
Diffstat (limited to 'backend/internal/analyzer/openai/openai.go')
-rw-r--r-- | backend/internal/analyzer/openai/openai.go | 126 |
1 files changed, 126 insertions, 0 deletions
diff --git a/backend/internal/analyzer/openai/openai.go b/backend/internal/analyzer/openai/openai.go new file mode 100644 index 0000000..0419c57 --- /dev/null +++ b/backend/internal/analyzer/openai/openai.go @@ -0,0 +1,126 @@ +package openai + +import ( + "context" + "encoding/json" + "fmt" + "math" + "strconv" + "strings" + "time" + + "github.com/ansg191/ibd-trader-backend/internal/analyzer" + "github.com/ansg191/ibd-trader-backend/internal/utils" + + "github.com/Rhymond/go-money" + "github.com/sashabaranov/go-openai" +) + +type Client interface { + CreateChatCompletion( + ctx context.Context, + request openai.ChatCompletionRequest, + ) (response openai.ChatCompletionResponse, err error) +} + +type Analyzer struct { + client Client + model string + systemMsg string + temperature float32 +} + +func NewAnalyzer(opts ...Option) *Analyzer { + a := &Analyzer{ + client: nil, + model: defaultModel, + systemMsg: defaultSystemMsg, + temperature: defaultTemperature, + } + for _, option := range opts { + option(a) + } + if a.client == nil { + panic("client is required") + } + + return a +} + +func (a *Analyzer) Analyze( + ctx context.Context, + symbol string, + price *money.Money, + rawAnalysis string, +) (*analyzer.Analysis, error) { + usrMsg := fmt.Sprintf( + "%s\n%s\n%s\n%s\n", + time.Now().Format(time.RFC3339), + symbol, + price.Display(), + rawAnalysis, + ) + res, err := a.client.CreateChatCompletion(ctx, openai.ChatCompletionRequest{ + Model: a.model, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleSystem, + Content: a.systemMsg, + }, + { + Role: openai.ChatMessageRoleUser, + Content: usrMsg, + }, + }, + MaxTokens: 0, + Temperature: a.temperature, + ResponseFormat: &openai.ChatCompletionResponseFormat{Type: openai.ChatCompletionResponseFormatTypeJSONObject}, + }) + if err != nil { + return nil, err + } + + var resp response + if err = json.Unmarshal([]byte(res.Choices[0].Message.Content), &resp); err != nil { + return nil, fmt.Errorf("failed to unmarshal gpt response: %w", err) + } + + var action analyzer.ChartAction + switch strings.ToLower(resp.Action) { + case "buy": + action = analyzer.Buy + case "sell": + action = analyzer.Sell + case "hold": + action = analyzer.Hold + default: + action = analyzer.Unknown + } + + m, err := utils.ParseMoney(resp.Price) + if err != nil { + return nil, fmt.Errorf("failed to parse price: %w", err) + } + + confidence, err := strconv.ParseFloat(resp.Confidence, 64) + if err != nil { + return nil, fmt.Errorf("failed to parse confidence: %w", err) + } + if confidence < 0 || confidence > 100 { + return nil, fmt.Errorf("confidence must be between 0 and 100, got %f", confidence) + } + + return &analyzer.Analysis{ + Action: action, + Price: m, + Reason: resp.Reason, + Confidence: uint8(math.Floor(confidence)), + }, nil +} + +type response struct { + Action string `json:"action"` + Price string `json:"price"` + Reason string `json:"reason"` + Confidence string `json:"confidence"` +} |