aboutsummaryrefslogtreecommitdiff
path: root/backend/internal/ibd/client_test.go
diff options
context:
space:
mode:
Diffstat (limited to 'backend/internal/ibd/client_test.go')
-rw-r--r--backend/internal/ibd/client_test.go241
1 files changed, 241 insertions, 0 deletions
diff --git a/backend/internal/ibd/client_test.go b/backend/internal/ibd/client_test.go
new file mode 100644
index 0000000..577987d
--- /dev/null
+++ b/backend/internal/ibd/client_test.go
@@ -0,0 +1,241 @@
+package ibd
+
+import (
+ "bytes"
+ "context"
+ "encoding/json"
+ "net/http"
+ "net/http/httptest"
+ "strings"
+ "testing"
+ "time"
+
+ "ibd-trader/internal/database"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+const apiKey = "test-api-key-123"
+
+func newServer(t *testing.T, handler http.Handler) *httptest.Server {
+ t.Helper()
+
+ return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ r.Close = true
+ defer r.Body.Close()
+ req := reconstructReq(t, r)
+
+ rw := newResponseWriter()
+ handler.ServeHTTP(rw, req)
+ require.NoError(t, rw.Done(w))
+ }))
+}
+
+func reconstructReq(t *testing.T, r *http.Request) *http.Request {
+ t.Helper()
+
+ params := r.URL.Query()
+ require.Equal(t, apiKey, params.Get("key"))
+
+ // Reconstruct the request from the query params
+ var key string
+ var url string
+ headers := make(http.Header)
+ for k, v := range params {
+ switch k {
+ case "key":
+ key = v[0]
+ case "url":
+ url = v[0]
+ default:
+ if strings.HasPrefix(k, "headers") {
+ var name string
+ // Get index of first [
+ i := strings.Index(k, "[")
+ if i == -1 {
+ t.Fatalf("invalid header key: %s", k)
+ }
+ // Get index of first ]
+ j := strings.Index(k, "]")
+ if j == -1 {
+ t.Fatalf("invalid header key: %s", k)
+ }
+
+ // Get the header name
+ name = k[i+1 : j]
+ headers.Set(name, v[0])
+ }
+ }
+ }
+ require.Equal(t, apiKey, key)
+ require.NotEmpty(t, url)
+
+ req, err := http.NewRequestWithContext(r.Context(), r.Method, url, r.Body)
+ require.NoError(t, err)
+ req.Header = headers
+
+ return req
+}
+
+type responsewriter struct {
+ ret ScraperResponse
+ body bytes.Buffer
+ headers http.Header
+}
+
+func newResponseWriter() *responsewriter {
+ return &responsewriter{
+ headers: make(http.Header),
+ }
+}
+
+func (w *responsewriter) Header() http.Header {
+ return w.headers
+}
+
+func (w *responsewriter) Write(bytes []byte) (int, error) {
+ if w.ret.Result.StatusCode == 0 {
+ w.ret.Result.StatusCode = http.StatusOK
+ }
+ return w.body.Write(bytes)
+}
+
+func (w *responsewriter) WriteHeader(statusCode int) {
+ w.ret.Result.StatusCode = statusCode
+}
+
+func (w *responsewriter) Done(rw http.ResponseWriter) error {
+ w.ret.Result.Content = w.body.String()
+
+ w.ret.Result.ResponseHeaders = make(map[string]string)
+ for k, v := range w.headers {
+ if k == "Set-Cookie" {
+ continue
+ }
+ w.ret.Result.ResponseHeaders[k] = v[0]
+ }
+
+ req := http.Response{Header: w.headers}
+ w.ret.Result.Cookies = make([]ScraperCookie, 0)
+ for _, c := range req.Cookies() {
+ var cookie ScraperCookie
+ cookie.FromHTTPCookie(c)
+ w.ret.Result.Cookies = append(w.ret.Result.Cookies, cookie)
+ }
+
+ rw.WriteHeader(http.StatusOK)
+ return json.NewEncoder(rw).Encode(w.ret)
+}
+
+func TestClient_getCookie(t *testing.T) {
+ t.Parallel()
+
+ t.Run("no cookies", func(t *testing.T) {
+ t.Parallel()
+
+ client, err := NewClient(
+ http.DefaultClient,
+ apiKey,
+ new(emptyCookieSourceStub),
+ "",
+ )
+ require.NoError(t, err)
+
+ _, _, err = client.getCookie(context.Background(), nil)
+ assert.ErrorIs(t, err, ErrNoAvailableCookies)
+ })
+
+ t.Run("no cookies by subject", func(t *testing.T) {
+ t.Parallel()
+
+ client, err := NewClient(
+ http.DefaultClient,
+ apiKey,
+ new(emptyCookieSourceStub),
+ "",
+ )
+ require.NoError(t, err)
+
+ subject := "test"
+ _, _, err = client.getCookie(context.Background(), &subject)
+ assert.ErrorIs(t, err, ErrNoAvailableCookies)
+ })
+
+ t.Run("get any cookie", func(t *testing.T) {
+ t.Parallel()
+
+ client, err := NewClient(
+ http.DefaultClient,
+ apiKey,
+ new(cookieSourceStub),
+ "",
+ )
+ require.NoError(t, err)
+
+ id, cookie, err := client.getCookie(context.Background(), nil)
+ require.NoError(t, err)
+ assert.Equal(t, uint(42), id)
+ assert.Equal(t, cookieName, cookie.Name)
+ assert.Equal(t, "test-token", cookie.Value)
+ assert.Equal(t, "/", cookie.Path)
+ assert.Equal(t, time.Unix(0, 0), cookie.Expires)
+ assert.Equal(t, "investors.com", cookie.Domain)
+ })
+
+ t.Run("get cookie by subject", func(t *testing.T) {
+ t.Parallel()
+
+ client, err := NewClient(
+ http.DefaultClient,
+ apiKey,
+ new(cookieSourceStub),
+ "",
+ )
+ require.NoError(t, err)
+
+ subject := "test"
+ id, cookie, err := client.getCookie(context.Background(), &subject)
+ require.NoError(t, err)
+ assert.Equal(t, uint(42), id)
+ assert.Equal(t, cookieName, cookie.Name)
+ assert.Equal(t, "test-token", cookie.Value)
+ assert.Equal(t, "/", cookie.Path)
+ assert.Equal(t, time.Unix(0, 0), cookie.Expires)
+ assert.Equal(t, "investors.com", cookie.Domain)
+ })
+}
+
+type emptyCookieSourceStub struct{}
+
+func (c *emptyCookieSourceStub) GetAnyCookie(_ context.Context) (*database.IBDCookie, error) {
+ return nil, nil
+}
+
+func (c *emptyCookieSourceStub) GetCookies(_ context.Context, _ string, _ bool) ([]database.IBDCookie, error) {
+ return nil, nil
+}
+
+func (c *emptyCookieSourceStub) ReportCookieFailure(_ context.Context, _ uint) error {
+ return nil
+}
+
+var testCookie = database.IBDCookie{
+ ID: 42,
+ Token: "test-token",
+ Expiry: time.Unix(0, 0),
+}
+
+type cookieSourceStub struct{}
+
+func (c *cookieSourceStub) GetAnyCookie(_ context.Context) (*database.IBDCookie, error) {
+ return &testCookie, nil
+}
+
+func (c *cookieSourceStub) GetCookies(_ context.Context, _ string, _ bool) ([]database.IBDCookie, error) {
+ return []database.IBDCookie{testCookie}, nil
+}
+
+func (c *cookieSourceStub) ReportCookieFailure(_ context.Context, _ uint) error {
+ return nil
+}