diff options
author | 2024-08-06 17:16:35 -0700 | |
---|---|---|
committer | 2024-08-06 17:16:35 -0700 | |
commit | 961f9e0a76c3cfe9ae92ca8da0531790e0610b69 (patch) | |
tree | f6de4ed36c3f48ee94ecd524dedeb0d7c84b72e5 /backend | |
parent | 641c81198d7fed7138bb482f226e54bd703094ab (diff) | |
download | ibd-trader-961f9e0a76c3cfe9ae92ca8da0531790e0610b69.tar.gz ibd-trader-961f9e0a76c3cfe9ae92ca8da0531790e0610b69.tar.zst ibd-trader-961f9e0a76c3cfe9ae92ca8da0531790e0610b69.zip |
Modify IBD to accept various transport backends
This allows IBD to try using faster and cheaper transports first with
fallback to more reliable and expensive transports later.
Diffstat (limited to 'backend')
-rw-r--r-- | backend/cmd/main.go | 20 | ||||
-rw-r--r-- | backend/go.mod | 1 | ||||
-rw-r--r-- | backend/go.sum | 4 | ||||
-rw-r--r-- | backend/internal/ibd/auth.go | 54 | ||||
-rw-r--r-- | backend/internal/ibd/auth_test.go | 84 | ||||
-rw-r--r-- | backend/internal/ibd/check_ibd_username.go | 19 | ||||
-rw-r--r-- | backend/internal/ibd/client.go | 125 | ||||
-rw-r--r-- | backend/internal/ibd/client_test.go | 154 | ||||
-rw-r--r-- | backend/internal/ibd/ibd50.go | 7 | ||||
-rw-r--r-- | backend/internal/ibd/search.go | 16 | ||||
-rw-r--r-- | backend/internal/ibd/search_test.go | 13 | ||||
-rw-r--r-- | backend/internal/ibd/stockinfo.go | 16 | ||||
-rw-r--r-- | backend/internal/ibd/transport/scrapfly/options.go (renamed from backend/internal/ibd/options.go) | 2 | ||||
-rw-r--r-- | backend/internal/ibd/transport/scrapfly/scraper_types.go (renamed from backend/internal/ibd/scraper_types.go) | 28 | ||||
-rw-r--r-- | backend/internal/ibd/transport/scrapfly/scrapfly.go | 95 | ||||
-rw-r--r-- | backend/internal/ibd/transport/transport.go | 12 | ||||
-rw-r--r-- | backend/internal/ibd/userinfo.go | 17 |
17 files changed, 336 insertions, 331 deletions
diff --git a/backend/cmd/main.go b/backend/cmd/main.go index c5104e8..16fc1ce 100644 --- a/backend/cmd/main.go +++ b/backend/cmd/main.go @@ -6,6 +6,7 @@ import ( "log" "log/slog" "net/http" + "net/url" "os" "os/signal" "time" @@ -15,6 +16,8 @@ import ( "github.com/ansg191/ibd-trader-backend/internal/config" "github.com/ansg191/ibd-trader-backend/internal/database" "github.com/ansg191/ibd-trader-backend/internal/ibd" + "github.com/ansg191/ibd-trader-backend/internal/ibd/transport" + "github.com/ansg191/ibd-trader-backend/internal/ibd/transport/scrapfly" "github.com/ansg191/ibd-trader-backend/internal/keys" "github.com/ansg191/ibd-trader-backend/internal/leader/election" "github.com/ansg191/ibd-trader-backend/internal/leader/manager" @@ -75,7 +78,7 @@ func main() { _ = auth // Setup IBD client - client, err := ibd.NewClient(http.DefaultClient, cfg.IBD.APIKey, db, cfg.IBD.ProxyURL) + client, err := setupIBDClient(cfg, db) if err != nil { log.Fatal("Unable to setup IBD client: ", err) } @@ -143,6 +146,21 @@ func main() { ) } +func setupIBDClient(cfg *config.Config, db database.Database) (*ibd.Client, error) { + pUrl, err := url.Parse(cfg.IBD.ProxyURL) + if err != nil { + return nil, fmt.Errorf("unable to parse proxy URL: %w", err) + } + t := http.DefaultTransport.(*http.Transport).Clone() + t.Proxy = http.ProxyURL(pUrl) + transports := []transport.Transport{ + &http.Client{Transport: t}, // Default proxied transport + scrapfly.New(http.DefaultClient, cfg.IBD.APIKey), // Scrapfly transport + } + client := ibd.NewClient(transports, db) + return client, nil +} + func connectDB(logger *slog.Logger, cfg *config.Config) (database.Database, error) { ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() diff --git a/backend/go.mod b/backend/go.mod index a5bf7da..bd55e54 100644 --- a/backend/go.mod +++ b/backend/go.mod @@ -13,6 +13,7 @@ require ( github.com/cespare/xxhash/v2 v2.3.0 github.com/coreos/go-oidc/v3 v3.11.0 github.com/golang-migrate/migrate/v4 v4.17.1 + github.com/jarcoal/httpmock v1.3.1 github.com/lib/pq v1.10.9 github.com/lmittmann/tint v1.0.5 github.com/mennanov/fmutils v0.3.0 diff --git a/backend/go.sum b/backend/go.sum index 930c6cc..3f5e220 100644 --- a/backend/go.sum +++ b/backend/go.sum @@ -584,6 +584,8 @@ github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4= github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ= github.com/ianlancetaylor/demangle v0.0.0-20181102032728-5e5cf60278f6/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc= github.com/ianlancetaylor/demangle v0.0.0-20200824232613-28f6c0f3b639/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc= +github.com/jarcoal/httpmock v1.3.1 h1:iUx3whfZWVf3jT01hQTO/Eo5sAYtB2/rqaUuOtpInww= +github.com/jarcoal/httpmock v1.3.1/go.mod h1:3yb8rc4BI7TCBhFY8ng0gjuLKJNquuDNiPaZjnENuYg= github.com/jstemmer/go-junit-report v0.0.0-20190106144839-af01ea7f8024/go.mod h1:6v2b51hI/fHJwM22ozAgKL4VKDeJcHhJFhtBdhmNjmU= github.com/jstemmer/go-junit-report v0.9.1/go.mod h1:Brl9GWCQeLvo8nXZwPNNblvFj/XSXhF0NWZEnDohbsk= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= @@ -601,6 +603,8 @@ github.com/lmittmann/tint v1.0.5 h1:NQclAutOfYsqs2F1Lenue6OoWCajs5wJcP3DfWVpePw= github.com/lmittmann/tint v1.0.5/go.mod h1:HIS3gSy7qNwGCj+5oRjAutErFBl4BzdQP6cJZ0NfMwE= github.com/magiconair/properties v1.8.7 h1:IeQXZAiQcpL9mgcAe1Nu6cX9LLw6ExEHKjN0VQdvPDY= github.com/magiconair/properties v1.8.7/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0= +github.com/maxatome/go-testdeep v1.12.0 h1:Ql7Go8Tg0C1D/uMMX59LAoYK7LffeJQ6X2T04nTH68g= +github.com/maxatome/go-testdeep v1.12.0/go.mod h1:lPZc/HAcJMP92l7yI6TRz1aZN5URwUBUAfUNvrclaNM= github.com/mennanov/fmutils v0.3.0 h1:2YSyrO8oOLQQwB/iKe+xDDGO6xCUHiIAj3gYhY7D4Ao= github.com/mennanov/fmutils v0.3.0/go.mod h1:ph1jsu8gV1gUgMURCmfIVbXKG3O2/O5o/UbPbbqu8zs= github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY= diff --git a/backend/internal/ibd/auth.go b/backend/internal/ibd/auth.go index 7dff3a7..f09f3f7 100644 --- a/backend/internal/ibd/auth.go +++ b/backend/internal/ibd/auth.go @@ -7,6 +7,7 @@ import ( "encoding/json" "errors" "fmt" + "io" "net/http" "strings" @@ -51,16 +52,23 @@ func (c *Client) getLoginPage(ctx context.Context) (*authConfig, error) { if err != nil { return nil, err } - - if resp.Result.StatusCode != http.StatusOK { + defer func(Body io.ReadCloser) { + _ = Body.Close() + }(resp.Body) + + if resp.StatusCode != http.StatusOK { + content, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } return nil, fmt.Errorf( "unexpected status code %d: %s", - resp.Result.StatusCode, - resp.Result.Content, + resp.StatusCode, + string(content), ) } - node, err := html.Parse(strings.NewReader(resp.Result.Content)) + node, err := html.Parse(resp.Body) if err != nil { return nil, err } @@ -113,18 +121,25 @@ func (c *Client) sendAuthRequest(ctx context.Context, cfg *authConfig, username, if err != nil { return "", "", err } + defer func(Body io.ReadCloser) { + _ = Body.Close() + }(resp.Body) - if resp.Result.StatusCode == http.StatusUnauthorized { + if resp.StatusCode == http.StatusUnauthorized { return "", "", ErrBadCredentials - } else if resp.Result.StatusCode != http.StatusOK { + } else if resp.StatusCode != http.StatusOK { + content, err := io.ReadAll(resp.Body) + if err != nil { + return "", "", fmt.Errorf("failed to read response body: %w", err) + } return "", "", fmt.Errorf( "unexpected status code %d: %s", - resp.Result.StatusCode, - resp.Result.Content, + resp.StatusCode, + string(content), ) } - node, err := html.Parse(strings.NewReader(resp.Result.Content)) + node, err := html.Parse(resp.Body) if err != nil { return "", "", err } @@ -145,19 +160,26 @@ func (c *Client) sendPostAuth(ctx context.Context, token, params string) (*http. if err != nil { return nil, err } - - if resp.Result.StatusCode != http.StatusOK { + defer func(Body io.ReadCloser) { + _ = Body.Close() + }(resp.Body) + + if resp.StatusCode != http.StatusOK { + content, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } return nil, fmt.Errorf( "unexpected status code %d: %s", - resp.Result.StatusCode, - resp.Result.Content, + resp.StatusCode, + string(content), ) } // Extract cookie - for _, cookie := range resp.Result.Cookies { + for _, cookie := range resp.Cookies() { if cookie.Name == cookieName { - return cookie.ToHTTPCookie() + return cookie, nil } } diff --git a/backend/internal/ibd/auth_test.go b/backend/internal/ibd/auth_test.go index d28b33a..8a00d42 100644 --- a/backend/internal/ibd/auth_test.go +++ b/backend/internal/ibd/auth_test.go @@ -9,6 +9,8 @@ import ( "testing" "time" + "github.com/ansg191/ibd-trader-backend/internal/ibd/transport" + "github.com/jarcoal/httpmock" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.org/x/net/html" @@ -134,44 +136,36 @@ func TestClient_Authenticate(t *testing.T) { expectedVal := "test-cookie" expectedExp := time.Now().Add(time.Hour).Round(time.Second).In(time.UTC) - server := newServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - uri := r.URL.String() - switch uri { - case signInUrl: - w.Header().Set("Content-Type", "text/html") - _, err := w.Write([]byte(extractAuthHtml)) - require.NoError(t, err) - return - case authenticateUrl: + tp := httpmock.NewMockTransport() + tp.RegisterResponder("GET", signInUrl, + httpmock.NewStringResponder(http.StatusOK, extractAuthHtml)) + tp.RegisterResponder("POST", authenticateUrl, + func(request *http.Request) (*http.Response, error) { var body authRequestBody - require.NoError(t, json.NewDecoder(r.Body).Decode(&body)) + require.NoError(t, json.NewDecoder(request.Body).Decode(&body)) assert.Equal(t, "abc", body.Username) assert.Equal(t, "xyz", body.Password) - w.Header().Set("Content-Type", "text/html") - _, err := w.Write([]byte(extractTokenParamsHtml)) - require.NoError(t, err) - return - case postAuthUrl: - require.NoError(t, r.ParseForm()) - assert.Equal(t, extractTokenExpectedToken, r.Form.Get("token")) + return httpmock.NewStringResponse(http.StatusOK, extractTokenParamsHtml), nil + }) + tp.RegisterResponder("POST", postAuthUrl, + func(request *http.Request) (*http.Response, error) { + require.NoError(t, request.ParseForm()) + assert.Equal(t, extractTokenExpectedToken, request.Form.Get("token")) params, err := url.QueryUnescape(extractTokenExpectedParams) require.NoError(t, err) - assert.Equal(t, params, r.Form.Get("params")) + assert.Equal(t, params, request.Form.Get("params")) - w.Header().Set("Content-Type", "text/html") - http.SetCookie(w, &http.Cookie{Name: cookieName, Value: expectedVal, Expires: expectedExp}) - _, err = w.Write([]byte("OK")) - require.NoError(t, err) - return - default: - t.Fatalf("unexpected URL: %s", uri) - } - })) + resp := httpmock.NewStringResponse(http.StatusOK, "OK") + cookie := &http.Cookie{Name: cookieName, Value: expectedVal, Expires: expectedExp} + resp.Header.Set("Set-Cookie", cookie.String()) + return resp, nil + }) - client, err := NewClient(http.DefaultClient, apiKey, nil, "", WithBaseURL(server.URL)) - require.NoError(t, err) + client := NewClient([]transport.Transport{ + &http.Client{Transport: tp}, + }, nil) cookie, err := client.Authenticate(context.Background(), "abc", "xyz") require.NoError(t, err) @@ -184,32 +178,22 @@ func TestClient_Authenticate(t *testing.T) { func TestClient_Authenticate_401(t *testing.T) { t.Parallel() - server := newServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - uri := r.URL.String() - switch uri { - case signInUrl: - w.Header().Set("Content-Type", "text/html") - _, err := w.Write([]byte(extractAuthHtml)) - require.NoError(t, err) - return - case authenticateUrl: + tp := httpmock.NewMockTransport() + tp.RegisterResponder("GET", signInUrl, + httpmock.NewStringResponder(http.StatusOK, extractAuthHtml)) + tp.RegisterResponder("POST", authenticateUrl, + func(request *http.Request) (*http.Response, error) { var body authRequestBody - require.NoError(t, json.NewDecoder(r.Body).Decode(&body)) + require.NoError(t, json.NewDecoder(request.Body).Decode(&body)) assert.Equal(t, "abc", body.Username) assert.Equal(t, "xyz", body.Password) - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusUnauthorized) - _, err := w.Write([]byte(`{"name":"ValidationError","code":"ERR016","message":"Wrong username or password","description":"Wrong username or password"}`)) - require.NoError(t, err) - return - default: - t.Fatalf("unexpected URL: %s", uri) - } - })) + return httpmock.NewStringResponse(http.StatusUnauthorized, `{"name":"ValidationError","code":"ERR016","message":"Wrong username or password","description":"Wrong username or password"}`), nil + }) - client, err := NewClient(http.DefaultClient, apiKey, nil, "", WithBaseURL(server.URL)) - require.NoError(t, err) + client := NewClient([]transport.Transport{ + &http.Client{Transport: tp}, + }, nil) cookie, err := client.Authenticate(context.Background(), "abc", "xyz") assert.Nil(t, cookie) diff --git a/backend/internal/ibd/check_ibd_username.go b/backend/internal/ibd/check_ibd_username.go index 03d2640..c03176e 100644 --- a/backend/internal/ibd/check_ibd_username.go +++ b/backend/internal/ibd/check_ibd_username.go @@ -5,6 +5,7 @@ import ( "context" "encoding/json" "fmt" + "io" "net/http" ) @@ -41,18 +42,26 @@ func (c *Client) checkIBDUsername(ctx context.Context, cfg *authConfig, username req.Header.Set("X-REQUEST-EDITIONID", "IBD-EN_US") req.Header.Set("X-REQUEST-SCHEME", "https") - resp, err := c.Do(req) + resp, err := c.DoWithStatus(req, []int{http.StatusOK, http.StatusUnauthorized}) if err != nil { return false, err } + defer func(Body io.ReadCloser) { + _ = Body.Close() + }(resp.Body) - if resp.Result.StatusCode == http.StatusUnauthorized { + if resp.StatusCode == http.StatusUnauthorized { return false, nil - } else if resp.Result.StatusCode != http.StatusOK { + } else if resp.StatusCode != http.StatusOK { + contentBytes, err := io.ReadAll(resp.Body) + if err != nil { + return false, fmt.Errorf("failed to read response body: %w", err) + } + content := string(contentBytes) return false, fmt.Errorf( "unexpected status code %d: %s", - resp.Result.StatusCode, - resp.Result.Content, + resp.StatusCode, + content, ) } return true, nil diff --git a/backend/internal/ibd/client.go b/backend/internal/ibd/client.go index c1cbb8a..2b91268 100644 --- a/backend/internal/ibd/client.go +++ b/backend/internal/ibd/client.go @@ -2,56 +2,28 @@ package ibd import ( "context" - "encoding/json" "errors" - "fmt" - "io" + "log/slog" "net/http" - "net/url" - "strconv" + "slices" "github.com/ansg191/ibd-trader-backend/internal/database" + "github.com/ansg191/ibd-trader-backend/internal/ibd/transport" ) var ErrNoAvailableCookies = errors.New("no available cookies") +var ErrNoAvailableTransports = errors.New("no available transports") type Client struct { - // HTTP client used to make requests - client *http.Client - // Scrapfly API key - apiKey string - // Client-wide Scrape options - options ScrapeOptions - // Cookie source - cookies database.CookieSource - // Proxy URL for non-scrapfly requests - proxyUrl *url.URL + transports []transport.Transport + cookies database.CookieSource } func NewClient( - client *http.Client, - apiKey string, + transports []transport.Transport, cookies database.CookieSource, - proxyUrl string, - opts ...ScrapeOption, -) (*Client, error) { - options := defaultScrapeOptions - for _, opt := range opts { - opt(&options) - } - - pProxyUrl, err := url.Parse(proxyUrl) - if err != nil { - return nil, err - } - - return &Client{ - client: client, - options: options, - apiKey: apiKey, - cookies: cookies, - proxyUrl: pProxyUrl, - }, nil +) *Client { + return &Client{transports, cookies} } func (c *Client) getCookie(ctx context.Context, subject *string) (uint, *http.Cookie, error) { @@ -83,64 +55,35 @@ func (c *Client) getCookie(ctx context.Context, subject *string) (uint, *http.Co return cookie.ID, cookie.ToHTTPCookie(), nil } -func (c *Client) Do(req *http.Request, opts ...ScrapeOption) (*ScraperResponse, error) { - options := c.options - for _, opt := range opts { - opt(&options) - } - - // Construct scrape request URL - scrapeUrl, err := url.Parse(options.baseURL) - if err != nil { - panic(err) - } - scrapeUrl.RawQuery = c.constructRawQuery(options, req.URL, req.Header) - - // Construct scrape request - scrapeReq, err := http.NewRequestWithContext(req.Context(), req.Method, scrapeUrl.String(), req.Body) - if err != nil { - return nil, err - } - - // Send scrape request - resp, err := c.client.Do(scrapeReq) - if err != nil { - return nil, err - } - defer func(Body io.ReadCloser) { - _ = Body.Close() - }(resp.Body) - - // Parse scrape response - scraperResponse := new(ScraperResponse) - err = json.NewDecoder(resp.Body).Decode(scraperResponse) - if err != nil { - return nil, err - } - - return scraperResponse, nil +func (c *Client) Do(req *http.Request) (*http.Response, error) { + return c.DoWithStatus(req, []int{http.StatusOK}) } -func (c *Client) constructRawQuery(options ScrapeOptions, u *url.URL, headers http.Header) string { - params := url.Values{} - params.Set("key", c.apiKey) - params.Set("url", u.String()) - if options.country != nil { - params.Set("country", *options.country) - } - params.Set("asp", strconv.FormatBool(options.asp)) - params.Set("proxy_pool", options.proxyPool.String()) - params.Set("render_js", strconv.FormatBool(options.renderJS)) - params.Set("cache", strconv.FormatBool(options.cache)) - - for k, v := range headers { - for i, vv := range v { - params.Add( - fmt.Sprintf("headers[%s][%d]", k, i), - vv, +func (c *Client) DoWithStatus(req *http.Request, expectedStatus []int) (*http.Response, error) { + for i, tp := range c.transports { + resp, err := tp.Do(req) + if errors.Is(err, transport.ErrUnsupportedRequest) { + // Skip unsupported transport + continue + } + if err != nil { + slog.ErrorContext(req.Context(), "transport error", + "transport", i, + "error", err, + ) + continue + } + if slices.Contains(expectedStatus, resp.StatusCode) { + return resp, nil + } else { + slog.ErrorContext(req.Context(), "unexpected status code", + "transport", i, + "expected", expectedStatus, + "actual", resp.StatusCode, ) + continue } } - return params.Encode() + return nil, ErrNoAvailableTransports } diff --git a/backend/internal/ibd/client_test.go b/backend/internal/ibd/client_test.go index 048ad59..0a8fa98 100644 --- a/backend/internal/ibd/client_test.go +++ b/backend/internal/ibd/client_test.go @@ -1,177 +1,50 @@ package ibd import ( - "bytes" "context" - "encoding/json" - "net/http" - "net/http/httptest" - "strings" "testing" "time" "github.com/ansg191/ibd-trader-backend/internal/database" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) -const apiKey = "test-api-key-123" - -func newServer(t *testing.T, handler http.Handler) *httptest.Server { - t.Helper() - - return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - r.Close = true - defer r.Body.Close() - req := reconstructReq(t, r) - - rw := newResponseWriter() - handler.ServeHTTP(rw, req) - require.NoError(t, rw.Done(w)) - })) -} - -func reconstructReq(t *testing.T, r *http.Request) *http.Request { - t.Helper() - - params := r.URL.Query() - require.Equal(t, apiKey, params.Get("key")) - - // Reconstruct the request from the query params - var key string - var url string - headers := make(http.Header) - for k, v := range params { - switch k { - case "key": - key = v[0] - case "url": - url = v[0] - default: - if strings.HasPrefix(k, "headers") { - var name string - // Get index of first [ - i := strings.Index(k, "[") - if i == -1 { - t.Fatalf("invalid header key: %s", k) - } - // Get index of first ] - j := strings.Index(k, "]") - if j == -1 { - t.Fatalf("invalid header key: %s", k) - } - - // Get the header name - name = k[i+1 : j] - headers.Set(name, v[0]) - } - } - } - require.Equal(t, apiKey, key) - require.NotEmpty(t, url) - - req, err := http.NewRequestWithContext(r.Context(), r.Method, url, r.Body) - require.NoError(t, err) - req.Header = headers - - return req -} - -type responsewriter struct { - ret ScraperResponse - body bytes.Buffer - headers http.Header -} - -func newResponseWriter() *responsewriter { - return &responsewriter{ - headers: make(http.Header), - } -} - -func (w *responsewriter) Header() http.Header { - return w.headers -} - -func (w *responsewriter) Write(bytes []byte) (int, error) { - if w.ret.Result.StatusCode == 0 { - w.ret.Result.StatusCode = http.StatusOK - } - return w.body.Write(bytes) -} - -func (w *responsewriter) WriteHeader(statusCode int) { - w.ret.Result.StatusCode = statusCode -} - -func (w *responsewriter) Done(rw http.ResponseWriter) error { - w.ret.Result.Content = w.body.String() - - w.ret.Result.ResponseHeaders = make(map[string]string) - for k, v := range w.headers { - if k == "Set-Cookie" { - continue - } - w.ret.Result.ResponseHeaders[k] = v[0] - } - - req := http.Response{Header: w.headers} - w.ret.Result.Cookies = make([]ScraperCookie, 0) - for _, c := range req.Cookies() { - var cookie ScraperCookie - cookie.FromHTTPCookie(c) - w.ret.Result.Cookies = append(w.ret.Result.Cookies, cookie) - } - - rw.WriteHeader(http.StatusOK) - return json.NewEncoder(rw).Encode(w.ret) -} - func TestClient_getCookie(t *testing.T) { t.Parallel() t.Run("no cookies", func(t *testing.T) { t.Parallel() - client, err := NewClient( - http.DefaultClient, - apiKey, + client := NewClient( + nil, new(emptyCookieSourceStub), - "", ) - require.NoError(t, err) - _, _, err = client.getCookie(context.Background(), nil) + _, _, err := client.getCookie(context.Background(), nil) assert.ErrorIs(t, err, ErrNoAvailableCookies) }) t.Run("no cookies by subject", func(t *testing.T) { t.Parallel() - client, err := NewClient( - http.DefaultClient, - apiKey, + client := NewClient( + nil, new(emptyCookieSourceStub), - "", ) - require.NoError(t, err) subject := "test" - _, _, err = client.getCookie(context.Background(), &subject) + _, _, err := client.getCookie(context.Background(), &subject) assert.ErrorIs(t, err, ErrNoAvailableCookies) }) t.Run("get any cookie", func(t *testing.T) { t.Parallel() - client, err := NewClient( - http.DefaultClient, - apiKey, - new(cookieSourceStub), - "", + client := NewClient( + nil, + new(emptyCookieSourceStub), ) - require.NoError(t, err) id, cookie, err := client.getCookie(context.Background(), nil) require.NoError(t, err) @@ -186,13 +59,10 @@ func TestClient_getCookie(t *testing.T) { t.Run("get cookie by subject", func(t *testing.T) { t.Parallel() - client, err := NewClient( - http.DefaultClient, - apiKey, - new(cookieSourceStub), - "", + client := NewClient( + nil, + new(emptyCookieSourceStub), ) - require.NoError(t, err) subject := "test" id, cookie, err := client.getCookie(context.Background(), &subject) diff --git a/backend/internal/ibd/ibd50.go b/backend/internal/ibd/ibd50.go index 93aa31d..ea02f82 100644 --- a/backend/internal/ibd/ibd50.go +++ b/backend/internal/ibd/ibd50.go @@ -47,12 +47,7 @@ func (c *Client) GetIBD50(ctx context.Context) ([]*Stock, error) { req.Header.Add("x-newrelic-id", "VwUOV1dTDhABV1FRBgQOVVUF") req.Header.Add("x-requested-with", "XMLHttpRequest") - // Clone client to add proxy - client := *(c.client) - transport := http.DefaultTransport.(*http.Transport).Clone() - transport.Proxy = http.ProxyURL(c.proxyUrl) - - resp, err := client.Do(req) + resp, err := c.Do(req) if err != nil { return nil, err } diff --git a/backend/internal/ibd/search.go b/backend/internal/ibd/search.go index 23ef08b..341b14b 100644 --- a/backend/internal/ibd/search.go +++ b/backend/internal/ibd/search.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "fmt" + "io" "net/http" "net/url" "time" @@ -37,17 +38,24 @@ func (c *Client) Search(ctx context.Context, symbol string) (database.Stock, err if err != nil { return database.Stock{}, err } + defer func(Body io.ReadCloser) { + _ = Body.Close() + }(resp.Body) - if resp.Result.StatusCode != http.StatusOK { + if resp.StatusCode != http.StatusOK { + content, err := io.ReadAll(resp.Body) + if err != nil { + return database.Stock{}, fmt.Errorf("failed to read response body: %w", err) + } return database.Stock{}, fmt.Errorf( "unexpected status code %d: %s", - resp.Result.StatusCode, - resp.Result.Content, + resp.StatusCode, + string(content), ) } var sr searchResponse - if err = json.Unmarshal([]byte(resp.Result.Content), &sr); err != nil { + if err = json.NewDecoder(resp.Body).Decode(&sr); err != nil { return database.Stock{}, err } diff --git a/backend/internal/ibd/search_test.go b/backend/internal/ibd/search_test.go index ac0f578..f291033 100644 --- a/backend/internal/ibd/search_test.go +++ b/backend/internal/ibd/search_test.go @@ -5,6 +5,8 @@ import ( "net/http" "testing" + "github.com/ansg191/ibd-trader-backend/internal/ibd/transport" + "github.com/jarcoal/httpmock" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -190,13 +192,12 @@ func TestClient_Search(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - server := newServer(t, http.HandlerFunc(func(writer http.ResponseWriter, r *http.Request) { - _, _ = writer.Write([]byte(tt.response)) - })) - defer server.Close() + tp := httpmock.NewMockTransport() + tp.RegisterResponder("GET", searchUrl, httpmock.NewStringResponder(200, tt.response)) - client, err := NewClient(http.DefaultClient, apiKey, new(cookieSourceStub), "", WithBaseURL(server.URL)) - require.NoError(t, err) + client := NewClient([]transport.Transport{ + &http.Client{Transport: tp}, + }, new(cookieSourceStub)) tt.f(t, client) }) diff --git a/backend/internal/ibd/stockinfo.go b/backend/internal/ibd/stockinfo.go index e278872..9caa956 100644 --- a/backend/internal/ibd/stockinfo.go +++ b/backend/internal/ibd/stockinfo.go @@ -3,6 +3,7 @@ package ibd import ( "context" "fmt" + "io" "net/http" "net/url" "strconv" @@ -37,16 +38,23 @@ func (c *Client) StockInfo(ctx context.Context, uri string) (*database.StockInfo if err != nil { return nil, err } + defer func(Body io.ReadCloser) { + _ = Body.Close() + }(resp.Body) - if resp.Result.StatusCode != http.StatusOK { + if resp.StatusCode != http.StatusOK { + content, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } return nil, fmt.Errorf( "unexpected status code %d: %s", - resp.Result.StatusCode, - resp.Result.Content, + resp.StatusCode, + string(content), ) } - node, err := html.Parse(strings.NewReader(resp.Result.Content)) + node, err := html.Parse(resp.Body) if err != nil { return nil, err } diff --git a/backend/internal/ibd/options.go b/backend/internal/ibd/transport/scrapfly/options.go index a07241e..f16a4b0 100644 --- a/backend/internal/ibd/options.go +++ b/backend/internal/ibd/transport/scrapfly/options.go @@ -1,4 +1,4 @@ -package ibd +package scrapfly const BaseURL = "https://api.scrapfly.io/scrape" diff --git a/backend/internal/ibd/scraper_types.go b/backend/internal/ibd/transport/scrapfly/scraper_types.go index c21ed1c..f3cf651 100644 --- a/backend/internal/ibd/scraper_types.go +++ b/backend/internal/ibd/transport/scrapfly/scraper_types.go @@ -1,8 +1,10 @@ -package ibd +package scrapfly import ( "fmt" + "io" "net/http" + "strings" "time" ) @@ -225,3 +227,27 @@ func (c *ScraperCookie) FromHTTPCookie(cookie *http.Cookie) { Version: "", } } + +func (r *ScraperResponse) ToHTTPResponse() (*http.Response, error) { + resp := &http.Response{ + StatusCode: r.Result.StatusCode, + Header: make(http.Header), + Body: io.NopCloser(strings.NewReader(r.Result.Content)), + ContentLength: int64(len(r.Result.Content)), + Close: true, + } + + for k, v := range r.Result.ResponseHeaders { + resp.Header.Set(k, v) + } + + for _, c := range r.Result.Cookies { + cookie, err := c.ToHTTPCookie() + if err != nil { + return nil, err + } + resp.Header.Add("Set-Cookie", cookie.String()) + } + + return resp, nil +} diff --git a/backend/internal/ibd/transport/scrapfly/scrapfly.go b/backend/internal/ibd/transport/scrapfly/scrapfly.go new file mode 100644 index 0000000..f34f3aa --- /dev/null +++ b/backend/internal/ibd/transport/scrapfly/scrapfly.go @@ -0,0 +1,95 @@ +package scrapfly + +import ( + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "strconv" + + "github.com/ansg191/ibd-trader-backend/internal/ibd/transport" +) + +type ScrapflyTransport struct { + client *http.Client + apiKey string + options ScrapeOptions +} + +func New(client *http.Client, apiKey string, opts ...ScrapeOption) *ScrapflyTransport { + options := defaultScrapeOptions + for _, opt := range opts { + opt(&options) + } + + return &ScrapflyTransport{ + client: client, + apiKey: apiKey, + options: options, + } +} + +func (s *ScrapflyTransport) Do(req *http.Request) (*http.Response, error) { + // Construct scrape request URL + scrapeUrl, err := url.Parse(s.options.baseURL) + if err != nil { + panic(err) + } + scrapeUrl.RawQuery = s.constructRawQuery(req.URL, req.Header) + + // We can't handle `Content-Type` header on GET requests + // Wierd quirk of the Scrapfly API + if req.Method == http.MethodGet && req.Header.Get("Content-Type") != "" { + return nil, transport.ErrUnsupportedRequest + } + + // Construct scrape request + scrapeReq, err := http.NewRequestWithContext(req.Context(), req.Method, scrapeUrl.String(), req.Body) + if err != nil { + return nil, err + } + + // Send scrape request + resp, err := s.client.Do(scrapeReq) + if err != nil { + return nil, err + } + defer func(Body io.ReadCloser) { + _ = Body.Close() + }(resp.Body) + + // Parse scrape response + scraperResponse := new(ScraperResponse) + err = json.NewDecoder(resp.Body).Decode(scraperResponse) + if err != nil { + return nil, err + } + + // Convert scraper response to http.Response + return scraperResponse.ToHTTPResponse() +} + +func (s *ScrapflyTransport) constructRawQuery(u *url.URL, headers http.Header) string { + params := url.Values{} + params.Set("key", s.apiKey) + params.Set("url", u.String()) + if s.options.country != nil { + params.Set("country", *s.options.country) + } + params.Set("asp", strconv.FormatBool(s.options.asp)) + params.Set("proxy_pool", s.options.proxyPool.String()) + params.Set("render_js", strconv.FormatBool(s.options.renderJS)) + params.Set("cache", strconv.FormatBool(s.options.cache)) + + for k, v := range headers { + for i, vv := range v { + params.Add( + fmt.Sprintf("headers[%s][%d]", k, i), + vv, + ) + } + } + + return params.Encode() +} diff --git a/backend/internal/ibd/transport/transport.go b/backend/internal/ibd/transport/transport.go new file mode 100644 index 0000000..7918f4f --- /dev/null +++ b/backend/internal/ibd/transport/transport.go @@ -0,0 +1,12 @@ +package transport + +import ( + "errors" + "net/http" +) + +var ErrUnsupportedRequest = errors.New("unsupported request") + +type Transport interface { + Do(req *http.Request) (*http.Response, error) +} diff --git a/backend/internal/ibd/userinfo.go b/backend/internal/ibd/userinfo.go index ba7a5b5..ed61497 100644 --- a/backend/internal/ibd/userinfo.go +++ b/backend/internal/ibd/userinfo.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "fmt" + "io" "log/slog" "net/http" ) @@ -24,17 +25,25 @@ func (c *Client) UserInfo(ctx context.Context, cookie *http.Cookie) (*UserProfil if err != nil { return nil, err } + defer func(Body io.ReadCloser) { + _ = Body.Close() + }(resp.Body) - if resp.Result.StatusCode != http.StatusOK { + content, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + + if resp.StatusCode != http.StatusOK { return nil, fmt.Errorf( "unexpected status code %d: %s", - resp.Result.StatusCode, - resp.Result.Content, + resp.StatusCode, + string(content), ) } up := new(UserProfile) - if err = up.UnmarshalJSON([]byte(resp.Result.Content)); err != nil { + if err = up.UnmarshalJSON(content); err != nil { return nil, err } |