Skip to content

Commit

Permalink
feat: graceful refresh token rotation
Browse files Browse the repository at this point in the history
This patch adds a configuration flag which enables graceful refresh token rotation. Previously, refresh tokens could only be used once. On reuse, all tokens of that chain would be revoked.

This is particularly challenging in environments, where it's difficult to make guarantees on synchronization. This could lead to refresh tokens being sent twice due to some parallel execution.

To resolve this, refresh tokens can now be graceful by changing `oauth2.grant.refresh_token.grace_period=10s` (example value). During this time, a refresh token can be used multiple times to generate new refresh, ID, and access tokens.

All tokens will correctly be invalidated, when the refresh token is re-used after the grace period expires, or when the delete consent endpoint is used.

Closes #1831 #3770
  • Loading branch information
bill-robbins-ss authored and aeneasr committed Oct 15, 2024
1 parent 0cd00dc commit 3d7414e
Show file tree
Hide file tree
Showing 12 changed files with 341 additions and 21 deletions.
9 changes: 9 additions & 0 deletions driver/config/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ const (
KeyExcludeNotBeforeClaim = "oauth2.exclude_not_before_claim"
KeyAllowedTopLevelClaims = "oauth2.allowed_top_level_claims"
KeyMirrorTopLevelClaims = "oauth2.mirror_top_level_claims"
KeyRefreshTokenRotationGracePeriod = "oauth2.grant.refresh_token.rotation_grace_period" // #nosec G101
KeyOAuth2GrantJWTIDOptional = "oauth2.grant.jwt.jti_optional"
KeyOAuth2GrantJWTIssuedDateOptional = "oauth2.grant.jwt.iat_optional"
KeyOAuth2GrantJWTMaxDuration = "oauth2.grant.jwt.max_ttl"
Expand Down Expand Up @@ -669,3 +670,11 @@ func (p *DefaultProvider) cookieSuffix(ctx context.Context, key string) string {

return p.getProvider(ctx).String(key) + suffix
}

func (p *DefaultProvider) RefreshTokenRotationGracePeriod(ctx context.Context) time.Duration {
var duration = p.getProvider(ctx).DurationF(KeyRefreshTokenRotationGracePeriod, 0)
if duration > time.Hour {
return time.Hour
}
return p.getProvider(ctx).DurationF(KeyRefreshTokenRotationGracePeriod, 0)
}
7 changes: 7 additions & 0 deletions driver/config/provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,13 @@ func TestViperProviderValidates(t *testing.T) {
assert.Equal(t, "random_salt", c.SubjectIdentifierAlgorithmSalt(ctx))
assert.Equal(t, []string{"whatever"}, c.DefaultClientScope(ctx))

// refresh
assert.Equal(t, time.Duration(0), c.RefreshTokenRotationGracePeriod(ctx))
require.NoError(t, c.Set(ctx, KeyRefreshTokenRotationGracePeriod, "1s"))
assert.Equal(t, time.Second, c.RefreshTokenRotationGracePeriod(ctx))
require.NoError(t, c.Set(ctx, KeyRefreshTokenRotationGracePeriod, "2h"))
assert.Equal(t, time.Hour, c.RefreshTokenRotationGracePeriod(ctx))

// urls
assert.Equal(t, urlx.ParseOrPanic("https://issuer"), c.IssuerURL(ctx))
assert.Equal(t, urlx.ParseOrPanic("https://public/"), c.PublicURL(ctx))
Expand Down
13 changes: 13 additions & 0 deletions internal/config/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -402,6 +402,19 @@ oauth2:
session:
# store encrypted data in database, default true
encrypt_at_rest: true
## refresh_token_rotation
# By default Refresh Tokens are rotated and invalidated with each use. See https://datatracker.ietf.org/doc/html/draft-ietf-oauth-security-topics#section-4.13.2 for more details
refresh_token_rotation:
#
## grace_period
#
# Set the grace period for a refresh token to allow it to be used for the duration of this configuration after its
# first use. New refresh tokens will continue to be issued.
#
# Examples:
# - 5s
# - 1m
grace_period: 0s

# The secrets section configures secrets used for encryption and signing of several systems. All secrets can be rotated,
# for more information on this topic navigate to:
Expand Down
63 changes: 63 additions & 0 deletions oauth2/fosite_store_helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,7 @@ func TestHelperRunner(t *testing.T, store InternalRegistry, k string) {
t.Run(fmt.Sprintf("case=testHelperDeleteAccessTokens/db=%s", k), testHelperDeleteAccessTokens(store))
t.Run(fmt.Sprintf("case=testHelperRevokeAccessToken/db=%s", k), testHelperRevokeAccessToken(store))
t.Run(fmt.Sprintf("case=testFositeJWTBearerGrantStorage/db=%s", k), testFositeJWTBearerGrantStorage(store))
t.Run(fmt.Sprintf("case=testHelperRevokeRefreshTokenMaybeGracePeriod/db=%s", k), testHelperRevokeRefreshTokenMaybeGracePeriod(store))
}

func testHelperRequestIDMultiples(m InternalRegistry, _ string) func(t *testing.T) {
Expand Down Expand Up @@ -553,6 +554,68 @@ func testHelperRevokeAccessToken(x InternalRegistry) func(t *testing.T) {
}
}

func testHelperRevokeRefreshTokenMaybeGracePeriod(x InternalRegistry) func(t *testing.T) {

return func(t *testing.T) {
t.Run("Revokes refresh token when grace period not configured", func(t *testing.T) {
// SETUP
m := x.OAuth2Storage()
ctx := context.Background()

refreshTokenSession := fmt.Sprintf("refresh_token_%d", time.Now().Unix())
err := m.CreateRefreshTokenSession(ctx, refreshTokenSession, &defaultRequest)
assert.NoError(t, err, "precondition failed: could not create refresh token session")

// ACT
err = m.RevokeRefreshTokenMaybeGracePeriod(ctx, defaultRequest.GetID(), refreshTokenSession)
assert.NoError(t, err)

tmpSession := new(fosite.Session)
_, err = m.GetRefreshTokenSession(ctx, refreshTokenSession, *tmpSession)

// ASSERT
// a revoked refresh token returns an error when getting the token again
assert.Error(t, err)
assert.True(t, errors.Is(err, fosite.ErrInactiveToken))
})

t.Run("refresh token enters grace period when configured,", func(t *testing.T) {
ctx := context.Background()

// SETUP
x.Config().MustSet(ctx, "oauth2.refresh_token_rotation.grace_period", "1m")

// always reset back to the default
t.Cleanup(func() {
x.Config().MustSet(ctx, "oauth2.refresh_token_rotation.grace_period", "0m")
})

m := x.OAuth2Storage()

refreshTokenSession := fmt.Sprintf("refresh_token_%d_with_grace_period", time.Now().Unix())
err := m.CreateRefreshTokenSession(ctx, refreshTokenSession, &defaultRequest)
assert.NoError(t, err, "precondition failed: could not create refresh token session")

// ACT
assert.NoError(t, m.RevokeRefreshTokenMaybeGracePeriod(ctx, defaultRequest.GetID(), refreshTokenSession))
assert.NoError(t, m.RevokeRefreshTokenMaybeGracePeriod(ctx, defaultRequest.GetID(), refreshTokenSession))
assert.NoError(t, m.RevokeRefreshTokenMaybeGracePeriod(ctx, defaultRequest.GetID(), refreshTokenSession))

tmpSession := new(fosite.Session)
req, err := m.GetRefreshTokenSession(ctx, refreshTokenSession, *tmpSession)
assert.NoError(t, err)

// ASSERT
// when grace period is configured the refresh token can be obtained within
// the grace period without error
assert.NoError(t, err)

assert.Equal(t, defaultRequest.GetID(), req.GetID())
})
}

}

func testHelperCreateGetDeletePKCERequestSession(x InternalRegistry) func(t *testing.T) {
return func(t *testing.T) {
m := x.OAuth2Storage()
Expand Down
149 changes: 147 additions & 2 deletions oauth2/oauth2_auth_code_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -176,8 +176,9 @@ func TestAuthCodeWithDefaultStrategy(t *testing.T) {
}

assertRefreshToken := func(t *testing.T, token *oauth2.Token, c *oauth2.Config, expectedExp time.Time) {
actualExp, err := strconv.ParseInt(testhelpers.IntrospectToken(t, c, token.RefreshToken, adminTS).Get("exp").String(), 10, 64)
require.NoError(t, err)
introspect := testhelpers.IntrospectToken(t, c, token.RefreshToken, adminTS)
actualExp, err := strconv.ParseInt(introspect.Get("exp").String(), 10, 64)
require.NoError(t, err, "%s", introspect)
requirex.EqualTime(t, expectedExp, time.Unix(actualExp, 0), time.Second)
}

Expand Down Expand Up @@ -330,6 +331,150 @@ func TestAuthCodeWithDefaultStrategy(t *testing.T) {
})
})

t.Run("case=graceful token rotation", func(t *testing.T) {
run := func(t *testing.T, strategy string) {
reg.Config().MustSet(ctx, config.KeyRefreshTokenRotationGracePeriod, "5s")
t.Cleanup(func() {
reg.Config().MustSet(ctx, config.KeyRefreshTokenRotationGracePeriod, nil)
})

c, conf := newOAuth2Client(t, reg, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler))
testhelpers.NewLoginConsentUI(t, reg.Config(),
acceptLoginHandler(t, c, subject, nil),
acceptConsentHandler(t, c, subject, nil),
)

issueTokens := func(t *testing.T) *oauth2.Token {
code, _ := getAuthorizeCode(t, conf, nil, oauth2.SetAuthURLParam("nonce", nonce))
require.NotEmpty(t, code)
token, err := conf.Exchange(context.Background(), code)
iat := time.Now()
require.NoError(t, err)

introspectAccessToken(t, conf, token, subject)
assertJWTAccessToken(t, strategy, conf, token, subject, iat.Add(reg.Config().GetAccessTokenLifespan(ctx)), `["hydra","offline","openid"]`)
assertIDToken(t, token, conf, subject, nonce, iat.Add(reg.Config().GetIDTokenLifespan(ctx)))
assertRefreshToken(t, token, conf, iat.Add(reg.Config().GetRefreshTokenLifespan(ctx)))
return token
}

refreshTokens := func(t *testing.T, token *oauth2.Token) *oauth2.Token {
require.NotEmpty(t, token.RefreshToken)
token.Expiry = token.Expiry.Add(-time.Hour * 24)
iat := time.Now()
refreshedToken, err := conf.TokenSource(context.Background(), token).Token()
require.NoError(t, err)

require.NotEqual(t, token.AccessToken, refreshedToken.AccessToken)
require.NotEqual(t, token.RefreshToken, refreshedToken.RefreshToken)
require.NotEqual(t, token.Extra("id_token"), refreshedToken.Extra("id_token"))

introspectAccessToken(t, conf, refreshedToken, subject)
assertJWTAccessToken(t, strategy, conf, refreshedToken, subject, iat.Add(reg.Config().GetAccessTokenLifespan(ctx)), `["hydra","offline","openid"]`)
assertIDToken(t, refreshedToken, conf, subject, nonce, iat.Add(reg.Config().GetIDTokenLifespan(ctx)))
assertRefreshToken(t, refreshedToken, conf, iat.Add(reg.Config().GetRefreshTokenLifespan(ctx)))
return refreshedToken
}

t.Run("followup=successfully perform refresh token flow", func(t *testing.T) {
start := time.Now()

token := issueTokens(t)
var first, second *oauth2.Token
t.Run("followup=first refresh", func(t *testing.T) {
first = refreshTokens(t, token)
})

t.Run("followup=second refresh", func(t *testing.T) {
second = refreshTokens(t, token)
})

// Sleep until the grace period is over
time.Sleep(time.Until(start.Add(5*time.Second + time.Millisecond*10)))
t.Run("followup=refresh failure invalidates all tokens", func(t *testing.T) {
_, err := conf.TokenSource(context.Background(), token).Token()
assert.Error(t, err)

i := testhelpers.IntrospectToken(t, conf, first.AccessToken, adminTS)
assert.False(t, i.Get("active").Bool(), "%s", i)

i = testhelpers.IntrospectToken(t, conf, second.AccessToken, adminTS)
assert.False(t, i.Get("active").Bool(), "%s", i)

i = testhelpers.IntrospectToken(t, conf, first.RefreshToken, adminTS)
assert.False(t, i.Get("active").Bool(), "%s", i)

i = testhelpers.IntrospectToken(t, conf, second.RefreshToken, adminTS)
assert.False(t, i.Get("active").Bool(), "%s", i)
})
})

t.Run("followup=graceful refresh tokens are all refreshed", func(t *testing.T) {
start := time.Now()
token := issueTokens(t)
var a1Refresh, b1Refresh, a2RefreshA, a2RefreshB, b2RefreshA, b2RefreshB *oauth2.Token
t.Run("followup=first refresh", func(t *testing.T) {
a1Refresh = refreshTokens(t, token)
})

t.Run("followup=second refresh", func(t *testing.T) {
b1Refresh = refreshTokens(t, token)
})

t.Run("followup=first refresh from first refresh", func(t *testing.T) {
a2RefreshA = refreshTokens(t, a1Refresh)
})

t.Run("followup=second refresh from first refresh", func(t *testing.T) {
a2RefreshB = refreshTokens(t, a1Refresh)
})

t.Run("followup=first refresh from second refresh", func(t *testing.T) {
b2RefreshA = refreshTokens(t, b1Refresh)
})

t.Run("followup=second refresh from second refresh", func(t *testing.T) {
b2RefreshB = refreshTokens(t, b1Refresh)
})

// Sleep until the grace period is over
time.Sleep(time.Until(start.Add(5*time.Second + time.Millisecond*10)))
t.Run("followup=refresh failure invalidates all tokens", func(t *testing.T) {
_, err := conf.TokenSource(context.Background(), token).Token()
assert.Error(t, err)

for k, token := range []*oauth2.Token{
a1Refresh, b1Refresh, a2RefreshA, a2RefreshB, b2RefreshA, b2RefreshB,
} {
t.Run(fmt.Sprintf("case=%d", k), func(t *testing.T) {
i := testhelpers.IntrospectToken(t, conf, token.AccessToken, adminTS)
assert.False(t, i.Get("active").Bool(), "%s", i)

i = testhelpers.IntrospectToken(t, conf, token.AccessToken, adminTS)
assert.False(t, i.Get("active").Bool(), "%s", i)

i = testhelpers.IntrospectToken(t, conf, token.RefreshToken, adminTS)
assert.False(t, i.Get("active").Bool(), "%s", i)

i = testhelpers.IntrospectToken(t, conf, token.RefreshToken, adminTS)
assert.False(t, i.Get("active").Bool(), "%s", i)
})
}
})
})
}

t.Run("strategy=jwt", func(t *testing.T) {
reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, "jwt")
run(t, "jwt")
})

t.Run("strategy=opaque", func(t *testing.T) {
reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, "opaque")
run(t, "opaque")
})
})

t.Run("case=perform authorize code flow with verifable credentials", func(t *testing.T) {
// Make sure we test against all crypto suites that we advertise.
cfg, _, err := publicClient.OidcAPI.DiscoverOidcConfiguration(ctx).Execute()
Expand Down
20 changes: 10 additions & 10 deletions persistence/sql/migratest/migration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ func TestMigrations(t *testing.T) {

t.Run("case=hydra_oauth2_authentication_session", func(t *testing.T) {
ss := []flow.LoginSession{}
c.All(&ss)
require.NoError(t, c.All(&ss))
require.Equal(t, 17, len(ss))

for _, s := range ss {
Expand All @@ -157,7 +157,7 @@ func TestMigrations(t *testing.T) {

t.Run("case=hydra_oauth2_obfuscated_authentication_session", func(t *testing.T) {
ss := []consent.ForcedObfuscatedLoginSession{}
c.All(&ss)
require.NoError(t, c.All(&ss))
require.Equal(t, 13, len(ss))

for _, s := range ss {
Expand All @@ -169,7 +169,7 @@ func TestMigrations(t *testing.T) {

t.Run("case=hydra_oauth2_logout_request", func(t *testing.T) {
lrs := []flow.LogoutRequest{}
c.All(&lrs)
require.NoError(t, c.All(&lrs))
require.Equal(t, 7, len(lrs))

for _, s := range lrs {
Expand All @@ -182,7 +182,7 @@ func TestMigrations(t *testing.T) {

t.Run("case=hydra_oauth2_jti_blacklist", func(t *testing.T) {
bjtis := []oauth2.BlacklistedJTI{}
c.All(&bjtis)
require.NoError(t, c.All(&bjtis))
require.Equal(t, 1, len(bjtis))
for _, bjti := range bjtis {
testhelpersuuid.AssertUUID(t, bjti.NID)
Expand All @@ -194,7 +194,7 @@ func TestMigrations(t *testing.T) {

t.Run("case=hydra_oauth2_access", func(t *testing.T) {
as := []sql.OAuth2RequestSQL{}
c.RawQuery("SELECT * FROM hydra_oauth2_access").All(&as)
require.NoError(t, c.RawQuery("SELECT * FROM hydra_oauth2_access").All(&as))
require.Equal(t, 13, len(as))

for _, a := range as {
Expand All @@ -210,7 +210,7 @@ func TestMigrations(t *testing.T) {

t.Run("case=hydra_oauth2_refresh", func(t *testing.T) {
rs := []sql.OAuth2RequestSQL{}
c.RawQuery("SELECT * FROM hydra_oauth2_refresh").All(&rs)
require.NoError(t, c.RawQuery(`SELECT signature, nid, request_id, challenge_id, requested_at, client_id, scope, granted_scope, requested_audience, granted_audience, form_data, subject, active, session_data, expires_at FROM hydra_oauth2_refresh`).All(&rs))
require.Equal(t, 13, len(rs))

for _, r := range rs {
Expand All @@ -226,7 +226,7 @@ func TestMigrations(t *testing.T) {

t.Run("case=hydra_oauth2_code", func(t *testing.T) {
cs := []sql.OAuth2RequestSQL{}
c.RawQuery("SELECT * FROM hydra_oauth2_code").All(&cs)
require.NoError(t, c.RawQuery("SELECT * FROM hydra_oauth2_code").All(&cs))
require.Equal(t, 13, len(cs))

for _, c := range cs {
Expand All @@ -242,7 +242,7 @@ func TestMigrations(t *testing.T) {

t.Run("case=hydra_oauth2_oidc", func(t *testing.T) {
os := []sql.OAuth2RequestSQL{}
c.RawQuery("SELECT * FROM hydra_oauth2_oidc").All(&os)
require.NoError(t, c.RawQuery("SELECT * FROM hydra_oauth2_oidc").All(&os))
require.Equal(t, 13, len(os))

for _, o := range os {
Expand All @@ -258,7 +258,7 @@ func TestMigrations(t *testing.T) {

t.Run("case=hydra_oauth2_pkce", func(t *testing.T) {
ps := []sql.OAuth2RequestSQL{}
c.RawQuery("SELECT * FROM hydra_oauth2_pkce").All(&ps)
require.NoError(t, c.RawQuery("SELECT * FROM hydra_oauth2_pkce").All(&ps))
require.Equal(t, 11, len(ps))

for _, p := range ps {
Expand All @@ -274,7 +274,7 @@ func TestMigrations(t *testing.T) {

t.Run("case=networks", func(t *testing.T) {
ns := []networkx.Network{}
c.RawQuery("SELECT * FROM networks").All(&ns)
require.NoError(t, c.RawQuery("SELECT * FROM networks").All(&ns))
require.Equal(t, 1, len(ns))
for _, n := range ns {
testhelpersuuid.AssertUUID(t, n.ID)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
ALTER TABLE hydra_oauth2_refresh DROP COLUMN first_used_at;
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
ALTER TABLE hydra_oauth2_refresh ADD first_used_at TIMESTAMP DEFAULT NULL;
1 change: 1 addition & 0 deletions persistence/sql/persister.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ type (
contextx.Provider
x.RegistryLogger
x.TracingProvider
config.Provider
}
)

Expand Down
Loading

0 comments on commit 3d7414e

Please sign in to comment.