Skip to content

Commit

Permalink
REC-110: refactor to remove Fallback
Browse files Browse the repository at this point in the history
Moves the Fallback implementation into appState.

This code is very tightly coupled with the app's behavior
and should not be held in a separate abstracted type.
  • Loading branch information
jayconrod committed Nov 26, 2024
1 parent 412f736 commit 5eef9d2
Show file tree
Hide file tree
Showing 7 changed files with 132 additions and 348 deletions.
58 changes: 23 additions & 35 deletions cmd/engflow_auth/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,13 @@ const (
type appState struct {
// These vars are initialized by `build()` if and only if they are not pre-populated;
// they should be pre-populated in tests and left nil otherwise.
userConfigDir string
browserOpener browser.Opener
authenticator oauthdevice.Authenticator
tokenStore oauthtoken.LoadStorer
stderr io.Writer
userConfigDir string
browserOpener browser.Opener
authenticator oauthdevice.Authenticator
fileStore oauthtoken.LoadStorer
keyringStore oauthtoken.LoadStorer
writeFileStore bool
stderr io.Writer
}

type ExportedToken struct {
Expand All @@ -81,42 +83,28 @@ func (r *appState) build(cliCtx *cli.Context) error {
if r.browserOpener == nil {
r.browserOpener = &browser.StderrPrint{}
}
if r.tokenStore == nil {
keyring, err := oauthtoken.NewKeyring()
if err != nil {
return autherr.CodedErrorf(autherr.CodeTokenStoreFailure, "failed to open keyring-based token store: %w", err)
}

if r.fileStore == nil {
tokensDir := filepath.Join(r.userConfigDir, "engflow_auth", "tokens")
fileStore, err := oauthtoken.NewFileTokenStore(tokensDir)
if err != nil {
return autherr.CodedErrorf(autherr.CodeTokenStoreFailure, "failed to open file-based token store: %w", err)
}

errorStore := &oauthtoken.FakeTokenStore{
StoreErr: fmt.Errorf("subcommand attempted invalid write to token storage"),
}

var writeStore oauthtoken.LoadStorer
switch writeStoreName := cliCtx.String("store"); writeStoreName {
case "keyring":
writeStore = keyring
case "file":
writeStore = fileStore
case "":
// Subcommands that don't have this flag defined will cause the flag
// value fetch to return empty (as opposed to a sane default value).
// These commands shouldn't write to token storage (else they should
// define the flag) so the corresponding token storage object errors
// on writes.
writeStore = errorStore
default:
return autherr.CodedErrorf(autherr.CodeBadParams, "unknown token store type %q", writeStoreName)
r.fileStore = fileStore
}
if r.keyringStore == nil {
keyringStore, err := oauthtoken.NewKeyring()
if err != nil {
return autherr.CodedErrorf(autherr.CodeTokenStoreFailure, "failed to open keyring-based token store: %w", err)
}

r.tokenStore = oauthtoken.NewFallback(
/* gets Store() operations */ writeStore,
/* gets Load() operations */ keyring, fileStore)
r.keyringStore = keyringStore
}
switch writeStoreName := cliCtx.String("store"); writeStoreName {
case "", "keyring":
r.writeFileStore = false
case "file":
r.writeFileStore = true
default:
return autherr.CodedErrorf(autherr.CodeBadParams, "unknown token store type %q", writeStoreName)
}
r.stderr = cliCtx.App.ErrWriter
return nil
Expand Down
77 changes: 57 additions & 20 deletions cmd/engflow_auth/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,8 @@ func TestRun(t *testing.T) {

machineInput io.Reader
authenticator oauthdevice.Authenticator
tokenStore oauthtoken.LoadStorer
keyringStore oauthtoken.LoadStorer
fileStore oauthtoken.LoadStorer
browserOpener browser.Opener

wantCode int
Expand Down Expand Up @@ -157,7 +158,7 @@ func TestRun(t *testing.T) {
desc: "get propagates token store error",
args: []string{"get"},
machineInput: strings.NewReader(`{"uri": "https://cluster.example.com"}`),
tokenStore: &oauthtoken.FakeTokenStore{
keyringStore: &oauthtoken.FakeTokenStore{
LoadErr: errors.New("token_load_error"),
},
wantCode: autherr.CodeReauthRequired,
Expand All @@ -167,7 +168,7 @@ func TestRun(t *testing.T) {
desc: "get with URL expired",
args: []string{"get"},
machineInput: strings.NewReader(`{"uri": "https://cluster.example.com"}`),
tokenStore: &oauthtoken.FakeTokenStore{
keyringStore: &oauthtoken.FakeTokenStore{
Tokens: map[string]*oauth2.Token{
"cluster.example.com": {
AccessToken: "access_token",
Expand All @@ -182,7 +183,7 @@ func TestRun(t *testing.T) {
desc: "get with URL not expired",
args: []string{"get"},
machineInput: strings.NewReader(`{"uri": "https://cluster.example.com"}`),
tokenStore: &oauthtoken.FakeTokenStore{
keyringStore: &oauthtoken.FakeTokenStore{
Tokens: map[string]*oauth2.Token{
"cluster.example.com": {
AccessToken: "access_token",
Expand All @@ -195,6 +196,36 @@ func TestRun(t *testing.T) {
`"x-engflow-auth-token":["access_token"]},"expires":` + "\"" + expiresInFuture.Format(expiryFormat) + "\"",
},
},
{
desc: "get from keyring",
args: []string{"get"},
machineInput: strings.NewReader(`{"uri": "https://cluster.example.com"}`),
keyringStore: oauthtoken.NewFakeTokenStore().WithTokenForSubject("cluster.example.com", "alice"),
wantStdoutContaining: []string{`"x-engflow-auth-token"`},
},
{
desc: "get from file",
args: []string{"get"},
machineInput: strings.NewReader(`{"uri": "https://cluster.example.com"}`),
fileStore: oauthtoken.NewFakeTokenStore().WithTokenForSubject("cluster.example.com", "alice"),
wantStdoutContaining: []string{`"x-engflow-auth-token"`},
},
{
desc: "get from file with keyring error",
args: []string{"get"},
machineInput: strings.NewReader(`{"uri": "https://cluster.example.com"}`),
keyringStore: oauthtoken.NewFakeTokenStore().WithLoadErr(errors.New("fake error")),
fileStore: oauthtoken.NewFakeTokenStore().WithTokenForSubject("cluster.example.com", "alice"),
wantStdoutContaining: []string{`"x-engflow-auth-token"`},
},
{
desc: "get from keyring with file error",
args: []string{"get"},
machineInput: strings.NewReader(`{"uri": "https://cluster.example.com"}`),
keyringStore: oauthtoken.NewFakeTokenStore().WithTokenForSubject("cluster.example.com", "alice"),
fileStore: oauthtoken.NewFakeTokenStore().WithLoadErr(errors.New("fake error")),
wantStdoutContaining: []string{`"x-engflow-auth-token"`},
},
{
desc: "version prints build metadata",
args: []string{"version"},
Expand Down Expand Up @@ -234,7 +265,6 @@ func TestRun(t *testing.T) {
VerificationURIComplete: "https://cluster.example.com/with/auth/code",
},
},
tokenStore: oauthtoken.NewFakeTokenStore(),
wantStored: []string{
"cluster.example.com",
"cluster.local.example.com",
Expand All @@ -248,7 +278,7 @@ func TestRun(t *testing.T) {
VerificationURIComplete: "https://cluster.example.com/with/auth/code",
},
},
tokenStore: &oauthtoken.FakeTokenStore{
keyringStore: &oauthtoken.FakeTokenStore{
StoreErr: errors.New("token_store_fail"),
},
wantCode: autherr.CodeTokenStoreFailure,
Expand Down Expand Up @@ -339,7 +369,7 @@ func TestRun(t *testing.T) {
VerificationURIComplete: "https://cluster.example.com/with/auth/code",
},
},
tokenStore: &oauthtoken.FakeTokenStore{
keyringStore: &oauthtoken.FakeTokenStore{
StoreErr: errors.New("token_store_fail"),
},
wantCode: autherr.CodeTokenStoreFailure,
Expand Down Expand Up @@ -378,7 +408,7 @@ func TestRun(t *testing.T) {
{
desc: "login with changed subject",
args: []string{"login", "cluster.example.com"},
tokenStore: oauthtoken.NewFakeTokenStore().WithTokenForSubject(
keyringStore: oauthtoken.NewFakeTokenStore().WithTokenForSubject(
"cluster.example.com", "alice"),
authenticator: &fakeAuth{
deviceResponse: &oauth2.DeviceAuthResponse{
Expand All @@ -403,7 +433,7 @@ func TestRun(t *testing.T) {
{
desc: "logout with cluster",
args: []string{"logout", "cluster.example.com"},
tokenStore: &oauthtoken.FakeTokenStore{
keyringStore: &oauthtoken.FakeTokenStore{
Tokens: map[string]*oauth2.Token{
"cluster.example.com": {},
},
Expand All @@ -412,7 +442,7 @@ func TestRun(t *testing.T) {
{
desc: "logout with error",
args: []string{"logout", "cluster.example.com"},
tokenStore: &oauthtoken.FakeTokenStore{
keyringStore: &oauthtoken.FakeTokenStore{
DeleteErr: errors.New("token_delete_error"),
},
wantCode: autherr.CodeTokenStoreFailure,
Expand All @@ -433,7 +463,7 @@ func TestRun(t *testing.T) {
{
desc: "export when token not found",
args: []string{"export", "https://cluster.example.com"},
tokenStore: &oauthtoken.FakeTokenStore{
keyringStore: &oauthtoken.FakeTokenStore{
LoadErr: autherr.ReauthRequired("https://cluster.example.com"),
},
wantCode: autherr.CodeReauthRequired,
Expand All @@ -442,7 +472,7 @@ func TestRun(t *testing.T) {
{
desc: "export when token store fails",
args: []string{"export", "https://cluster.example.com"},
tokenStore: &oauthtoken.FakeTokenStore{
keyringStore: &oauthtoken.FakeTokenStore{
LoadErr: fmt.Errorf("token_load_error"),
},
wantCode: autherr.CodeTokenStoreFailure,
Expand All @@ -451,7 +481,7 @@ func TestRun(t *testing.T) {
{
desc: "export when token expired",
args: []string{"export", "https://cluster.example.com"},
tokenStore: &oauthtoken.FakeTokenStore{
keyringStore: &oauthtoken.FakeTokenStore{
Tokens: map[string]*oauth2.Token{
"cluster.example.com": {
AccessToken: "access_token",
Expand All @@ -465,7 +495,7 @@ func TestRun(t *testing.T) {
{
desc: "export token",
args: []string{"export", "https://cluster.example.com"},
tokenStore: &oauthtoken.FakeTokenStore{
keyringStore: &oauthtoken.FakeTokenStore{
Tokens: map[string]*oauth2.Token{
"cluster.example.com": {
AccessToken: "token_data",
Expand All @@ -482,7 +512,7 @@ func TestRun(t *testing.T) {
{
desc: "export token with alias",
args: []string{"export", "--alias", "cluster.local.example.com:8080", "https://cluster.example.com"},
tokenStore: &oauthtoken.FakeTokenStore{
keyringStore: &oauthtoken.FakeTokenStore{
Tokens: map[string]*oauth2.Token{
"cluster.example.com": {
AccessToken: "token_data",
Expand All @@ -508,7 +538,7 @@ func TestRun(t *testing.T) {
desc: "import with valid data",
args: []string{"import"},
machineInput: strings.NewReader(`{"token":{"access_token":"token_data"},"cluster_host":"cluster.example.com"}`),
tokenStore: oauthtoken.NewFakeTokenStore(),
keyringStore: oauthtoken.NewFakeTokenStore(),
wantStored: []string{
"cluster.example.com",
},
Expand All @@ -517,7 +547,7 @@ func TestRun(t *testing.T) {
desc: "import with alias",
args: []string{"import"},
machineInput: strings.NewReader(`{"token":{"access_token":"token_data"},"cluster_host":"cluster.example.com","aliases":["cluster.local.example.com"]}`),
tokenStore: oauthtoken.NewFakeTokenStore(),
keyringStore: oauthtoken.NewFakeTokenStore(),
wantStored: []string{
"cluster.example.com",
"cluster.local.example.com",
Expand All @@ -527,7 +557,7 @@ func TestRun(t *testing.T) {
desc: "import with store error",
args: []string{"import"},
machineInput: strings.NewReader(`{"token":{"access_token":"token_data"},"cluster_host":"cluster.example.com"}`),
tokenStore: &oauthtoken.FakeTokenStore{
keyringStore: &oauthtoken.FakeTokenStore{
StoreErr: errors.New("token_store_fail"),
},
wantCode: autherr.CodeTokenStoreFailure,
Expand All @@ -551,14 +581,21 @@ func TestRun(t *testing.T) {
userConfigDir: t.TempDir(),
browserOpener: tc.browserOpener,
authenticator: tc.authenticator,
tokenStore: tc.tokenStore,
keyringStore: tc.keyringStore,
fileStore: tc.fileStore,
}
if root.browserOpener == nil {
root.browserOpener = &fakeBrowser{}
}
if root.authenticator == nil {
root.authenticator = &fakeAuth{}
}
if root.keyringStore == nil {
root.keyringStore = oauthtoken.NewFakeTokenStore()
}
if root.fileStore == nil {
root.fileStore = oauthtoken.NewFakeTokenStore()
}

app := makeApp(root)

Expand Down Expand Up @@ -596,7 +633,7 @@ func TestRun(t *testing.T) {
t.Logf("\n====== BEGIN APP STDERR ======\n%s\n====== END APP STDERR ======\n\n", stderr.String())
}
}
if tokenStore, ok := tc.tokenStore.(*oauthtoken.FakeTokenStore); ok {
if tokenStore, ok := tc.keyringStore.(*oauthtoken.FakeTokenStore); ok {
assert.Subset(t, tokenStore.Tokens, tc.wantStored)
}
})
Expand Down
55 changes: 52 additions & 3 deletions cmd/engflow_auth/tokens.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,11 @@
package main

import (
"errors"
"fmt"
"io/fs"

"github.com/EngFlow/auth/internal/oauthtoken"
"github.com/golang-jwt/jwt/v5"
"golang.org/x/oauth2"
)
Expand All @@ -27,7 +30,16 @@ import (
// loadToken may contain logic specific to this app and should be called
// by commands instead of calling LoadStorer.Load directly.
func (r *appState) loadToken(cluster string) (*oauth2.Token, error) {
return r.tokenStore.Load(cluster)
var errs []error
backends := []oauthtoken.LoadStorer{r.keyringStore, r.fileStore}
for _, backend := range backends {
token, err := backend.Load(cluster)
if err == nil {
return token, nil
}
errs = append(errs, err)
}
return nil, fmt.Errorf("failed to load token from %d backend(s): %w", len(backends), errors.Join(errs...))
}

// storeToken stores a token for the given cluster in one of the backends.
Expand All @@ -40,7 +52,12 @@ func (r *appState) storeToken(cluster string, token *oauth2.Token) error {
if err == nil {
r.warnIfSubjectChanged(cluster, oldToken, token)
}
return r.tokenStore.Store(cluster, token)

if r.writeFileStore {
return r.fileStore.Store(cluster, token)
} else {
return r.keyringStore.Store(cluster, token)
}
}

// warnIfSubjectChanged prints a warning on stderr if the new token belongs to
Expand Down Expand Up @@ -70,5 +87,37 @@ func (r *appState) warnIfSubjectChanged(cluster string, oldToken, newToken *oaut
// deleteToken may contain logic specific to this app and should be called
// by commands instead of calling LoadStorer.Delete directly.
func (r *appState) deleteToken(cluster string) error {
return r.tokenStore.Delete(cluster)
var errs []error
// Don't bother to delete from storeBackend, which should also be present in
// loadBackends
backends := []oauthtoken.LoadStorer{r.keyringStore, r.fileStore}
for _, backend := range backends {
errs = append(errs, backend.Delete(cluster))
}

var nonNotFoundErrs []error
for _, err := range errs {
if err == nil {
return nil
}
if !errors.Is(err, fs.ErrNotExist) {
nonNotFoundErrs = append(nonNotFoundErrs, err)
}
}
if err := errors.Join(nonNotFoundErrs...); err != nil {
return fmt.Errorf("failed to delete token from %d backend(s): %w", len(backends), err)
}
return &multiBackendNotFoundError{backendsCount: len(backends)}
}

type multiBackendNotFoundError struct {
backendsCount int
}

func (m *multiBackendNotFoundError) Error() string {
return fmt.Sprintf("token for cluster not found after trying %d token storage backends", m.backendsCount)
}

func (m *multiBackendNotFoundError) Is(err error) bool {
return err == fs.ErrNotExist
}
Loading

0 comments on commit 5eef9d2

Please sign in to comment.