diff options
Diffstat (limited to 'backend/internal/ibd/auth.go')
-rw-r--r-- | backend/internal/ibd/auth.go | 333 |
1 files changed, 333 insertions, 0 deletions
diff --git a/backend/internal/ibd/auth.go b/backend/internal/ibd/auth.go new file mode 100644 index 0000000..7b82057 --- /dev/null +++ b/backend/internal/ibd/auth.go @@ -0,0 +1,333 @@ +package ibd + +import ( + "bytes" + "context" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "strings" + + "github.com/ansg191/ibd-trader-backend/internal/ibd/transport" + "golang.org/x/net/html" +) + +const ( + signInUrl = "https://myibd.investors.com/secure/signin.aspx?eurl=https%3A%2F%2Fwww.investors.com" + authenticateUrl = "https://sso.accounts.dowjones.com/authenticate" + postAuthUrl = "https://sso.accounts.dowjones.com/postauth/handler" + cookieName = ".ASPXAUTH" +) + +var ErrAuthCookieNotFound = errors.New("cookie not found") +var ErrBadCredentials = errors.New("bad credentials") + +func (c *Client) Authenticate( + ctx context.Context, + username, + password string, +) (*http.Cookie, error) { + cfg, err := c.getLoginPage(ctx) + if err != nil { + return nil, err + } + + token, params, err := c.sendAuthRequest(ctx, cfg, username, password) + if err != nil { + return nil, err + } + + return c.sendPostAuth(ctx, token, params) +} + +func (c *Client) getLoginPage(ctx context.Context) (*authConfig, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, signInUrl, nil) + if err != nil { + return nil, err + } + + resp, err := c.Do(req, withRequiredProps(transport.PropertiesReliable)) + if err != nil { + return nil, err + } + defer func(Body io.ReadCloser) { + _ = Body.Close() + }(resp.Body) + + if resp.StatusCode != http.StatusOK { + content, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + return nil, fmt.Errorf( + "unexpected status code %d: %s", + resp.StatusCode, + string(content), + ) + } + + node, err := html.Parse(resp.Body) + if err != nil { + return nil, err + } + + cfg, err := extractAuthConfig(node) + if err != nil { + return nil, fmt.Errorf("failed to extract auth config: %w", err) + } + + return cfg, nil +} + +func (c *Client) sendAuthRequest(ctx context.Context, cfg *authConfig, username, password string) (string, string, error) { + body := authRequestBody{ + ClientId: cfg.ClientID, + RedirectUri: cfg.CallbackURL, + Tenant: "sso", + ResponseType: cfg.ExtraParams.ResponseType, + Username: username, + Password: password, + Scope: cfg.ExtraParams.Scope, + State: cfg.ExtraParams.State, + Headers: struct { + XRemoteUser string `json:"x-_remote-_user"` + }(struct{ XRemoteUser string }{ + XRemoteUser: username, + }), + XOidcProvider: "localop", + Protocol: cfg.ExtraParams.Protocol, + Nonce: cfg.ExtraParams.Nonce, + UiLocales: cfg.ExtraParams.UiLocales, + Csrf: cfg.ExtraParams.Csrf, + Intstate: cfg.ExtraParams.Intstate, + Connection: "DJldap", + } + bodyJson, err := json.Marshal(body) + if err != nil { + return "", "", fmt.Errorf("failed to marshal auth request body: %w", err) + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, authenticateUrl, bytes.NewReader(bodyJson)) + if err != nil { + return "", "", err + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Auth0-Client", "eyJuYW1lIjoiYXV0aDAuanMtdWxwIiwidmVyc2lvbiI6IjkuMjQuMSJ9") + + resp, err := c.Do(req, + withRequiredProps(transport.PropertiesReliable), + withExpectedStatuses(http.StatusOK, http.StatusUnauthorized)) + if err != nil { + return "", "", err + } + defer func(Body io.ReadCloser) { + _ = Body.Close() + }(resp.Body) + + if resp.StatusCode == http.StatusUnauthorized { + return "", "", ErrBadCredentials + } else if resp.StatusCode != http.StatusOK { + content, err := io.ReadAll(resp.Body) + if err != nil { + return "", "", fmt.Errorf("failed to read response body: %w", err) + } + return "", "", fmt.Errorf( + "unexpected status code %d: %s", + resp.StatusCode, + string(content), + ) + } + + node, err := html.Parse(resp.Body) + if err != nil { + return "", "", err + } + + return extractTokenParams(node) +} + +func (c *Client) sendPostAuth(ctx context.Context, token, params string) (*http.Cookie, error) { + body := fmt.Sprintf("token=%s¶ms=%s", token, params) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, postAuthUrl, strings.NewReader(body)) + if err != nil { + return nil, err + } + + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + resp, err := c.Do(req, withRequiredProps(transport.PropertiesReliable)) + if err != nil { + return nil, err + } + defer func(Body io.ReadCloser) { + _ = Body.Close() + }(resp.Body) + + if resp.StatusCode != http.StatusOK { + content, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + return nil, fmt.Errorf( + "unexpected status code %d: %s", + resp.StatusCode, + string(content), + ) + } + + // Extract cookie + for _, cookie := range resp.Cookies() { + if cookie.Name == cookieName { + return cookie, nil + } + } + + return nil, ErrAuthCookieNotFound +} + +func extractAuthConfig(node *html.Node) (*authConfig, error) { + // Find `root` element + root := findId(node, "root") + if root == nil { + return nil, fmt.Errorf("root element not found") + } + + // Get adjacent script element + var script *html.Node + for s := root.NextSibling; s != nil; s = s.NextSibling { + if s.Type == html.ElementNode && s.Data == "script" { + script = s + break + } + } + + if script == nil { + return nil, fmt.Errorf("script element not found") + } + + // Get script content + content := extractText(script) + + // Find `AUTH_CONFIG` variable + const authConfigVar = "const AUTH_CONFIG = '" + i := strings.Index(content, authConfigVar) + if i == -1 { + return nil, fmt.Errorf("AUTH_CONFIG not found") + } + + // Find end of `AUTH_CONFIG` variable + j := strings.Index(content[i+len(authConfigVar):], "'") + + // Extract `AUTH_CONFIG` value + authConfigJSONB64 := content[i+len(authConfigVar) : i+len(authConfigVar)+j] + + // Decode `AUTH_CONFIG` value + authConfigJSON, err := base64.StdEncoding.DecodeString(authConfigJSONB64) + if err != nil { + return nil, fmt.Errorf("failed to decode AUTH_CONFIG: %w", err) + } + + // Unmarshal `AUTH_CONFIG` value + var cfg authConfig + if err = json.Unmarshal(authConfigJSON, &cfg); err != nil { + return nil, fmt.Errorf("failed to unmarshal AUTH_CONFIG: %w", err) + } + + return &cfg, nil +} + +type authConfig struct { + Auth0Domain string `json:"auth0Domain"` + CallbackURL string `json:"callbackURL"` + ClientID string `json:"clientID"` + ExtraParams struct { + Protocol string `json:"protocol"` + Scope string `json:"scope"` + ResponseType string `json:"response_type"` + Nonce string `json:"nonce"` + UiLocales string `json:"ui_locales"` + Csrf string `json:"_csrf"` + Intstate string `json:"_intstate"` + State string `json:"state"` + } `json:"extraParams"` + InternalOptions struct { + ResponseType string `json:"response_type"` + ClientId string `json:"client_id"` + Scope string `json:"scope"` + RedirectUri string `json:"redirect_uri"` + UiLocales string `json:"ui_locales"` + Eurl string `json:"eurl"` + Nonce string `json:"nonce"` + State string `json:"state"` + Resource string `json:"resource"` + Protocol string `json:"protocol"` + Client string `json:"client"` + } `json:"internalOptions"` + IsThirdPartyClient bool `json:"isThirdPartyClient"` + AuthorizationServer struct { + Url string `json:"url"` + Issuer string `json:"issuer"` + } `json:"authorizationServer"` +} + +func extractTokenParams(node *html.Node) (token string, params string, err error) { + inputs := findChildrenRecursive(node, func(node *html.Node) bool { + return node.Type == html.ElementNode && node.Data == "input" + }) + + var tokenNode, paramsNode *html.Node + for _, input := range inputs { + for _, attr := range input.Attr { + if attr.Key == "name" && attr.Val == "token" { + tokenNode = input + } else if attr.Key == "name" && attr.Val == "params" { + paramsNode = input + } + } + } + + if tokenNode == nil { + return "", "", fmt.Errorf("token input not found") + } + if paramsNode == nil { + return "", "", fmt.Errorf("params input not found") + } + + for _, attr := range tokenNode.Attr { + if attr.Key == "value" { + token = attr.Val + } + } + for _, attr := range paramsNode.Attr { + if attr.Key == "value" { + params = attr.Val + } + } + + return +} + +type authRequestBody struct { + ClientId string `json:"client_id"` + RedirectUri string `json:"redirect_uri"` + Tenant string `json:"tenant"` + ResponseType string `json:"response_type"` + Username string `json:"username"` + Password string `json:"password"` + Scope string `json:"scope"` + State string `json:"state"` + Headers struct { + XRemoteUser string `json:"x-_remote-_user"` + } `json:"headers"` + XOidcProvider string `json:"x-_oidc-_provider"` + Protocol string `json:"protocol"` + Nonce string `json:"nonce"` + UiLocales string `json:"ui_locales"` + Csrf string `json:"_csrf"` + Intstate string `json:"_intstate"` + Connection string `json:"connection"` +} |