aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--backend/cmd/main.go20
-rw-r--r--backend/go.mod1
-rw-r--r--backend/go.sum4
-rw-r--r--backend/internal/ibd/auth.go54
-rw-r--r--backend/internal/ibd/auth_test.go84
-rw-r--r--backend/internal/ibd/check_ibd_username.go19
-rw-r--r--backend/internal/ibd/client.go125
-rw-r--r--backend/internal/ibd/client_test.go154
-rw-r--r--backend/internal/ibd/ibd50.go7
-rw-r--r--backend/internal/ibd/search.go16
-rw-r--r--backend/internal/ibd/search_test.go13
-rw-r--r--backend/internal/ibd/stockinfo.go16
-rw-r--r--backend/internal/ibd/transport/scrapfly/options.go (renamed from backend/internal/ibd/options.go)2
-rw-r--r--backend/internal/ibd/transport/scrapfly/scraper_types.go (renamed from backend/internal/ibd/scraper_types.go)28
-rw-r--r--backend/internal/ibd/transport/scrapfly/scrapfly.go95
-rw-r--r--backend/internal/ibd/transport/transport.go12
-rw-r--r--backend/internal/ibd/userinfo.go17
17 files changed, 336 insertions, 331 deletions
diff --git a/backend/cmd/main.go b/backend/cmd/main.go
index c5104e8..16fc1ce 100644
--- a/backend/cmd/main.go
+++ b/backend/cmd/main.go
@@ -6,6 +6,7 @@ import (
"log"
"log/slog"
"net/http"
+ "net/url"
"os"
"os/signal"
"time"
@@ -15,6 +16,8 @@ import (
"github.com/ansg191/ibd-trader-backend/internal/config"
"github.com/ansg191/ibd-trader-backend/internal/database"
"github.com/ansg191/ibd-trader-backend/internal/ibd"
+ "github.com/ansg191/ibd-trader-backend/internal/ibd/transport"
+ "github.com/ansg191/ibd-trader-backend/internal/ibd/transport/scrapfly"
"github.com/ansg191/ibd-trader-backend/internal/keys"
"github.com/ansg191/ibd-trader-backend/internal/leader/election"
"github.com/ansg191/ibd-trader-backend/internal/leader/manager"
@@ -75,7 +78,7 @@ func main() {
_ = auth
// Setup IBD client
- client, err := ibd.NewClient(http.DefaultClient, cfg.IBD.APIKey, db, cfg.IBD.ProxyURL)
+ client, err := setupIBDClient(cfg, db)
if err != nil {
log.Fatal("Unable to setup IBD client: ", err)
}
@@ -143,6 +146,21 @@ func main() {
)
}
+func setupIBDClient(cfg *config.Config, db database.Database) (*ibd.Client, error) {
+ pUrl, err := url.Parse(cfg.IBD.ProxyURL)
+ if err != nil {
+ return nil, fmt.Errorf("unable to parse proxy URL: %w", err)
+ }
+ t := http.DefaultTransport.(*http.Transport).Clone()
+ t.Proxy = http.ProxyURL(pUrl)
+ transports := []transport.Transport{
+ &http.Client{Transport: t}, // Default proxied transport
+ scrapfly.New(http.DefaultClient, cfg.IBD.APIKey), // Scrapfly transport
+ }
+ client := ibd.NewClient(transports, db)
+ return client, nil
+}
+
func connectDB(logger *slog.Logger, cfg *config.Config) (database.Database, error) {
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
diff --git a/backend/go.mod b/backend/go.mod
index a5bf7da..bd55e54 100644
--- a/backend/go.mod
+++ b/backend/go.mod
@@ -13,6 +13,7 @@ require (
github.com/cespare/xxhash/v2 v2.3.0
github.com/coreos/go-oidc/v3 v3.11.0
github.com/golang-migrate/migrate/v4 v4.17.1
+ github.com/jarcoal/httpmock v1.3.1
github.com/lib/pq v1.10.9
github.com/lmittmann/tint v1.0.5
github.com/mennanov/fmutils v0.3.0
diff --git a/backend/go.sum b/backend/go.sum
index 930c6cc..3f5e220 100644
--- a/backend/go.sum
+++ b/backend/go.sum
@@ -584,6 +584,8 @@ github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4=
github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ=
github.com/ianlancetaylor/demangle v0.0.0-20181102032728-5e5cf60278f6/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc=
github.com/ianlancetaylor/demangle v0.0.0-20200824232613-28f6c0f3b639/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc=
+github.com/jarcoal/httpmock v1.3.1 h1:iUx3whfZWVf3jT01hQTO/Eo5sAYtB2/rqaUuOtpInww=
+github.com/jarcoal/httpmock v1.3.1/go.mod h1:3yb8rc4BI7TCBhFY8ng0gjuLKJNquuDNiPaZjnENuYg=
github.com/jstemmer/go-junit-report v0.0.0-20190106144839-af01ea7f8024/go.mod h1:6v2b51hI/fHJwM22ozAgKL4VKDeJcHhJFhtBdhmNjmU=
github.com/jstemmer/go-junit-report v0.9.1/go.mod h1:Brl9GWCQeLvo8nXZwPNNblvFj/XSXhF0NWZEnDohbsk=
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
@@ -601,6 +603,8 @@ github.com/lmittmann/tint v1.0.5 h1:NQclAutOfYsqs2F1Lenue6OoWCajs5wJcP3DfWVpePw=
github.com/lmittmann/tint v1.0.5/go.mod h1:HIS3gSy7qNwGCj+5oRjAutErFBl4BzdQP6cJZ0NfMwE=
github.com/magiconair/properties v1.8.7 h1:IeQXZAiQcpL9mgcAe1Nu6cX9LLw6ExEHKjN0VQdvPDY=
github.com/magiconair/properties v1.8.7/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0=
+github.com/maxatome/go-testdeep v1.12.0 h1:Ql7Go8Tg0C1D/uMMX59LAoYK7LffeJQ6X2T04nTH68g=
+github.com/maxatome/go-testdeep v1.12.0/go.mod h1:lPZc/HAcJMP92l7yI6TRz1aZN5URwUBUAfUNvrclaNM=
github.com/mennanov/fmutils v0.3.0 h1:2YSyrO8oOLQQwB/iKe+xDDGO6xCUHiIAj3gYhY7D4Ao=
github.com/mennanov/fmutils v0.3.0/go.mod h1:ph1jsu8gV1gUgMURCmfIVbXKG3O2/O5o/UbPbbqu8zs=
github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY=
diff --git a/backend/internal/ibd/auth.go b/backend/internal/ibd/auth.go
index 7dff3a7..f09f3f7 100644
--- a/backend/internal/ibd/auth.go
+++ b/backend/internal/ibd/auth.go
@@ -7,6 +7,7 @@ import (
"encoding/json"
"errors"
"fmt"
+ "io"
"net/http"
"strings"
@@ -51,16 +52,23 @@ func (c *Client) getLoginPage(ctx context.Context) (*authConfig, error) {
if err != nil {
return nil, err
}
-
- if resp.Result.StatusCode != http.StatusOK {
+ defer func(Body io.ReadCloser) {
+ _ = Body.Close()
+ }(resp.Body)
+
+ if resp.StatusCode != http.StatusOK {
+ content, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return nil, fmt.Errorf("failed to read response body: %w", err)
+ }
return nil, fmt.Errorf(
"unexpected status code %d: %s",
- resp.Result.StatusCode,
- resp.Result.Content,
+ resp.StatusCode,
+ string(content),
)
}
- node, err := html.Parse(strings.NewReader(resp.Result.Content))
+ node, err := html.Parse(resp.Body)
if err != nil {
return nil, err
}
@@ -113,18 +121,25 @@ func (c *Client) sendAuthRequest(ctx context.Context, cfg *authConfig, username,
if err != nil {
return "", "", err
}
+ defer func(Body io.ReadCloser) {
+ _ = Body.Close()
+ }(resp.Body)
- if resp.Result.StatusCode == http.StatusUnauthorized {
+ if resp.StatusCode == http.StatusUnauthorized {
return "", "", ErrBadCredentials
- } else if resp.Result.StatusCode != http.StatusOK {
+ } else if resp.StatusCode != http.StatusOK {
+ content, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return "", "", fmt.Errorf("failed to read response body: %w", err)
+ }
return "", "", fmt.Errorf(
"unexpected status code %d: %s",
- resp.Result.StatusCode,
- resp.Result.Content,
+ resp.StatusCode,
+ string(content),
)
}
- node, err := html.Parse(strings.NewReader(resp.Result.Content))
+ node, err := html.Parse(resp.Body)
if err != nil {
return "", "", err
}
@@ -145,19 +160,26 @@ func (c *Client) sendPostAuth(ctx context.Context, token, params string) (*http.
if err != nil {
return nil, err
}
-
- if resp.Result.StatusCode != http.StatusOK {
+ defer func(Body io.ReadCloser) {
+ _ = Body.Close()
+ }(resp.Body)
+
+ if resp.StatusCode != http.StatusOK {
+ content, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return nil, fmt.Errorf("failed to read response body: %w", err)
+ }
return nil, fmt.Errorf(
"unexpected status code %d: %s",
- resp.Result.StatusCode,
- resp.Result.Content,
+ resp.StatusCode,
+ string(content),
)
}
// Extract cookie
- for _, cookie := range resp.Result.Cookies {
+ for _, cookie := range resp.Cookies() {
if cookie.Name == cookieName {
- return cookie.ToHTTPCookie()
+ return cookie, nil
}
}
diff --git a/backend/internal/ibd/auth_test.go b/backend/internal/ibd/auth_test.go
index d28b33a..8a00d42 100644
--- a/backend/internal/ibd/auth_test.go
+++ b/backend/internal/ibd/auth_test.go
@@ -9,6 +9,8 @@ import (
"testing"
"time"
+ "github.com/ansg191/ibd-trader-backend/internal/ibd/transport"
+ "github.com/jarcoal/httpmock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/net/html"
@@ -134,44 +136,36 @@ func TestClient_Authenticate(t *testing.T) {
expectedVal := "test-cookie"
expectedExp := time.Now().Add(time.Hour).Round(time.Second).In(time.UTC)
- server := newServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- uri := r.URL.String()
- switch uri {
- case signInUrl:
- w.Header().Set("Content-Type", "text/html")
- _, err := w.Write([]byte(extractAuthHtml))
- require.NoError(t, err)
- return
- case authenticateUrl:
+ tp := httpmock.NewMockTransport()
+ tp.RegisterResponder("GET", signInUrl,
+ httpmock.NewStringResponder(http.StatusOK, extractAuthHtml))
+ tp.RegisterResponder("POST", authenticateUrl,
+ func(request *http.Request) (*http.Response, error) {
var body authRequestBody
- require.NoError(t, json.NewDecoder(r.Body).Decode(&body))
+ require.NoError(t, json.NewDecoder(request.Body).Decode(&body))
assert.Equal(t, "abc", body.Username)
assert.Equal(t, "xyz", body.Password)
- w.Header().Set("Content-Type", "text/html")
- _, err := w.Write([]byte(extractTokenParamsHtml))
- require.NoError(t, err)
- return
- case postAuthUrl:
- require.NoError(t, r.ParseForm())
- assert.Equal(t, extractTokenExpectedToken, r.Form.Get("token"))
+ return httpmock.NewStringResponse(http.StatusOK, extractTokenParamsHtml), nil
+ })
+ tp.RegisterResponder("POST", postAuthUrl,
+ func(request *http.Request) (*http.Response, error) {
+ require.NoError(t, request.ParseForm())
+ assert.Equal(t, extractTokenExpectedToken, request.Form.Get("token"))
params, err := url.QueryUnescape(extractTokenExpectedParams)
require.NoError(t, err)
- assert.Equal(t, params, r.Form.Get("params"))
+ assert.Equal(t, params, request.Form.Get("params"))
- w.Header().Set("Content-Type", "text/html")
- http.SetCookie(w, &http.Cookie{Name: cookieName, Value: expectedVal, Expires: expectedExp})
- _, err = w.Write([]byte("OK"))
- require.NoError(t, err)
- return
- default:
- t.Fatalf("unexpected URL: %s", uri)
- }
- }))
+ resp := httpmock.NewStringResponse(http.StatusOK, "OK")
+ cookie := &http.Cookie{Name: cookieName, Value: expectedVal, Expires: expectedExp}
+ resp.Header.Set("Set-Cookie", cookie.String())
+ return resp, nil
+ })
- client, err := NewClient(http.DefaultClient, apiKey, nil, "", WithBaseURL(server.URL))
- require.NoError(t, err)
+ client := NewClient([]transport.Transport{
+ &http.Client{Transport: tp},
+ }, nil)
cookie, err := client.Authenticate(context.Background(), "abc", "xyz")
require.NoError(t, err)
@@ -184,32 +178,22 @@ func TestClient_Authenticate(t *testing.T) {
func TestClient_Authenticate_401(t *testing.T) {
t.Parallel()
- server := newServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- uri := r.URL.String()
- switch uri {
- case signInUrl:
- w.Header().Set("Content-Type", "text/html")
- _, err := w.Write([]byte(extractAuthHtml))
- require.NoError(t, err)
- return
- case authenticateUrl:
+ tp := httpmock.NewMockTransport()
+ tp.RegisterResponder("GET", signInUrl,
+ httpmock.NewStringResponder(http.StatusOK, extractAuthHtml))
+ tp.RegisterResponder("POST", authenticateUrl,
+ func(request *http.Request) (*http.Response, error) {
var body authRequestBody
- require.NoError(t, json.NewDecoder(r.Body).Decode(&body))
+ require.NoError(t, json.NewDecoder(request.Body).Decode(&body))
assert.Equal(t, "abc", body.Username)
assert.Equal(t, "xyz", body.Password)
- w.Header().Set("Content-Type", "application/json")
- w.WriteHeader(http.StatusUnauthorized)
- _, err := w.Write([]byte(`{"name":"ValidationError","code":"ERR016","message":"Wrong username or password","description":"Wrong username or password"}`))
- require.NoError(t, err)
- return
- default:
- t.Fatalf("unexpected URL: %s", uri)
- }
- }))
+ return httpmock.NewStringResponse(http.StatusUnauthorized, `{"name":"ValidationError","code":"ERR016","message":"Wrong username or password","description":"Wrong username or password"}`), nil
+ })
- client, err := NewClient(http.DefaultClient, apiKey, nil, "", WithBaseURL(server.URL))
- require.NoError(t, err)
+ client := NewClient([]transport.Transport{
+ &http.Client{Transport: tp},
+ }, nil)
cookie, err := client.Authenticate(context.Background(), "abc", "xyz")
assert.Nil(t, cookie)
diff --git a/backend/internal/ibd/check_ibd_username.go b/backend/internal/ibd/check_ibd_username.go
index 03d2640..c03176e 100644
--- a/backend/internal/ibd/check_ibd_username.go
+++ b/backend/internal/ibd/check_ibd_username.go
@@ -5,6 +5,7 @@ import (
"context"
"encoding/json"
"fmt"
+ "io"
"net/http"
)
@@ -41,18 +42,26 @@ func (c *Client) checkIBDUsername(ctx context.Context, cfg *authConfig, username
req.Header.Set("X-REQUEST-EDITIONID", "IBD-EN_US")
req.Header.Set("X-REQUEST-SCHEME", "https")
- resp, err := c.Do(req)
+ resp, err := c.DoWithStatus(req, []int{http.StatusOK, http.StatusUnauthorized})
if err != nil {
return false, err
}
+ defer func(Body io.ReadCloser) {
+ _ = Body.Close()
+ }(resp.Body)
- if resp.Result.StatusCode == http.StatusUnauthorized {
+ if resp.StatusCode == http.StatusUnauthorized {
return false, nil
- } else if resp.Result.StatusCode != http.StatusOK {
+ } else if resp.StatusCode != http.StatusOK {
+ contentBytes, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return false, fmt.Errorf("failed to read response body: %w", err)
+ }
+ content := string(contentBytes)
return false, fmt.Errorf(
"unexpected status code %d: %s",
- resp.Result.StatusCode,
- resp.Result.Content,
+ resp.StatusCode,
+ content,
)
}
return true, nil
diff --git a/backend/internal/ibd/client.go b/backend/internal/ibd/client.go
index c1cbb8a..2b91268 100644
--- a/backend/internal/ibd/client.go
+++ b/backend/internal/ibd/client.go
@@ -2,56 +2,28 @@ package ibd
import (
"context"
- "encoding/json"
"errors"
- "fmt"
- "io"
+ "log/slog"
"net/http"
- "net/url"
- "strconv"
+ "slices"
"github.com/ansg191/ibd-trader-backend/internal/database"
+ "github.com/ansg191/ibd-trader-backend/internal/ibd/transport"
)
var ErrNoAvailableCookies = errors.New("no available cookies")
+var ErrNoAvailableTransports = errors.New("no available transports")
type Client struct {
- // HTTP client used to make requests
- client *http.Client
- // Scrapfly API key
- apiKey string
- // Client-wide Scrape options
- options ScrapeOptions
- // Cookie source
- cookies database.CookieSource
- // Proxy URL for non-scrapfly requests
- proxyUrl *url.URL
+ transports []transport.Transport
+ cookies database.CookieSource
}
func NewClient(
- client *http.Client,
- apiKey string,
+ transports []transport.Transport,
cookies database.CookieSource,
- proxyUrl string,
- opts ...ScrapeOption,
-) (*Client, error) {
- options := defaultScrapeOptions
- for _, opt := range opts {
- opt(&options)
- }
-
- pProxyUrl, err := url.Parse(proxyUrl)
- if err != nil {
- return nil, err
- }
-
- return &Client{
- client: client,
- options: options,
- apiKey: apiKey,
- cookies: cookies,
- proxyUrl: pProxyUrl,
- }, nil
+) *Client {
+ return &Client{transports, cookies}
}
func (c *Client) getCookie(ctx context.Context, subject *string) (uint, *http.Cookie, error) {
@@ -83,64 +55,35 @@ func (c *Client) getCookie(ctx context.Context, subject *string) (uint, *http.Co
return cookie.ID, cookie.ToHTTPCookie(), nil
}
-func (c *Client) Do(req *http.Request, opts ...ScrapeOption) (*ScraperResponse, error) {
- options := c.options
- for _, opt := range opts {
- opt(&options)
- }
-
- // Construct scrape request URL
- scrapeUrl, err := url.Parse(options.baseURL)
- if err != nil {
- panic(err)
- }
- scrapeUrl.RawQuery = c.constructRawQuery(options, req.URL, req.Header)
-
- // Construct scrape request
- scrapeReq, err := http.NewRequestWithContext(req.Context(), req.Method, scrapeUrl.String(), req.Body)
- if err != nil {
- return nil, err
- }
-
- // Send scrape request
- resp, err := c.client.Do(scrapeReq)
- if err != nil {
- return nil, err
- }
- defer func(Body io.ReadCloser) {
- _ = Body.Close()
- }(resp.Body)
-
- // Parse scrape response
- scraperResponse := new(ScraperResponse)
- err = json.NewDecoder(resp.Body).Decode(scraperResponse)
- if err != nil {
- return nil, err
- }
-
- return scraperResponse, nil
+func (c *Client) Do(req *http.Request) (*http.Response, error) {
+ return c.DoWithStatus(req, []int{http.StatusOK})
}
-func (c *Client) constructRawQuery(options ScrapeOptions, u *url.URL, headers http.Header) string {
- params := url.Values{}
- params.Set("key", c.apiKey)
- params.Set("url", u.String())
- if options.country != nil {
- params.Set("country", *options.country)
- }
- params.Set("asp", strconv.FormatBool(options.asp))
- params.Set("proxy_pool", options.proxyPool.String())
- params.Set("render_js", strconv.FormatBool(options.renderJS))
- params.Set("cache", strconv.FormatBool(options.cache))
-
- for k, v := range headers {
- for i, vv := range v {
- params.Add(
- fmt.Sprintf("headers[%s][%d]", k, i),
- vv,
+func (c *Client) DoWithStatus(req *http.Request, expectedStatus []int) (*http.Response, error) {
+ for i, tp := range c.transports {
+ resp, err := tp.Do(req)
+ if errors.Is(err, transport.ErrUnsupportedRequest) {
+ // Skip unsupported transport
+ continue
+ }
+ if err != nil {
+ slog.ErrorContext(req.Context(), "transport error",
+ "transport", i,
+ "error", err,
+ )
+ continue
+ }
+ if slices.Contains(expectedStatus, resp.StatusCode) {
+ return resp, nil
+ } else {
+ slog.ErrorContext(req.Context(), "unexpected status code",
+ "transport", i,
+ "expected", expectedStatus,
+ "actual", resp.StatusCode,
)
+ continue
}
}
- return params.Encode()
+ return nil, ErrNoAvailableTransports
}
diff --git a/backend/internal/ibd/client_test.go b/backend/internal/ibd/client_test.go
index 048ad59..0a8fa98 100644
--- a/backend/internal/ibd/client_test.go
+++ b/backend/internal/ibd/client_test.go
@@ -1,177 +1,50 @@
package ibd
import (
- "bytes"
"context"
- "encoding/json"
- "net/http"
- "net/http/httptest"
- "strings"
"testing"
"time"
"github.com/ansg191/ibd-trader-backend/internal/database"
-
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
-const apiKey = "test-api-key-123"
-
-func newServer(t *testing.T, handler http.Handler) *httptest.Server {
- t.Helper()
-
- return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- r.Close = true
- defer r.Body.Close()
- req := reconstructReq(t, r)
-
- rw := newResponseWriter()
- handler.ServeHTTP(rw, req)
- require.NoError(t, rw.Done(w))
- }))
-}
-
-func reconstructReq(t *testing.T, r *http.Request) *http.Request {
- t.Helper()
-
- params := r.URL.Query()
- require.Equal(t, apiKey, params.Get("key"))
-
- // Reconstruct the request from the query params
- var key string
- var url string
- headers := make(http.Header)
- for k, v := range params {
- switch k {
- case "key":
- key = v[0]
- case "url":
- url = v[0]
- default:
- if strings.HasPrefix(k, "headers") {
- var name string
- // Get index of first [
- i := strings.Index(k, "[")
- if i == -1 {
- t.Fatalf("invalid header key: %s", k)
- }
- // Get index of first ]
- j := strings.Index(k, "]")
- if j == -1 {
- t.Fatalf("invalid header key: %s", k)
- }
-
- // Get the header name
- name = k[i+1 : j]
- headers.Set(name, v[0])
- }
- }
- }
- require.Equal(t, apiKey, key)
- require.NotEmpty(t, url)
-
- req, err := http.NewRequestWithContext(r.Context(), r.Method, url, r.Body)
- require.NoError(t, err)
- req.Header = headers
-
- return req
-}
-
-type responsewriter struct {
- ret ScraperResponse
- body bytes.Buffer
- headers http.Header
-}
-
-func newResponseWriter() *responsewriter {
- return &responsewriter{
- headers: make(http.Header),
- }
-}
-
-func (w *responsewriter) Header() http.Header {
- return w.headers
-}
-
-func (w *responsewriter) Write(bytes []byte) (int, error) {
- if w.ret.Result.StatusCode == 0 {
- w.ret.Result.StatusCode = http.StatusOK
- }
- return w.body.Write(bytes)
-}
-
-func (w *responsewriter) WriteHeader(statusCode int) {
- w.ret.Result.StatusCode = statusCode
-}
-
-func (w *responsewriter) Done(rw http.ResponseWriter) error {
- w.ret.Result.Content = w.body.String()
-
- w.ret.Result.ResponseHeaders = make(map[string]string)
- for k, v := range w.headers {
- if k == "Set-Cookie" {
- continue
- }
- w.ret.Result.ResponseHeaders[k] = v[0]
- }
-
- req := http.Response{Header: w.headers}
- w.ret.Result.Cookies = make([]ScraperCookie, 0)
- for _, c := range req.Cookies() {
- var cookie ScraperCookie
- cookie.FromHTTPCookie(c)
- w.ret.Result.Cookies = append(w.ret.Result.Cookies, cookie)
- }
-
- rw.WriteHeader(http.StatusOK)
- return json.NewEncoder(rw).Encode(w.ret)
-}
-
func TestClient_getCookie(t *testing.T) {
t.Parallel()
t.Run("no cookies", func(t *testing.T) {
t.Parallel()
- client, err := NewClient(
- http.DefaultClient,
- apiKey,
+ client := NewClient(
+ nil,
new(emptyCookieSourceStub),
- "",
)
- require.NoError(t, err)
- _, _, err = client.getCookie(context.Background(), nil)
+ _, _, err := client.getCookie(context.Background(), nil)
assert.ErrorIs(t, err, ErrNoAvailableCookies)
})
t.Run("no cookies by subject", func(t *testing.T) {
t.Parallel()
- client, err := NewClient(
- http.DefaultClient,
- apiKey,
+ client := NewClient(
+ nil,
new(emptyCookieSourceStub),
- "",
)
- require.NoError(t, err)
subject := "test"
- _, _, err = client.getCookie(context.Background(), &subject)
+ _, _, err := client.getCookie(context.Background(), &subject)
assert.ErrorIs(t, err, ErrNoAvailableCookies)
})
t.Run("get any cookie", func(t *testing.T) {
t.Parallel()
- client, err := NewClient(
- http.DefaultClient,
- apiKey,
- new(cookieSourceStub),
- "",
+ client := NewClient(
+ nil,
+ new(emptyCookieSourceStub),
)
- require.NoError(t, err)
id, cookie, err := client.getCookie(context.Background(), nil)
require.NoError(t, err)
@@ -186,13 +59,10 @@ func TestClient_getCookie(t *testing.T) {
t.Run("get cookie by subject", func(t *testing.T) {
t.Parallel()
- client, err := NewClient(
- http.DefaultClient,
- apiKey,
- new(cookieSourceStub),
- "",
+ client := NewClient(
+ nil,
+ new(emptyCookieSourceStub),
)
- require.NoError(t, err)
subject := "test"
id, cookie, err := client.getCookie(context.Background(), &subject)
diff --git a/backend/internal/ibd/ibd50.go b/backend/internal/ibd/ibd50.go
index 93aa31d..ea02f82 100644
--- a/backend/internal/ibd/ibd50.go
+++ b/backend/internal/ibd/ibd50.go
@@ -47,12 +47,7 @@ func (c *Client) GetIBD50(ctx context.Context) ([]*Stock, error) {
req.Header.Add("x-newrelic-id", "VwUOV1dTDhABV1FRBgQOVVUF")
req.Header.Add("x-requested-with", "XMLHttpRequest")
- // Clone client to add proxy
- client := *(c.client)
- transport := http.DefaultTransport.(*http.Transport).Clone()
- transport.Proxy = http.ProxyURL(c.proxyUrl)
-
- resp, err := client.Do(req)
+ resp, err := c.Do(req)
if err != nil {
return nil, err
}
diff --git a/backend/internal/ibd/search.go b/backend/internal/ibd/search.go
index 23ef08b..341b14b 100644
--- a/backend/internal/ibd/search.go
+++ b/backend/internal/ibd/search.go
@@ -4,6 +4,7 @@ import (
"context"
"encoding/json"
"fmt"
+ "io"
"net/http"
"net/url"
"time"
@@ -37,17 +38,24 @@ func (c *Client) Search(ctx context.Context, symbol string) (database.Stock, err
if err != nil {
return database.Stock{}, err
}
+ defer func(Body io.ReadCloser) {
+ _ = Body.Close()
+ }(resp.Body)
- if resp.Result.StatusCode != http.StatusOK {
+ if resp.StatusCode != http.StatusOK {
+ content, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return database.Stock{}, fmt.Errorf("failed to read response body: %w", err)
+ }
return database.Stock{}, fmt.Errorf(
"unexpected status code %d: %s",
- resp.Result.StatusCode,
- resp.Result.Content,
+ resp.StatusCode,
+ string(content),
)
}
var sr searchResponse
- if err = json.Unmarshal([]byte(resp.Result.Content), &sr); err != nil {
+ if err = json.NewDecoder(resp.Body).Decode(&sr); err != nil {
return database.Stock{}, err
}
diff --git a/backend/internal/ibd/search_test.go b/backend/internal/ibd/search_test.go
index ac0f578..f291033 100644
--- a/backend/internal/ibd/search_test.go
+++ b/backend/internal/ibd/search_test.go
@@ -5,6 +5,8 @@ import (
"net/http"
"testing"
+ "github.com/ansg191/ibd-trader-backend/internal/ibd/transport"
+ "github.com/jarcoal/httpmock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@@ -190,13 +192,12 @@ func TestClient_Search(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
- server := newServer(t, http.HandlerFunc(func(writer http.ResponseWriter, r *http.Request) {
- _, _ = writer.Write([]byte(tt.response))
- }))
- defer server.Close()
+ tp := httpmock.NewMockTransport()
+ tp.RegisterResponder("GET", searchUrl, httpmock.NewStringResponder(200, tt.response))
- client, err := NewClient(http.DefaultClient, apiKey, new(cookieSourceStub), "", WithBaseURL(server.URL))
- require.NoError(t, err)
+ client := NewClient([]transport.Transport{
+ &http.Client{Transport: tp},
+ }, new(cookieSourceStub))
tt.f(t, client)
})
diff --git a/backend/internal/ibd/stockinfo.go b/backend/internal/ibd/stockinfo.go
index e278872..9caa956 100644
--- a/backend/internal/ibd/stockinfo.go
+++ b/backend/internal/ibd/stockinfo.go
@@ -3,6 +3,7 @@ package ibd
import (
"context"
"fmt"
+ "io"
"net/http"
"net/url"
"strconv"
@@ -37,16 +38,23 @@ func (c *Client) StockInfo(ctx context.Context, uri string) (*database.StockInfo
if err != nil {
return nil, err
}
+ defer func(Body io.ReadCloser) {
+ _ = Body.Close()
+ }(resp.Body)
- if resp.Result.StatusCode != http.StatusOK {
+ if resp.StatusCode != http.StatusOK {
+ content, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return nil, fmt.Errorf("failed to read response body: %w", err)
+ }
return nil, fmt.Errorf(
"unexpected status code %d: %s",
- resp.Result.StatusCode,
- resp.Result.Content,
+ resp.StatusCode,
+ string(content),
)
}
- node, err := html.Parse(strings.NewReader(resp.Result.Content))
+ node, err := html.Parse(resp.Body)
if err != nil {
return nil, err
}
diff --git a/backend/internal/ibd/options.go b/backend/internal/ibd/transport/scrapfly/options.go
index a07241e..f16a4b0 100644
--- a/backend/internal/ibd/options.go
+++ b/backend/internal/ibd/transport/scrapfly/options.go
@@ -1,4 +1,4 @@
-package ibd
+package scrapfly
const BaseURL = "https://api.scrapfly.io/scrape"
diff --git a/backend/internal/ibd/scraper_types.go b/backend/internal/ibd/transport/scrapfly/scraper_types.go
index c21ed1c..f3cf651 100644
--- a/backend/internal/ibd/scraper_types.go
+++ b/backend/internal/ibd/transport/scrapfly/scraper_types.go
@@ -1,8 +1,10 @@
-package ibd
+package scrapfly
import (
"fmt"
+ "io"
"net/http"
+ "strings"
"time"
)
@@ -225,3 +227,27 @@ func (c *ScraperCookie) FromHTTPCookie(cookie *http.Cookie) {
Version: "",
}
}
+
+func (r *ScraperResponse) ToHTTPResponse() (*http.Response, error) {
+ resp := &http.Response{
+ StatusCode: r.Result.StatusCode,
+ Header: make(http.Header),
+ Body: io.NopCloser(strings.NewReader(r.Result.Content)),
+ ContentLength: int64(len(r.Result.Content)),
+ Close: true,
+ }
+
+ for k, v := range r.Result.ResponseHeaders {
+ resp.Header.Set(k, v)
+ }
+
+ for _, c := range r.Result.Cookies {
+ cookie, err := c.ToHTTPCookie()
+ if err != nil {
+ return nil, err
+ }
+ resp.Header.Add("Set-Cookie", cookie.String())
+ }
+
+ return resp, nil
+}
diff --git a/backend/internal/ibd/transport/scrapfly/scrapfly.go b/backend/internal/ibd/transport/scrapfly/scrapfly.go
new file mode 100644
index 0000000..f34f3aa
--- /dev/null
+++ b/backend/internal/ibd/transport/scrapfly/scrapfly.go
@@ -0,0 +1,95 @@
+package scrapfly
+
+import (
+ "encoding/json"
+ "fmt"
+ "io"
+ "net/http"
+ "net/url"
+ "strconv"
+
+ "github.com/ansg191/ibd-trader-backend/internal/ibd/transport"
+)
+
+type ScrapflyTransport struct {
+ client *http.Client
+ apiKey string
+ options ScrapeOptions
+}
+
+func New(client *http.Client, apiKey string, opts ...ScrapeOption) *ScrapflyTransport {
+ options := defaultScrapeOptions
+ for _, opt := range opts {
+ opt(&options)
+ }
+
+ return &ScrapflyTransport{
+ client: client,
+ apiKey: apiKey,
+ options: options,
+ }
+}
+
+func (s *ScrapflyTransport) Do(req *http.Request) (*http.Response, error) {
+ // Construct scrape request URL
+ scrapeUrl, err := url.Parse(s.options.baseURL)
+ if err != nil {
+ panic(err)
+ }
+ scrapeUrl.RawQuery = s.constructRawQuery(req.URL, req.Header)
+
+ // We can't handle `Content-Type` header on GET requests
+ // Weird quirk of the Scrapfly API
+ if req.Method == http.MethodGet && req.Header.Get("Content-Type") != "" {
+ return nil, transport.ErrUnsupportedRequest
+ }
+
+ // Construct scrape request
+ scrapeReq, err := http.NewRequestWithContext(req.Context(), req.Method, scrapeUrl.String(), req.Body)
+ if err != nil {
+ return nil, err
+ }
+
+ // Send scrape request
+ resp, err := s.client.Do(scrapeReq)
+ if err != nil {
+ return nil, err
+ }
+ defer func(Body io.ReadCloser) {
+ _ = Body.Close()
+ }(resp.Body)
+
+ // Parse scrape response
+ scraperResponse := new(ScraperResponse)
+ err = json.NewDecoder(resp.Body).Decode(scraperResponse)
+ if err != nil {
+ return nil, err
+ }
+
+ // Convert scraper response to http.Response
+ return scraperResponse.ToHTTPResponse()
+}
+
+func (s *ScrapflyTransport) constructRawQuery(u *url.URL, headers http.Header) string {
+ params := url.Values{}
+ params.Set("key", s.apiKey)
+ params.Set("url", u.String())
+ if s.options.country != nil {
+ params.Set("country", *s.options.country)
+ }
+ params.Set("asp", strconv.FormatBool(s.options.asp))
+ params.Set("proxy_pool", s.options.proxyPool.String())
+ params.Set("render_js", strconv.FormatBool(s.options.renderJS))
+ params.Set("cache", strconv.FormatBool(s.options.cache))
+
+ for k, v := range headers {
+ for i, vv := range v {
+ params.Add(
+ fmt.Sprintf("headers[%s][%d]", k, i),
+ vv,
+ )
+ }
+ }
+
+ return params.Encode()
+}
diff --git a/backend/internal/ibd/transport/transport.go b/backend/internal/ibd/transport/transport.go
new file mode 100644
index 0000000..7918f4f
--- /dev/null
+++ b/backend/internal/ibd/transport/transport.go
@@ -0,0 +1,12 @@
+package transport
+
+import (
+ "errors"
+ "net/http"
+)
+
+var ErrUnsupportedRequest = errors.New("unsupported request")
+
+type Transport interface {
+ Do(req *http.Request) (*http.Response, error)
+}
diff --git a/backend/internal/ibd/userinfo.go b/backend/internal/ibd/userinfo.go
index ba7a5b5..ed61497 100644
--- a/backend/internal/ibd/userinfo.go
+++ b/backend/internal/ibd/userinfo.go
@@ -4,6 +4,7 @@ import (
"context"
"encoding/json"
"fmt"
+ "io"
"log/slog"
"net/http"
)
@@ -24,17 +25,25 @@ func (c *Client) UserInfo(ctx context.Context, cookie *http.Cookie) (*UserProfil
if err != nil {
return nil, err
}
+ defer func(Body io.ReadCloser) {
+ _ = Body.Close()
+ }(resp.Body)
- if resp.Result.StatusCode != http.StatusOK {
+ content, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return nil, fmt.Errorf("failed to read response body: %w", err)
+ }
+
+ if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf(
"unexpected status code %d: %s",
- resp.Result.StatusCode,
- resp.Result.Content,
+ resp.StatusCode,
+ string(content),
)
}
up := new(UserProfile)
- if err = up.UnmarshalJSON([]byte(resp.Result.Content)); err != nil {
+ if err = up.UnmarshalJSON(content); err != nil {
return nil, err
}