aboutsummaryrefslogtreecommitdiff
path: root/backend/internal
diff options
context:
space:
mode:
authorGravatar Anshul Gupta <ansg191@anshulg.com> 2024-08-06 18:53:22 -0700
committerGravatar Anshul Gupta <ansg191@anshulg.com> 2024-08-06 18:53:22 -0700
commit825ba9d21d15e1f9b34c60bac68e42ee1fb125f9 (patch)
treec466380d15d672a4619a7e1c15f058d52123dbb4 /backend/internal
parent961f9e0a76c3cfe9ae92ca8da0531790e0610b69 (diff)
downloadibd-trader-825ba9d21d15e1f9b34c60bac68e42ee1fb125f9.tar.gz
ibd-trader-825ba9d21d15e1f9b34c60bac68e42ee1fb125f9.tar.zst
ibd-trader-825ba9d21d15e1f9b34c60bac68e42ee1fb125f9.zip
Improve selection of IBD transports
Diffstat (limited to 'backend/internal')
-rw-r--r--backend/internal/ibd/auth.go9
-rw-r--r--backend/internal/ibd/auth_test.go26
-rw-r--r--backend/internal/ibd/check_ibd_username.go2
-rw-r--r--backend/internal/ibd/client.go25
-rw-r--r--backend/internal/ibd/client_test.go20
-rw-r--r--backend/internal/ibd/options.go26
-rw-r--r--backend/internal/ibd/search_test.go4
-rw-r--r--backend/internal/ibd/transport/scrapfly/scrapfly.go8
-rw-r--r--backend/internal/ibd/transport/standard.go41
-rw-r--r--backend/internal/ibd/transport/transport.go54
-rw-r--r--backend/internal/server/idb/user/v1/user.go57
-rw-r--r--backend/internal/server/server.go2
12 files changed, 227 insertions, 47 deletions
diff --git a/backend/internal/ibd/auth.go b/backend/internal/ibd/auth.go
index f09f3f7..7b82057 100644
--- a/backend/internal/ibd/auth.go
+++ b/backend/internal/ibd/auth.go
@@ -11,6 +11,7 @@ import (
"net/http"
"strings"
+ "github.com/ansg191/ibd-trader-backend/internal/ibd/transport"
"golang.org/x/net/html"
)
@@ -48,7 +49,7 @@ func (c *Client) getLoginPage(ctx context.Context) (*authConfig, error) {
return nil, err
}
- resp, err := c.Do(req)
+ resp, err := c.Do(req, withRequiredProps(transport.PropertiesReliable))
if err != nil {
return nil, err
}
@@ -117,7 +118,9 @@ func (c *Client) sendAuthRequest(ctx context.Context, cfg *authConfig, username,
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Auth0-Client", "eyJuYW1lIjoiYXV0aDAuanMtdWxwIiwidmVyc2lvbiI6IjkuMjQuMSJ9")
- resp, err := c.Do(req)
+ resp, err := c.Do(req,
+ withRequiredProps(transport.PropertiesReliable),
+ withExpectedStatuses(http.StatusOK, http.StatusUnauthorized))
if err != nil {
return "", "", err
}
@@ -156,7 +159,7 @@ func (c *Client) sendPostAuth(ctx context.Context, token, params string) (*http.
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
- resp, err := c.Do(req)
+ resp, err := c.Do(req, withRequiredProps(transport.PropertiesReliable))
if err != nil {
return nil, err
}
diff --git a/backend/internal/ibd/auth_test.go b/backend/internal/ibd/auth_test.go
index 8a00d42..54ea98a 100644
--- a/backend/internal/ibd/auth_test.go
+++ b/backend/internal/ibd/auth_test.go
@@ -163,9 +163,7 @@ func TestClient_Authenticate(t *testing.T) {
return resp, nil
})
- client := NewClient([]transport.Transport{
- &http.Client{Transport: tp},
- }, nil)
+ client := NewClient(nil, newTransport(tp))
cookie, err := client.Authenticate(context.Background(), "abc", "xyz")
require.NoError(t, err)
@@ -191,11 +189,27 @@ func TestClient_Authenticate_401(t *testing.T) {
return httpmock.NewStringResponse(http.StatusUnauthorized, `{"name":"ValidationError","code":"ERR016","message":"Wrong username or password","description":"Wrong username or password"}`), nil
})
- client := NewClient([]transport.Transport{
- &http.Client{Transport: tp},
- }, nil)
+ client := NewClient(nil, newTransport(tp))
cookie, err := client.Authenticate(context.Background(), "abc", "xyz")
assert.Nil(t, cookie)
assert.ErrorIs(t, err, ErrBadCredentials)
}
+
+type testReliableTransport http.Client
+
+func newTransport(tp *httpmock.MockTransport) *testReliableTransport {
+ return (*testReliableTransport)(&http.Client{Transport: tp})
+}
+
+func (t *testReliableTransport) String() string {
+ return "testReliableTransport"
+}
+
+func (t *testReliableTransport) Do(req *http.Request) (*http.Response, error) {
+ return (*http.Client)(t).Do(req)
+}
+
+func (t *testReliableTransport) Properties() transport.Properties {
+ return transport.PropertiesFree | transport.PropertiesReliable
+}
diff --git a/backend/internal/ibd/check_ibd_username.go b/backend/internal/ibd/check_ibd_username.go
index c03176e..b026151 100644
--- a/backend/internal/ibd/check_ibd_username.go
+++ b/backend/internal/ibd/check_ibd_username.go
@@ -42,7 +42,7 @@ func (c *Client) checkIBDUsername(ctx context.Context, cfg *authConfig, username
req.Header.Set("X-REQUEST-EDITIONID", "IBD-EN_US")
req.Header.Set("X-REQUEST-SCHEME", "https")
- resp, err := c.DoWithStatus(req, []int{http.StatusOK, http.StatusUnauthorized})
+ resp, err := c.Do(req, withExpectedStatuses(http.StatusOK, http.StatusUnauthorized))
if err != nil {
return false, err
}
diff --git a/backend/internal/ibd/client.go b/backend/internal/ibd/client.go
index 2b91268..25c5173 100644
--- a/backend/internal/ibd/client.go
+++ b/backend/internal/ibd/client.go
@@ -20,8 +20,8 @@ type Client struct {
}
func NewClient(
- transports []transport.Transport,
cookies database.CookieSource,
+ transports ...transport.Transport,
) *Client {
return &Client{transports, cookies}
}
@@ -55,12 +55,17 @@ func (c *Client) getCookie(ctx context.Context, subject *string) (uint, *http.Co
return cookie.ID, cookie.ToHTTPCookie(), nil
}
-func (c *Client) Do(req *http.Request) (*http.Response, error) {
- return c.DoWithStatus(req, []int{http.StatusOK})
-}
+func (c *Client) Do(req *http.Request, opts ...optionFunc) (*http.Response, error) {
+ o := defaultOptions
+ for _, opt := range opts {
+ opt(&o)
+ }
+
+ // Sort and filter transports by properties
+ transports := transport.FilterTransports(c.transports, o.requiredProps)
+ transport.SortTransports(transports)
-func (c *Client) DoWithStatus(req *http.Request, expectedStatus []int) (*http.Response, error) {
- for i, tp := range c.transports {
+ for _, tp := range transports {
resp, err := tp.Do(req)
if errors.Is(err, transport.ErrUnsupportedRequest) {
// Skip unsupported transport
@@ -68,17 +73,17 @@ func (c *Client) DoWithStatus(req *http.Request, expectedStatus []int) (*http.Re
}
if err != nil {
slog.ErrorContext(req.Context(), "transport error",
- "transport", i,
+ "transport", tp.String(),
"error", err,
)
continue
}
- if slices.Contains(expectedStatus, resp.StatusCode) {
+ if slices.Contains(o.expectedStatuses, resp.StatusCode) {
return resp, nil
} else {
slog.ErrorContext(req.Context(), "unexpected status code",
- "transport", i,
- "expected", expectedStatus,
+ "transport", tp.String(),
+ "expected", o.expectedStatuses,
"actual", resp.StatusCode,
)
continue
diff --git a/backend/internal/ibd/client_test.go b/backend/internal/ibd/client_test.go
index 0a8fa98..d2dc1b2 100644
--- a/backend/internal/ibd/client_test.go
+++ b/backend/internal/ibd/client_test.go
@@ -16,10 +16,7 @@ func TestClient_getCookie(t *testing.T) {
t.Run("no cookies", func(t *testing.T) {
t.Parallel()
- client := NewClient(
- nil,
- new(emptyCookieSourceStub),
- )
+ client := NewClient(new(emptyCookieSourceStub))
_, _, err := client.getCookie(context.Background(), nil)
assert.ErrorIs(t, err, ErrNoAvailableCookies)
@@ -28,10 +25,7 @@ func TestClient_getCookie(t *testing.T) {
t.Run("no cookies by subject", func(t *testing.T) {
t.Parallel()
- client := NewClient(
- nil,
- new(emptyCookieSourceStub),
- )
+ client := NewClient(new(emptyCookieSourceStub))
subject := "test"
_, _, err := client.getCookie(context.Background(), &subject)
@@ -41,10 +35,7 @@ func TestClient_getCookie(t *testing.T) {
t.Run("get any cookie", func(t *testing.T) {
t.Parallel()
- client := NewClient(
- nil,
- new(emptyCookieSourceStub),
- )
+ client := NewClient(new(cookieSourceStub))
id, cookie, err := client.getCookie(context.Background(), nil)
require.NoError(t, err)
@@ -59,10 +50,7 @@ func TestClient_getCookie(t *testing.T) {
t.Run("get cookie by subject", func(t *testing.T) {
t.Parallel()
- client := NewClient(
- nil,
- new(emptyCookieSourceStub),
- )
+ client := NewClient(new(cookieSourceStub))
subject := "test"
id, cookie, err := client.getCookie(context.Background(), &subject)
diff --git a/backend/internal/ibd/options.go b/backend/internal/ibd/options.go
new file mode 100644
index 0000000..5c378d5
--- /dev/null
+++ b/backend/internal/ibd/options.go
@@ -0,0 +1,26 @@
+package ibd
+
+import "github.com/ansg191/ibd-trader-backend/internal/ibd/transport"
+
+type optionFunc func(*options)
+
+var defaultOptions = options{
+ expectedStatuses: []int{200},
+}
+
+type options struct {
+ expectedStatuses []int
+ requiredProps transport.Properties
+}
+
+func withExpectedStatuses(statuses ...int) optionFunc {
+ return func(o *options) {
+ o.expectedStatuses = append(o.expectedStatuses, statuses...)
+ }
+}
+
+func withRequiredProps(props transport.Properties) optionFunc {
+ return func(o *options) {
+ o.requiredProps = props
+ }
+}
diff --git a/backend/internal/ibd/search_test.go b/backend/internal/ibd/search_test.go
index f291033..99157cf 100644
--- a/backend/internal/ibd/search_test.go
+++ b/backend/internal/ibd/search_test.go
@@ -195,9 +195,7 @@ func TestClient_Search(t *testing.T) {
tp := httpmock.NewMockTransport()
tp.RegisterResponder("GET", searchUrl, httpmock.NewStringResponder(200, tt.response))
- client := NewClient([]transport.Transport{
- &http.Client{Transport: tp},
- }, new(cookieSourceStub))
+ client := NewClient(new(cookieSourceStub), transport.NewStandardTransport(&http.Client{Transport: tp}))
tt.f(t, client)
})
diff --git a/backend/internal/ibd/transport/scrapfly/scrapfly.go b/backend/internal/ibd/transport/scrapfly/scrapfly.go
index f34f3aa..3b414de 100644
--- a/backend/internal/ibd/transport/scrapfly/scrapfly.go
+++ b/backend/internal/ibd/transport/scrapfly/scrapfly.go
@@ -30,6 +30,10 @@ func New(client *http.Client, apiKey string, opts ...ScrapeOption) *ScrapflyTran
}
}
+func (s *ScrapflyTransport) String() string {
+ return "scrapfly"
+}
+
func (s *ScrapflyTransport) Do(req *http.Request) (*http.Response, error) {
// Construct scrape request URL
scrapeUrl, err := url.Parse(s.options.baseURL)
@@ -70,6 +74,10 @@ func (s *ScrapflyTransport) Do(req *http.Request) (*http.Response, error) {
return scraperResponse.ToHTTPResponse()
}
+func (s *ScrapflyTransport) Properties() transport.Properties {
+ return transport.PropertiesReliable
+}
+
func (s *ScrapflyTransport) constructRawQuery(u *url.URL, headers http.Header) string {
params := url.Values{}
params.Set("key", s.apiKey)
diff --git a/backend/internal/ibd/transport/standard.go b/backend/internal/ibd/transport/standard.go
new file mode 100644
index 0000000..9fa9ff9
--- /dev/null
+++ b/backend/internal/ibd/transport/standard.go
@@ -0,0 +1,41 @@
+package transport
+
+import (
+ "net/http"
+
+ "github.com/EDDYCJY/fake-useragent"
+)
+
+type StandardTransport http.Client
+
+func NewStandardTransport(client *http.Client) *StandardTransport {
+ return (*StandardTransport)(client)
+}
+
+func (t *StandardTransport) Do(req *http.Request) (*http.Response, error) {
+ addFakeHeaders(req)
+ return (*http.Client)(t).Do(req)
+}
+
+func (t *StandardTransport) String() string {
+ return "standard"
+}
+
+func (t *StandardTransport) Properties() Properties {
+ return PropertiesFree
+}
+
+func addFakeHeaders(req *http.Request) {
+ req.Header.Set("User-Agent", browser.Linux())
+ req.Header.Set("Sec-CH-UA", `"Not)A;Brand";v="99", "Google Chrome";v="127", "Chromium";v="127"`)
+ req.Header.Set("Sec-CH-UA-Mobile", "?0")
+ req.Header.Set("Sec-CH-UA-Platform", "Linux")
+ req.Header.Set("Upgrade-Insecure-Requests", "1")
+ req.Header.Set("Priority", "u=0, i")
+ req.Header.Set("Sec-Fetch-Site", "none")
+ req.Header.Set("Sec-Fetch-Mode", "navigate")
+ req.Header.Set("Sec-Fetch-Dest", "document")
+ req.Header.Set("Sec-Fetch-User", "?1")
+ req.Header.Set("Accept-Language", "en-US,en;q=0.9")
+ req.Header.Set("Accept", "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.7")
+}
diff --git a/backend/internal/ibd/transport/transport.go b/backend/internal/ibd/transport/transport.go
index 7918f4f..95e9ef3 100644
--- a/backend/internal/ibd/transport/transport.go
+++ b/backend/internal/ibd/transport/transport.go
@@ -1,12 +1,66 @@
package transport
import (
+ "cmp"
"errors"
+ "fmt"
"net/http"
+ "slices"
)
var ErrUnsupportedRequest = errors.New("unsupported request")
+type Properties uint8
+
+const (
+ // PropertiesFree indicates that the transport is free.
+ // This means that requests made with this transport don't cost any money.
+ PropertiesFree Properties = 1 << iota
+ // PropertiesReliable indicates that the transport is reliable.
+ // This means that requests made with this transport are guaranteed to be
+ // successful if the server is reachable.
+ PropertiesReliable
+)
+
+func (p Properties) IsReliable() bool {
+ return p&PropertiesReliable != 0
+}
+
+func (p Properties) IsFree() bool {
+ return p&PropertiesFree != 0
+}
+
type Transport interface {
+ fmt.Stringer
+
Do(req *http.Request) (*http.Response, error)
+ Properties() Properties
+}
+
+// SortTransports sorts the transports by their properties.
+//
+// The transports are sorted in the following order:
+// 1. Free transports
+// 2. Reliable transports
+func SortTransports(transports []Transport) {
+ priorities := map[Properties]int{
+ PropertiesFree | PropertiesReliable: 0,
+ PropertiesFree: 1,
+ PropertiesReliable: 2,
+ }
+ slices.SortStableFunc(transports, func(a, b Transport) int {
+ iPriority := priorities[a.Properties()]
+ jPriority := priorities[b.Properties()]
+ return cmp.Compare(iPriority, jPriority)
+ })
+}
+
+func FilterTransports(transport []Transport, props Properties) []Transport {
+ var filtered []Transport
+ for _, tp := range transport {
+ if tp.Properties()&props == props {
+ filtered = append(filtered, tp)
+ }
+ }
+ return filtered
}
diff --git a/backend/internal/server/idb/user/v1/user.go b/backend/internal/server/idb/user/v1/user.go
index 8e5f16a..c100465 100644
--- a/backend/internal/server/idb/user/v1/user.go
+++ b/backend/internal/server/idb/user/v1/user.go
@@ -17,21 +17,26 @@ import (
type Server struct {
pb.UnimplementedUserServiceServer
- db database.UserStore
+ user database.UserStore
+ cookie database.CookieSource
client *ibd.Client
}
-func New(db database.UserStore, client *ibd.Client) *Server {
- return &Server{db: db, client: client}
+func New(userStore database.UserStore, cookieStore database.CookieStore, client *ibd.Client) *Server {
+ return &Server{
+ user: userStore,
+ cookie: cookieStore,
+ client: client,
+ }
}
func (u *Server) CreateUser(ctx context.Context, request *pb.CreateUserRequest) (*pb.CreateUserResponse, error) {
- err := u.db.AddUser(ctx, request.Subject)
+ err := u.user.AddUser(ctx, request.Subject)
if err != nil {
return nil, status.Errorf(codes.Internal, "unable to create user: %v", err)
}
- user, err := u.db.GetUser(ctx, request.Subject)
+ user, err := u.user.GetUser(ctx, request.Subject)
if err != nil {
return nil, status.Errorf(codes.Internal, "unable to get user: %v", err)
}
@@ -46,7 +51,7 @@ func (u *Server) CreateUser(ctx context.Context, request *pb.CreateUserRequest)
}
func (u *Server) GetUser(ctx context.Context, request *pb.GetUserRequest) (*pb.GetUserResponse, error) {
- user, err := u.db.GetUser(ctx, request.Subject)
+ user, err := u.user.GetUser(ctx, request.Subject)
if errors.Is(err, database.ErrUserNotFound) {
return nil, status.New(codes.NotFound, "user not found").Err()
}
@@ -83,7 +88,7 @@ func (u *Server) UpdateUser(ctx context.Context, request *pb.UpdateUserRequest)
(newUser.IbdPassword != existingUser.IbdPassword ||
newUser.IbdUsername != existingUser.IbdUsername) {
// Update IBD creds
- err = u.db.AddIBDCreds(ctx, newUser.Subject, *newUser.IbdUsername, *newUser.IbdPassword)
+ err = u.user.AddIBDCreds(ctx, newUser.Subject, *newUser.IbdUsername, *newUser.IbdPassword)
if err != nil {
return nil, status.Errorf(codes.Internal, "unable to update user: %v", err)
}
@@ -111,3 +116,41 @@ func (u *Server) CheckIBDUsername(ctx context.Context, req *pb.CheckIBDUsernameR
Exists: exists,
}, nil
}
+
+func (u *Server) AuthenticateUser(ctx context.Context, req *pb.AuthenticateUserRequest) (*pb.AuthenticateUserResponse, error) {
+ // Check if user has cookies
+ cookies, err := u.cookie.GetCookies(ctx, req.Subject, false)
+ if err != nil {
+ return nil, status.Errorf(codes.Internal, "unable to get cookies: %v", err)
+ }
+ if len(cookies) > 0 {
+ return &pb.AuthenticateUserResponse{
+ Authenticated: true,
+ }, nil
+ }
+
+ // Authenticate user
+ // Get IBD creds
+ username, password, err := u.user.GetIBDCreds(ctx, req.Subject)
+ if errors.Is(err, database.ErrIBDCredsNotFound) {
+ return nil, status.New(codes.NotFound, "User has no IDB creds").Err()
+ }
+ if err != nil {
+ return nil, status.Errorf(codes.Internal, "unable to get IBD creds: %v", err)
+ }
+
+ // Authenticate user
+ cookie, err := u.client.Authenticate(ctx, username, password)
+ if errors.Is(err, ibd.ErrBadCredentials) {
+ return &pb.AuthenticateUserResponse{
+ Authenticated: false,
+ }, nil
+ }
+ if err != nil {
+ return nil, status.Errorf(codes.Internal, "unable to authenticate user: %v", err)
+ }
+
+ return &pb.AuthenticateUserResponse{
+ Authenticated: cookie != nil,
+ }, nil
+}
diff --git a/backend/internal/server/server.go b/backend/internal/server/server.go
index c46a629..186d581 100644
--- a/backend/internal/server/server.go
+++ b/backend/internal/server/server.go
@@ -45,7 +45,7 @@ func New(
}
s := grpc.NewServer()
- upb.RegisterUserServiceServer(s, user.New(db, client))
+ upb.RegisterUserServiceServer(s, user.New(db, db, client))
spb.RegisterStockServiceServer(s, stock.New(db, scrapeQueue))
longrunningpb.RegisterOperationsServer(s, newOperationServer(scrapeQueue))
reflection.Register(s)