diff options
Diffstat (limited to 'backend/internal/ibd/client_test.go')
-rw-r--r-- | backend/internal/ibd/client_test.go | 241 |
1 files changed, 241 insertions, 0 deletions
diff --git a/backend/internal/ibd/client_test.go b/backend/internal/ibd/client_test.go new file mode 100644 index 0000000..577987d --- /dev/null +++ b/backend/internal/ibd/client_test.go @@ -0,0 +1,241 @@ +package ibd + +import ( + "bytes" + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "ibd-trader/internal/database" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +const apiKey = "test-api-key-123" + +func newServer(t *testing.T, handler http.Handler) *httptest.Server { + t.Helper() + + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + r.Close = true + defer r.Body.Close() + req := reconstructReq(t, r) + + rw := newResponseWriter() + handler.ServeHTTP(rw, req) + require.NoError(t, rw.Done(w)) + })) +} + +func reconstructReq(t *testing.T, r *http.Request) *http.Request { + t.Helper() + + params := r.URL.Query() + require.Equal(t, apiKey, params.Get("key")) + + // Reconstruct the request from the query params + var key string + var url string + headers := make(http.Header) + for k, v := range params { + switch k { + case "key": + key = v[0] + case "url": + url = v[0] + default: + if strings.HasPrefix(k, "headers") { + var name string + // Get index of first [ + i := strings.Index(k, "[") + if i == -1 { + t.Fatalf("invalid header key: %s", k) + } + // Get index of first ] + j := strings.Index(k, "]") + if j == -1 { + t.Fatalf("invalid header key: %s", k) + } + + // Get the header name + name = k[i+1 : j] + headers.Set(name, v[0]) + } + } + } + require.Equal(t, apiKey, key) + require.NotEmpty(t, url) + + req, err := http.NewRequestWithContext(r.Context(), r.Method, url, r.Body) + require.NoError(t, err) + req.Header = headers + + return req +} + +type responsewriter struct { + ret ScraperResponse + body bytes.Buffer + headers http.Header +} + +func newResponseWriter() *responsewriter { + return &responsewriter{ + headers: make(http.Header), + } +} + +func (w *responsewriter) Header() http.Header { + return w.headers +} + +func (w *responsewriter) Write(bytes []byte) (int, error) { + if w.ret.Result.StatusCode == 0 { + w.ret.Result.StatusCode = http.StatusOK + } + return w.body.Write(bytes) +} + +func (w *responsewriter) WriteHeader(statusCode int) { + w.ret.Result.StatusCode = statusCode +} + +func (w *responsewriter) Done(rw http.ResponseWriter) error { + w.ret.Result.Content = w.body.String() + + w.ret.Result.ResponseHeaders = make(map[string]string) + for k, v := range w.headers { + if k == "Set-Cookie" { + continue + } + w.ret.Result.ResponseHeaders[k] = v[0] + } + + req := http.Response{Header: w.headers} + w.ret.Result.Cookies = make([]ScraperCookie, 0) + for _, c := range req.Cookies() { + var cookie ScraperCookie + cookie.FromHTTPCookie(c) + w.ret.Result.Cookies = append(w.ret.Result.Cookies, cookie) + } + + rw.WriteHeader(http.StatusOK) + return json.NewEncoder(rw).Encode(w.ret) +} + +func TestClient_getCookie(t *testing.T) { + t.Parallel() + + t.Run("no cookies", func(t *testing.T) { + t.Parallel() + + client, err := NewClient( + http.DefaultClient, + apiKey, + new(emptyCookieSourceStub), + "", + ) + require.NoError(t, err) + + _, _, err = client.getCookie(context.Background(), nil) + assert.ErrorIs(t, err, ErrNoAvailableCookies) + }) + + t.Run("no cookies by subject", func(t *testing.T) { + t.Parallel() + + client, err := NewClient( + http.DefaultClient, + apiKey, + new(emptyCookieSourceStub), + "", + ) + require.NoError(t, err) + + subject := "test" + _, _, err = client.getCookie(context.Background(), &subject) + assert.ErrorIs(t, err, ErrNoAvailableCookies) + }) + + t.Run("get any cookie", func(t *testing.T) { + t.Parallel() + + client, err := NewClient( + http.DefaultClient, + apiKey, + new(cookieSourceStub), + "", + ) + require.NoError(t, err) + + id, cookie, err := client.getCookie(context.Background(), nil) + require.NoError(t, err) + assert.Equal(t, uint(42), id) + assert.Equal(t, cookieName, cookie.Name) + assert.Equal(t, "test-token", cookie.Value) + assert.Equal(t, "/", cookie.Path) + assert.Equal(t, time.Unix(0, 0), cookie.Expires) + assert.Equal(t, "investors.com", cookie.Domain) + }) + + t.Run("get cookie by subject", func(t *testing.T) { + t.Parallel() + + client, err := NewClient( + http.DefaultClient, + apiKey, + new(cookieSourceStub), + "", + ) + require.NoError(t, err) + + subject := "test" + id, cookie, err := client.getCookie(context.Background(), &subject) + require.NoError(t, err) + assert.Equal(t, uint(42), id) + assert.Equal(t, cookieName, cookie.Name) + assert.Equal(t, "test-token", cookie.Value) + assert.Equal(t, "/", cookie.Path) + assert.Equal(t, time.Unix(0, 0), cookie.Expires) + assert.Equal(t, "investors.com", cookie.Domain) + }) +} + +type emptyCookieSourceStub struct{} + +func (c *emptyCookieSourceStub) GetAnyCookie(_ context.Context) (*database.IBDCookie, error) { + return nil, nil +} + +func (c *emptyCookieSourceStub) GetCookies(_ context.Context, _ string, _ bool) ([]database.IBDCookie, error) { + return nil, nil +} + +func (c *emptyCookieSourceStub) ReportCookieFailure(_ context.Context, _ uint) error { + return nil +} + +var testCookie = database.IBDCookie{ + ID: 42, + Token: "test-token", + Expiry: time.Unix(0, 0), +} + +type cookieSourceStub struct{} + +func (c *cookieSourceStub) GetAnyCookie(_ context.Context) (*database.IBDCookie, error) { + return &testCookie, nil +} + +func (c *cookieSourceStub) GetCookies(_ context.Context, _ string, _ bool) ([]database.IBDCookie, error) { + return []database.IBDCookie{testCookie}, nil +} + +func (c *cookieSourceStub) ReportCookieFailure(_ context.Context, _ uint) error { + return nil +} |