diff --git a/cmd/engflow_auth/BUILD b/cmd/engflow_auth/BUILD index 58071d5..0d2b5ff 100644 --- a/cmd/engflow_auth/BUILD +++ b/cmd/engflow_auth/BUILD @@ -7,10 +7,10 @@ go_library( importpath = "github.com/EngFlow/auth/cmd/engflow_auth", visibility = ["//visibility:private"], deps = [ + "//internal/auth", "//internal/autherr", "//internal/browser", "//internal/buildstamp", - "//internal/oauthdevice", "//internal/oauthtoken", "@com_github_engflow_credential_helper_go//:credential-helper-go", "@com_github_urfave_cli_v2//:cli", @@ -29,9 +29,8 @@ go_test( srcs = ["main_test.go"], embed = [":engflow_auth_lib"], deps = [ + "//internal/auth", "//internal/autherr", - "//internal/browser", - "//internal/oauthdevice", "//internal/oauthtoken", "@com_github_stretchr_testify//assert", "@org_golang_x_oauth2//:oauth2", diff --git a/cmd/engflow_auth/main.go b/cmd/engflow_auth/main.go index 61f310e..6271c5a 100644 --- a/cmd/engflow_auth/main.go +++ b/cmd/engflow_auth/main.go @@ -26,10 +26,10 @@ import ( "strings" "time" + "github.com/EngFlow/auth/internal/auth" "github.com/EngFlow/auth/internal/autherr" "github.com/EngFlow/auth/internal/browser" "github.com/EngFlow/auth/internal/buildstamp" - "github.com/EngFlow/auth/internal/oauthdevice" "github.com/EngFlow/auth/internal/oauthtoken" "github.com/urfave/cli/v2" @@ -44,8 +44,7 @@ const ( ) type appState struct { - browserOpener browser.Opener - authenticator oauthdevice.Authenticator + authenticator auth.Backend tokenStore oauthtoken.LoadStorer } @@ -194,7 +193,7 @@ func (r *appState) login(cliCtx *cli.Context) error { } } - authRes, err := r.authenticator.FetchCode(ctx, oauthURL) + token, err := r.authenticator.Authenticate(ctx, clusterURL) if err != nil { if oauthErr := (*oauth2.RetrieveError)(nil); errors.Is(err, autherr.UnexpectedHTML) || errors.As(err, &oauthErr) { return autherr.CodedErrorf( @@ -209,24 +208,6 @@ Visit %s for help.`, } return autherr.CodedErrorf(autherr.CodeAuthFailure, "failed to generate device code: %w", err) } - // The "complete" URI that includes the device code pre-populated is ideal, - // but technically optional. Prefer it, but fall back to the required URL in - // the response if necessary. - verificationURLStr := authRes.VerificationURIComplete - if verificationURLStr == "" { - verificationURLStr = authRes.VerificationURI - } - verificationURL, err := url.Parse(verificationURLStr) - if err != nil { - return autherr.CodedErrorf(autherr.CodeAuthFailure, "failed to parse authentication URL: %w", err) - } - if err := r.browserOpener.Open(verificationURL); err != nil { - return autherr.CodedErrorf(autherr.CodeAuthFailure, "failed to open browser to perform authentication: %w", err) - } - token, err := r.authenticator.FetchToken(ctx, authRes) - if err != nil { - return autherr.CodedErrorf(autherr.CodeAuthFailure, "failed to obtain auth token: %w", err) - } var storeErrs []error for _, storeURL := range storeURLs { @@ -356,14 +337,12 @@ func main() { ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt) defer cancel() - deviceAuth := oauthdevice.NewAuth(cliClientID, nil) - browserOpener := &browser.StderrPrint{} + deviceAuth := auth.NewDeviceCode(&browser.StderrPrint{}, cliClientID, nil) tokenStore, err := oauthtoken.NewKeyring() if err != nil { exitOnError(autherr.CodedErrorf(autherr.CodeTokenStoreFailure, "failed to open token store: %w", err)) } root := &appState{ - browserOpener: browserOpener, authenticator: deviceAuth, tokenStore: oauthtoken.NewCacheAlert(tokenStore, os.Stderr), } diff --git a/cmd/engflow_auth/main_test.go b/cmd/engflow_auth/main_test.go index cff190b..3b74c64 100644 --- a/cmd/engflow_auth/main_test.go +++ b/cmd/engflow_auth/main_test.go @@ -25,9 +25,8 @@ import ( "testing" "time" + "github.com/EngFlow/auth/internal/auth" "github.com/EngFlow/auth/internal/autherr" - "github.com/EngFlow/auth/internal/browser" - "github.com/EngFlow/auth/internal/oauthdevice" "github.com/EngFlow/auth/internal/oauthtoken" "github.com/stretchr/testify/assert" "golang.org/x/oauth2" @@ -78,27 +77,13 @@ func codedErrorContains(t *testing.T, gotErr error, code int, wantMsg string) bo } type fakeAuth struct { - res *oauth2.DeviceAuthResponse - fetchCodeErr error fetchTokenErr error } -func (f *fakeAuth) FetchCode(ctx context.Context, authEndpint *oauth2.Endpoint) (*oauth2.DeviceAuthResponse, error) { - return f.res, f.fetchCodeErr -} - -func (f *fakeAuth) FetchToken(ctx context.Context, authRes *oauth2.DeviceAuthResponse) (*oauth2.Token, error) { +func (f *fakeAuth) Authenticate(ctx context.Context, host *url.URL) (*oauth2.Token, error) { return nil, f.fetchTokenErr } -type fakeBrowser struct { - openErr error -} - -func (f *fakeBrowser) Open(u *url.URL) error { - return f.openErr -} - func TestRun(t *testing.T) { expiresInFuture := time.Now().AddDate(0, 0, 7).UTC() expiryFormat := "2006-01-02T15:04:05Z" @@ -108,9 +93,8 @@ func TestRun(t *testing.T) { args []string machineInput io.Reader - authenticator oauthdevice.Authenticator + authenticator auth.Backend tokenStore oauthtoken.LoadStorer - browserOpener browser.Opener wantCode int wantErr string @@ -211,20 +195,10 @@ func TestRun(t *testing.T) { { desc: "login happy path", args: []string{"login", "cluster.example.com"}, - authenticator: &fakeAuth{ - res: &oauth2.DeviceAuthResponse{ - VerificationURIComplete: "https://cluster.example.com/with/auth/code", - }, - }, }, { - desc: "login with alias", - args: []string{"login", "--alias", "cluster.local.example.com", "cluster.example.com"}, - authenticator: &fakeAuth{ - res: &oauth2.DeviceAuthResponse{ - VerificationURIComplete: "https://cluster.example.com/with/auth/code", - }, - }, + desc: "login with alias", + args: []string{"login", "--alias", "cluster.local.example.com", "cluster.example.com"}, tokenStore: oauthtoken.NewFakeTokenStore(), wantStored: []string{ "cluster.example.com", @@ -234,11 +208,6 @@ func TestRun(t *testing.T) { { desc: "login with alias with store errors", args: []string{"login", "--alias", "cluster.local.example.com", "cluster.example.com"}, - authenticator: &fakeAuth{ - res: &oauth2.DeviceAuthResponse{ - VerificationURIComplete: "https://cluster.example.com/with/auth/code", - }, - }, tokenStore: &oauthtoken.FakeTokenStore{ StoreErr: errors.New("token_store_fail"), }, @@ -248,11 +217,6 @@ func TestRun(t *testing.T) { { desc: "login with host and port", args: []string{"login", "cluster.example.com:8080"}, - authenticator: &fakeAuth{ - res: &oauth2.DeviceAuthResponse{ - VerificationURIComplete: "https://cluster.example.com:8080/with/auth/code", - }, - }, }, { desc: "login with invalid scheme", @@ -260,26 +224,11 @@ func TestRun(t *testing.T) { wantCode: autherr.CodeBadParams, wantErr: "illegal scheme", }, - { - desc: "login code fetch failure", - args: []string{"login", "cluster.example.com"}, - authenticator: &fakeAuth{ - res: &oauth2.DeviceAuthResponse{ - VerificationURIComplete: "https://cluster.example.com/with/auth/code", - }, - fetchCodeErr: errors.New("fetch_code_fail"), - }, - wantCode: autherr.CodeAuthFailure, - wantErr: "fetch_code_fail", - }, { desc: "login code fetch RetrieveError", args: []string{"login", "cluster.example.com"}, authenticator: &fakeAuth{ - res: &oauth2.DeviceAuthResponse{ - VerificationURIComplete: "https://cluster.example.com/with/auth/code", - }, - fetchCodeErr: &oauth2.RetrieveError{}, + fetchTokenErr: &oauth2.RetrieveError{}, }, wantCode: autherr.CodeAuthFailure, wantErr: "This cluster may not support 'engflow_auth login'.\nVisit https://cluster.example.com/gettingstarted for help.", @@ -288,35 +237,15 @@ func TestRun(t *testing.T) { desc: "login code fetch unexpected HTML", args: []string{"login", "cluster.example.com"}, authenticator: &fakeAuth{ - res: &oauth2.DeviceAuthResponse{ - VerificationURIComplete: "https://cluster.example.com/with/auth/code", - }, - fetchCodeErr: autherr.UnexpectedHTML, + fetchTokenErr: autherr.UnexpectedHTML, }, wantCode: autherr.CodeAuthFailure, wantErr: "This cluster may not support 'engflow_auth login'.\nVisit https://cluster.example.com/gettingstarted for help.", }, - { - desc: "login browser open failure", - args: []string{"login", "cluster.example.com"}, - authenticator: &fakeAuth{ - res: &oauth2.DeviceAuthResponse{ - VerificationURIComplete: "https://cluster.example.com/with/auth/code", - }, - }, - browserOpener: &fakeBrowser{ - openErr: errors.New("browser_open_fail"), - }, - wantCode: autherr.CodeAuthFailure, - wantErr: "browser_open_fail", - }, { desc: "login token fetch failure", args: []string{"login", "cluster.example.com"}, authenticator: &fakeAuth{ - res: &oauth2.DeviceAuthResponse{ - VerificationURIComplete: "https://cluster.example.com/with/auth/code", - }, fetchTokenErr: errors.New("fetch_token_fail"), }, wantCode: autherr.CodeAuthFailure, @@ -325,11 +254,6 @@ func TestRun(t *testing.T) { { desc: "login token store failure", args: []string{"login", "cluster.example.com"}, - authenticator: &fakeAuth{ - res: &oauth2.DeviceAuthResponse{ - VerificationURIComplete: "https://cluster.example.com/with/auth/code", - }, - }, tokenStore: &oauthtoken.FakeTokenStore{ StoreErr: errors.New("token_store_fail"), }, @@ -494,13 +418,9 @@ func TestRun(t *testing.T) { stderr := bytes.NewBuffer(nil) root := &appState{ - browserOpener: tc.browserOpener, authenticator: tc.authenticator, tokenStore: tc.tokenStore, } - if root.browserOpener == nil { - root.browserOpener = &fakeBrowser{} - } if root.authenticator == nil { root.authenticator = &fakeAuth{} } diff --git a/go.mod b/go.mod index 9721f69..d7a7fba 100644 --- a/go.mod +++ b/go.mod @@ -20,6 +20,7 @@ require ( github.com/godbus/dbus/v5 v5.1.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/russross/blackfriday/v2 v2.1.0 // indirect + github.com/stretchr/objx v0.5.2 // indirect github.com/xrash/smetrics v0.0.0-20240312152122-5f08fbb34913 // indirect golang.org/x/sys v0.8.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect diff --git a/internal/auth/BUILD b/internal/auth/BUILD new file mode 100644 index 0000000..0ee4c63 --- /dev/null +++ b/internal/auth/BUILD @@ -0,0 +1,25 @@ +load("@rules_go//go:def.bzl", "go_library", "go_test") + +go_library( + name = "auth", + srcs = ["authenticator.go"], + importpath = "github.com/EngFlow/auth/internal/auth", + visibility = ["//:__subpackages__"], + deps = [ + "//internal/autherr", + "//internal/browser", + "@org_golang_x_oauth2//:oauth2", + ], +) + +go_test( + name = "auth_test", + srcs = ["authenticator_test.go"], + embed = [":auth"], + deps = [ + "//internal/browser", + "@com_github_stretchr_testify//assert", + "@com_github_stretchr_testify//mock", + "@org_golang_x_oauth2//:oauth2", + ], +) diff --git a/internal/auth/authenticator.go b/internal/auth/authenticator.go new file mode 100644 index 0000000..c402f92 --- /dev/null +++ b/internal/auth/authenticator.go @@ -0,0 +1,125 @@ +// Copyright 2024 EngFlow Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package auth + +import ( + "context" + "errors" + "net/http" + "net/url" + "strings" + + "golang.org/x/oauth2" + + "github.com/EngFlow/auth/internal/autherr" + "github.com/EngFlow/auth/internal/browser" +) + +var errUnexpectedHTML = errors.New("request to JSON API returned HTML unexpectedly") + +// Backend implementations specify the entire authentication flow to obtain an +// OAuth2 token given a backend endpoint. +type Backend interface { + // Authenticate returns a valid token for the specified endpoint. + Authenticate(ctx context.Context, host *url.URL) (*oauth2.Token, error) +} + +// DeviceCode implements Backend via the "OAuth2 device flow", which involves: +// * obtaining an ephemeral device code +// * instructing the user to log in via a browser, and enter the device code +// * polling the server for completion, and fetching the generated token +type DeviceCode struct { + browserOpener browser.Opener + clientID string + scopes []string + // httpTransport sets the behavior of HTTP calls under unit tests; it is + // intentionally unexported so that only package-level tests can hook HTTP + // calls. + httpTransport http.RoundTripper +} + +func NewDeviceCode(browserOpener browser.Opener, clientID string, scopes []string) *DeviceCode { + return &DeviceCode{ + browserOpener: browserOpener, + clientID: clientID, + scopes: scopes, + // Explicitly leave this unset, to ensure the default HTTP transport is + // used in non-test usecases. + httpTransport: nil, + } +} + +func (d *DeviceCode) Authenticate(ctx context.Context, host *url.URL) (*oauth2.Token, error) { + // Under tests, the HTTP transport might be set in order to stub out network + // calls. If this is the case, ensure the oauth2 library is using it; the + // library API around this is that it discovers a client via the context, or + // uses some default if none is set (currently http.DefaultTransport). + if d.httpTransport != nil { + client := &http.Client{ + Transport: d.httpTransport, + } + ctx = context.WithValue(ctx, oauth2.HTTPClient, client) + } + + config := &oauth2.Config{ + ClientID: d.clientID, + Scopes: d.scopes, + Endpoint: oauth2.Endpoint{ + DeviceAuthURL: urlWithPath(host, "api/v1/oauth2/device").String(), + TokenURL: urlWithPath(host, "api/v1/oauth2/token").String(), + AuthStyle: oauth2.AuthStyleInParams, + }, + } + res, err := config.DeviceAuth(ctx) + if err != nil { + if oauthErr := (*oauth2.RetrieveError)(nil); errors.As(err, &oauthErr) { + return nil, err + } + // BUG(CUS-320): Older versions of engflow backends sometimes respond to + // requests with HTTP 200 and an HTML body, which confuses the oauth2 + // library. This shouldn't happen anymore (newer versions return a 404 + // when oauth2 is not supported) but we still guard against it until no + // old versions of backends are running anywhere. + if strings.Contains(err.Error(), "invalid character '<'") { + return nil, errUnexpectedHTML + } + // Default error handling + return nil, err + } + + // The "complete" URI that includes the device code pre-populated is ideal, + // but technically optional. Prefer it, but fall back to the required URL in + // the response if necessary. + verificationURLStr := res.VerificationURIComplete + if verificationURLStr == "" { + verificationURLStr = res.VerificationURI + } + verificationURL, err := url.Parse(verificationURLStr) + if err != nil { + return nil, autherr.CodedErrorf(autherr.CodeAuthFailure, "failed to parse authentication URL: %w", err) + } + if err := d.browserOpener.Open(verificationURL); err != nil { + return nil, autherr.CodedErrorf(autherr.CodeAuthFailure, "failed to open browser to perform authentication: %w", err) + } + + return config.DeviceAccessToken(ctx, res) +} + +func urlWithPath(u *url.URL, path string) *url.URL { + newURL := &url.URL{} + *newURL = *u + newURL.Path = path + return newURL +} diff --git a/internal/auth/authenticator_test.go b/internal/auth/authenticator_test.go new file mode 100644 index 0000000..4cc996f --- /dev/null +++ b/internal/auth/authenticator_test.go @@ -0,0 +1,158 @@ +// Copyright 2024 EngFlow Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package auth + +import ( + "context" + "errors" + "io" + "net/http" + "net/url" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "golang.org/x/oauth2" +) + +func assertErrorContains(t *testing.T, got error, want string) { + t.Helper() + if want == "" { + assert.NoError(t, got) + } else { + assert.ErrorContains(t, got, want) + } +} + +type mockTransport struct { + mock.Mock +} + +func (m *mockTransport) RoundTrip(req *http.Request) (*http.Response, error) { + args := m.Called(req) + return args.Get(0).(*http.Response), args.Error(1) +} + +func requestTargetMatches(url string) any { + return mock.MatchedBy(func(req *http.Request) bool { + return req.URL.String() == url + }) +} + +func httpResponse(code int, body string) *http.Response { + return &http.Response{ + StatusCode: code, + Body: io.NopCloser(strings.NewReader(body)), + } +} + +type mockOpener struct { + mock.Mock +} + +func (m *mockOpener) Open(u *url.URL) error { + args := m.Called(u) + return args.Error(0) +} + +func TestDeviceCode(t *testing.T) { + testCases := []struct { + desc string + browserOpenErr error + codeFetchResponse *http.Response + codeFetchErr error + tokenFetchResponse *http.Response + tokenFetchErr error + + wantToken *oauth2.Token + wantErr string + }{ + { + desc: "successful auth", + codeFetchResponse: httpResponse(200, `{"device_code":"75ba408f-fdf0-469a-a56e-b9a3a698f8b3","verification_uri":"https://oauth2.example.com/login?deviceCode\u003d75ba408f-fdf0-469a-a56e-b9a3a698f8b3","verification_uri_complete":"https://oauth2.example.com/login?deviceCode\u003d75ba408f-fdf0-469a-a56e-b9a3a698f8b3\u0026userCode\u003dKLJQ-OQGG","user_code":"KLJQ-OQGG","expires_in":300,"interval":1}`), + tokenFetchResponse: httpResponse(200, `{"access_token":"yippeekiyay","expires_in":7776000}`), + wantToken: &oauth2.Token{ + AccessToken: "yippeekiyay", + }, + }, + { + desc: "code fetch failure", + codeFetchErr: errors.New("code_fetch_failure"), + + wantErr: "code_fetch_failure", + }, + { + desc: "code fetch http error", + codeFetchResponse: httpResponse(403, ``), + + wantErr: "oauth2: cannot fetch token", + }, + { + desc: "browser open error", + codeFetchResponse: httpResponse(200, `{"device_code":"75ba408f-fdf0-469a-a56e-b9a3a698f8b3","verification_uri":"https://oauth2.example.com/login?deviceCode\u003d75ba408f-fdf0-469a-a56e-b9a3a698f8b3","verification_uri_complete":"https://oauth2.example.com/login?deviceCode\u003d75ba408f-fdf0-469a-a56e-b9a3a698f8b3\u0026userCode\u003dKLJQ-OQGG","user_code":"KLJQ-OQGG","expires_in":300,"interval":1}`), + browserOpenErr: errors.New("browser_open_failure"), + + wantErr: "browser_open_failure", + }, + { + desc: "token fetch failure", + codeFetchResponse: httpResponse(200, `{"device_code":"75ba408f-fdf0-469a-a56e-b9a3a698f8b3","verification_uri":"https://oauth2.example.com/login?deviceCode\u003d75ba408f-fdf0-469a-a56e-b9a3a698f8b3","verification_uri_complete":"https://oauth2.example.com/login?deviceCode\u003d75ba408f-fdf0-469a-a56e-b9a3a698f8b3\u0026userCode\u003dKLJQ-OQGG","user_code":"KLJQ-OQGG","expires_in":300,"interval":1}`), + tokenFetchErr: errors.New("token_fetch_failure"), + + wantErr: "token_fetch_failure", + }, + { + desc: "token fetch http error", + codeFetchResponse: httpResponse(200, `{"device_code":"75ba408f-fdf0-469a-a56e-b9a3a698f8b3","verification_uri":"https://oauth2.example.com/login?deviceCode\u003d75ba408f-fdf0-469a-a56e-b9a3a698f8b3","verification_uri_complete":"https://oauth2.example.com/login?deviceCode\u003d75ba408f-fdf0-469a-a56e-b9a3a698f8b3\u0026userCode\u003dKLJQ-OQGG","user_code":"KLJQ-OQGG","expires_in":300,"interval":1}`), + tokenFetchResponse: httpResponse(500, `internal server error`), + + wantErr: "oauth2: cannot fetch token", + }, + } + for _, tc := range testCases { + testHost := &url.URL{Scheme: "https", Host: "oauth2.example.com"} + deviceAuthEndpoint := "https://oauth2.example.com/api/v1/oauth2/device" + tokenEndpoint := "https://oauth2.example.com/api/v1/oauth2/token" + t.Run(tc.desc, func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) + defer cancel() + + opener := &mockOpener{} + opener.On("Open", mock.Anything).Return(tc.browserOpenErr) + + transport := &mockTransport{} + transport.On("RoundTrip", requestTargetMatches(deviceAuthEndpoint)).Return(tc.codeFetchResponse, tc.codeFetchErr) + transport.On("RoundTrip", requestTargetMatches(tokenEndpoint)).Return(tc.tokenFetchResponse, tc.tokenFetchErr) + + deviceCode := &DeviceCode{ + browserOpener: opener, + clientID: "john_mcclane", + scopes: []string{"nypd", "lapd"}, + httpTransport: transport, + } + got, gotErr := deviceCode.Authenticate(ctx, testHost) + + assertErrorContains(t, gotErr, tc.wantErr) + if gotErr != nil { + return + } + + got.Expiry = time.Time{} + assert.EqualExportedValues(t, tc.wantToken, got) + }) + } +} diff --git a/internal/oauthdevice/BUILD b/internal/oauthdevice/BUILD deleted file mode 100644 index 3bf1df2..0000000 --- a/internal/oauthdevice/BUILD +++ /dev/null @@ -1,9 +0,0 @@ -load("@rules_go//go:def.bzl", "go_library") - -go_library( - name = "oauthdevice", - srcs = ["authenticator.go"], - importpath = "github.com/EngFlow/auth/internal/oauthdevice", - visibility = ["//:__subpackages__"], - deps = ["@org_golang_x_oauth2//:oauth2"], -) diff --git a/internal/oauthdevice/authenticator.go b/internal/oauthdevice/authenticator.go deleted file mode 100644 index 0537bb0..0000000 --- a/internal/oauthdevice/authenticator.go +++ /dev/null @@ -1,66 +0,0 @@ -// Copyright 2024 EngFlow Inc. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package oauthdevice - -import ( - "context" - "errors" - "strings" - - "golang.org/x/oauth2" -) - -var errUnexpectedHTML = errors.New("request to JSON API returned HTML unexpectedly") - -type Authenticator interface { - FetchCode(context.Context, *oauth2.Endpoint) (*oauth2.DeviceAuthResponse, error) - FetchToken(context.Context, *oauth2.DeviceAuthResponse) (*oauth2.Token, error) -} - -type Auth struct { - config *oauth2.Config -} - -func NewAuth(clientID string, scopes []string) *Auth { - return &Auth{ - config: &oauth2.Config{ - ClientID: clientID, - Scopes: scopes, - }, - } -} - -func (a *Auth) FetchCode(ctx context.Context, authEndpoint *oauth2.Endpoint) (*oauth2.DeviceAuthResponse, error) { - a.config.Endpoint = *authEndpoint - res, err := a.config.DeviceAuth(ctx) - if err != nil { - if oauthErr := (*oauth2.RetrieveError)(nil); errors.As(err, &oauthErr) { - return nil, err - } - // BUG(CUS-320): Clusters that are not oauth-aware will return HTML with - // a 2xx code, confusing the oauth library. Detect and alias those - // errors here. - if strings.Contains(err.Error(), "invalid character '<'") { - return res, errUnexpectedHTML - } - // Default error handling - return nil, err - } - return res, err -} - -func (a *Auth) FetchToken(ctx context.Context, authRes *oauth2.DeviceAuthResponse) (*oauth2.Token, error) { - return a.config.DeviceAccessToken(ctx, authRes) -}