diff options
author | 2024-08-06 18:53:22 -0700 | |
---|---|---|
committer | 2024-08-06 18:53:22 -0700 | |
commit | 825ba9d21d15e1f9b34c60bac68e42ee1fb125f9 (patch) | |
tree | c466380d15d672a4619a7e1c15f058d52123dbb4 /backend/internal | |
parent | 961f9e0a76c3cfe9ae92ca8da0531790e0610b69 (diff) | |
download | ibd-trader-825ba9d21d15e1f9b34c60bac68e42ee1fb125f9.tar.gz ibd-trader-825ba9d21d15e1f9b34c60bac68e42ee1fb125f9.tar.zst ibd-trader-825ba9d21d15e1f9b34c60bac68e42ee1fb125f9.zip |
Improve selection of IBD transports
Diffstat (limited to 'backend/internal')
-rw-r--r-- | backend/internal/ibd/auth.go | 9 | ||||
-rw-r--r-- | backend/internal/ibd/auth_test.go | 26 | ||||
-rw-r--r-- | backend/internal/ibd/check_ibd_username.go | 2 | ||||
-rw-r--r-- | backend/internal/ibd/client.go | 25 | ||||
-rw-r--r-- | backend/internal/ibd/client_test.go | 20 | ||||
-rw-r--r-- | backend/internal/ibd/options.go | 26 | ||||
-rw-r--r-- | backend/internal/ibd/search_test.go | 4 | ||||
-rw-r--r-- | backend/internal/ibd/transport/scrapfly/scrapfly.go | 8 | ||||
-rw-r--r-- | backend/internal/ibd/transport/standard.go | 41 | ||||
-rw-r--r-- | backend/internal/ibd/transport/transport.go | 54 | ||||
-rw-r--r-- | backend/internal/server/idb/user/v1/user.go | 57 | ||||
-rw-r--r-- | backend/internal/server/server.go | 2 |
12 files changed, 227 insertions, 47 deletions
diff --git a/backend/internal/ibd/auth.go b/backend/internal/ibd/auth.go index f09f3f7..7b82057 100644 --- a/backend/internal/ibd/auth.go +++ b/backend/internal/ibd/auth.go @@ -11,6 +11,7 @@ import ( "net/http" "strings" + "github.com/ansg191/ibd-trader-backend/internal/ibd/transport" "golang.org/x/net/html" ) @@ -48,7 +49,7 @@ func (c *Client) getLoginPage(ctx context.Context) (*authConfig, error) { return nil, err } - resp, err := c.Do(req) + resp, err := c.Do(req, withRequiredProps(transport.PropertiesReliable)) if err != nil { return nil, err } @@ -117,7 +118,9 @@ func (c *Client) sendAuthRequest(ctx context.Context, cfg *authConfig, username, req.Header.Set("Content-Type", "application/json") req.Header.Set("Auth0-Client", "eyJuYW1lIjoiYXV0aDAuanMtdWxwIiwidmVyc2lvbiI6IjkuMjQuMSJ9") - resp, err := c.Do(req) + resp, err := c.Do(req, + withRequiredProps(transport.PropertiesReliable), + withExpectedStatuses(http.StatusOK, http.StatusUnauthorized)) if err != nil { return "", "", err } @@ -156,7 +159,7 @@ func (c *Client) sendPostAuth(ctx context.Context, token, params string) (*http. req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - resp, err := c.Do(req) + resp, err := c.Do(req, withRequiredProps(transport.PropertiesReliable)) if err != nil { return nil, err } diff --git a/backend/internal/ibd/auth_test.go b/backend/internal/ibd/auth_test.go index 8a00d42..54ea98a 100644 --- a/backend/internal/ibd/auth_test.go +++ b/backend/internal/ibd/auth_test.go @@ -163,9 +163,7 @@ func TestClient_Authenticate(t *testing.T) { return resp, nil }) - client := NewClient([]transport.Transport{ - &http.Client{Transport: tp}, - }, nil) + client := NewClient(nil, newTransport(tp)) cookie, err := client.Authenticate(context.Background(), "abc", "xyz") require.NoError(t, err) @@ -191,11 +189,27 @@ func TestClient_Authenticate_401(t *testing.T) { return httpmock.NewStringResponse(http.StatusUnauthorized, `{"name":"ValidationError","code":"ERR016","message":"Wrong username or password","description":"Wrong username or password"}`), nil }) - client := NewClient([]transport.Transport{ - &http.Client{Transport: tp}, - }, nil) + client := NewClient(nil, newTransport(tp)) cookie, err := client.Authenticate(context.Background(), "abc", "xyz") assert.Nil(t, cookie) assert.ErrorIs(t, err, ErrBadCredentials) } + +type testReliableTransport http.Client + +func newTransport(tp *httpmock.MockTransport) *testReliableTransport { + return (*testReliableTransport)(&http.Client{Transport: tp}) +} + +func (t *testReliableTransport) String() string { + return "testReliableTransport" +} + +func (t *testReliableTransport) Do(req *http.Request) (*http.Response, error) { + return (*http.Client)(t).Do(req) +} + +func (t *testReliableTransport) Properties() transport.Properties { + return transport.PropertiesFree | transport.PropertiesReliable +} diff --git a/backend/internal/ibd/check_ibd_username.go b/backend/internal/ibd/check_ibd_username.go index c03176e..b026151 100644 --- a/backend/internal/ibd/check_ibd_username.go +++ b/backend/internal/ibd/check_ibd_username.go @@ -42,7 +42,7 @@ func (c *Client) checkIBDUsername(ctx context.Context, cfg *authConfig, username req.Header.Set("X-REQUEST-EDITIONID", "IBD-EN_US") req.Header.Set("X-REQUEST-SCHEME", "https") - resp, err := c.DoWithStatus(req, []int{http.StatusOK, http.StatusUnauthorized}) + resp, err := c.Do(req, withExpectedStatuses(http.StatusOK, http.StatusUnauthorized)) if err != nil { return false, err } diff --git a/backend/internal/ibd/client.go b/backend/internal/ibd/client.go index 2b91268..25c5173 100644 --- a/backend/internal/ibd/client.go +++ b/backend/internal/ibd/client.go @@ -20,8 +20,8 @@ type Client struct { } func NewClient( - transports []transport.Transport, cookies database.CookieSource, + transports ...transport.Transport, ) *Client { return &Client{transports, cookies} } @@ -55,12 +55,17 @@ func (c *Client) getCookie(ctx context.Context, subject *string) (uint, *http.Co return cookie.ID, cookie.ToHTTPCookie(), nil } -func (c *Client) Do(req *http.Request) (*http.Response, error) { - return c.DoWithStatus(req, []int{http.StatusOK}) -} +func (c *Client) Do(req *http.Request, opts ...optionFunc) (*http.Response, error) { + o := defaultOptions + for _, opt := range opts { + opt(&o) + } + + // Sort and filter transports by properties + transports := transport.FilterTransports(c.transports, o.requiredProps) + transport.SortTransports(transports) -func (c *Client) DoWithStatus(req *http.Request, expectedStatus []int) (*http.Response, error) { - for i, tp := range c.transports { + for _, tp := range transports { resp, err := tp.Do(req) if errors.Is(err, transport.ErrUnsupportedRequest) { // Skip unsupported transport @@ -68,17 +73,17 @@ func (c *Client) DoWithStatus(req *http.Request, expectedStatus []int) (*http.Re } if err != nil { slog.ErrorContext(req.Context(), "transport error", - "transport", i, + "transport", tp.String(), "error", err, ) continue } - if slices.Contains(expectedStatus, resp.StatusCode) { + if slices.Contains(o.expectedStatuses, resp.StatusCode) { return resp, nil } else { slog.ErrorContext(req.Context(), "unexpected status code", - "transport", i, - "expected", expectedStatus, + "transport", tp.String(), + "expected", o.expectedStatuses, "actual", resp.StatusCode, ) continue diff --git a/backend/internal/ibd/client_test.go b/backend/internal/ibd/client_test.go index 0a8fa98..d2dc1b2 100644 --- a/backend/internal/ibd/client_test.go +++ b/backend/internal/ibd/client_test.go @@ -16,10 +16,7 @@ func TestClient_getCookie(t *testing.T) { t.Run("no cookies", func(t *testing.T) { t.Parallel() - client := NewClient( - nil, - new(emptyCookieSourceStub), - ) + client := NewClient(new(emptyCookieSourceStub)) _, _, err := client.getCookie(context.Background(), nil) assert.ErrorIs(t, err, ErrNoAvailableCookies) @@ -28,10 +25,7 @@ func TestClient_getCookie(t *testing.T) { t.Run("no cookies by subject", func(t *testing.T) { t.Parallel() - client := NewClient( - nil, - new(emptyCookieSourceStub), - ) + client := NewClient(new(emptyCookieSourceStub)) subject := "test" _, _, err := client.getCookie(context.Background(), &subject) @@ -41,10 +35,7 @@ func TestClient_getCookie(t *testing.T) { t.Run("get any cookie", func(t *testing.T) { t.Parallel() - client := NewClient( - nil, - new(emptyCookieSourceStub), - ) + client := NewClient(new(cookieSourceStub)) id, cookie, err := client.getCookie(context.Background(), nil) require.NoError(t, err) @@ -59,10 +50,7 @@ func TestClient_getCookie(t *testing.T) { t.Run("get cookie by subject", func(t *testing.T) { t.Parallel() - client := NewClient( - nil, - new(emptyCookieSourceStub), - ) + client := NewClient(new(cookieSourceStub)) subject := "test" id, cookie, err := client.getCookie(context.Background(), &subject) diff --git a/backend/internal/ibd/options.go b/backend/internal/ibd/options.go new file mode 100644 index 0000000..5c378d5 --- /dev/null +++ b/backend/internal/ibd/options.go @@ -0,0 +1,26 @@ +package ibd + +import "github.com/ansg191/ibd-trader-backend/internal/ibd/transport" + +type optionFunc func(*options) + +var defaultOptions = options{ + expectedStatuses: []int{200}, +} + +type options struct { + expectedStatuses []int + requiredProps transport.Properties +} + +func withExpectedStatuses(statuses ...int) optionFunc { + return func(o *options) { + o.expectedStatuses = append(o.expectedStatuses, statuses...) + } +} + +func withRequiredProps(props transport.Properties) optionFunc { + return func(o *options) { + o.requiredProps = props + } +} diff --git a/backend/internal/ibd/search_test.go b/backend/internal/ibd/search_test.go index f291033..99157cf 100644 --- a/backend/internal/ibd/search_test.go +++ b/backend/internal/ibd/search_test.go @@ -195,9 +195,7 @@ func TestClient_Search(t *testing.T) { tp := httpmock.NewMockTransport() tp.RegisterResponder("GET", searchUrl, httpmock.NewStringResponder(200, tt.response)) - client := NewClient([]transport.Transport{ - &http.Client{Transport: tp}, - }, new(cookieSourceStub)) + client := NewClient(new(cookieSourceStub), transport.NewStandardTransport(&http.Client{Transport: tp})) tt.f(t, client) }) diff --git a/backend/internal/ibd/transport/scrapfly/scrapfly.go b/backend/internal/ibd/transport/scrapfly/scrapfly.go index f34f3aa..3b414de 100644 --- a/backend/internal/ibd/transport/scrapfly/scrapfly.go +++ b/backend/internal/ibd/transport/scrapfly/scrapfly.go @@ -30,6 +30,10 @@ func New(client *http.Client, apiKey string, opts ...ScrapeOption) *ScrapflyTran } } +func (s *ScrapflyTransport) String() string { + return "scrapfly" +} + func (s *ScrapflyTransport) Do(req *http.Request) (*http.Response, error) { // Construct scrape request URL scrapeUrl, err := url.Parse(s.options.baseURL) @@ -70,6 +74,10 @@ func (s *ScrapflyTransport) Do(req *http.Request) (*http.Response, error) { return scraperResponse.ToHTTPResponse() } +func (s *ScrapflyTransport) Properties() transport.Properties { + return transport.PropertiesReliable +} + func (s *ScrapflyTransport) constructRawQuery(u *url.URL, headers http.Header) string { params := url.Values{} params.Set("key", s.apiKey) diff --git a/backend/internal/ibd/transport/standard.go b/backend/internal/ibd/transport/standard.go new file mode 100644 index 0000000..9fa9ff9 --- /dev/null +++ b/backend/internal/ibd/transport/standard.go @@ -0,0 +1,41 @@ +package transport + +import ( + "net/http" + + "github.com/EDDYCJY/fake-useragent" +) + +type StandardTransport http.Client + +func NewStandardTransport(client *http.Client) *StandardTransport { + return (*StandardTransport)(client) +} + +func (t *StandardTransport) Do(req *http.Request) (*http.Response, error) { + addFakeHeaders(req) + return (*http.Client)(t).Do(req) +} + +func (t *StandardTransport) String() string { + return "standard" +} + +func (t *StandardTransport) Properties() Properties { + return PropertiesFree +} + +func addFakeHeaders(req *http.Request) { + req.Header.Set("User-Agent", browser.Linux()) + req.Header.Set("Sec-CH-UA", `"Not)A;Brand";v="99", "Google Chrome";v="127", "Chromium";v="127"`) + req.Header.Set("Sec-CH-UA-Mobile", "?0") + req.Header.Set("Sec-CH-UA-Platform", "Linux") + req.Header.Set("Upgrade-Insecure-Requests", "1") + req.Header.Set("Priority", "u=0, i") + req.Header.Set("Sec-Fetch-Site", "none") + req.Header.Set("Sec-Fetch-Mode", "navigate") + req.Header.Set("Sec-Fetch-Dest", "document") + req.Header.Set("Sec-Fetch-User", "?1") + req.Header.Set("Accept-Language", "en-US,en;q=0.9") + req.Header.Set("Accept", "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.7") +} diff --git a/backend/internal/ibd/transport/transport.go b/backend/internal/ibd/transport/transport.go index 7918f4f..95e9ef3 100644 --- a/backend/internal/ibd/transport/transport.go +++ b/backend/internal/ibd/transport/transport.go @@ -1,12 +1,66 @@ package transport import ( + "cmp" "errors" + "fmt" "net/http" + "slices" ) var ErrUnsupportedRequest = errors.New("unsupported request") +type Properties uint8 + +const ( + // PropertiesFree indicates that the transport is free. + // This means that requests made with this transport don't cost any money. + PropertiesFree Properties = 1 << iota + // PropertiesReliable indicates that the transport is reliable. + // This means that requests made with this transport are guaranteed to be + // successful if the server is reachable. + PropertiesReliable +) + +func (p Properties) IsReliable() bool { + return p&PropertiesReliable != 0 +} + +func (p Properties) IsFree() bool { + return p&PropertiesFree != 0 +} + type Transport interface { + fmt.Stringer + Do(req *http.Request) (*http.Response, error) + Properties() Properties +} + +// SortTransports sorts the transports by their properties. +// +// The transports are sorted in the following order: +// 1. Free transports +// 2. Reliable transports +func SortTransports(transports []Transport) { + priorities := map[Properties]int{ + PropertiesFree | PropertiesReliable: 0, + PropertiesFree: 1, + PropertiesReliable: 2, + } + slices.SortStableFunc(transports, func(a, b Transport) int { + iPriority := priorities[a.Properties()] + jPriority := priorities[b.Properties()] + return cmp.Compare(iPriority, jPriority) + }) +} + +func FilterTransports(transport []Transport, props Properties) []Transport { + var filtered []Transport + for _, tp := range transport { + if tp.Properties()&props == props { + filtered = append(filtered, tp) + } + } + return filtered } diff --git a/backend/internal/server/idb/user/v1/user.go b/backend/internal/server/idb/user/v1/user.go index 8e5f16a..c100465 100644 --- a/backend/internal/server/idb/user/v1/user.go +++ b/backend/internal/server/idb/user/v1/user.go @@ -17,21 +17,26 @@ import ( type Server struct { pb.UnimplementedUserServiceServer - db database.UserStore + user database.UserStore + cookie database.CookieSource client *ibd.Client } -func New(db database.UserStore, client *ibd.Client) *Server { - return &Server{db: db, client: client} +func New(userStore database.UserStore, cookieStore database.CookieStore, client *ibd.Client) *Server { + return &Server{ + user: userStore, + cookie: cookieStore, + client: client, + } } func (u *Server) CreateUser(ctx context.Context, request *pb.CreateUserRequest) (*pb.CreateUserResponse, error) { - err := u.db.AddUser(ctx, request.Subject) + err := u.user.AddUser(ctx, request.Subject) if err != nil { return nil, status.Errorf(codes.Internal, "unable to create user: %v", err) } - user, err := u.db.GetUser(ctx, request.Subject) + user, err := u.user.GetUser(ctx, request.Subject) if err != nil { return nil, status.Errorf(codes.Internal, "unable to get user: %v", err) } @@ -46,7 +51,7 @@ func (u *Server) CreateUser(ctx context.Context, request *pb.CreateUserRequest) } func (u *Server) GetUser(ctx context.Context, request *pb.GetUserRequest) (*pb.GetUserResponse, error) { - user, err := u.db.GetUser(ctx, request.Subject) + user, err := u.user.GetUser(ctx, request.Subject) if errors.Is(err, database.ErrUserNotFound) { return nil, status.New(codes.NotFound, "user not found").Err() } @@ -83,7 +88,7 @@ func (u *Server) UpdateUser(ctx context.Context, request *pb.UpdateUserRequest) (newUser.IbdPassword != existingUser.IbdPassword || newUser.IbdUsername != existingUser.IbdUsername) { // Update IBD creds - err = u.db.AddIBDCreds(ctx, newUser.Subject, *newUser.IbdUsername, *newUser.IbdPassword) + err = u.user.AddIBDCreds(ctx, newUser.Subject, *newUser.IbdUsername, *newUser.IbdPassword) if err != nil { return nil, status.Errorf(codes.Internal, "unable to update user: %v", err) } @@ -111,3 +116,41 @@ func (u *Server) CheckIBDUsername(ctx context.Context, req *pb.CheckIBDUsernameR Exists: exists, }, nil } + +func (u *Server) AuthenticateUser(ctx context.Context, req *pb.AuthenticateUserRequest) (*pb.AuthenticateUserResponse, error) { + // Check if user has cookies + cookies, err := u.cookie.GetCookies(ctx, req.Subject, false) + if err != nil { + return nil, status.Errorf(codes.Internal, "unable to get cookies: %v", err) + } + if len(cookies) > 0 { + return &pb.AuthenticateUserResponse{ + Authenticated: true, + }, nil + } + + // Authenticate user + // Get IBD creds + username, password, err := u.user.GetIBDCreds(ctx, req.Subject) + if errors.Is(err, database.ErrIBDCredsNotFound) { + return nil, status.New(codes.NotFound, "User has no IDB creds").Err() + } + if err != nil { + return nil, status.Errorf(codes.Internal, "unable to get IBD creds: %v", err) + } + + // Authenticate user + cookie, err := u.client.Authenticate(ctx, username, password) + if errors.Is(err, ibd.ErrBadCredentials) { + return &pb.AuthenticateUserResponse{ + Authenticated: false, + }, nil + } + if err != nil { + return nil, status.Errorf(codes.Internal, "unable to authenticate user: %v", err) + } + + return &pb.AuthenticateUserResponse{ + Authenticated: cookie != nil, + }, nil +} diff --git a/backend/internal/server/server.go b/backend/internal/server/server.go index c46a629..186d581 100644 --- a/backend/internal/server/server.go +++ b/backend/internal/server/server.go @@ -45,7 +45,7 @@ func New( } s := grpc.NewServer() - upb.RegisterUserServiceServer(s, user.New(db, client)) + upb.RegisterUserServiceServer(s, user.New(db, db, client)) spb.RegisterStockServiceServer(s, stock.New(db, scrapeQueue)) longrunningpb.RegisterOperationsServer(s, newOperationServer(scrapeQueue)) reflection.Register(s) |