diff options
Diffstat (limited to 'backend')
-rw-r--r-- | backend/cmd/main.go | 20 | ||||
-rw-r--r-- | backend/go.mod | 1 | ||||
-rw-r--r-- | backend/go.sum | 4 | ||||
-rw-r--r-- | backend/internal/ibd/auth.go | 54 | ||||
-rw-r--r-- | backend/internal/ibd/auth_test.go | 84 | ||||
-rw-r--r-- | backend/internal/ibd/check_ibd_username.go | 19 | ||||
-rw-r--r-- | backend/internal/ibd/client.go | 125 | ||||
-rw-r--r-- | backend/internal/ibd/client_test.go | 154 | ||||
-rw-r--r-- | backend/internal/ibd/ibd50.go | 7 | ||||
-rw-r--r-- | backend/internal/ibd/search.go | 16 | ||||
-rw-r--r-- | backend/internal/ibd/search_test.go | 13 | ||||
-rw-r--r-- | backend/internal/ibd/stockinfo.go | 16 | ||||
-rw-r--r-- | backend/internal/ibd/transport/scrapfly/options.go (renamed from backend/internal/ibd/options.go) | 2 | ||||
-rw-r--r-- | backend/internal/ibd/transport/scrapfly/scraper_types.go (renamed from backend/internal/ibd/scraper_types.go) | 28 | ||||
-rw-r--r-- | backend/internal/ibd/transport/scrapfly/scrapfly.go | 95 | ||||
-rw-r--r-- | backend/internal/ibd/transport/transport.go | 12 | ||||
-rw-r--r-- | backend/internal/ibd/userinfo.go | 17 |
17 files changed, 336 insertions, 331 deletions
diff --git a/backend/cmd/main.go b/backend/cmd/main.go index c5104e8..16fc1ce 100644 --- a/backend/cmd/main.go +++ b/backend/cmd/main.go @@ -6,6 +6,7 @@ import ( "log" "log/slog" "net/http" + "net/url" "os" "os/signal" "time" @@ -15,6 +16,8 @@ import ( "github.com/ansg191/ibd-trader-backend/internal/config" "github.com/ansg191/ibd-trader-backend/internal/database" "github.com/ansg191/ibd-trader-backend/internal/ibd" + "github.com/ansg191/ibd-trader-backend/internal/ibd/transport" + "github.com/ansg191/ibd-trader-backend/internal/ibd/transport/scrapfly" "github.com/ansg191/ibd-trader-backend/internal/keys" "github.com/ansg191/ibd-trader-backend/internal/leader/election" "github.com/ansg191/ibd-trader-backend/internal/leader/manager" @@ -75,7 +78,7 @@ func main() { _ = auth // Setup IBD client - client, err := ibd.NewClient(http.DefaultClient, cfg.IBD.APIKey, db, cfg.IBD.ProxyURL) + client, err := setupIBDClient(cfg, db) if err != nil { log.Fatal("Unable to setup IBD client: ", err) } @@ -143,6 +146,21 @@ func main() { ) } +func setupIBDClient(cfg *config.Config, db database.Database) (*ibd.Client, error) { + pUrl, err := url.Parse(cfg.IBD.ProxyURL) + if err != nil { + return nil, fmt.Errorf("unable to parse proxy URL: %w", err) + } + t := http.DefaultTransport.(*http.Transport).Clone() + t.Proxy = http.ProxyURL(pUrl) + transports := []transport.Transport{ + &http.Client{Transport: t}, // Default proxied transport + scrapfly.New(http.DefaultClient, cfg.IBD.APIKey), // Scrapfly transport + } + client := ibd.NewClient(transports, db) + return client, nil +} + func connectDB(logger *slog.Logger, cfg *config.Config) (database.Database, error) { ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() diff --git a/backend/go.mod b/backend/go.mod index a5bf7da..bd55e54 100644 --- a/backend/go.mod +++ b/backend/go.mod @@ -13,6 +13,7 @@ require ( github.com/cespare/xxhash/v2 v2.3.0 github.com/coreos/go-oidc/v3 v3.11.0 github.com/golang-migrate/migrate/v4 v4.17.1 + github.com/jarcoal/httpmock v1.3.1 github.com/lib/pq v1.10.9 github.com/lmittmann/tint v1.0.5 github.com/mennanov/fmutils v0.3.0 diff --git a/backend/go.sum b/backend/go.sum index 930c6cc..3f5e220 100644 --- a/backend/go.sum +++ b/backend/go.sum @@ -584,6 +584,8 @@ github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4= github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ= github.com/ianlancetaylor/demangle v0.0.0-20181102032728-5e5cf60278f6/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc= github.com/ianlancetaylor/demangle v0.0.0-20200824232613-28f6c0f3b639/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc= +github.com/jarcoal/httpmock v1.3.1 h1:iUx3whfZWVf3jT01hQTO/Eo5sAYtB2/rqaUuOtpInww= +github.com/jarcoal/httpmock v1.3.1/go.mod h1:3yb8rc4BI7TCBhFY8ng0gjuLKJNquuDNiPaZjnENuYg= github.com/jstemmer/go-junit-report v0.0.0-20190106144839-af01ea7f8024/go.mod h1:6v2b51hI/fHJwM22ozAgKL4VKDeJcHhJFhtBdhmNjmU= github.com/jstemmer/go-junit-report v0.9.1/go.mod h1:Brl9GWCQeLvo8nXZwPNNblvFj/XSXhF0NWZEnDohbsk= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= @@ -601,6 +603,8 @@ github.com/lmittmann/tint v1.0.5 h1:NQclAutOfYsqs2F1Lenue6OoWCajs5wJcP3DfWVpePw= github.com/lmittmann/tint v1.0.5/go.mod h1:HIS3gSy7qNwGCj+5oRjAutErFBl4BzdQP6cJZ0NfMwE= github.com/magiconair/properties v1.8.7 h1:IeQXZAiQcpL9mgcAe1Nu6cX9LLw6ExEHKjN0VQdvPDY= github.com/magiconair/properties v1.8.7/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0= +github.com/maxatome/go-testdeep v1.12.0 h1:Ql7Go8Tg0C1D/uMMX59LAoYK7LffeJQ6X2T04nTH68g= +github.com/maxatome/go-testdeep v1.12.0/go.mod h1:lPZc/HAcJMP92l7yI6TRz1aZN5URwUBUAfUNvrclaNM= github.com/mennanov/fmutils v0.3.0 h1:2YSyrO8oOLQQwB/iKe+xDDGO6xCUHiIAj3gYhY7D4Ao= github.com/mennanov/fmutils v0.3.0/go.mod h1:ph1jsu8gV1gUgMURCmfIVbXKG3O2/O5o/UbPbbqu8zs= github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY= diff --git a/backend/internal/ibd/auth.go b/backend/internal/ibd/auth.go index 7dff3a7..f09f3f7 100644 --- a/backend/internal/ibd/auth.go +++ b/backend/internal/ibd/auth.go @@ -7,6 +7,7 @@ import ( "encoding/json" "errors" "fmt" + "io" "net/http" "strings" @@ -51,16 +52,23 @@ func (c *Client) getLoginPage(ctx context.Context) (*authConfig, error) { if err != nil { return nil, err } - - if resp.Result.StatusCode != http.StatusOK { + defer func(Body io.ReadCloser) { + _ = Body.Close() + }(resp.Body) + + if resp.StatusCode != http.StatusOK { + content, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } return nil, fmt.Errorf( "unexpected status code %d: %s", - resp.Result.StatusCode, - resp.Result.Content, + resp.StatusCode, + string(content), ) } - node, err := html.Parse(strings.NewReader(resp.Result.Content)) + node, err := html.Parse(resp.Body) if err != nil { return nil, err } @@ -113,18 +121,25 @@ func (c *Client) sendAuthRequest(ctx context.Context, cfg *authConfig, username, if err != nil { return "", "", err } + defer func(Body io.ReadCloser) { + _ = Body.Close() + }(resp.Body) - if resp.Result.StatusCode == http.StatusUnauthorized { + if resp.StatusCode == http.StatusUnauthorized { return "", "", ErrBadCredentials - } else if resp.Result.StatusCode != http.StatusOK { + } else if resp.StatusCode != http.StatusOK { + content, err := io.ReadAll(resp.Body) + if err != nil { + return "", "", fmt.Errorf("failed to read response body: %w", err) + } return "", "", fmt.Errorf( "unexpected status code %d: %s", - resp.Result.StatusCode, - resp.Result.Content, + resp.StatusCode, + string(content), ) } - node, err := html.Parse(strings.NewReader(resp.Result.Content)) + node, err := html.Parse(resp.Body) if err != nil { return "", "", err } @@ -145,19 +160,26 @@ func (c *Client) sendPostAuth(ctx context.Context, token, params string) (*http. if err != nil { return nil, err } - - if resp.Result.StatusCode != http.StatusOK { + defer func(Body io.ReadCloser) { + _ = Body.Close() + }(resp.Body) + + if resp.StatusCode != http.StatusOK { + content, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } return nil, fmt.Errorf( "unexpected status code %d: %s", - resp.Result.StatusCode, - resp.Result.Content, + resp.StatusCode, + string(content), ) } // Extract cookie - for _, cookie := range resp.Result.Cookies { + for _, cookie := range resp.Cookies() { if cookie.Name == cookieName { - return cookie.ToHTTPCookie() + return cookie, nil } } diff --git a/backend/internal/ibd/auth_test.go b/backend/internal/ibd/auth_test.go index d28b33a..8a00d42 100644 --- a/backend/internal/ibd/auth_test.go +++ b/backend/internal/ibd/auth_test.go @@ -9,6 +9,8 @@ import ( "testing" "time" + "github.com/ansg191/ibd-trader-backend/internal/ibd/transport" + "github.com/jarcoal/httpmock" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.org/x/net/html" @@ -134,44 +136,36 @@ func TestClient_Authenticate(t *testing.T) { expectedVal := "test-cookie" expectedExp := time.Now().Add(time.Hour).Round(time.Second).In(time.UTC) - server := newServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - uri := r.URL.String() - switch uri { - case signInUrl: - w.Header().Set("Content-Type", "text/html") - _, err := w.Write([]byte(extractAuthHtml)) - require.NoError(t, err) - return - case authenticateUrl: + tp := httpmock.NewMockTransport() + tp.RegisterResponder("GET", signInUrl, + httpmock.NewStringResponder(http.StatusOK, extractAuthHtml)) + tp.RegisterResponder("POST", authenticateUrl, + func(request *http.Request) (*http.Response, error) { var body authRequestBody - require.NoError(t, json.NewDecoder(r.Body).Decode(&body)) + require.NoError(t, json.NewDecoder(request.Body).Decode(&body)) assert.Equal(t, "abc", body.Username) assert.Equal(t, "xyz", body.Password) - w.Header().Set("Content-Type", "text/html") - _, err := w.Write([]byte(extractTokenParamsHtml)) - require.NoError(t, err) - return - case postAuthUrl: - require.NoError(t, r.ParseForm()) - assert.Equal(t, extractTokenExpectedToken, r.Form.Get("token")) + return httpmock.NewStringResponse(http.StatusOK, extractTokenParamsHtml), nil + }) + tp.RegisterResponder("POST", postAuthUrl, + func(request *http.Request) (*http.Response, error) { + require.NoError(t, request.ParseForm()) + assert.Equal(t, extractTokenExpectedToken, request.Form.Get("token")) params, err := url.QueryUnescape(extractTokenExpectedParams) require.NoError(t, err) - assert.Equal(t, params, r.Form.Get("params")) + assert.Equal(t, params, request.Form.Get("params")) - w.Header().Set("Content-Type", "text/html") - http.SetCookie(w, &http.Cookie{Name: cookieName, Value: expectedVal, Expires: expectedExp}) - _, err = w.Write([]byte("OK")) - require.NoError(t, err) - return - default: - t.Fatalf("unexpected URL: %s", uri) - } - })) + resp := httpmock.NewStringResponse(http.StatusOK, "OK") + cookie := &http.Cookie{Name: cookieName, Value: expectedVal, Expires: expectedExp} + resp.Header.Set("Set-Cookie", cookie.String()) + return resp, nil + }) - client, err := NewClient(http.DefaultClient, apiKey, nil, "", WithBaseURL(server.URL)) - require.NoError(t, err) + client := NewClient([]transport.Transport{ + &http.Client{Transport: tp}, + }, nil) cookie, err := client.Authenticate(context.Background(), "abc", "xyz") require.NoError(t, err) @@ -184,32 +178,22 @@ func TestClient_Authenticate(t *testing.T) { func TestClient_Authenticate_401(t *testing.T) { t.Parallel() - server := newServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - uri := r.URL.String() - switch uri { - case signInUrl: - w.Header().Set("Content-Type", "text/html") - _, err := w.Write([]byte(extractAuthHtml)) - require.NoError(t, err) - return - case authenticateUrl: + tp := httpmock.NewMockTransport() + tp.RegisterResponder("GET", signInUrl, + httpmock.NewStringResponder(http.StatusOK, extractAuthHtml)) + tp.RegisterResponder("POST", authenticateUrl, + func(request *http.Request) (*http.Response, error) { var body authRequestBody - require.NoError(t, json.NewDecoder(r.Body).Decode(&body)) + require.NoError(t, json.NewDecoder(request.Body).Decode(&body)) assert.Equal(t, "abc", body.Username) assert.Equal(t, "xyz", body.Password) - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusUnauthorized) - _, err := w.Write([]byte(`{"name":"ValidationError","code":"ERR016","message":"Wrong username or password","description":"Wrong username or password"}`)) - require.NoError(t, err) - return - default: - t.Fatalf("unexpected URL: %s", uri) - } - })) + return httpmock.NewStringResponse(http.StatusUnauthorized, `{"name":"ValidationError","code":"ERR016","message":"Wrong username or password","description":"Wrong username or password"}`), nil + }) - client, err := NewClient(http.DefaultClient, apiKey, nil, "", WithBaseURL(server.URL)) - require.NoError(t, err) + client := NewClient([]transport.Transport{ + &http.Client{Transport: tp}, + }, nil) cookie, err := client.Authenticate(context.Background(), "abc", "xyz") assert.Nil(t, cookie) diff --git a/backend/internal/ibd/check_ibd_username.go b/backend/internal/ibd/check_ibd_username.go index 03d2640..c03176e 100644 --- a/backend/internal/ibd/check_ibd_username.go +++ b/backend/internal/ibd/check_ibd_username.go @@ -5,6 +5,7 @@ import ( "context" "encoding/json" "fmt" + "io" "net/http" ) @@ -41,18 +42,26 @@ func (c *Client) checkIBDUsername(ctx context.Context, cfg *authConfig, username req.Header.Set("X-REQUEST-EDITIONID", "IBD-EN_US") req.Header.Set("X-REQUEST-SCHEME", "https") - resp, err := c.Do(req) + resp, err := c.DoWithStatus(req, []int{http.StatusOK, http.StatusUnauthorized}) if err != nil { return false, err } + defer func(Body io.ReadCloser) { + _ = Body.Close() + }(resp.Body) - if resp.Result.StatusCode == http.StatusUnauthorized { + if resp.StatusCode == http.StatusUnauthorized { return false, nil - } else if resp.Result.StatusCode != http.StatusOK { + } else if resp.StatusCode != http.StatusOK { + contentBytes, err := io.ReadAll(resp.Body) + if err != nil { + return false, fmt.Errorf("failed to read response body: %w", err) + } + content := string(contentBytes) return false, fmt.Errorf( "unexpected status code %d: %s", - resp.Result.StatusCode, - resp.Result.Content, + resp.StatusCode, + content, ) } return true, nil diff --git a/backend/internal/ibd/client.go b/backend/internal/ibd/client.go index c1cbb8a..2b91268 100644 --- a/backend/internal/ibd/client.go +++ b/backend/internal/ibd/client.go @@ -2,56 +2,28 @@ package ibd import ( "context" - "encoding/json" "errors" - "fmt" - "io" + "log/slog" "net/http" - "net/url" - "strconv" + "slices" "github.com/ansg191/ibd-trader-backend/internal/database" + "github.com/ansg191/ibd-trader-backend/internal/ibd/transport" ) var ErrNoAvailableCookies = errors.New("no available cookies") +var ErrNoAvailableTransports = errors.New("no available transports") type Client struct { - // HTTP client used to make requests - client *http.Client - // Scrapfly API key - apiKey string - // Client-wide Scrape options - options ScrapeOptions - // Cookie source - cookies database.CookieSource - // Proxy URL for non-scrapfly requests - proxyUrl *url.URL + transports []transport.Transport + cookies database.CookieSource } func NewClient( - client *http.Client, - apiKey string, + transports []transport.Transport, cookies database.CookieSource, - proxyUrl string, - opts ...ScrapeOption, -) (*Client, error) { - options := defaultScrapeOptions - for _, opt := range opts { - opt(&options) - } - - pProxyUrl, err := url.Parse(proxyUrl) - if err != nil { - return nil, err - } - - return &Client{ - client: client, - options: options, - apiKey: apiKey, - cookies: cookies, - proxyUrl: pProxyUrl, - }, nil +) *Client { + return &Client{transports, cookies} } func (c *Client) getCookie(ctx context.Context, subject *string) (uint, *http.Cookie, error) { @@ -83,64 +55,35 @@ func (c *Client) getCookie(ctx context.Context, subject *string) (uint, *http.Co return cookie.ID, cookie.ToHTTPCookie(), nil } -func (c *Client) Do(req *http.Request, opts ...ScrapeOption) (*ScraperResponse, error) { - options := c.options - for _, opt := range opts { - opt(&options) - } - - // Construct scrape request URL - scrapeUrl, err := url.Parse(options.baseURL) - if err != nil { - panic(err) - } - scrapeUrl.RawQuery = c.constructRawQuery(options, req.URL, req.Header) - - // Construct scrape request - scrapeReq, err := http.NewRequestWithContext(req.Context(), req.Method, scrapeUrl.String(), req.Body) - if err != nil { - return nil, err - } - - // Send scrape request - resp, err := c.client.Do(scrapeReq) - if err != nil { - return nil, err - } - defer func(Body io.ReadCloser) { - _ = Body.Close() - }(resp.Body) - - // Parse scrape response - scraperResponse := new(ScraperResponse) - err = json.NewDecoder(resp.Body).Decode(scraperResponse) - if err != nil { - return nil, err - } - - return scraperResponse, nil +func (c *Client) Do(req *http.Request) (*http.Response, error) { + return c.DoWithStatus(req, []int{http.StatusOK}) } -func (c *Client) constructRawQuery(options ScrapeOptions, u *url.URL, headers http.Header) string { - params := url.Values{} - params.Set("key", c.apiKey) - params.Set("url", u.String()) - if options.country != nil { - params.Set("country", *options.country) - } - params.Set("asp", strconv.FormatBool(options.asp)) - params.Set("proxy_pool", options.proxyPool.String()) - params.Set("render_js", strconv.FormatBool(options.renderJS)) - params.Set("cache", strconv.FormatBool(options.cache)) - - for k, v := range headers { - for i, vv := range v { - params.Add( - fmt.Sprintf("headers[%s][%d]", k, i), - vv, +func (c *Client) DoWithStatus(req *http.Request, expectedStatus []int) (*http.Response, error) { + for i, tp := range c.transports { + resp, err := tp.Do(req) + if errors.Is(err, transport.ErrUnsupportedRequest) { + // Skip unsupported transport + continue + } + if err != nil { + slog.ErrorContext(req.Context(), "transport error", + "transport", i, + "error", err, + ) + continue + } + if slices.Contains(expectedStatus, resp.StatusCode) { + return resp, nil + } else { + slog.ErrorContext(req.Context(), "unexpected status code", + "transport", i, + "expected", expectedStatus, + "actual", resp.StatusCode, ) + continue } } - return params.Encode() + return nil, ErrNoAvailableTransports } diff --git a/backend/internal/ibd/client_test.go b/backend/internal/ibd/client_test.go index 048ad59..0a8fa98 100644 --- a/backend/internal/ibd/client_test.go +++ b/backend/internal/ibd/client_test.go @@ -1,177 +1,50 @@ package ibd import ( - "bytes" "context" - "encoding/json" - "net/http" - "net/http/httptest" - "strings" "testing" "time" "github.com/ansg191/ibd-trader-backend/internal/database" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) -const apiKey = "test-api-key-123" - -func newServer(t *testing.T, handler http.Handler) *httptest.Server { - t.Helper() - - return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - r.Close = true - defer r.Body.Close() - req := reconstructReq(t, r) - - rw := newResponseWriter() - handler.ServeHTTP(rw, req) - require.NoError(t, rw.Done(w)) - })) -} - -func reconstructReq(t *testing.T, r *http.Request) *http.Request { - t.Helper() - - params := r.URL.Query() - require.Equal(t, apiKey, params.Get("key")) - - // Reconstruct the request from the query params - var key string - var url string - headers := make(http.Header) - for k, v := range params { - switch k { - case "key": - key = v[0] - case "url": - url = v[0] - default: - if strings.HasPrefix(k, "headers") { - var name string - // Get index of first [ - i := strings.Index(k, "[") - if i == -1 { - t.Fatalf("invalid header key: %s", k) - } - // Get index of first ] - j := strings.Index(k, "]") - if j == -1 { - t.Fatalf("invalid header key: %s", k) - } - - // Get the header name - name = k[i+1 : j] - headers.Set(name, v[0]) - } - } - } - require.Equal(t, apiKey, key) - require.NotEmpty(t, url) - - req, err := http.NewRequestWithContext(r.Context(), r.Method, url, r.Body) - require.NoError(t, err) - req.Header = headers - - return req -} - -type responsewriter struct { - ret ScraperResponse - body bytes.Buffer - headers http.Header -} - -func newResponseWriter() *responsewriter { - return &responsewriter{ - headers: make(http.Header), - } -} - -func (w *responsewriter) Header() http.Header { - return w.headers -} - -func (w *responsewriter) Write(bytes []byte) (int, error) { - if w.ret.Result.StatusCode == 0 { - w.ret.Result.StatusCode = http.StatusOK - } - return w.body.Write(bytes) -} - -func (w *responsewriter) WriteHeader(statusCode int) { - w.ret.Result.StatusCode = statusCode -} - -func (w *responsewriter) Done(rw http.ResponseWriter) error { - w.ret.Result.Content = w.body.String() - - w.ret.Result.ResponseHeaders = make(map[string]string) - for k, v := range w.headers { - if k == "Set-Cookie" { - continue - } - w.ret.Result.ResponseHeaders[k] = v[0] - } - - req := http.Response{Header: w.headers} - w.ret.Result.Cookies = make([]ScraperCookie, 0) - for _, c := range req.Cookies() { - var cookie ScraperCookie - cookie.FromHTTPCookie(c) - w.ret.Result.Cookies = append(w.ret.Result.Cookies, cookie) - } - - rw.WriteHeader(http.StatusOK) - return json.NewEncoder(rw).Encode(w.ret) -} - func TestClient_getCookie(t *testing.T) { t.Parallel() t.Run("no cookies", func(t *testing.T) { t.Parallel() - client, err := NewClient( - http.DefaultClient, - apiKey, + client := NewClient( + nil, new(emptyCookieSourceStub), - "", ) - require.NoError(t, err) - _, _, err = client.getCookie(context.Background(), nil) + _, _, err := client.getCookie(context.Background(), nil) assert.ErrorIs(t, err, ErrNoAvailableCookies) }) t.Run("no cookies by subject", func(t *testing.T) { t.Parallel() - client, err := NewClient( - http.DefaultClient, - apiKey, + client := NewClient( + nil, new(emptyCookieSourceStub), - "", ) - require.NoError(t, err) subject := "test" - _, _, err = client.getCookie(context.Background(), &subject) + _, _, err := client.getCookie(context.Background(), &subject) assert.ErrorIs(t, err, ErrNoAvailableCookies) }) t.Run("get any cookie", func(t *testing.T) { t.Parallel() - client, err := NewClient( - http.DefaultClient, - apiKey, - new(cookieSourceStub), - "", + client := NewClient( + nil, + new(emptyCookieSourceStub), ) - require.NoError(t, err) id, cookie, err := client.getCookie(context.Background(), nil) require.NoError(t, err) @@ -186,13 +59,10 @@ func TestClient_getCookie(t *testing.T) { t.Run("get cookie by subject", func(t *testing.T) { t.Parallel() - client, err := NewClient( - http.DefaultClient, - apiKey, - new(cookieSourceStub), - "", + client := NewClient( + nil, + new(emptyCookieSourceStub), ) - require.NoError(t, err) subject := "test" id, cookie, err := client.getCookie(context.Background(), &subject) diff --git a/backend/internal/ibd/ibd50.go b/backend/internal/ibd/ibd50.go index 93aa31d..ea02f82 100644 --- a/backend/internal/ibd/ibd50.go +++ b/backend/internal/ibd/ibd50.go @@ -47,12 +47,7 @@ func (c *Client) GetIBD50(ctx context.Context) ([]*Stock, error) { req.Header.Add("x-newrelic-id", "VwUOV1dTDhABV1FRBgQOVVUF") req.Header.Add("x-requested-with", "XMLHttpRequest") - // Clone client to add proxy - client := *(c.client) - transport := http.DefaultTransport.(*http.Transport).Clone() - transport.Proxy = http.ProxyURL(c.proxyUrl) - - resp, err := client.Do(req) + resp, err := c.Do(req) if err != nil { return nil, err } diff --git a/backend/internal/ibd/search.go b/backend/internal/ibd/search.go index 23ef08b..341b14b 100644 --- a/backend/internal/ibd/search.go +++ b/backend/internal/ibd/search.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "fmt" + "io" "net/http" "net/url" "time" @@ -37,17 +38,24 @@ func (c *Client) Search(ctx context.Context, symbol string) (database.Stock, err if err != nil { return database.Stock{}, err } + defer func(Body io.ReadCloser) { + _ = Body.Close() + }(resp.Body) - if resp.Result.StatusCode != http.StatusOK { + if resp.StatusCode != http.StatusOK { + content, err := io.ReadAll(resp.Body) + if err != nil { + return database.Stock{}, fmt.Errorf("failed to read response body: %w", err) + } return database.Stock{}, fmt.Errorf( "unexpected status code %d: %s", - resp.Result.StatusCode, - resp.Result.Content, + resp.StatusCode, + string(content), ) } var sr searchResponse - if err = json.Unmarshal([]byte(resp.Result.Content), &sr); err != nil { + if err = json.NewDecoder(resp.Body).Decode(&sr); err != nil { return database.Stock{}, err } diff --git a/backend/internal/ibd/search_test.go b/backend/internal/ibd/search_test.go index ac0f578..f291033 100644 --- a/backend/internal/ibd/search_test.go +++ b/backend/internal/ibd/search_test.go @@ -5,6 +5,8 @@ import ( "net/http" "testing" + "github.com/ansg191/ibd-trader-backend/internal/ibd/transport" + "github.com/jarcoal/httpmock" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -190,13 +192,12 @@ func TestClient_Search(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - server := newServer(t, http.HandlerFunc(func(writer http.ResponseWriter, r *http.Request) { - _, _ = writer.Write([]byte(tt.response)) - })) - defer server.Close() + tp := httpmock.NewMockTransport() + tp.RegisterResponder("GET", searchUrl, httpmock.NewStringResponder(200, tt.response)) - client, err := NewClient(http.DefaultClient, apiKey, new(cookieSourceStub), "", WithBaseURL(server.URL)) - require.NoError(t, err) + client := NewClient([]transport.Transport{ + &http.Client{Transport: tp}, + }, new(cookieSourceStub)) tt.f(t, client) }) diff --git a/backend/internal/ibd/stockinfo.go b/backend/internal/ibd/stockinfo.go index e278872..9caa956 100644 --- a/backend/internal/ibd/stockinfo.go +++ b/backend/internal/ibd/stockinfo.go @@ -3,6 +3,7 @@ package ibd import ( "context" "fmt" + "io" "net/http" "net/url" "strconv" @@ -37,16 +38,23 @@ func (c *Client) StockInfo(ctx context.Context, uri string) (*database.StockInfo if err != nil { return nil, err } + defer func(Body io.ReadCloser) { + _ = Body.Close() + }(resp.Body) - if resp.Result.StatusCode != http.StatusOK { + if resp.StatusCode != http.StatusOK { + content, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } return nil, fmt.Errorf( "unexpected status code %d: %s", - resp.Result.StatusCode, - resp.Result.Content, + resp.StatusCode, + string(content), ) } - node, err := html.Parse(strings.NewReader(resp.Result.Content)) + node, err := html.Parse(resp.Body) if err != nil { return nil, err } diff --git a/backend/internal/ibd/options.go b/backend/internal/ibd/transport/scrapfly/options.go index a07241e..f16a4b0 100644 --- a/backend/internal/ibd/options.go +++ b/backend/internal/ibd/transport/scrapfly/options.go @@ -1,4 +1,4 @@ -package ibd +package scrapfly const BaseURL = "https://api.scrapfly.io/scrape" diff --git a/backend/internal/ibd/scraper_types.go b/backend/internal/ibd/transport/scrapfly/scraper_types.go index c21ed1c..f3cf651 100644 --- a/backend/internal/ibd/scraper_types.go +++ b/backend/internal/ibd/transport/scrapfly/scraper_types.go @@ -1,8 +1,10 @@ -package ibd +package scrapfly import ( "fmt" + "io" "net/http" + "strings" "time" ) @@ -225,3 +227,27 @@ func (c *ScraperCookie) FromHTTPCookie(cookie *http.Cookie) { Version: "", } } + +func (r *ScraperResponse) ToHTTPResponse() (*http.Response, error) { + resp := &http.Response{ + StatusCode: r.Result.StatusCode, + Header: make(http.Header), + Body: io.NopCloser(strings.NewReader(r.Result.Content)), + ContentLength: int64(len(r.Result.Content)), + Close: true, + } + + for k, v := range r.Result.ResponseHeaders { + resp.Header.Set(k, v) + } + + for _, c := range r.Result.Cookies { + cookie, err := c.ToHTTPCookie() + if err != nil { + return nil, err + } + resp.Header.Add("Set-Cookie", cookie.String()) + } + + return resp, nil +} diff --git a/backend/internal/ibd/transport/scrapfly/scrapfly.go b/backend/internal/ibd/transport/scrapfly/scrapfly.go new file mode 100644 index 0000000..f34f3aa --- /dev/null +++ b/backend/internal/ibd/transport/scrapfly/scrapfly.go @@ -0,0 +1,95 @@ +package scrapfly + +import ( + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "strconv" + + "github.com/ansg191/ibd-trader-backend/internal/ibd/transport" +) + +type ScrapflyTransport struct { + client *http.Client + apiKey string + options ScrapeOptions +} + +func New(client *http.Client, apiKey string, opts ...ScrapeOption) *ScrapflyTransport { + options := defaultScrapeOptions + for _, opt := range opts { + opt(&options) + } + + return &ScrapflyTransport{ + client: client, + apiKey: apiKey, + options: options, + } +} + +func (s *ScrapflyTransport) Do(req *http.Request) (*http.Response, error) { + // Construct scrape request URL + scrapeUrl, err := url.Parse(s.options.baseURL) + if err != nil { + panic(err) + } + scrapeUrl.RawQuery = s.constructRawQuery(req.URL, req.Header) + + // We can't handle `Content-Type` header on GET requests + // Wierd quirk of the Scrapfly API + if req.Method == http.MethodGet && req.Header.Get("Content-Type") != "" { + return nil, transport.ErrUnsupportedRequest + } + + // Construct scrape request + scrapeReq, err := http.NewRequestWithContext(req.Context(), req.Method, scrapeUrl.String(), req.Body) + if err != nil { + return nil, err + } + + // Send scrape request + resp, err := s.client.Do(scrapeReq) + if err != nil { + return nil, err + } + defer func(Body io.ReadCloser) { + _ = Body.Close() + }(resp.Body) + + // Parse scrape response + scraperResponse := new(ScraperResponse) + err = json.NewDecoder(resp.Body).Decode(scraperResponse) + if err != nil { + return nil, err + } + + // Convert scraper response to http.Response + return scraperResponse.ToHTTPResponse() +} + +func (s *ScrapflyTransport) constructRawQuery(u *url.URL, headers http.Header) string { + params := url.Values{} + params.Set("key", s.apiKey) + params.Set("url", u.String()) + if s.options.country != nil { + params.Set("country", *s.options.country) + } + params.Set("asp", strconv.FormatBool(s.options.asp)) + params.Set("proxy_pool", s.options.proxyPool.String()) + params.Set("render_js", strconv.FormatBool(s.options.renderJS)) + params.Set("cache", strconv.FormatBool(s.options.cache)) + + for k, v := range headers { + for i, vv := range v { + params.Add( + fmt.Sprintf("headers[%s][%d]", k, i), + vv, + ) + } + } + + return params.Encode() +} diff --git a/backend/internal/ibd/transport/transport.go b/backend/internal/ibd/transport/transport.go new file mode 100644 index 0000000..7918f4f --- /dev/null +++ b/backend/internal/ibd/transport/transport.go @@ -0,0 +1,12 @@ +package transport + +import ( + "errors" + "net/http" +) + +var ErrUnsupportedRequest = errors.New("unsupported request") + +type Transport interface { + Do(req *http.Request) (*http.Response, error) +} diff --git a/backend/internal/ibd/userinfo.go b/backend/internal/ibd/userinfo.go index ba7a5b5..ed61497 100644 --- a/backend/internal/ibd/userinfo.go +++ b/backend/internal/ibd/userinfo.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "fmt" + "io" "log/slog" "net/http" ) @@ -24,17 +25,25 @@ func (c *Client) UserInfo(ctx context.Context, cookie *http.Cookie) (*UserProfil if err != nil { return nil, err } + defer func(Body io.ReadCloser) { + _ = Body.Close() + }(resp.Body) - if resp.Result.StatusCode != http.StatusOK { + content, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + + if resp.StatusCode != http.StatusOK { return nil, fmt.Errorf( "unexpected status code %d: %s", - resp.Result.StatusCode, - resp.Result.Content, + resp.StatusCode, + string(content), ) } up := new(UserProfile) - if err = up.UnmarshalJSON([]byte(resp.Result.Content)); err != nil { + if err = up.UnmarshalJSON(content); err != nil { return nil, err } |