aboutsummaryrefslogtreecommitdiff
path: root/backend/internal/ibd/auth.go
diff options
context:
space:
mode:
Diffstat (limited to 'backend/internal/ibd/auth.go')
-rw-r--r--backend/internal/ibd/auth.go333
1 files changed, 333 insertions, 0 deletions
diff --git a/backend/internal/ibd/auth.go b/backend/internal/ibd/auth.go
new file mode 100644
index 0000000..7b82057
--- /dev/null
+++ b/backend/internal/ibd/auth.go
@@ -0,0 +1,333 @@
+package ibd
+
+import (
+ "bytes"
+ "context"
+ "encoding/base64"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "io"
+ "net/http"
+ "strings"
+
+ "github.com/ansg191/ibd-trader-backend/internal/ibd/transport"
+ "golang.org/x/net/html"
+)
+
+const (
+ signInUrl = "https://myibd.investors.com/secure/signin.aspx?eurl=https%3A%2F%2Fwww.investors.com"
+ authenticateUrl = "https://sso.accounts.dowjones.com/authenticate"
+ postAuthUrl = "https://sso.accounts.dowjones.com/postauth/handler"
+ cookieName = ".ASPXAUTH"
+)
+
+var ErrAuthCookieNotFound = errors.New("cookie not found")
+var ErrBadCredentials = errors.New("bad credentials")
+
+func (c *Client) Authenticate(
+ ctx context.Context,
+ username,
+ password string,
+) (*http.Cookie, error) {
+ cfg, err := c.getLoginPage(ctx)
+ if err != nil {
+ return nil, err
+ }
+
+ token, params, err := c.sendAuthRequest(ctx, cfg, username, password)
+ if err != nil {
+ return nil, err
+ }
+
+ return c.sendPostAuth(ctx, token, params)
+}
+
+func (c *Client) getLoginPage(ctx context.Context) (*authConfig, error) {
+ req, err := http.NewRequestWithContext(ctx, http.MethodGet, signInUrl, nil)
+ if err != nil {
+ return nil, err
+ }
+
+ resp, err := c.Do(req, withRequiredProps(transport.PropertiesReliable))
+ if err != nil {
+ return nil, err
+ }
+ defer func(Body io.ReadCloser) {
+ _ = Body.Close()
+ }(resp.Body)
+
+ if resp.StatusCode != http.StatusOK {
+ content, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return nil, fmt.Errorf("failed to read response body: %w", err)
+ }
+ return nil, fmt.Errorf(
+ "unexpected status code %d: %s",
+ resp.StatusCode,
+ string(content),
+ )
+ }
+
+ node, err := html.Parse(resp.Body)
+ if err != nil {
+ return nil, err
+ }
+
+ cfg, err := extractAuthConfig(node)
+ if err != nil {
+ return nil, fmt.Errorf("failed to extract auth config: %w", err)
+ }
+
+ return cfg, nil
+}
+
+func (c *Client) sendAuthRequest(ctx context.Context, cfg *authConfig, username, password string) (string, string, error) {
+ body := authRequestBody{
+ ClientId: cfg.ClientID,
+ RedirectUri: cfg.CallbackURL,
+ Tenant: "sso",
+ ResponseType: cfg.ExtraParams.ResponseType,
+ Username: username,
+ Password: password,
+ Scope: cfg.ExtraParams.Scope,
+ State: cfg.ExtraParams.State,
+ Headers: struct {
+ XRemoteUser string `json:"x-_remote-_user"`
+ }(struct{ XRemoteUser string }{
+ XRemoteUser: username,
+ }),
+ XOidcProvider: "localop",
+ Protocol: cfg.ExtraParams.Protocol,
+ Nonce: cfg.ExtraParams.Nonce,
+ UiLocales: cfg.ExtraParams.UiLocales,
+ Csrf: cfg.ExtraParams.Csrf,
+ Intstate: cfg.ExtraParams.Intstate,
+ Connection: "DJldap",
+ }
+ bodyJson, err := json.Marshal(body)
+ if err != nil {
+ return "", "", fmt.Errorf("failed to marshal auth request body: %w", err)
+ }
+
+ req, err := http.NewRequestWithContext(ctx, http.MethodPost, authenticateUrl, bytes.NewReader(bodyJson))
+ if err != nil {
+ return "", "", err
+ }
+
+ req.Header.Set("Content-Type", "application/json")
+ req.Header.Set("Auth0-Client", "eyJuYW1lIjoiYXV0aDAuanMtdWxwIiwidmVyc2lvbiI6IjkuMjQuMSJ9")
+
+ resp, err := c.Do(req,
+ withRequiredProps(transport.PropertiesReliable),
+ withExpectedStatuses(http.StatusOK, http.StatusUnauthorized))
+ if err != nil {
+ return "", "", err
+ }
+ defer func(Body io.ReadCloser) {
+ _ = Body.Close()
+ }(resp.Body)
+
+ if resp.StatusCode == http.StatusUnauthorized {
+ return "", "", ErrBadCredentials
+ } else if resp.StatusCode != http.StatusOK {
+ content, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return "", "", fmt.Errorf("failed to read response body: %w", err)
+ }
+ return "", "", fmt.Errorf(
+ "unexpected status code %d: %s",
+ resp.StatusCode,
+ string(content),
+ )
+ }
+
+ node, err := html.Parse(resp.Body)
+ if err != nil {
+ return "", "", err
+ }
+
+ return extractTokenParams(node)
+}
+
+func (c *Client) sendPostAuth(ctx context.Context, token, params string) (*http.Cookie, error) {
+ body := fmt.Sprintf("token=%s&params=%s", token, params)
+ req, err := http.NewRequestWithContext(ctx, http.MethodPost, postAuthUrl, strings.NewReader(body))
+ if err != nil {
+ return nil, err
+ }
+
+ req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+
+ resp, err := c.Do(req, withRequiredProps(transport.PropertiesReliable))
+ if err != nil {
+ return nil, err
+ }
+ defer func(Body io.ReadCloser) {
+ _ = Body.Close()
+ }(resp.Body)
+
+ if resp.StatusCode != http.StatusOK {
+ content, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return nil, fmt.Errorf("failed to read response body: %w", err)
+ }
+ return nil, fmt.Errorf(
+ "unexpected status code %d: %s",
+ resp.StatusCode,
+ string(content),
+ )
+ }
+
+ // Extract cookie
+ for _, cookie := range resp.Cookies() {
+ if cookie.Name == cookieName {
+ return cookie, nil
+ }
+ }
+
+ return nil, ErrAuthCookieNotFound
+}
+
+func extractAuthConfig(node *html.Node) (*authConfig, error) {
+ // Find `root` element
+ root := findId(node, "root")
+ if root == nil {
+ return nil, fmt.Errorf("root element not found")
+ }
+
+ // Get adjacent script element
+ var script *html.Node
+ for s := root.NextSibling; s != nil; s = s.NextSibling {
+ if s.Type == html.ElementNode && s.Data == "script" {
+ script = s
+ break
+ }
+ }
+
+ if script == nil {
+ return nil, fmt.Errorf("script element not found")
+ }
+
+ // Get script content
+ content := extractText(script)
+
+ // Find `AUTH_CONFIG` variable
+ const authConfigVar = "const AUTH_CONFIG = '"
+ i := strings.Index(content, authConfigVar)
+ if i == -1 {
+ return nil, fmt.Errorf("AUTH_CONFIG not found")
+ }
+
+ // Find end of `AUTH_CONFIG` variable
+ j := strings.Index(content[i+len(authConfigVar):], "'")
+
+ // Extract `AUTH_CONFIG` value
+ authConfigJSONB64 := content[i+len(authConfigVar) : i+len(authConfigVar)+j]
+
+ // Decode `AUTH_CONFIG` value
+ authConfigJSON, err := base64.StdEncoding.DecodeString(authConfigJSONB64)
+ if err != nil {
+ return nil, fmt.Errorf("failed to decode AUTH_CONFIG: %w", err)
+ }
+
+ // Unmarshal `AUTH_CONFIG` value
+ var cfg authConfig
+ if err = json.Unmarshal(authConfigJSON, &cfg); err != nil {
+ return nil, fmt.Errorf("failed to unmarshal AUTH_CONFIG: %w", err)
+ }
+
+ return &cfg, nil
+}
+
+type authConfig struct {
+ Auth0Domain string `json:"auth0Domain"`
+ CallbackURL string `json:"callbackURL"`
+ ClientID string `json:"clientID"`
+ ExtraParams struct {
+ Protocol string `json:"protocol"`
+ Scope string `json:"scope"`
+ ResponseType string `json:"response_type"`
+ Nonce string `json:"nonce"`
+ UiLocales string `json:"ui_locales"`
+ Csrf string `json:"_csrf"`
+ Intstate string `json:"_intstate"`
+ State string `json:"state"`
+ } `json:"extraParams"`
+ InternalOptions struct {
+ ResponseType string `json:"response_type"`
+ ClientId string `json:"client_id"`
+ Scope string `json:"scope"`
+ RedirectUri string `json:"redirect_uri"`
+ UiLocales string `json:"ui_locales"`
+ Eurl string `json:"eurl"`
+ Nonce string `json:"nonce"`
+ State string `json:"state"`
+ Resource string `json:"resource"`
+ Protocol string `json:"protocol"`
+ Client string `json:"client"`
+ } `json:"internalOptions"`
+ IsThirdPartyClient bool `json:"isThirdPartyClient"`
+ AuthorizationServer struct {
+ Url string `json:"url"`
+ Issuer string `json:"issuer"`
+ } `json:"authorizationServer"`
+}
+
+func extractTokenParams(node *html.Node) (token string, params string, err error) {
+ inputs := findChildrenRecursive(node, func(node *html.Node) bool {
+ return node.Type == html.ElementNode && node.Data == "input"
+ })
+
+ var tokenNode, paramsNode *html.Node
+ for _, input := range inputs {
+ for _, attr := range input.Attr {
+ if attr.Key == "name" && attr.Val == "token" {
+ tokenNode = input
+ } else if attr.Key == "name" && attr.Val == "params" {
+ paramsNode = input
+ }
+ }
+ }
+
+ if tokenNode == nil {
+ return "", "", fmt.Errorf("token input not found")
+ }
+ if paramsNode == nil {
+ return "", "", fmt.Errorf("params input not found")
+ }
+
+ for _, attr := range tokenNode.Attr {
+ if attr.Key == "value" {
+ token = attr.Val
+ }
+ }
+ for _, attr := range paramsNode.Attr {
+ if attr.Key == "value" {
+ params = attr.Val
+ }
+ }
+
+ return
+}
+
+type authRequestBody struct {
+ ClientId string `json:"client_id"`
+ RedirectUri string `json:"redirect_uri"`
+ Tenant string `json:"tenant"`
+ ResponseType string `json:"response_type"`
+ Username string `json:"username"`
+ Password string `json:"password"`
+ Scope string `json:"scope"`
+ State string `json:"state"`
+ Headers struct {
+ XRemoteUser string `json:"x-_remote-_user"`
+ } `json:"headers"`
+ XOidcProvider string `json:"x-_oidc-_provider"`
+ Protocol string `json:"protocol"`
+ Nonce string `json:"nonce"`
+ UiLocales string `json:"ui_locales"`
+ Csrf string `json:"_csrf"`
+ Intstate string `json:"_intstate"`
+ Connection string `json:"connection"`
+}