aboutsummaryrefslogtreecommitdiff
path: root/backend/internal
diff options
context:
space:
mode:
authorGravatar Anshul Gupta <ansg191@anshulg.com> 2024-08-06 17:16:35 -0700
committerGravatar Anshul Gupta <ansg191@anshulg.com> 2024-08-06 17:16:35 -0700
commit961f9e0a76c3cfe9ae92ca8da0531790e0610b69 (patch)
treef6de4ed36c3f48ee94ecd524dedeb0d7c84b72e5 /backend/internal
parent641c81198d7fed7138bb482f226e54bd703094ab (diff)
downloadibd-trader-961f9e0a76c3cfe9ae92ca8da0531790e0610b69.tar.gz
ibd-trader-961f9e0a76c3cfe9ae92ca8da0531790e0610b69.tar.zst
ibd-trader-961f9e0a76c3cfe9ae92ca8da0531790e0610b69.zip
Modify IBD to accept various transport backends
This allows IBD to try using faster and cheaper transports first with fallback to more reliable and expensive transports later.
Diffstat (limited to 'backend/internal')
-rw-r--r--backend/internal/ibd/auth.go54
-rw-r--r--backend/internal/ibd/auth_test.go84
-rw-r--r--backend/internal/ibd/check_ibd_username.go19
-rw-r--r--backend/internal/ibd/client.go125
-rw-r--r--backend/internal/ibd/client_test.go154
-rw-r--r--backend/internal/ibd/ibd50.go7
-rw-r--r--backend/internal/ibd/search.go16
-rw-r--r--backend/internal/ibd/search_test.go13
-rw-r--r--backend/internal/ibd/stockinfo.go16
-rw-r--r--backend/internal/ibd/transport/scrapfly/options.go (renamed from backend/internal/ibd/options.go)2
-rw-r--r--backend/internal/ibd/transport/scrapfly/scraper_types.go (renamed from backend/internal/ibd/scraper_types.go)28
-rw-r--r--backend/internal/ibd/transport/scrapfly/scrapfly.go95
-rw-r--r--backend/internal/ibd/transport/transport.go12
-rw-r--r--backend/internal/ibd/userinfo.go17
14 files changed, 312 insertions, 330 deletions
diff --git a/backend/internal/ibd/auth.go b/backend/internal/ibd/auth.go
index 7dff3a7..f09f3f7 100644
--- a/backend/internal/ibd/auth.go
+++ b/backend/internal/ibd/auth.go
@@ -7,6 +7,7 @@ import (
"encoding/json"
"errors"
"fmt"
+ "io"
"net/http"
"strings"
@@ -51,16 +52,23 @@ func (c *Client) getLoginPage(ctx context.Context) (*authConfig, error) {
if err != nil {
return nil, err
}
-
- if resp.Result.StatusCode != http.StatusOK {
+ defer func(Body io.ReadCloser) {
+ _ = Body.Close()
+ }(resp.Body)
+
+ if resp.StatusCode != http.StatusOK {
+ content, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return nil, fmt.Errorf("failed to read response body: %w", err)
+ }
return nil, fmt.Errorf(
"unexpected status code %d: %s",
- resp.Result.StatusCode,
- resp.Result.Content,
+ resp.StatusCode,
+ string(content),
)
}
- node, err := html.Parse(strings.NewReader(resp.Result.Content))
+ node, err := html.Parse(resp.Body)
if err != nil {
return nil, err
}
@@ -113,18 +121,25 @@ func (c *Client) sendAuthRequest(ctx context.Context, cfg *authConfig, username,
if err != nil {
return "", "", err
}
+ defer func(Body io.ReadCloser) {
+ _ = Body.Close()
+ }(resp.Body)
- if resp.Result.StatusCode == http.StatusUnauthorized {
+ if resp.StatusCode == http.StatusUnauthorized {
return "", "", ErrBadCredentials
- } else if resp.Result.StatusCode != http.StatusOK {
+ } else if resp.StatusCode != http.StatusOK {
+ content, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return "", "", fmt.Errorf("failed to read response body: %w", err)
+ }
return "", "", fmt.Errorf(
"unexpected status code %d: %s",
- resp.Result.StatusCode,
- resp.Result.Content,
+ resp.StatusCode,
+ string(content),
)
}
- node, err := html.Parse(strings.NewReader(resp.Result.Content))
+ node, err := html.Parse(resp.Body)
if err != nil {
return "", "", err
}
@@ -145,19 +160,26 @@ func (c *Client) sendPostAuth(ctx context.Context, token, params string) (*http.
if err != nil {
return nil, err
}
-
- if resp.Result.StatusCode != http.StatusOK {
+ defer func(Body io.ReadCloser) {
+ _ = Body.Close()
+ }(resp.Body)
+
+ if resp.StatusCode != http.StatusOK {
+ content, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return nil, fmt.Errorf("failed to read response body: %w", err)
+ }
return nil, fmt.Errorf(
"unexpected status code %d: %s",
- resp.Result.StatusCode,
- resp.Result.Content,
+ resp.StatusCode,
+ string(content),
)
}
// Extract cookie
- for _, cookie := range resp.Result.Cookies {
+ for _, cookie := range resp.Cookies() {
if cookie.Name == cookieName {
- return cookie.ToHTTPCookie()
+ return cookie, nil
}
}
diff --git a/backend/internal/ibd/auth_test.go b/backend/internal/ibd/auth_test.go
index d28b33a..8a00d42 100644
--- a/backend/internal/ibd/auth_test.go
+++ b/backend/internal/ibd/auth_test.go
@@ -9,6 +9,8 @@ import (
"testing"
"time"
+ "github.com/ansg191/ibd-trader-backend/internal/ibd/transport"
+ "github.com/jarcoal/httpmock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/net/html"
@@ -134,44 +136,36 @@ func TestClient_Authenticate(t *testing.T) {
expectedVal := "test-cookie"
expectedExp := time.Now().Add(time.Hour).Round(time.Second).In(time.UTC)
- server := newServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- uri := r.URL.String()
- switch uri {
- case signInUrl:
- w.Header().Set("Content-Type", "text/html")
- _, err := w.Write([]byte(extractAuthHtml))
- require.NoError(t, err)
- return
- case authenticateUrl:
+ tp := httpmock.NewMockTransport()
+ tp.RegisterResponder("GET", signInUrl,
+ httpmock.NewStringResponder(http.StatusOK, extractAuthHtml))
+ tp.RegisterResponder("POST", authenticateUrl,
+ func(request *http.Request) (*http.Response, error) {
var body authRequestBody
- require.NoError(t, json.NewDecoder(r.Body).Decode(&body))
+ require.NoError(t, json.NewDecoder(request.Body).Decode(&body))
assert.Equal(t, "abc", body.Username)
assert.Equal(t, "xyz", body.Password)
- w.Header().Set("Content-Type", "text/html")
- _, err := w.Write([]byte(extractTokenParamsHtml))
- require.NoError(t, err)
- return
- case postAuthUrl:
- require.NoError(t, r.ParseForm())
- assert.Equal(t, extractTokenExpectedToken, r.Form.Get("token"))
+ return httpmock.NewStringResponse(http.StatusOK, extractTokenParamsHtml), nil
+ })
+ tp.RegisterResponder("POST", postAuthUrl,
+ func(request *http.Request) (*http.Response, error) {
+ require.NoError(t, request.ParseForm())
+ assert.Equal(t, extractTokenExpectedToken, request.Form.Get("token"))
params, err := url.QueryUnescape(extractTokenExpectedParams)
require.NoError(t, err)
- assert.Equal(t, params, r.Form.Get("params"))
+ assert.Equal(t, params, request.Form.Get("params"))
- w.Header().Set("Content-Type", "text/html")
- http.SetCookie(w, &http.Cookie{Name: cookieName, Value: expectedVal, Expires: expectedExp})
- _, err = w.Write([]byte("OK"))
- require.NoError(t, err)
- return
- default:
- t.Fatalf("unexpected URL: %s", uri)
- }
- }))
+ resp := httpmock.NewStringResponse(http.StatusOK, "OK")
+ cookie := &http.Cookie{Name: cookieName, Value: expectedVal, Expires: expectedExp}
+ resp.Header.Set("Set-Cookie", cookie.String())
+ return resp, nil
+ })
- client, err := NewClient(http.DefaultClient, apiKey, nil, "", WithBaseURL(server.URL))
- require.NoError(t, err)
+ client := NewClient([]transport.Transport{
+ &http.Client{Transport: tp},
+ }, nil)
cookie, err := client.Authenticate(context.Background(), "abc", "xyz")
require.NoError(t, err)
@@ -184,32 +178,22 @@ func TestClient_Authenticate(t *testing.T) {
func TestClient_Authenticate_401(t *testing.T) {
t.Parallel()
- server := newServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- uri := r.URL.String()
- switch uri {
- case signInUrl:
- w.Header().Set("Content-Type", "text/html")
- _, err := w.Write([]byte(extractAuthHtml))
- require.NoError(t, err)
- return
- case authenticateUrl:
+ tp := httpmock.NewMockTransport()
+ tp.RegisterResponder("GET", signInUrl,
+ httpmock.NewStringResponder(http.StatusOK, extractAuthHtml))
+ tp.RegisterResponder("POST", authenticateUrl,
+ func(request *http.Request) (*http.Response, error) {
var body authRequestBody
- require.NoError(t, json.NewDecoder(r.Body).Decode(&body))
+ require.NoError(t, json.NewDecoder(request.Body).Decode(&body))
assert.Equal(t, "abc", body.Username)
assert.Equal(t, "xyz", body.Password)
- w.Header().Set("Content-Type", "application/json")
- w.WriteHeader(http.StatusUnauthorized)
- _, err := w.Write([]byte(`{"name":"ValidationError","code":"ERR016","message":"Wrong username or password","description":"Wrong username or password"}`))
- require.NoError(t, err)
- return
- default:
- t.Fatalf("unexpected URL: %s", uri)
- }
- }))
+ return httpmock.NewStringResponse(http.StatusUnauthorized, `{"name":"ValidationError","code":"ERR016","message":"Wrong username or password","description":"Wrong username or password"}`), nil
+ })
- client, err := NewClient(http.DefaultClient, apiKey, nil, "", WithBaseURL(server.URL))
- require.NoError(t, err)
+ client := NewClient([]transport.Transport{
+ &http.Client{Transport: tp},
+ }, nil)
cookie, err := client.Authenticate(context.Background(), "abc", "xyz")
assert.Nil(t, cookie)
diff --git a/backend/internal/ibd/check_ibd_username.go b/backend/internal/ibd/check_ibd_username.go
index 03d2640..c03176e 100644
--- a/backend/internal/ibd/check_ibd_username.go
+++ b/backend/internal/ibd/check_ibd_username.go
@@ -5,6 +5,7 @@ import (
"context"
"encoding/json"
"fmt"
+ "io"
"net/http"
)
@@ -41,18 +42,26 @@ func (c *Client) checkIBDUsername(ctx context.Context, cfg *authConfig, username
req.Header.Set("X-REQUEST-EDITIONID", "IBD-EN_US")
req.Header.Set("X-REQUEST-SCHEME", "https")
- resp, err := c.Do(req)
+ resp, err := c.DoWithStatus(req, []int{http.StatusOK, http.StatusUnauthorized})
if err != nil {
return false, err
}
+ defer func(Body io.ReadCloser) {
+ _ = Body.Close()
+ }(resp.Body)
- if resp.Result.StatusCode == http.StatusUnauthorized {
+ if resp.StatusCode == http.StatusUnauthorized {
return false, nil
- } else if resp.Result.StatusCode != http.StatusOK {
+ } else if resp.StatusCode != http.StatusOK {
+ contentBytes, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return false, fmt.Errorf("failed to read response body: %w", err)
+ }
+ content := string(contentBytes)
return false, fmt.Errorf(
"unexpected status code %d: %s",
- resp.Result.StatusCode,
- resp.Result.Content,
+ resp.StatusCode,
+ content,
)
}
return true, nil
diff --git a/backend/internal/ibd/client.go b/backend/internal/ibd/client.go
index c1cbb8a..2b91268 100644
--- a/backend/internal/ibd/client.go
+++ b/backend/internal/ibd/client.go
@@ -2,56 +2,28 @@ package ibd
import (
"context"
- "encoding/json"
"errors"
- "fmt"
- "io"
+ "log/slog"
"net/http"
- "net/url"
- "strconv"
+ "slices"
"github.com/ansg191/ibd-trader-backend/internal/database"
+ "github.com/ansg191/ibd-trader-backend/internal/ibd/transport"
)
var ErrNoAvailableCookies = errors.New("no available cookies")
+var ErrNoAvailableTransports = errors.New("no available transports")
type Client struct {
- // HTTP client used to make requests
- client *http.Client
- // Scrapfly API key
- apiKey string
- // Client-wide Scrape options
- options ScrapeOptions
- // Cookie source
- cookies database.CookieSource
- // Proxy URL for non-scrapfly requests
- proxyUrl *url.URL
+ transports []transport.Transport
+ cookies database.CookieSource
}
func NewClient(
- client *http.Client,
- apiKey string,
+ transports []transport.Transport,
cookies database.CookieSource,
- proxyUrl string,
- opts ...ScrapeOption,
-) (*Client, error) {
- options := defaultScrapeOptions
- for _, opt := range opts {
- opt(&options)
- }
-
- pProxyUrl, err := url.Parse(proxyUrl)
- if err != nil {
- return nil, err
- }
-
- return &Client{
- client: client,
- options: options,
- apiKey: apiKey,
- cookies: cookies,
- proxyUrl: pProxyUrl,
- }, nil
+) *Client {
+ return &Client{transports, cookies}
}
func (c *Client) getCookie(ctx context.Context, subject *string) (uint, *http.Cookie, error) {
@@ -83,64 +55,35 @@ func (c *Client) getCookie(ctx context.Context, subject *string) (uint, *http.Co
return cookie.ID, cookie.ToHTTPCookie(), nil
}
-func (c *Client) Do(req *http.Request, opts ...ScrapeOption) (*ScraperResponse, error) {
- options := c.options
- for _, opt := range opts {
- opt(&options)
- }
-
- // Construct scrape request URL
- scrapeUrl, err := url.Parse(options.baseURL)
- if err != nil {
- panic(err)
- }
- scrapeUrl.RawQuery = c.constructRawQuery(options, req.URL, req.Header)
-
- // Construct scrape request
- scrapeReq, err := http.NewRequestWithContext(req.Context(), req.Method, scrapeUrl.String(), req.Body)
- if err != nil {
- return nil, err
- }
-
- // Send scrape request
- resp, err := c.client.Do(scrapeReq)
- if err != nil {
- return nil, err
- }
- defer func(Body io.ReadCloser) {
- _ = Body.Close()
- }(resp.Body)
-
- // Parse scrape response
- scraperResponse := new(ScraperResponse)
- err = json.NewDecoder(resp.Body).Decode(scraperResponse)
- if err != nil {
- return nil, err
- }
-
- return scraperResponse, nil
+func (c *Client) Do(req *http.Request) (*http.Response, error) {
+ return c.DoWithStatus(req, []int{http.StatusOK})
}
-func (c *Client) constructRawQuery(options ScrapeOptions, u *url.URL, headers http.Header) string {
- params := url.Values{}
- params.Set("key", c.apiKey)
- params.Set("url", u.String())
- if options.country != nil {
- params.Set("country", *options.country)
- }
- params.Set("asp", strconv.FormatBool(options.asp))
- params.Set("proxy_pool", options.proxyPool.String())
- params.Set("render_js", strconv.FormatBool(options.renderJS))
- params.Set("cache", strconv.FormatBool(options.cache))
-
- for k, v := range headers {
- for i, vv := range v {
- params.Add(
- fmt.Sprintf("headers[%s][%d]", k, i),
- vv,
+func (c *Client) DoWithStatus(req *http.Request, expectedStatus []int) (*http.Response, error) {
+ for i, tp := range c.transports {
+ resp, err := tp.Do(req)
+ if errors.Is(err, transport.ErrUnsupportedRequest) {
+ // Skip unsupported transport
+ continue
+ }
+ if err != nil {
+ slog.ErrorContext(req.Context(), "transport error",
+ "transport", i,
+ "error", err,
+ )
+ continue
+ }
+ if slices.Contains(expectedStatus, resp.StatusCode) {
+ return resp, nil
+ } else {
+ slog.ErrorContext(req.Context(), "unexpected status code",
+ "transport", i,
+ "expected", expectedStatus,
+ "actual", resp.StatusCode,
)
+ continue
}
}
- return params.Encode()
+ return nil, ErrNoAvailableTransports
}
diff --git a/backend/internal/ibd/client_test.go b/backend/internal/ibd/client_test.go
index 048ad59..0a8fa98 100644
--- a/backend/internal/ibd/client_test.go
+++ b/backend/internal/ibd/client_test.go
@@ -1,177 +1,50 @@
package ibd
import (
- "bytes"
"context"
- "encoding/json"
- "net/http"
- "net/http/httptest"
- "strings"
"testing"
"time"
"github.com/ansg191/ibd-trader-backend/internal/database"
-
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
-const apiKey = "test-api-key-123"
-
-func newServer(t *testing.T, handler http.Handler) *httptest.Server {
- t.Helper()
-
- return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- r.Close = true
- defer r.Body.Close()
- req := reconstructReq(t, r)
-
- rw := newResponseWriter()
- handler.ServeHTTP(rw, req)
- require.NoError(t, rw.Done(w))
- }))
-}
-
-func reconstructReq(t *testing.T, r *http.Request) *http.Request {
- t.Helper()
-
- params := r.URL.Query()
- require.Equal(t, apiKey, params.Get("key"))
-
- // Reconstruct the request from the query params
- var key string
- var url string
- headers := make(http.Header)
- for k, v := range params {
- switch k {
- case "key":
- key = v[0]
- case "url":
- url = v[0]
- default:
- if strings.HasPrefix(k, "headers") {
- var name string
- // Get index of first [
- i := strings.Index(k, "[")
- if i == -1 {
- t.Fatalf("invalid header key: %s", k)
- }
- // Get index of first ]
- j := strings.Index(k, "]")
- if j == -1 {
- t.Fatalf("invalid header key: %s", k)
- }
-
- // Get the header name
- name = k[i+1 : j]
- headers.Set(name, v[0])
- }
- }
- }
- require.Equal(t, apiKey, key)
- require.NotEmpty(t, url)
-
- req, err := http.NewRequestWithContext(r.Context(), r.Method, url, r.Body)
- require.NoError(t, err)
- req.Header = headers
-
- return req
-}
-
-type responsewriter struct {
- ret ScraperResponse
- body bytes.Buffer
- headers http.Header
-}
-
-func newResponseWriter() *responsewriter {
- return &responsewriter{
- headers: make(http.Header),
- }
-}
-
-func (w *responsewriter) Header() http.Header {
- return w.headers
-}
-
-func (w *responsewriter) Write(bytes []byte) (int, error) {
- if w.ret.Result.StatusCode == 0 {
- w.ret.Result.StatusCode = http.StatusOK
- }
- return w.body.Write(bytes)
-}
-
-func (w *responsewriter) WriteHeader(statusCode int) {
- w.ret.Result.StatusCode = statusCode
-}
-
-func (w *responsewriter) Done(rw http.ResponseWriter) error {
- w.ret.Result.Content = w.body.String()
-
- w.ret.Result.ResponseHeaders = make(map[string]string)
- for k, v := range w.headers {
- if k == "Set-Cookie" {
- continue
- }
- w.ret.Result.ResponseHeaders[k] = v[0]
- }
-
- req := http.Response{Header: w.headers}
- w.ret.Result.Cookies = make([]ScraperCookie, 0)
- for _, c := range req.Cookies() {
- var cookie ScraperCookie
- cookie.FromHTTPCookie(c)
- w.ret.Result.Cookies = append(w.ret.Result.Cookies, cookie)
- }
-
- rw.WriteHeader(http.StatusOK)
- return json.NewEncoder(rw).Encode(w.ret)
-}
-
func TestClient_getCookie(t *testing.T) {
t.Parallel()
t.Run("no cookies", func(t *testing.T) {
t.Parallel()
- client, err := NewClient(
- http.DefaultClient,
- apiKey,
+ client := NewClient(
+ nil,
new(emptyCookieSourceStub),
- "",
)
- require.NoError(t, err)
- _, _, err = client.getCookie(context.Background(), nil)
+ _, _, err := client.getCookie(context.Background(), nil)
assert.ErrorIs(t, err, ErrNoAvailableCookies)
})
t.Run("no cookies by subject", func(t *testing.T) {
t.Parallel()
- client, err := NewClient(
- http.DefaultClient,
- apiKey,
+ client := NewClient(
+ nil,
new(emptyCookieSourceStub),
- "",
)
- require.NoError(t, err)
subject := "test"
- _, _, err = client.getCookie(context.Background(), &subject)
+ _, _, err := client.getCookie(context.Background(), &subject)
assert.ErrorIs(t, err, ErrNoAvailableCookies)
})
t.Run("get any cookie", func(t *testing.T) {
t.Parallel()
- client, err := NewClient(
- http.DefaultClient,
- apiKey,
- new(cookieSourceStub),
- "",
+ client := NewClient(
+ nil,
+ new(emptyCookieSourceStub),
)
- require.NoError(t, err)
id, cookie, err := client.getCookie(context.Background(), nil)
require.NoError(t, err)
@@ -186,13 +59,10 @@ func TestClient_getCookie(t *testing.T) {
t.Run("get cookie by subject", func(t *testing.T) {
t.Parallel()
- client, err := NewClient(
- http.DefaultClient,
- apiKey,
- new(cookieSourceStub),
- "",
+ client := NewClient(
+ nil,
+ new(emptyCookieSourceStub),
)
- require.NoError(t, err)
subject := "test"
id, cookie, err := client.getCookie(context.Background(), &subject)
diff --git a/backend/internal/ibd/ibd50.go b/backend/internal/ibd/ibd50.go
index 93aa31d..ea02f82 100644
--- a/backend/internal/ibd/ibd50.go
+++ b/backend/internal/ibd/ibd50.go
@@ -47,12 +47,7 @@ func (c *Client) GetIBD50(ctx context.Context) ([]*Stock, error) {
req.Header.Add("x-newrelic-id", "VwUOV1dTDhABV1FRBgQOVVUF")
req.Header.Add("x-requested-with", "XMLHttpRequest")
- // Clone client to add proxy
- client := *(c.client)
- transport := http.DefaultTransport.(*http.Transport).Clone()
- transport.Proxy = http.ProxyURL(c.proxyUrl)
-
- resp, err := client.Do(req)
+ resp, err := c.Do(req)
if err != nil {
return nil, err
}
diff --git a/backend/internal/ibd/search.go b/backend/internal/ibd/search.go
index 23ef08b..341b14b 100644
--- a/backend/internal/ibd/search.go
+++ b/backend/internal/ibd/search.go
@@ -4,6 +4,7 @@ import (
"context"
"encoding/json"
"fmt"
+ "io"
"net/http"
"net/url"
"time"
@@ -37,17 +38,24 @@ func (c *Client) Search(ctx context.Context, symbol string) (database.Stock, err
if err != nil {
return database.Stock{}, err
}
+ defer func(Body io.ReadCloser) {
+ _ = Body.Close()
+ }(resp.Body)
- if resp.Result.StatusCode != http.StatusOK {
+ if resp.StatusCode != http.StatusOK {
+ content, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return database.Stock{}, fmt.Errorf("failed to read response body: %w", err)
+ }
return database.Stock{}, fmt.Errorf(
"unexpected status code %d: %s",
- resp.Result.StatusCode,
- resp.Result.Content,
+ resp.StatusCode,
+ string(content),
)
}
var sr searchResponse
- if err = json.Unmarshal([]byte(resp.Result.Content), &sr); err != nil {
+ if err = json.NewDecoder(resp.Body).Decode(&sr); err != nil {
return database.Stock{}, err
}
diff --git a/backend/internal/ibd/search_test.go b/backend/internal/ibd/search_test.go
index ac0f578..f291033 100644
--- a/backend/internal/ibd/search_test.go
+++ b/backend/internal/ibd/search_test.go
@@ -5,6 +5,8 @@ import (
"net/http"
"testing"
+ "github.com/ansg191/ibd-trader-backend/internal/ibd/transport"
+ "github.com/jarcoal/httpmock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@@ -190,13 +192,12 @@ func TestClient_Search(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
- server := newServer(t, http.HandlerFunc(func(writer http.ResponseWriter, r *http.Request) {
- _, _ = writer.Write([]byte(tt.response))
- }))
- defer server.Close()
+ tp := httpmock.NewMockTransport()
+ tp.RegisterResponder("GET", searchUrl, httpmock.NewStringResponder(200, tt.response))
- client, err := NewClient(http.DefaultClient, apiKey, new(cookieSourceStub), "", WithBaseURL(server.URL))
- require.NoError(t, err)
+ client := NewClient([]transport.Transport{
+ &http.Client{Transport: tp},
+ }, new(cookieSourceStub))
tt.f(t, client)
})
diff --git a/backend/internal/ibd/stockinfo.go b/backend/internal/ibd/stockinfo.go
index e278872..9caa956 100644
--- a/backend/internal/ibd/stockinfo.go
+++ b/backend/internal/ibd/stockinfo.go
@@ -3,6 +3,7 @@ package ibd
import (
"context"
"fmt"
+ "io"
"net/http"
"net/url"
"strconv"
@@ -37,16 +38,23 @@ func (c *Client) StockInfo(ctx context.Context, uri string) (*database.StockInfo
if err != nil {
return nil, err
}
+ defer func(Body io.ReadCloser) {
+ _ = Body.Close()
+ }(resp.Body)
- if resp.Result.StatusCode != http.StatusOK {
+ if resp.StatusCode != http.StatusOK {
+ content, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return nil, fmt.Errorf("failed to read response body: %w", err)
+ }
return nil, fmt.Errorf(
"unexpected status code %d: %s",
- resp.Result.StatusCode,
- resp.Result.Content,
+ resp.StatusCode,
+ string(content),
)
}
- node, err := html.Parse(strings.NewReader(resp.Result.Content))
+ node, err := html.Parse(resp.Body)
if err != nil {
return nil, err
}
diff --git a/backend/internal/ibd/options.go b/backend/internal/ibd/transport/scrapfly/options.go
index a07241e..f16a4b0 100644
--- a/backend/internal/ibd/options.go
+++ b/backend/internal/ibd/transport/scrapfly/options.go
@@ -1,4 +1,4 @@
-package ibd
+package scrapfly
const BaseURL = "https://api.scrapfly.io/scrape"
diff --git a/backend/internal/ibd/scraper_types.go b/backend/internal/ibd/transport/scrapfly/scraper_types.go
index c21ed1c..f3cf651 100644
--- a/backend/internal/ibd/scraper_types.go
+++ b/backend/internal/ibd/transport/scrapfly/scraper_types.go
@@ -1,8 +1,10 @@
-package ibd
+package scrapfly
import (
"fmt"
+ "io"
"net/http"
+ "strings"
"time"
)
@@ -225,3 +227,27 @@ func (c *ScraperCookie) FromHTTPCookie(cookie *http.Cookie) {
Version: "",
}
}
+
+func (r *ScraperResponse) ToHTTPResponse() (*http.Response, error) {
+ resp := &http.Response{
+ StatusCode: r.Result.StatusCode,
+ Header: make(http.Header),
+ Body: io.NopCloser(strings.NewReader(r.Result.Content)),
+ ContentLength: int64(len(r.Result.Content)),
+ Close: true,
+ }
+
+ for k, v := range r.Result.ResponseHeaders {
+ resp.Header.Set(k, v)
+ }
+
+ for _, c := range r.Result.Cookies {
+ cookie, err := c.ToHTTPCookie()
+ if err != nil {
+ return nil, err
+ }
+ resp.Header.Add("Set-Cookie", cookie.String())
+ }
+
+ return resp, nil
+}
diff --git a/backend/internal/ibd/transport/scrapfly/scrapfly.go b/backend/internal/ibd/transport/scrapfly/scrapfly.go
new file mode 100644
index 0000000..f34f3aa
--- /dev/null
+++ b/backend/internal/ibd/transport/scrapfly/scrapfly.go
@@ -0,0 +1,95 @@
+package scrapfly
+
+import (
+ "encoding/json"
+ "fmt"
+ "io"
+ "net/http"
+ "net/url"
+ "strconv"
+
+ "github.com/ansg191/ibd-trader-backend/internal/ibd/transport"
+)
+
+type ScrapflyTransport struct {
+ client *http.Client
+ apiKey string
+ options ScrapeOptions
+}
+
+func New(client *http.Client, apiKey string, opts ...ScrapeOption) *ScrapflyTransport {
+ options := defaultScrapeOptions
+ for _, opt := range opts {
+ opt(&options)
+ }
+
+ return &ScrapflyTransport{
+ client: client,
+ apiKey: apiKey,
+ options: options,
+ }
+}
+
+func (s *ScrapflyTransport) Do(req *http.Request) (*http.Response, error) {
+ // Construct scrape request URL
+ scrapeUrl, err := url.Parse(s.options.baseURL)
+ if err != nil {
+ panic(err)
+ }
+ scrapeUrl.RawQuery = s.constructRawQuery(req.URL, req.Header)
+
+ // We can't handle `Content-Type` header on GET requests
+ // Wierd quirk of the Scrapfly API
+ if req.Method == http.MethodGet && req.Header.Get("Content-Type") != "" {
+ return nil, transport.ErrUnsupportedRequest
+ }
+
+ // Construct scrape request
+ scrapeReq, err := http.NewRequestWithContext(req.Context(), req.Method, scrapeUrl.String(), req.Body)
+ if err != nil {
+ return nil, err
+ }
+
+ // Send scrape request
+ resp, err := s.client.Do(scrapeReq)
+ if err != nil {
+ return nil, err
+ }
+ defer func(Body io.ReadCloser) {
+ _ = Body.Close()
+ }(resp.Body)
+
+ // Parse scrape response
+ scraperResponse := new(ScraperResponse)
+ err = json.NewDecoder(resp.Body).Decode(scraperResponse)
+ if err != nil {
+ return nil, err
+ }
+
+ // Convert scraper response to http.Response
+ return scraperResponse.ToHTTPResponse()
+}
+
+func (s *ScrapflyTransport) constructRawQuery(u *url.URL, headers http.Header) string {
+ params := url.Values{}
+ params.Set("key", s.apiKey)
+ params.Set("url", u.String())
+ if s.options.country != nil {
+ params.Set("country", *s.options.country)
+ }
+ params.Set("asp", strconv.FormatBool(s.options.asp))
+ params.Set("proxy_pool", s.options.proxyPool.String())
+ params.Set("render_js", strconv.FormatBool(s.options.renderJS))
+ params.Set("cache", strconv.FormatBool(s.options.cache))
+
+ for k, v := range headers {
+ for i, vv := range v {
+ params.Add(
+ fmt.Sprintf("headers[%s][%d]", k, i),
+ vv,
+ )
+ }
+ }
+
+ return params.Encode()
+}
diff --git a/backend/internal/ibd/transport/transport.go b/backend/internal/ibd/transport/transport.go
new file mode 100644
index 0000000..7918f4f
--- /dev/null
+++ b/backend/internal/ibd/transport/transport.go
@@ -0,0 +1,12 @@
+package transport
+
+import (
+ "errors"
+ "net/http"
+)
+
+var ErrUnsupportedRequest = errors.New("unsupported request")
+
+type Transport interface {
+ Do(req *http.Request) (*http.Response, error)
+}
diff --git a/backend/internal/ibd/userinfo.go b/backend/internal/ibd/userinfo.go
index ba7a5b5..ed61497 100644
--- a/backend/internal/ibd/userinfo.go
+++ b/backend/internal/ibd/userinfo.go
@@ -4,6 +4,7 @@ import (
"context"
"encoding/json"
"fmt"
+ "io"
"log/slog"
"net/http"
)
@@ -24,17 +25,25 @@ func (c *Client) UserInfo(ctx context.Context, cookie *http.Cookie) (*UserProfil
if err != nil {
return nil, err
}
+ defer func(Body io.ReadCloser) {
+ _ = Body.Close()
+ }(resp.Body)
- if resp.Result.StatusCode != http.StatusOK {
+ content, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return nil, fmt.Errorf("failed to read response body: %w", err)
+ }
+
+ if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf(
"unexpected status code %d: %s",
- resp.Result.StatusCode,
- resp.Result.Content,
+ resp.StatusCode,
+ string(content),
)
}
up := new(UserProfile)
- if err = up.UnmarshalJSON([]byte(resp.Result.Content)); err != nil {
+ if err = up.UnmarshalJSON(content); err != nil {
return nil, err
}