aboutsummaryrefslogtreecommitdiff
path: root/backend/internal/analyzer/openai/openai.go
blob: 0419c57e0aed3c3edf87e52b818ee8aee3d33e18 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
package openai

import (
	"context"
	"encoding/json"
	"fmt"
	"math"
	"strconv"
	"strings"
	"time"

	"github.com/ansg191/ibd-trader-backend/internal/analyzer"
	"github.com/ansg191/ibd-trader-backend/internal/utils"

	"github.com/Rhymond/go-money"
	"github.com/sashabaranov/go-openai"
)

// Client is the single chat-completion operation Analyzer needs from an
// OpenAI backend. NOTE(review): presumably satisfied by *openai.Client from
// github.com/sashabaranov/go-openai, whose method has this signature — confirm.
type Client interface {
	CreateChatCompletion(ctx context.Context, request openai.ChatCompletionRequest) (openai.ChatCompletionResponse, error)
}

// Analyzer turns free-form stock analysis text into a structured
// analyzer.Analysis by prompting a chat model through a Client.
// Build it with NewAnalyzer, which requires a non-nil client.
type Analyzer struct {
	client Client // backend used for every chat-completion request

	model, systemMsg string  // model name and system prompt sent with each request
	temperature      float32 // sampling temperature forwarded to the request
}

// NewAnalyzer creates an Analyzer seeded with the package defaults
// (defaultModel, defaultSystemMsg, defaultTemperature) and then applies each
// Option in the order given, so later options override earlier ones.
//
// It panics if, after all options have run, no Client has been configured.
func NewAnalyzer(opts ...Option) *Analyzer {
	var an Analyzer
	an.model = defaultModel
	an.systemMsg = defaultSystemMsg
	an.temperature = defaultTemperature

	for i := range opts {
		opts[i](&an)
	}

	if an.client != nil {
		return &an
	}
	panic("client is required")
}

// Analyze asks the configured chat model to turn rawAnalysis for symbol into
// a structured analyzer.Analysis.
//
// The user message is four newline-terminated lines: the current local time
// in RFC 3339 format, the symbol, price.Display(), and rawAnalysis. The request
// asks for a JSON-object response, which is decoded into action, price,
// reason and confidence fields.
//
// An error is returned when:
//   - the completion request fails;
//   - the response has no choices;
//   - the content is not JSON of the expected shape;
//   - the price cannot be parsed;
//   - the confidence is not a number in [0, 100].
//
// An unrecognized action is not an error; it maps to analyzer.Unknown.
//
// NOTE(review): price must be non-nil; it is dereferenced via Display().
func (a *Analyzer) Analyze(
	ctx context.Context,
	symbol string,
	price *money.Money,
	rawAnalysis string,
) (*analyzer.Analysis, error) {
	usrMsg := fmt.Sprintf(
		"%s\n%s\n%s\n%s\n",
		time.Now().Format(time.RFC3339),
		symbol,
		price.Display(),
		rawAnalysis,
	)
	res, err := a.client.CreateChatCompletion(ctx, openai.ChatCompletionRequest{
		Model: a.model,
		Messages: []openai.ChatCompletionMessage{
			{
				Role:    openai.ChatMessageRoleSystem,
				Content: a.systemMsg,
			},
			{
				Role:    openai.ChatMessageRoleUser,
				Content: usrMsg,
			},
		},
		// Zero value; presumably omitted from the request so the API default
		// length limit applies — confirm against go-openai's field tags.
		MaxTokens:      0,
		Temperature:    a.temperature,
		ResponseFormat: &openai.ChatCompletionResponseFormat{Type: openai.ChatCompletionResponseFormatTypeJSONObject},
	})
	if err != nil {
		return nil, fmt.Errorf("failed to create chat completion: %w", err)
	}
	// Guard the index below: an empty Choices slice would otherwise panic.
	if len(res.Choices) == 0 {
		return nil, fmt.Errorf("gpt response contained no choices")
	}

	var resp response
	if err = json.Unmarshal([]byte(res.Choices[0].Message.Content), &resp); err != nil {
		return nil, fmt.Errorf("failed to unmarshal gpt response: %w", err)
	}

	var action analyzer.ChartAction
	switch strings.ToLower(resp.Action) {
	case "buy":
		action = analyzer.Buy
	case "sell":
		action = analyzer.Sell
	case "hold":
		action = analyzer.Hold
	default:
		action = analyzer.Unknown
	}

	m, err := utils.ParseMoney(resp.Price)
	if err != nil {
		return nil, fmt.Errorf("failed to parse price: %w", err)
	}

	// Tolerate surrounding whitespace and a trailing percent sign (e.g. "85%"),
	// since the value is a percentage.
	rawConfidence := strings.TrimSuffix(strings.TrimSpace(resp.Confidence), "%")
	confidence, err := strconv.ParseFloat(rawConfidence, 64)
	if err != nil {
		return nil, fmt.Errorf("failed to parse confidence: %w", err)
	}
	// ParseFloat accepts "NaN", which compares false against both bounds. It
	// must be rejected explicitly, because converting NaN to uint8 yields an
	// implementation-dependent value.
	if math.IsNaN(confidence) || confidence < 0 || confidence > 100 {
		return nil, fmt.Errorf("confidence must be between 0 and 100, got %f", confidence)
	}

	return &analyzer.Analysis{
		Action:     action,
		Price:      m,
		Reason:     resp.Reason,
		Confidence: uint8(math.Floor(confidence)),
	}, nil
}

// response is the JSON object expected from the model; Analyze requests a
// JSON-object response format. Every field is decoded as a JSON string, so the
// model must quote numeric values. A bare JSON number for price or confidence
// makes json.Unmarshal return an error.
type response struct {
	Action     string `json:"action"`     // matched case-insensitively against "buy", "sell", "hold"; anything else is Unknown
	Price      string `json:"price"`      // parsed by utils.ParseMoney
	Reason     string `json:"reason"`     // copied verbatim into Analysis.Reason
	Confidence string `json:"confidence"` // parsed as a float and required to lie in [0, 100]
}