package ibd

import (
	"bytes"
	"context"
	"encoding/json"
	"net/http"
	"net/http/httptest"
	"strings"
	"testing"
	"time"

	"github.com/ansg191/ibd-trader-backend/internal/database"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// apiKey is the key the fake server expects in the "key" query parameter.
const apiKey = "test-api-key-123"

// newServer starts an httptest.Server that accepts requests whose real
// target is encoded in the query string (see reconstructReq), replays them
// against handler, and replies with the handler's output wrapped in a
// JSON-encoded ScraperResponse.
//
// Failures are reported with t.Errorf/assert rather than require/t.Fatalf:
// the handler runs on a goroutine owned by the HTTP server, and FailNow may
// only be called from the goroutine running the test.
func newServer(t *testing.T, handler http.Handler) *httptest.Server {
	t.Helper()

	return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		r.Close = true
		defer r.Body.Close()

		req := reconstructReq(t, r)
		if req == nil {
			// The failure has already been recorded on t; still answer the
			// client so it does not see a dropped connection.
			http.Error(w, "invalid scraper request", http.StatusBadRequest)
			return
		}

		rw := newResponseWriter()
		handler.ServeHTTP(rw, req)
		assert.NoError(t, rw.Done(w))
	}))
}

// reconstructReq rebuilds the request that was encoded into r's query
// string:
//
//   - "key" must equal apiKey,
//   - "url" is the target URL and must be non-empty,
//   - "headers[Name]" parameters become request headers (first value only).
//
// The method, context and body of r are reused. On any validation error the
// failure is recorded on t and nil is returned.
func reconstructReq(t *testing.T, r *http.Request) *http.Request {
	t.Helper()

	params := r.URL.Query()
	if !assert.Equal(t, apiKey, params.Get("key")) {
		return nil
	}

	target := params.Get("url")
	if !assert.NotEmpty(t, target) {
		return nil
	}

	headers := make(http.Header)
	for k, v := range params {
		if !strings.HasPrefix(k, "headers") {
			continue
		}

		// Extract Name from "headers[Name]". The closing bracket is searched
		// for after the opening one so a stray "]" earlier in the key cannot
		// produce an inverted slice range.
		open := strings.Index(k, "[")
		if open == -1 {
			t.Errorf("invalid header key: %s", k)
			return nil
		}
		rest := k[open+1:]
		end := strings.Index(rest, "]")
		if end == -1 {
			t.Errorf("invalid header key: %s", k)
			return nil
		}

		headers.Set(rest[:end], v[0])
	}

	req, err := http.NewRequestWithContext(r.Context(), r.Method, target, r.Body)
	if !assert.NoError(t, err) {
		return nil
	}
	req.Header = headers
	return req
}

// responsewriter is an in-memory http.ResponseWriter that captures a
// handler's status, headers and body so they can be re-emitted as a
// ScraperResponse by Done.
type responsewriter struct {
	ret     ScraperResponse
	body    bytes.Buffer
	headers http.Header
}

// newResponseWriter returns a responsewriter with an initialised header map.
func newResponseWriter() *responsewriter {
	return &responsewriter{
		headers: make(http.Header),
	}
}

// Header returns the header map the wrapped handler writes into.
func (w *responsewriter) Header() http.Header {
	return w.headers
}

// Write buffers p as response content. As with a real ResponseWriter, the
// first Write without a prior WriteHeader implies http.StatusOK.
func (w *responsewriter) Write(p []byte) (int, error) {
	if w.ret.Result.StatusCode == 0 {
		w.ret.Result.StatusCode = http.StatusOK
	}
	return w.body.Write(p)
}

// WriteHeader records the status code reported by the wrapped handler.
func (w *responsewriter) WriteHeader(statusCode int) {
	w.ret.Result.StatusCode = statusCode
}

// Done finalises the captured response and writes it to rw as JSON with an
// outer HTTP status of 200.
//
// Set-Cookie headers are excluded from ResponseHeaders and instead parsed
// into Result.Cookies; for every other header only the first value is kept.
func (w *responsewriter) Done(rw http.ResponseWriter) error {
	// A handler that never called Write or WriteHeader still produces a 200
	// on a real ResponseWriter; mirror that instead of reporting status 0.
	if w.ret.Result.StatusCode == 0 {
		w.ret.Result.StatusCode = http.StatusOK
	}

	w.ret.Result.Content = w.body.String()

	w.ret.Result.ResponseHeaders = make(map[string]string)
	for k, v := range w.headers {
		if k == "Set-Cookie" || len(v) == 0 {
			continue
		}
		w.ret.Result.ResponseHeaders[k] = v[0]
	}

	// http.Response knows how to parse Set-Cookie headers, so borrow it.
	resp := http.Response{Header: w.headers}
	// Keep the slice non-nil so it encodes as [] rather than null.
	w.ret.Result.Cookies = make([]ScraperCookie, 0)
	for _, c := range resp.Cookies() {
		var cookie ScraperCookie
		cookie.FromHTTPCookie(c)
		w.ret.Result.Cookies = append(w.ret.Result.Cookies, cookie)
	}

	rw.WriteHeader(http.StatusOK)
	return json.NewEncoder(rw).Encode(w.ret)
}

// TestClient_getCookie checks Client.getCookie against cookie sources that
// return nothing and sources that return testCookie, both with and without
// a subject.
func TestClient_getCookie(t *testing.T) {
	t.Parallel()

	t.Run("no cookies", func(t *testing.T) {
		t.Parallel()

		client, err := NewClient(
			http.DefaultClient,
			apiKey,
			new(emptyCookieSourceStub),
			"",
		)
		require.NoError(t, err)

		_, _, err = client.getCookie(context.Background(), nil)
		assert.ErrorIs(t, err, ErrNoAvailableCookies)
	})

	t.Run("no cookies by subject", func(t *testing.T) {
		t.Parallel()

		client, err := NewClient(
			http.DefaultClient,
			apiKey,
			new(emptyCookieSourceStub),
			"",
		)
		require.NoError(t, err)

		subject := "test"
		_, _, err = client.getCookie(context.Background(), &subject)
		assert.ErrorIs(t, err, ErrNoAvailableCookies)
	})

	t.Run("get any cookie", func(t *testing.T) {
		t.Parallel()

		client, err := NewClient(
			http.DefaultClient,
			apiKey,
			new(cookieSourceStub),
			"",
		)
		require.NoError(t, err)

		id, cookie, err := client.getCookie(context.Background(), nil)
		require.NoError(t, err)
		assert.Equal(t, uint(42), id)
		assert.Equal(t, cookieName, cookie.Name)
		assert.Equal(t, "test-token", cookie.Value)
		assert.Equal(t, "/", cookie.Path)
		assert.Equal(t, time.Unix(0, 0), cookie.Expires)
		assert.Equal(t, "investors.com", cookie.Domain)
	})

	t.Run("get cookie by subject", func(t *testing.T) {
		t.Parallel()

		client, err := NewClient(
			http.DefaultClient,
			apiKey,
			new(cookieSourceStub),
			"",
		)
		require.NoError(t, err)

		subject := "test"
		id, cookie, err := client.getCookie(context.Background(), &subject)
		require.NoError(t, err)
		assert.Equal(t, uint(42), id)
		assert.Equal(t, cookieName, cookie.Name)
		assert.Equal(t, "test-token", cookie.Value)
		assert.Equal(t, "/", cookie.Path)
		assert.Equal(t, time.Unix(0, 0), cookie.Expires)
		assert.Equal(t, "investors.com", cookie.Domain)
	})
}

// emptyCookieSourceStub is a cookie source that never has a cookie to offer.
type emptyCookieSourceStub struct{}

// GetAnyCookie reports that no cookie is available.
func (c *emptyCookieSourceStub) GetAnyCookie(_ context.Context) (*database.IBDCookie, error) {
	return nil, nil
}

// GetCookies reports that no cookies are available.
func (c *emptyCookieSourceStub) GetCookies(_ context.Context, _ string, _ bool) ([]database.IBDCookie, error) {
	return nil, nil
}

// ReportCookieFailure is a no-op.
func (c *emptyCookieSourceStub) ReportCookieFailure(_ context.Context, _ uint) error {
	return nil
}

// testCookie is the fixture returned by cookieSourceStub.
var testCookie = database.IBDCookie{
	ID:     42,
	Token:  "test-token",
	Expiry: time.Unix(0, 0),
}

// cookieSourceStub is a cookie source that always returns testCookie.
type cookieSourceStub struct{}

// GetAnyCookie returns a pointer to a private copy of testCookie so that
// parallel subtests never share a pointer to the package-level fixture.
func (c *cookieSourceStub) GetAnyCookie(_ context.Context) (*database.IBDCookie, error) {
	cookie := testCookie
	return &cookie, nil
}

// GetCookies returns a single-element slice holding a copy of testCookie.
func (c *cookieSourceStub) GetCookies(_ context.Context, _ string, _ bool) ([]database.IBDCookie, error) {
	return []database.IBDCookie{testCookie}, nil
}

// ReportCookieFailure is a no-op.
func (c *cookieSourceStub) ReportCookieFailure(_ context.Context, _ uint) error {
	return nil
}