Skip to content

Commit

Permalink
REC-110: refactor to remove CacheAlert
Browse files Browse the repository at this point in the history
This is the beginning of a refactoring to move some LoadStorer
implementations into appState.

Although CacheAlert satisfied the LoadStorer interface type,
it was a decorator that didn't actually load or store any tokens
or provide useful or meaningful abstraction. It seems better to
squash this into appState and remove unnecessary abstraction.
This also makes the test a little more realistic.

Also refactored fake token generation to make testing easier.
  • Loading branch information
jayconrod committed Nov 26, 2024
1 parent 5dd80b5 commit 412f736
Show file tree
Hide file tree
Showing 8 changed files with 149 additions and 183 deletions.
6 changes: 5 additions & 1 deletion cmd/engflow_auth/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@ load("//infra:visibility.bzl", "RELEASE_ARTIFACT")

go_library(
name = "engflow_auth_lib",
srcs = ["main.go"],
srcs = [
"main.go",
"tokens.go",
],
importpath = "github.com/EngFlow/auth/cmd/engflow_auth",
visibility = ["//visibility:private"],
deps = [
Expand All @@ -13,6 +16,7 @@ go_library(
"//internal/oauthdevice",
"//internal/oauthtoken",
"@com_github_engflow_credential_helper_go//:credential-helper-go",
"@com_github_golang_jwt_jwt_v5//:jwt",
"@com_github_urfave_cli_v2//:cli",
"@org_golang_x_oauth2//:oauth2",
],
Expand Down
24 changes: 11 additions & 13 deletions cmd/engflow_auth/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"encoding/json"
"errors"
"fmt"
"io"
"io/fs"
"net"
"net/url"
Expand Down Expand Up @@ -53,6 +54,7 @@ type appState struct {
browserOpener browser.Opener
authenticator oauthdevice.Authenticator
tokenStore oauthtoken.LoadStorer
stderr io.Writer
}

type ExportedToken struct {
Expand Down Expand Up @@ -112,15 +114,11 @@ func (r *appState) build(cliCtx *cli.Context) error {
return autherr.CodedErrorf(autherr.CodeBadParams, "unknown token store type %q", writeStoreName)
}

r.tokenStore =
oauthtoken.NewCacheAlert(
oauthtoken.NewFallback(
/* gets Store() operations */ writeStore,
/* gets Load() operations */ keyring, fileStore,
),
cliCtx.App.ErrWriter,
)
r.tokenStore = oauthtoken.NewFallback(
/* gets Store() operations */ writeStore,
/* gets Load() operations */ keyring, fileStore)
}
r.stderr = cliCtx.App.ErrWriter
return nil
}

Expand All @@ -136,7 +134,7 @@ func (r *appState) get(cliCtx *cli.Context) error {
if err != nil {
return autherr.CodedErrorf(autherr.CodeBadParams, "failed to parse cluster URL %q from request: %w", req.URI, err)
}
token, err := r.tokenStore.Load(clusterURL.Host)
token, err := r.loadToken(clusterURL.Host)
if err != nil {
return autherr.ReauthRequired(clusterURL.Host)
}
Expand Down Expand Up @@ -165,7 +163,7 @@ func (r *appState) export(cliCtx *cli.Context) error {
return autherr.CodedErrorf(autherr.CodeBadParams, "invalid cluster: %w", err)
}

token, err := r.tokenStore.Load(clusterURL.Host)
token, err := r.loadToken(clusterURL.Host)
if err != nil {
if reauthErr := (*autherr.CodedError)(nil); errors.As(err, &reauthErr) && reauthErr.Code == autherr.CodeReauthRequired {
return reauthErr
Expand Down Expand Up @@ -207,7 +205,7 @@ func (r *appState) import_(cliCtx *cli.Context) error {

var storeErrs []error
for _, storeURL := range storeURLs {
if err := r.tokenStore.Store(storeURL.Host, token.Token); err != nil {
if err := r.storeToken(storeURL.Host, token.Token); err != nil {
storeErrs = append(storeErrs, fmt.Errorf("failed to save token for host %q: %w", storeURL.Host, err))
}
}
Expand Down Expand Up @@ -289,7 +287,7 @@ Visit %s for help.`,

var storeErrs []error
for _, storeURL := range storeURLs {
if err := r.tokenStore.Store(storeURL.Host, token); err != nil {
if err := r.storeToken(storeURL.Host, token); err != nil {
storeErrs = append(storeErrs, fmt.Errorf("failed to save token for host %q: %w", storeURL.Host, err))
}
}
Expand Down Expand Up @@ -323,7 +321,7 @@ func (r *appState) logout(cliCtx *cli.Context) error {
return autherr.CodedErrorf(autherr.CodeBadParams, "invalid cluster: %w", err)
}

if err := r.tokenStore.Delete(clusterURL.Host); errors.Is(err, fs.ErrNotExist) {
if err := r.deleteToken(clusterURL.Host); errors.Is(err, fs.ErrNotExist) {
return &autherr.CodedError{Code: autherr.CodeBadParams, Err: fmt.Errorf("no credentials found for cluster %q", clusterURL.Host)}
} else if err != nil {
return &autherr.CodedError{Code: autherr.CodeTokenStoreFailure, Err: err}
Expand Down
48 changes: 31 additions & 17 deletions cmd/engflow_auth/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,17 +83,18 @@ func codedErrorContains(t *testing.T, gotErr error, code int, wantMsg string) bo
}

type fakeAuth struct {
res *oauth2.DeviceAuthResponse
fetchCodeErr error
fetchTokenErr error
deviceResponse *oauth2.DeviceAuthResponse
token *oauth2.Token
fetchCodeErr error
fetchTokenErr error
}

func (f *fakeAuth) FetchCode(ctx context.Context, authEndpint *oauth2.Endpoint) (*oauth2.DeviceAuthResponse, error) {
return f.res, f.fetchCodeErr
return f.deviceResponse, f.fetchCodeErr
}

func (f *fakeAuth) FetchToken(ctx context.Context, authRes *oauth2.DeviceAuthResponse) (*oauth2.Token, error) {
return nil, f.fetchTokenErr
return f.token, f.fetchTokenErr
}

type fakeBrowser struct {
Expand Down Expand Up @@ -220,7 +221,7 @@ func TestRun(t *testing.T) {
desc: "login happy path",
args: []string{"login", "cluster.example.com"},
authenticator: &fakeAuth{
res: &oauth2.DeviceAuthResponse{
deviceResponse: &oauth2.DeviceAuthResponse{
VerificationURIComplete: "https://cluster.example.com/with/auth/code",
},
},
Expand All @@ -229,7 +230,7 @@ func TestRun(t *testing.T) {
desc: "login with alias",
args: []string{"login", "--alias", "cluster.local.example.com", "cluster.example.com"},
authenticator: &fakeAuth{
res: &oauth2.DeviceAuthResponse{
deviceResponse: &oauth2.DeviceAuthResponse{
VerificationURIComplete: "https://cluster.example.com/with/auth/code",
},
},
Expand All @@ -243,7 +244,7 @@ func TestRun(t *testing.T) {
desc: "login with alias with store errors",
args: []string{"login", "--alias", "cluster.local.example.com", "cluster.example.com"},
authenticator: &fakeAuth{
res: &oauth2.DeviceAuthResponse{
deviceResponse: &oauth2.DeviceAuthResponse{
VerificationURIComplete: "https://cluster.example.com/with/auth/code",
},
},
Expand All @@ -257,7 +258,7 @@ func TestRun(t *testing.T) {
desc: "login with host and port",
args: []string{"login", "cluster.example.com:8080"},
authenticator: &fakeAuth{
res: &oauth2.DeviceAuthResponse{
deviceResponse: &oauth2.DeviceAuthResponse{
VerificationURIComplete: "https://cluster.example.com:8080/with/auth/code",
},
},
Expand All @@ -272,7 +273,7 @@ func TestRun(t *testing.T) {
desc: "login code fetch failure",
args: []string{"login", "cluster.example.com"},
authenticator: &fakeAuth{
res: &oauth2.DeviceAuthResponse{
deviceResponse: &oauth2.DeviceAuthResponse{
VerificationURIComplete: "https://cluster.example.com/with/auth/code",
},
fetchCodeErr: errors.New("fetch_code_fail"),
Expand All @@ -284,7 +285,7 @@ func TestRun(t *testing.T) {
desc: "login code fetch RetrieveError",
args: []string{"login", "cluster.example.com"},
authenticator: &fakeAuth{
res: &oauth2.DeviceAuthResponse{
deviceResponse: &oauth2.DeviceAuthResponse{
VerificationURIComplete: "https://cluster.example.com/with/auth/code",
},
fetchCodeErr: &oauth2.RetrieveError{},
Expand All @@ -296,7 +297,7 @@ func TestRun(t *testing.T) {
desc: "login code fetch unexpected HTML",
args: []string{"login", "cluster.example.com"},
authenticator: &fakeAuth{
res: &oauth2.DeviceAuthResponse{
deviceResponse: &oauth2.DeviceAuthResponse{
VerificationURIComplete: "https://cluster.example.com/with/auth/code",
},
fetchCodeErr: autherr.UnexpectedHTML,
Expand All @@ -308,7 +309,7 @@ func TestRun(t *testing.T) {
desc: "login browser open failure",
args: []string{"login", "cluster.example.com"},
authenticator: &fakeAuth{
res: &oauth2.DeviceAuthResponse{
deviceResponse: &oauth2.DeviceAuthResponse{
VerificationURIComplete: "https://cluster.example.com/with/auth/code",
},
},
Expand All @@ -322,7 +323,7 @@ func TestRun(t *testing.T) {
desc: "login token fetch failure",
args: []string{"login", "cluster.example.com"},
authenticator: &fakeAuth{
res: &oauth2.DeviceAuthResponse{
deviceResponse: &oauth2.DeviceAuthResponse{
VerificationURIComplete: "https://cluster.example.com/with/auth/code",
},
fetchTokenErr: errors.New("fetch_token_fail"),
Expand All @@ -334,7 +335,7 @@ func TestRun(t *testing.T) {
desc: "login token store failure",
args: []string{"login", "cluster.example.com"},
authenticator: &fakeAuth{
res: &oauth2.DeviceAuthResponse{
deviceResponse: &oauth2.DeviceAuthResponse{
VerificationURIComplete: "https://cluster.example.com/with/auth/code",
},
},
Expand All @@ -348,7 +349,7 @@ func TestRun(t *testing.T) {
desc: "login with file-backed token storage",
args: []string{"login", "--store=file", "cluster.example.com"},
authenticator: &fakeAuth{
res: &oauth2.DeviceAuthResponse{
deviceResponse: &oauth2.DeviceAuthResponse{
VerificationURIComplete: "https://cluster.example.com/with/auth/code",
},
},
Expand All @@ -357,7 +358,7 @@ func TestRun(t *testing.T) {
desc: "login with keyring-backed token storage",
args: []string{"login", "--store=keyring", "cluster.example.com"},
authenticator: &fakeAuth{
res: &oauth2.DeviceAuthResponse{
deviceResponse: &oauth2.DeviceAuthResponse{
VerificationURIComplete: "https://cluster.example.com/with/auth/code",
},
},
Expand All @@ -374,6 +375,19 @@ func TestRun(t *testing.T) {
wantCode: autherr.CodeBadParams,
wantErr: "flag provided but not defined",
},
{
desc: "login with changed subject",
args: []string{"login", "cluster.example.com"},
tokenStore: oauthtoken.NewFakeTokenStore().WithTokenForSubject(
"cluster.example.com", "alice"),
authenticator: &fakeAuth{
deviceResponse: &oauth2.DeviceAuthResponse{
VerificationURIComplete: "https://cluster.example.com/with/auth/code",
},
token: oauthtoken.NewFakeTokenForSubject("bob"),
},
wantStderrContaining: []string{"Login identity has changed"},
},
{
desc: "logout without cluster",
args: []string{"logout"},
Expand Down
74 changes: 74 additions & 0 deletions cmd/engflow_auth/tokens.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
// Copyright 2024 EngFlow Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package main

import (
"fmt"

"github.com/golang-jwt/jwt/v5"
"golang.org/x/oauth2"
)

// loadToken loads a token for the given cluster or returns an error equivalent
// to fs.ErrNotExist if the token is not found in any store.
//
// loadToken may contain logic specific to this app and should be called
// by commands instead of calling LoadStorer.Load directly.
func (r *appState) loadToken(cluster string) (*oauth2.Token, error) {
return r.tokenStore.Load(cluster)
}

// storeToken stores a token for the given cluster in one of the backends.
//
// storeToken may contain logic specific to this app and should be called
// by commands instead of calling LoadStorer.Store directly. For example,
// storeToken prints a message if the token's subject has changed.
func (r *appState) storeToken(cluster string, token *oauth2.Token) error {
oldToken, err := r.loadToken(cluster)
if err == nil {
r.warnIfSubjectChanged(cluster, oldToken, token)
}
return r.tokenStore.Store(cluster, token)
}

// warnIfSubjectChanged prints a warning on stderr if the new token belongs to
// a different user than the previously stored token. The user is reminded to
// shutdown Bazel since it caches tokens in memory to avoid running actions
// with the old credential, which is probably still valid.
func (r *appState) warnIfSubjectChanged(cluster string, oldToken, newToken *oauth2.Token) {
// Disable claims validation, since expired tokens should be allowed to
// parse.
parser := jwt.NewParser(jwt.WithoutClaimsValidation())
oldClaims, newClaims := &jwt.RegisteredClaims{}, &jwt.RegisteredClaims{}
// Unverified parsing, since issuing a warning vs. not is not a security
// concern.
if _, _, err := parser.ParseUnverified(oldToken.AccessToken, oldClaims); err != nil {
return
}
if _, _, err := parser.ParseUnverified(newToken.AccessToken, newClaims); err != nil {
return
}
if oldClaims.Subject != newClaims.Subject {
fmt.Fprintf(r.stderr, "WARNING: Login identity has changed since last login to %q.\nPlease run `bazel shutdown` in current workspaces in order to ensure bazel picks up new credentials.\n", cluster)
}
}

// deleteToken removes a token from all of the backends.
//
// deleteToken may contain logic specific to this app and should be called
// by commands instead of calling LoadStorer.Delete directly.
func (r *appState) deleteToken(cluster string) error {
return r.tokenStore.Delete(cluster)
}
3 changes: 0 additions & 3 deletions internal/oauthtoken/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ load("@rules_go//go:def.bzl", "go_library", "go_test")
go_library(
name = "oauthtoken",
srcs = [
"cache_alert.go",
"debug.go",
"fake.go",
"fallback.go",
Expand All @@ -23,14 +22,12 @@ go_library(
go_test(
name = "oauthtoken_test",
srcs = [
"cache_alert_test.go",
"fallback_test.go",
"keyring_test.go",
"load_storer_test.go",
],
embed = [":oauthtoken"],
deps = [
"@com_github_golang_jwt_jwt_v5//:jwt",
"@com_github_google_uuid//:uuid",
"@com_github_stretchr_testify//assert",
"@com_github_stretchr_testify//require",
Expand Down
Loading

0 comments on commit 412f736

Please sign in to comment.