diff --git a/component/wallet-cli/pkg/walletrunner/consent/cognito.go b/component/wallet-cli/pkg/walletrunner/consent/cognito.go new file mode 100644 index 000000000..7a2a6213d --- /dev/null +++ b/component/wallet-cli/pkg/walletrunner/consent/cognito.go @@ -0,0 +1,97 @@ +/* +Copyright Avast Software. All Rights Reserved. + +SPDX-License-Identifier: Apache-2.0 +*/ + +package consent + +import ( + "errors" + "fmt" + "io" + "net/http" + "net/url" + "strings" +) + +type Cognito struct { + httpClient httpClient + url string + password string + login string + existingCookies map[string]*http.Cookie +} + +func NewCognito(client httpClient, cookies []*http.Cookie, url string, login string, password string) *Cognito { + existing := map[string]*http.Cookie{} + for _, c := range cookies { + existing[c.Name] = c + } + + return &Cognito{ + url: url, + existingCookies: existing, + httpClient: client, + login: login, + password: password, + } +} + +func (c *Cognito) Execute() error { + getReq, err := http.NewRequest(http.MethodGet, c.url, nil) + if err != nil { + return err + } + + getResp, err := c.httpClient.Do(getReq) + if err != nil { + return err + } + + data := url.Values{} + data.Set("username", c.login) + data.Set("password", c.password) + data.Add("signInSubmitButton", "Sign in") + + for _, cookie := range getResp.Cookies() { + c.existingCookies[cookie.Name] = cookie + } + + xsrf, ok := c.existingCookies["XSRF-TOKEN"] + if !ok { + return errors.New("XSRF-TOKEN cookie not found") + } + + data.Add("_csrf", xsrf.Value) + + postReq, err := http.NewRequest(http.MethodPost, c.url, strings.NewReader(data.Encode())) + if err != nil { + return err + } + for _, cookie := range c.existingCookies { + postReq.AddCookie(cookie) + } + + postReq.Header.Add("Content-Type", "application/x-www-form-urlencoded") + + postResp, postErr := c.httpClient.Do(postReq) + if postErr != nil { + return postErr + } + + if postResp.StatusCode != http.StatusFound { + var body []byte + if postResp.Body != nil { + body, _ = io.ReadAll(postResp.Body) + defer func() { + _ = postResp.Body.Close() + }() + } + + return fmt.Errorf("unexpected status code from post cognito. %v with body %s", + postResp.StatusCode, body) + } + + return nil +} diff --git a/component/wallet-cli/pkg/walletrunner/consent/cognito_test.go b/component/wallet-cli/pkg/walletrunner/consent/cognito_test.go new file mode 100644 index 000000000..30c432af9 --- /dev/null +++ b/component/wallet-cli/pkg/walletrunner/consent/cognito_test.go @@ -0,0 +1,191 @@ +/* +Copyright Avast Software. All Rights Reserved. + +SPDX-License-Identifier: Apache-2.0 +*/ + +package consent_test + +import ( + "errors" + "io" + "net/http" + "strings" + "testing" + + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/assert" + + "github.com/trustbloc/vcs/component/wallet-cli/pkg/walletrunner/consent" +) + +func TestCognitoConsent(t *testing.T) { + t.Run("success", func(t *testing.T) { + cl := NewMockhttpClient(gomock.NewController(t)) + + targetURL := "https://example.auth.us-east-2.amazoncognito.com/login?client_id=example&redirect_uri=" + + "https%3A%2F%2Fexample-redirect.com%2Fvcs%2Foidc%2Fredirect&response_type=code&" + + "state=9bc93ec1-7bdd-4084-8948-299ef35adab8" + + ct := consent.NewCognito( + cl, + []*http.Cookie{ + { + Name: "XSRF-TOKEN", + Value: "abcd", + }, + }, + targetURL, + "some-login", + "some-password", + ) + + cl.EXPECT().Do(gomock.Any()).DoAndReturn(func(request *http.Request) (*http.Response, error) { + assert.Equal(t, http.MethodGet, request.Method) + assert.Equal(t, targetURL, request.URL.String()) + + return &http.Response{ + Header: map[string][]string{ + "Set-Cookie": { + "XSRF-TOKEN=8f6cafbe-34c3-4c96-b53b-47c798297e79; Path=/; Secure; HttpOnly; SameSite=Lax", + }, + }, + }, nil + }) + + cl.EXPECT().Do(gomock.Any()).DoAndReturn(func(request *http.Request) (*http.Response, error) { + assert.Equal(t, http.MethodPost, request.Method) + assert.Equal(t, targetURL, request.URL.String()) + assert.Equal(t, "XSRF-TOKEN", request.Cookies()[0].Name) + assert.Equal(t, "8f6cafbe-34c3-4c96-b53b-47c798297e79", request.Cookies()[0].Value) + assert.NoError(t, request.ParseForm()) + + assert.Equal(t, "some-login", request.Form.Get("username")) + assert.Equal(t, "Sign in", request.Form.Get("signInSubmitButton")) + assert.Equal(t, "some-password", request.Form.Get("password")) + assert.Equal(t, "8f6cafbe-34c3-4c96-b53b-47c798297e79", request.Form.Get("_csrf")) + return &http.Response{ + StatusCode: http.StatusFound, + }, nil + }) + + assert.NoError(t, ct.Execute()) + }) + + t.Run("fail get", func(t *testing.T) { + cl := NewMockhttpClient(gomock.NewController(t)) + + targetURL := "https://example.auth.us-east-2.amazoncognito.com/login?client_id=example&redirect_uri=https%3A%" + + "2F%2Fexample-redirect.com%2Fvcs%2Foidc%2Fredirect&response_type=code&" + + "state=9bc93ec1-7bdd-4084-8948-299ef35adab8" + + ct := consent.NewCognito( + cl, + []*http.Cookie{ + { + Name: "XSRF-TOKEN", + Value: "abcd", + }, + }, + targetURL, + "some-login", + "some-password", + ) + + cl.EXPECT().Do(gomock.Any()).DoAndReturn(func(request *http.Request) (*http.Response, error) { + assert.Equal(t, http.MethodGet, request.Method) + assert.Equal(t, targetURL, request.URL.String()) + + return nil, errors.New("some error") + }) + + assert.ErrorContains(t, ct.Execute(), "some error") + }) + + t.Run("missing csrf", func(t *testing.T) { + cl := NewMockhttpClient(gomock.NewController(t)) + + targetURL := "https://example.auth.us-east-2.amazoncognito.com/login?client_id=example&redirect_uri=https%3A%" + + "2F%2Fexample-redirect.com%2Fvcs%2Foidc%2Fredirect&response_type=code&" + + "state=9bc93ec1-7bdd-4084-8948-299ef35adab8" + + ct := consent.NewCognito( + cl, + nil, + targetURL, + "some-login", + "some-password", + ) + + cl.EXPECT().Do(gomock.Any()).DoAndReturn(func(request *http.Request) (*http.Response, error) { + return &http.Response{}, nil + }) + + assert.ErrorContains(t, ct.Execute(), "XSRF-TOKEN cookie not found") + }) + + t.Run("fail post", func(t *testing.T) { + cl := NewMockhttpClient(gomock.NewController(t)) + + targetURL := "https://example.auth.us-east-2.amazoncognito.com/login?client_id=example&redirect_uri=https%3A%" + + "2F%2Fexample-redirect.com%2Fvcs%2Foidc%2Fredirect&response_type=code&" + + "state=9bc93ec1-7bdd-4084-8948-299ef35adab8" + + ct := consent.NewCognito( + cl, + []*http.Cookie{ + { + Name: "XSRF-TOKEN", + Value: "abcd", + }, + }, + targetURL, + "some-login", + "some-password", + ) + + cl.EXPECT().Do(gomock.Any()).DoAndReturn(func(request *http.Request) (*http.Response, error) { + return &http.Response{}, nil + }) + + cl.EXPECT().Do(gomock.Any()).DoAndReturn(func(request *http.Request) (*http.Response, error) { + return nil, errors.New("post failed") + }) + + assert.ErrorContains(t, ct.Execute(), "post failed") + }) + + t.Run("fail invalid status code", func(t *testing.T) { + cl := NewMockhttpClient(gomock.NewController(t)) + + targetURL := "https://example.auth.us-east-2.amazoncognito.com/login?client_id=example&redirect_uri=https%3A%" + + "2F%2Fexample-redirect.com%2Fvcs%2Foidc%2Fredirect&response_type=code&" + + "state=9bc93ec1-7bdd-4084-8948-299ef35adab8" + + ct := consent.NewCognito( + cl, + []*http.Cookie{ + { + Name: "XSRF-TOKEN", + Value: "abcd", + }, + }, + targetURL, + "some-login", + "some-password", + ) + + cl.EXPECT().Do(gomock.Any()).DoAndReturn(func(request *http.Request) (*http.Response, error) { + return &http.Response{}, nil + }) + + cl.EXPECT().Do(gomock.Any()).DoAndReturn(func(request *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: http.StatusMultiStatus, + Body: io.NopCloser(strings.NewReader(`some random text`)), + }, nil + }) + + assert.ErrorContains(t, ct.Execute(), "unexpected status code from post cognito. 207 with body some random text") + }) +} diff --git a/component/wallet-cli/pkg/walletrunner/consent/interfaces.go b/component/wallet-cli/pkg/walletrunner/consent/interfaces.go new file mode 100644 index 000000000..9f256b66a --- /dev/null +++ b/component/wallet-cli/pkg/walletrunner/consent/interfaces.go @@ -0,0 +1,14 @@ +/* +Copyright Avast Software. All Rights Reserved. + +SPDX-License-Identifier: Apache-2.0 +*/ + +package consent + +import "net/http" + +//go:generate mockgen -destination interfaces_mocks_test.go -package consent_test -source=interfaces.go +type httpClient interface { + Do(req *http.Request) (*http.Response, error) +} diff --git a/component/wallet-cli/pkg/walletrunner/wallet_runner_oidc4ci.go b/component/wallet-cli/pkg/walletrunner/wallet_runner_oidc4ci.go index ef2a77c3a..426c77492 100644 --- a/component/wallet-cli/pkg/walletrunner/wallet_runner_oidc4ci.go +++ b/component/wallet-cli/pkg/walletrunner/wallet_runner_oidc4ci.go @@ -30,6 +30,7 @@ import ( "golang.org/x/oauth2" "github.com/trustbloc/vcs/component/wallet-cli/pkg/credentialoffer" + "github.com/trustbloc/vcs/component/wallet-cli/pkg/walletrunner/consent" "github.com/trustbloc/vcs/pkg/kms/signer" "github.com/trustbloc/vcs/pkg/restapi/v1/common" issuerv1 "github.com/trustbloc/vcs/pkg/restapi/v1/issuer" @@ -47,11 +48,16 @@ type OIDC4CIConfig struct { Password string } -func (s *Service) RunOIDC4CI(config *OIDC4CIConfig) error { +func (s *Service) RunOIDC4CI( + config *OIDC4CIConfig, +) error { log.Println("Starting OIDC4VCI authorized code flow") log.Printf("Initiate issuance URL:\n\n\t%s\n\n", config.InitiateIssuanceURL) - offerResponse, err := credentialoffer.ParseInitiateIssuanceUrl(config.InitiateIssuanceURL, s.httpClient) + offerResponse, err := credentialoffer.ParseInitiateIssuanceUrl( + config.InitiateIssuanceURL, + s.httpClient, + ) if err != nil { return fmt.Errorf("parse initiate issuance url: %w", err) } @@ -62,7 +68,9 @@ func (s *Service) RunOIDC4CI(config *OIDC4CIConfig) error { return fmt.Errorf("get issuer oidc config: %w", err) } - oidcIssuerCredentialConfig, err := s.getIssuerCredentialsOIDCConfig(offerResponse.CredentialIssuer) + oidcIssuerCredentialConfig, err := s.getIssuerCredentialsOIDCConfig( + offerResponse.CredentialIssuer, + ) if err != nil { return fmt.Errorf("get issuer oidc issuer config: %w", err) } @@ -80,7 +88,11 @@ func (s *Service) RunOIDC4CI(config *OIDC4CIConfig) error { return fmt.Errorf("listen: %w", err) } - redirectURL.Host = fmt.Sprintf("%s:%d", redirectURL.Hostname(), listener.Addr().(*net.TCPAddr).Port) + redirectURL.Host = fmt.Sprintf( + "%s:%d", + redirectURL.Hostname(), + listener.Addr().(*net.TCPAddr).Port, + ) } s.oauthClient = &oauth2.Config{ @@ -152,7 +164,11 @@ func (s *Service) RunOIDC4CI(config *OIDC4CIConfig) error { } s.print("Getting credential") - vc, _, err := s.getCredential(oidcIssuerCredentialConfig.CredentialEndpoint, config.CredentialType, config.CredentialFormat) + vc, _, err := s.getCredential( + oidcIssuerCredentialConfig.CredentialEndpoint, + config.CredentialType, + config.CredentialFormat, + ) if err != nil { return fmt.Errorf("get credential: %w", err) } @@ -175,7 +191,11 @@ func (s *Service) RunOIDC4CI(config *OIDC4CIConfig) error { return fmt.Errorf("parse vc: %w", err) } - log.Printf("Credential with ID [%s] and type [%v] added successfully", vcParsed.ID, config.CredentialType) + log.Printf( + "Credential with ID [%s] and type [%v] added successfully", + vcParsed.ID, + config.CredentialType, + ) if !s.keepWalletOpen { s.wallet.Close() @@ -184,7 +204,9 @@ func (s *Service) RunOIDC4CI(config *OIDC4CIConfig) error { return nil } -func (s *Service) getIssuerOIDCConfig(issuerURL string) (*issuerv1.WellKnownOpenIDConfiguration, error) { +func (s *Service) getIssuerOIDCConfig( + issuerURL string, +) (*issuerv1.WellKnownOpenIDConfiguration, error) { // GET /issuer/{profileID}/{profileVersion}/.well-known/openid-configuration resp, err := s.httpClient.Get(issuerURL + "/.well-known/openid-configuration") if err != nil { @@ -206,7 +228,9 @@ func (s *Service) getIssuerOIDCConfig(issuerURL string) (*issuerv1.WellKnownOpen return &oidcConfig, nil } -func (s *Service) getIssuerCredentialsOIDCConfig(issuerURL string) (*issuerv1.WellKnownOpenIDIssuerConfiguration, error) { +func (s *Service) getIssuerCredentialsOIDCConfig( + issuerURL string, +) (*issuerv1.WellKnownOpenIDIssuerConfiguration, error) { // GET /issuer/{profileID}/.well-known/openid-credential-issuer resp, err := s.httpClient.Get(issuerURL + "/.well-known/openid-credential-issuer") if err != nil { @@ -228,23 +252,37 @@ func (s *Service) getIssuerCredentialsOIDCConfig(issuerURL string) (*issuerv1.We return &oidcConfig, nil } -func (s *Service) getAuthCode(config *OIDC4CIConfig, authCodeURL string) (string, error) { +func (s *Service) getAuthCode( + config *OIDC4CIConfig, + authCodeURL string, +) (string, error) { //var loginURL, consentURL *url.URL var authCode string httpClient := &http.Client{ Jar: s.httpClient.Jar, Transport: s.httpClient.Transport, - CheckRedirect: func(req *http.Request, via []*http.Request) error { - // intercept client auth code - if strings.HasPrefix(req.URL.String(), config.RedirectURI) { - authCode = req.URL.Query().Get("code") + } - return http.ErrUseLastResponse - } + httpClient.CheckRedirect = func(req *http.Request, via []*http.Request) error { + if strings.Contains(req.URL.String(), ".amazoncognito.com/login") { + s.print("got cognito consent screen") + return consent.NewCognito( + httpClient, + httpClient.Jar.Cookies(req.URL), + req.URL.String(), + config.Login, config.Password, + ).Execute() + } - return nil - }, + // intercept client auth code + if strings.HasPrefix(req.URL.String(), config.RedirectURI) { + authCode = req.URL.Query().Get("code") + + return http.ErrUseLastResponse + } + + return nil } s.print("Getting authorization code") @@ -257,7 +295,10 @@ func (s *Service) getAuthCode(config *OIDC4CIConfig, authCodeURL string) (string return authCode, nil } -func (s *Service) getAuthCodeFromBrowser(listener net.Listener, authCodeURL string) (string, error) { +func (s *Service) getAuthCodeFromBrowser( + listener net.Listener, + authCodeURL string, +) (string, error) { server := &callbackServer{ listener: listener, codeChan: make(chan string, 1), @@ -267,7 +308,10 @@ func (s *Service) getAuthCodeFromBrowser(listener net.Listener, authCodeURL stri http.Serve(listener, server) }() - log.Printf("Log in with a browser:\n\n%s\n\nor press [Enter] to open link in your default browser\n", authCodeURL) + log.Printf( + "Log in with a browser:\n\n%s\n\nor press [Enter] to open link in your default browser\n", + authCodeURL, + ) done := make(chan struct{}) @@ -298,7 +342,13 @@ func (s *Service) getCredential( didKeyID := s.vcProviderConf.WalletParams.DidKeyID[0] - kmsSigner, err := signer.NewKMSSigner(km, cr, strings.Split(didKeyID, "#")[1], s.vcProviderConf.WalletParams.SignType, nil) + kmsSigner, err := signer.NewKMSSigner( + km, + cr, + strings.Split(didKeyID, "#")[1], + s.vcProviderConf.WalletParams.SignType, + nil, + ) if err != nil { return nil, 0, fmt.Errorf("create kms signer: %w", err) } @@ -364,7 +414,11 @@ func (s *Service) getCredential( if resp.StatusCode != http.StatusOK { b, _ := io.ReadAll(resp.Body) - return nil, finalDuration, fmt.Errorf("get credential: status %s and body %s", resp.Status, string(b)) + return nil, finalDuration, fmt.Errorf( + "get credential: status %s and body %s", + resp.Status, + string(b), + ) } var credentialResp CredentialResponse @@ -376,7 +430,9 @@ func (s *Service) getCredential( return credentialResp.Credential, finalDuration, nil } -func (s *Service) print(msg string) { +func (s *Service) print( + msg string, +) { if s.debug { fmt.Println() } @@ -384,7 +440,9 @@ func (s *Service) print(msg string) { log.Printf("%s\n\n", msg) } -func waitForEnter(done chan<- struct{}) { +func waitForEnter( + done chan<- struct{}, +) { _, _ = fmt.Scanln() done <- struct{}{} } @@ -394,7 +452,10 @@ type callbackServer struct { codeChan chan string } -func (s *callbackServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { +func (s *callbackServer) ServeHTTP( + w http.ResponseWriter, + r *http.Request, +) { if r.URL.Path != "/callback" { http.NotFound(w, r)