Skip to content

Commit

Permalink
feat: graceful refresh token rotation
Browse files Browse the repository at this point in the history
This patch adds a configuration flag which enables graceful refresh token rotation. Previously, refresh tokens could only be used once. On reuse, all tokens of that chain would be revoked.

This is particularly challenging in environments, where it's difficult to make guarantees on synchronization. This could lead to refresh tokens being sent twice due to some parallel execution.

To resolve this, refresh tokens can now be graceful by changing `oauth2.grant.refresh_token.grace_period=10s` (example value). During this time, a refresh token can be used multiple times to generate new refresh, ID, and access tokens.

All tokens will correctly be invalidated, when the refresh token is re-used after the grace period expires, or when the delete consent endpoint is used.

Closes #1831 #3770
  • Loading branch information
aeneasr committed Oct 14, 2024
1 parent 6b17bcc commit db1b1fd
Show file tree
Hide file tree
Showing 6 changed files with 224 additions and 101 deletions.
8 changes: 4 additions & 4 deletions driver/config/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ const (
KeyExcludeNotBeforeClaim = "oauth2.exclude_not_before_claim"
KeyAllowedTopLevelClaims = "oauth2.allowed_top_level_claims"
KeyMirrorTopLevelClaims = "oauth2.mirror_top_level_claims"
KeyRefreshTokenRotationGracePeriod = "oauth2.refresh_token_rotation.grace_period" // #nosec G101
KeyRefreshTokenRotationGracePeriod = "oauth2.grant.refresh_token.rotation_grace_period" // #nosec G101
KeyOAuth2GrantJWTIDOptional = "oauth2.grant.jwt.jti_optional"
KeyOAuth2GrantJWTIssuedDateOptional = "oauth2.grant.jwt.iat_optional"
KeyOAuth2GrantJWTMaxDuration = "oauth2.grant.jwt.max_ttl"
Expand Down Expand Up @@ -671,10 +671,10 @@ func (p *DefaultProvider) cookieSuffix(ctx context.Context, key string) string {
return p.getProvider(ctx).String(key) + suffix
}

func (p *DefaultProvider) RefreshTokenRotationGracePeriod() time.Duration {
var duration = p.p.DurationF(KeyRefreshTokenRotationGracePeriod, 0)
func (p *DefaultProvider) RefreshTokenRotationGracePeriod(ctx context.Context) time.Duration {
var duration = p.getProvider(ctx).DurationF(KeyRefreshTokenRotationGracePeriod, 0)
if duration > time.Hour {
return time.Hour
}
return p.p.DurationF(KeyRefreshTokenRotationGracePeriod, 0)
return p.getProvider(ctx).DurationF(KeyRefreshTokenRotationGracePeriod, 0)
}
149 changes: 147 additions & 2 deletions oauth2/oauth2_auth_code_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -176,8 +176,9 @@ func TestAuthCodeWithDefaultStrategy(t *testing.T) {
}

assertRefreshToken := func(t *testing.T, token *oauth2.Token, c *oauth2.Config, expectedExp time.Time) {
actualExp, err := strconv.ParseInt(testhelpers.IntrospectToken(t, c, token.RefreshToken, adminTS).Get("exp").String(), 10, 64)
require.NoError(t, err)
introspect := testhelpers.IntrospectToken(t, c, token.RefreshToken, adminTS)
actualExp, err := strconv.ParseInt(introspect.Get("exp").String(), 10, 64)
require.NoError(t, err, "%s", introspect)
requirex.EqualTime(t, expectedExp, time.Unix(actualExp, 0), time.Second)
}

Expand Down Expand Up @@ -330,6 +331,150 @@ func TestAuthCodeWithDefaultStrategy(t *testing.T) {
})
})

t.Run("case=graceful token rotation", func(t *testing.T) {
run := func(t *testing.T, strategy string) {
reg.Config().MustSet(ctx, config.KeyRefreshTokenRotationGracePeriod, "5s")
t.Cleanup(func() {
reg.Config().MustSet(ctx, config.KeyRefreshTokenRotationGracePeriod, nil)
})

c, conf := newOAuth2Client(t, reg, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler))
testhelpers.NewLoginConsentUI(t, reg.Config(),
acceptLoginHandler(t, c, subject, nil),
acceptConsentHandler(t, c, subject, nil),
)

issueTokens := func(t *testing.T) *oauth2.Token {
code, _ := getAuthorizeCode(t, conf, nil, oauth2.SetAuthURLParam("nonce", nonce))
require.NotEmpty(t, code)
token, err := conf.Exchange(context.Background(), code)
iat := time.Now()
require.NoError(t, err)

introspectAccessToken(t, conf, token, subject)
assertJWTAccessToken(t, strategy, conf, token, subject, iat.Add(reg.Config().GetAccessTokenLifespan(ctx)), `["hydra","offline","openid"]`)
assertIDToken(t, token, conf, subject, nonce, iat.Add(reg.Config().GetIDTokenLifespan(ctx)))
assertRefreshToken(t, token, conf, iat.Add(reg.Config().GetRefreshTokenLifespan(ctx)))
return token
}

refreshTokens := func(t *testing.T, token *oauth2.Token) *oauth2.Token {
require.NotEmpty(t, token.RefreshToken)
token.Expiry = token.Expiry.Add(-time.Hour * 24)
iat := time.Now()
refreshedToken, err := conf.TokenSource(context.Background(), token).Token()
require.NoError(t, err)

require.NotEqual(t, token.AccessToken, refreshedToken.AccessToken)
require.NotEqual(t, token.RefreshToken, refreshedToken.RefreshToken)
require.NotEqual(t, token.Extra("id_token"), refreshedToken.Extra("id_token"))

introspectAccessToken(t, conf, refreshedToken, subject)
assertJWTAccessToken(t, strategy, conf, refreshedToken, subject, iat.Add(reg.Config().GetAccessTokenLifespan(ctx)), `["hydra","offline","openid"]`)
assertIDToken(t, refreshedToken, conf, subject, nonce, iat.Add(reg.Config().GetIDTokenLifespan(ctx)))
assertRefreshToken(t, refreshedToken, conf, iat.Add(reg.Config().GetRefreshTokenLifespan(ctx)))
return refreshedToken
}

t.Run("followup=successfully perform refresh token flow", func(t *testing.T) {
start := time.Now()

token := issueTokens(t)
var first, second *oauth2.Token
t.Run("followup=first refresh", func(t *testing.T) {
first = refreshTokens(t, token)
})

t.Run("followup=second refresh", func(t *testing.T) {
second = refreshTokens(t, token)
})

// Sleep until the grace period is over
time.Sleep(time.Until(start.Add(5*time.Second + time.Millisecond*10)))
t.Run("followup=refresh failure invalidates all tokens", func(t *testing.T) {
_, err := conf.TokenSource(context.Background(), token).Token()
assert.Error(t, err)

i := testhelpers.IntrospectToken(t, conf, first.AccessToken, adminTS)
assert.False(t, i.Get("active").Bool(), "%s", i)

i = testhelpers.IntrospectToken(t, conf, second.AccessToken, adminTS)
assert.False(t, i.Get("active").Bool(), "%s", i)

i = testhelpers.IntrospectToken(t, conf, first.RefreshToken, adminTS)
assert.False(t, i.Get("active").Bool(), "%s", i)

i = testhelpers.IntrospectToken(t, conf, second.RefreshToken, adminTS)
assert.False(t, i.Get("active").Bool(), "%s", i)
})
})

t.Run("followup=graceful refresh tokens are all refreshed", func(t *testing.T) {
start := time.Now()
token := issueTokens(t)
var a1Refresh, b1Refresh, a2RefreshA, a2RefreshB, b2RefreshA, b2RefreshB *oauth2.Token
t.Run("followup=first refresh", func(t *testing.T) {
a1Refresh = refreshTokens(t, token)
})

t.Run("followup=second refresh", func(t *testing.T) {
b1Refresh = refreshTokens(t, token)
})

t.Run("followup=first refresh from first refresh", func(t *testing.T) {
a2RefreshA = refreshTokens(t, a1Refresh)
})

t.Run("followup=second refresh from first refresh", func(t *testing.T) {
a2RefreshB = refreshTokens(t, a1Refresh)
})

t.Run("followup=first refresh from second refresh", func(t *testing.T) {
b2RefreshA = refreshTokens(t, b1Refresh)
})

t.Run("followup=second refresh from second refresh", func(t *testing.T) {
b2RefreshB = refreshTokens(t, b1Refresh)
})

// Sleep until the grace period is over
time.Sleep(time.Until(start.Add(5*time.Second + time.Millisecond*10)))
t.Run("followup=refresh failure invalidates all tokens", func(t *testing.T) {
_, err := conf.TokenSource(context.Background(), token).Token()
assert.Error(t, err)

for k, token := range []*oauth2.Token{
a1Refresh, b1Refresh, a2RefreshA, a2RefreshB, b2RefreshA, b2RefreshB,
} {
t.Run(fmt.Sprintf("case=%d", k), func(t *testing.T) {
i := testhelpers.IntrospectToken(t, conf, token.AccessToken, adminTS)
assert.False(t, i.Get("active").Bool(), "%s", i)

i = testhelpers.IntrospectToken(t, conf, token.AccessToken, adminTS)
assert.False(t, i.Get("active").Bool(), "%s", i)

i = testhelpers.IntrospectToken(t, conf, token.RefreshToken, adminTS)
assert.False(t, i.Get("active").Bool(), "%s", i)

i = testhelpers.IntrospectToken(t, conf, token.RefreshToken, adminTS)
assert.False(t, i.Get("active").Bool(), "%s", i)
})
}
})
})
}

t.Run("strategy=jwt", func(t *testing.T) {
reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, "jwt")
run(t, "jwt")
})

t.Run("strategy=opaque", func(t *testing.T) {
reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, "opaque")
run(t, "opaque")
})
})

t.Run("case=perform authorize code flow with verifable credentials", func(t *testing.T) {
// Make sure we test against all crypto suites that we advertise.
cfg, _, err := publicClient.OidcAPI.DiscoverOidcConfiguration(ctx).Execute()
Expand Down
20 changes: 10 additions & 10 deletions persistence/sql/migratest/migration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ func TestMigrations(t *testing.T) {

t.Run("case=hydra_oauth2_authentication_session", func(t *testing.T) {
ss := []flow.LoginSession{}
c.All(&ss)
require.NoError(t, c.All(&ss))
require.Equal(t, 17, len(ss))

for _, s := range ss {
Expand All @@ -157,7 +157,7 @@ func TestMigrations(t *testing.T) {

t.Run("case=hydra_oauth2_obfuscated_authentication_session", func(t *testing.T) {
ss := []consent.ForcedObfuscatedLoginSession{}
c.All(&ss)
require.NoError(t, c.All(&ss))
require.Equal(t, 13, len(ss))

for _, s := range ss {
Expand All @@ -169,7 +169,7 @@ func TestMigrations(t *testing.T) {

t.Run("case=hydra_oauth2_logout_request", func(t *testing.T) {
lrs := []flow.LogoutRequest{}
c.All(&lrs)
require.NoError(t, c.All(&lrs))
require.Equal(t, 7, len(lrs))

for _, s := range lrs {
Expand All @@ -182,7 +182,7 @@ func TestMigrations(t *testing.T) {

t.Run("case=hydra_oauth2_jti_blacklist", func(t *testing.T) {
bjtis := []oauth2.BlacklistedJTI{}
c.All(&bjtis)
require.NoError(t, c.All(&bjtis))
require.Equal(t, 1, len(bjtis))
for _, bjti := range bjtis {
testhelpersuuid.AssertUUID(t, bjti.NID)
Expand All @@ -194,7 +194,7 @@ func TestMigrations(t *testing.T) {

t.Run("case=hydra_oauth2_access", func(t *testing.T) {
as := []sql.OAuth2RequestSQL{}
c.RawQuery("SELECT * FROM hydra_oauth2_access").All(&as)
require.NoError(t, c.RawQuery("SELECT * FROM hydra_oauth2_access").All(&as))
require.Equal(t, 13, len(as))

for _, a := range as {
Expand All @@ -210,7 +210,7 @@ func TestMigrations(t *testing.T) {

t.Run("case=hydra_oauth2_refresh", func(t *testing.T) {
rs := []sql.OAuth2RequestSQL{}
c.RawQuery("SELECT * FROM hydra_oauth2_refresh").All(&rs)
require.NoError(t, c.RawQuery(`SELECT signature, nid, request_id, challenge_id, requested_at, client_id, scope, granted_scope, requested_audience, granted_audience, form_data, subject, active, session_data, expires_at FROM hydra_oauth2_refresh`).All(&rs))
require.Equal(t, 13, len(rs))

for _, r := range rs {
Expand All @@ -226,7 +226,7 @@ func TestMigrations(t *testing.T) {

t.Run("case=hydra_oauth2_code", func(t *testing.T) {
cs := []sql.OAuth2RequestSQL{}
c.RawQuery("SELECT * FROM hydra_oauth2_code").All(&cs)
require.NoError(t, c.RawQuery("SELECT * FROM hydra_oauth2_code").All(&cs))
require.Equal(t, 13, len(cs))

for _, c := range cs {
Expand All @@ -242,7 +242,7 @@ func TestMigrations(t *testing.T) {

t.Run("case=hydra_oauth2_oidc", func(t *testing.T) {
os := []sql.OAuth2RequestSQL{}
c.RawQuery("SELECT * FROM hydra_oauth2_oidc").All(&os)
require.NoError(t, c.RawQuery("SELECT * FROM hydra_oauth2_oidc").All(&os))
require.Equal(t, 13, len(os))

for _, o := range os {
Expand All @@ -258,7 +258,7 @@ func TestMigrations(t *testing.T) {

t.Run("case=hydra_oauth2_pkce", func(t *testing.T) {
ps := []sql.OAuth2RequestSQL{}
c.RawQuery("SELECT * FROM hydra_oauth2_pkce").All(&ps)
require.NoError(t, c.RawQuery("SELECT * FROM hydra_oauth2_pkce").All(&ps))
require.Equal(t, 11, len(ps))

for _, p := range ps {
Expand All @@ -274,7 +274,7 @@ func TestMigrations(t *testing.T) {

t.Run("case=networks", func(t *testing.T) {
ns := []networkx.Network{}
c.RawQuery("SELECT * FROM networks").All(&ns)
require.NoError(t, c.RawQuery("SELECT * FROM networks").All(&ns))
require.Equal(t, 1, len(ns))
for _, n := range ns {
testhelpersuuid.AssertUUID(t, n.ID)
Expand Down
1 change: 1 addition & 0 deletions persistence/sql/persister.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ type (
contextx.Provider
x.RegistryLogger
x.TracingProvider
config.Provider
}
)

Expand Down
Loading

0 comments on commit db1b1fd

Please sign in to comment.