Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refresh expiring tokens before making requests #1594

Merged
merged 6 commits into from
Mar 28, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 39 additions & 22 deletions api/cloudcontroller/wrapper/uaa_authentication.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,19 @@
package wrapper

import (
"strings"
"time"

"github.com/SermoDigital/jose/jws"

"code.cloudfoundry.org/cli/api/cloudcontroller"
"code.cloudfoundry.org/cli/api/cloudcontroller/ccerror"
"code.cloudfoundry.org/cli/api/uaa"
)

//go:generate counterfeiter . UAAClient

const accessTokenExpirationMargin = time.Minute

// UAAClient is the interface for getting a valid access token
type UAAClient interface {
RefreshAccessToken(refreshToken string) (uaa.RefreshedTokens, error)
Expand Down Expand Up @@ -48,32 +54,17 @@ func (t *UAAAuthentication) Make(request *cloudcontroller.Request, passedRespons
return t.connection.Make(request, passedResponse)
}

request.Header.Set("Authorization", t.cache.AccessToken())

requestErr := t.connection.Make(request, passedResponse)
if _, ok := requestErr.(ccerror.InvalidAuthTokenError); ok {
tokens, err := t.client.RefreshAccessToken(t.cache.RefreshToken())
if err != nil {
if t.cache.AccessToken() != "" || t.cache.RefreshToken() != "" {
// assert a valid access token for authenticated requests
err := t.refreshToken()
if nil != err {
return err
}

t.cache.SetAccessToken(tokens.AuthorizationToken())
t.cache.SetRefreshToken(tokens.RefreshToken)

if request.Body != nil {
err = request.ResetBody()
if err != nil {
if _, ok := err.(ccerror.PipeSeekError); ok {
return ccerror.PipeSeekError{Err: requestErr}
}
return err
}
}
request.Header.Set("Authorization", t.cache.AccessToken())
requestErr = t.connection.Make(request, passedResponse)
}

return requestErr
err := t.connection.Make(request, passedResponse)
return err
}

// SetClient sets the UAA client that the wrapper will use.
Expand All @@ -86,3 +77,29 @@ func (t *UAAAuthentication) Wrap(innerconnection cloudcontroller.Connection) clo
t.connection = innerconnection
return t
}

// refreshToken refreshes the JWT access token if it is expired or about to expire.
// If the access token is not yet expired, no action is performed.
func (t *UAAAuthentication) refreshToken() error {
var expiresIn time.Duration

tokenStr := strings.TrimPrefix(t.cache.AccessToken(), "bearer ")
token, err := jws.ParseJWT([]byte(tokenStr))
if err != nil {
// if the JWT could not be parsed, force a refresh
expiresIn = 0
} else {
expiration, _ := token.Claims().Expiration()
expiresIn = expiration.Sub(time.Now())
}

if expiresIn < accessTokenExpirationMargin {
tokens, err := t.client.RefreshAccessToken(t.cache.RefreshToken())
if err != nil {
return err
}
t.cache.SetAccessToken(tokens.AuthorizationToken())
t.cache.SetRefreshToken(tokens.RefreshToken)
}
return nil
}
123 changes: 72 additions & 51 deletions api/cloudcontroller/wrapper/uaa_authentication_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,19 @@ import (
"io/ioutil"
"net/http"
"strings"
"time"

"code.cloudfoundry.org/cli/api/uaa"

"github.com/SermoDigital/jose/crypto"
"github.com/SermoDigital/jose/jws"

"code.cloudfoundry.org/cli/api/cloudcontroller"
"code.cloudfoundry.org/cli/api/cloudcontroller/ccerror"

"code.cloudfoundry.org/cli/api/cloudcontroller"
"code.cloudfoundry.org/cli/api/cloudcontroller/cloudcontrollerfakes"
. "code.cloudfoundry.org/cli/api/cloudcontroller/wrapper"
"code.cloudfoundry.org/cli/api/cloudcontroller/wrapper/wrapperfakes"
"code.cloudfoundry.org/cli/api/uaa"
"code.cloudfoundry.org/cli/api/uaa/wrapper/util"

. "github.com/onsi/ginkgo"
Expand All @@ -33,8 +39,6 @@ var _ = Describe("UAA Authentication", func() {
fakeConnection = new(cloudcontrollerfakes.FakeConnection)
fakeClient = new(wrapperfakes.FakeUAAClient)
inMemoryCache = util.NewInMemoryTokenCache()
inMemoryCache.SetAccessToken("a-ok")

inner = NewUAAAuthentication(fakeClient, inMemoryCache)
wrapper = inner.Wrap(fakeConnection)

Expand Down Expand Up @@ -62,15 +66,53 @@ var _ = Describe("UAA Authentication", func() {
})
})

When("the token is valid", func() {
When("no tokens are set", func() {
BeforeEach(func() {
inMemoryCache.SetAccessToken("")
inMemoryCache.SetRefreshToken("")
})

It("does not attempt to refresh the token", func() {
Expect(fakeClient.RefreshAccessTokenCallCount()).To(Equal(0))
})
})

When("the access token is invalid", func() {
var (
executeErr error
)
BeforeEach(func() {
inMemoryCache.SetAccessToken("Bearer some.invalid.token")
inMemoryCache.SetRefreshToken("some refresh token")
executeErr = wrapper.Make(request, nil)
})

It("should refresh the token", func() {
Expect(executeErr).ToNot(HaveOccurred())
Expect(fakeClient.RefreshAccessTokenCallCount()).To(Equal(1))
})
})

When("the access token is valid", func() {
var (
accessToken = buildTokenString(time.Now().AddDate(0, 0, 1))
)
BeforeEach(func() {
inMemoryCache.SetAccessToken(accessToken)
})

It("does not refresh the token", func() {
Expect(fakeClient.RefreshAccessTokenCallCount()).To(Equal(0))
})

It("adds authentication headers", func() {
err := wrapper.Make(request, nil)
Expect(err).ToNot(HaveOccurred())

Expect(fakeConnection.MakeCallCount()).To(Equal(1))
authenticatedRequest, _ := fakeConnection.MakeArgsForCall(0)
headers := authenticatedRequest.Header
Expect(headers["Authorization"]).To(ConsistOf([]string{"a-ok"}))
Expect(headers["Authorization"]).To(ConsistOf([]string{accessToken}))
})

When("the request already has headers", func() {
Expand Down Expand Up @@ -106,13 +148,17 @@ var _ = Describe("UAA Authentication", func() {
})
})

When("the token is invalid", func() {
When("the access token is expired", func() {
var (
expectedBody string
request *cloudcontroller.Request
executeErr error
)

invalidAccessToken := buildTokenString(time.Time{})
newAccessToken := buildTokenString(time.Now().AddDate(0, 1, 1))
newRefreshToken := "newRefreshToken"

BeforeEach(func() {
expectedBody = "this body content should be preserved"
body := strings.NewReader(expectedBody)
Expand All @@ -121,26 +167,12 @@ var _ = Describe("UAA Authentication", func() {
Body: ioutil.NopCloser(body),
}, body)

makeCount := 0
fakeConnection.MakeStub = func(request *cloudcontroller.Request, response *cloudcontroller.Response) error {
body, err := ioutil.ReadAll(request.Body)
Expect(err).NotTo(HaveOccurred())
Expect(string(body)).To(Equal(expectedBody))

if makeCount == 0 {
makeCount++
return ccerror.InvalidAuthTokenError{}
} else {
return nil
}
}

inMemoryCache.SetAccessToken("what")
inMemoryCache.SetAccessToken(invalidAccessToken)

fakeClient.RefreshAccessTokenReturns(
uaa.RefreshedTokens{
AccessToken: "foobar-2",
RefreshToken: "bananananananana",
AccessToken: newAccessToken,
RefreshToken: newRefreshToken,
Type: "bearer",
},
nil,
Expand All @@ -156,40 +188,29 @@ var _ = Describe("UAA Authentication", func() {
Expect(fakeClient.RefreshAccessTokenCallCount()).To(Equal(1))
})

It("should resend the request", func() {
Expect(executeErr).ToNot(HaveOccurred())
Expect(fakeConnection.MakeCallCount()).To(Equal(2))

requestArg, _ := fakeConnection.MakeArgsForCall(1)
Expect(requestArg.Header.Get("Authorization")).To(Equal("bearer foobar-2"))
})

It("should save the refresh token", func() {
Expect(executeErr).ToNot(HaveOccurred())
Expect(inMemoryCache.RefreshToken()).To(Equal("bananananananana"))
Expect(inMemoryCache.RefreshToken()).To(Equal(newRefreshToken))
Expect(inMemoryCache.AccessToken()).To(ContainSubstring(newAccessToken))
})

When("a PipeSeekError is returned from ResetBody", func() {
BeforeEach(func() {
body, writer := cloudcontroller.NewPipeBomb()
req, err := http.NewRequest(http.MethodGet, "https://foo.bar.com/banana", body)
Expect(err).NotTo(HaveOccurred())
request = cloudcontroller.NewRequest(req, body)

go func() {
defer GinkgoRecover()

_, err := writer.Write([]byte(expectedBody))
Expect(err).NotTo(HaveOccurred())
err = writer.Close()
Expect(err).NotTo(HaveOccurred())
}()
When("token cannot be refreshed", func() {
JustBeforeEach(func() {
fakeConnection.MakeReturns(ccerror.InvalidAuthTokenError{})
})

It("set the err on PipeSeekError", func() {
Expect(executeErr).To(MatchError(ccerror.PipeSeekError{Err: ccerror.InvalidAuthTokenError{}}))
It("should not re-try the initial request", func() {
Expect(fakeConnection.MakeCallCount()).To(Equal(1))
})
})

})
})
})

func buildTokenString(expiration time.Time) string {
c := jws.Claims{}
c.SetExpiration(expiration)
token := jws.NewJWT(c, crypto.Unsecured)
tokenBytes, _ := token.Serialize(nil)
return string(tokenBytes)
}
19 changes: 19 additions & 0 deletions cf/api/authentication/authentication.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ import (
"strings"
"time"

"github.com/SermoDigital/jose/jws"

"code.cloudfoundry.org/cli/cf/configuration/coreconfig"
"code.cloudfoundry.org/cli/cf/errors"
. "code.cloudfoundry.org/cli/cf/i18n"
Expand All @@ -17,6 +19,8 @@ import (

//go:generate counterfeiter . TokenRefresher

const accessTokenExpirationMargin = time.Minute

type TokenRefresher interface {
RefreshAuthToken() (updatedToken string, apiErr error)
}
Expand All @@ -27,6 +31,7 @@ type Repository interface {
net.RequestDumperInterface

RefreshAuthToken() (updatedToken string, apiErr error)
RefreshToken(token string) (updatedToken string, apiErr error)
Authenticate(credentials map[string]string) (apiErr error)
Authorize(token string) (string, error)
GetLoginPromptsAndSaveUAAServerURL() (map[string]coreconfig.AuthPrompt, error)
Expand Down Expand Up @@ -181,6 +186,20 @@ func (uaa UAARepository) GetLoginPromptsAndSaveUAAServerURL() (prompts map[strin
}

func (uaa UAARepository) RefreshAuthToken() (string, error) {
return uaa.RefreshToken(uaa.config.AccessToken())
}

func (uaa UAARepository) RefreshToken(t string) (string, error) {
tokenStr := strings.TrimPrefix(t, "bearer ")
token, err := jws.ParseJWT([]byte(tokenStr))
if err != nil {
return "", err
}
expiration, _ := token.Claims().Expiration()
if expiration.Sub(time.Now()) > accessTokenExpirationMargin {
return t, nil
}

data := url.Values{}

switch uaa.config.UAAGrantType() {
Expand Down
Loading