From 9d66008099ac9bbed5fbd269218955414ae5563f Mon Sep 17 00:00:00 2001 From: Bill Robbins <23531640+bill-robbins-ss@users.noreply.github.com> Date: Wed, 27 Oct 2021 17:24:25 -0600 Subject: [PATCH] feat: graceful refresh token rotation This patch adds a configuration flag which enables graceful refresh token rotation. Previously, refresh tokens could only be used once. On reuse, all tokens of that chain would be revoked. This is particularly challenging in environments, where it's difficult to make guarantees on synchronization. This could lead to refresh tokens being sent twice due to some parallel execution. To resolve this, refresh tokens can now be graceful by changing `oauth2.grant.refresh_token.grace_period=10s` (example value). During this time, a refresh token can be used multiple times to generate new refresh, ID, and access tokens. All tokens will correctly be invalidated, when the refresh token is re-used after the grace period expires, or when the delete consent endpoint is used. Closes #1831 #3770 --- driver/config/provider.go | 9 ++ driver/config/provider_test.go | 7 + internal/config/config.yaml | 13 ++ oauth2/fosite_store_helpers.go | 63 ++++++++ oauth2/oauth2_auth_code_test.go | 149 +++++++++++++++++- persistence/sql/migratest/migration_test.go | 20 +-- ...efresh_token_in_grace_period_flag.down.sql | 1 + ..._refresh_token_in_grace_period_flag.up.sql | 1 + persistence/sql/persister.go | 1 + persistence/sql/persister_oauth2.go | 74 ++++++++- spec/config.json | 19 ++- x/fosite_storer.go | 5 +- 12 files changed, 341 insertions(+), 21 deletions(-) create mode 100644 persistence/sql/migrations/20241014121000000000_add_refresh_token_in_grace_period_flag.down.sql create mode 100644 persistence/sql/migrations/20241014121000000000_add_refresh_token_in_grace_period_flag.up.sql diff --git a/driver/config/provider.go b/driver/config/provider.go index a6abecaad0e..40a4a5206a8 100644 --- a/driver/config/provider.go +++ b/driver/config/provider.go @@ -102,6 +102,7 @@ const ( KeyExcludeNotBeforeClaim = "oauth2.exclude_not_before_claim" KeyAllowedTopLevelClaims = "oauth2.allowed_top_level_claims" KeyMirrorTopLevelClaims = "oauth2.mirror_top_level_claims" + KeyRefreshTokenRotationGracePeriod = "oauth2.grant.refresh_token.rotation_grace_period" // #nosec G101 KeyOAuth2GrantJWTIDOptional = "oauth2.grant.jwt.jti_optional" KeyOAuth2GrantJWTIssuedDateOptional = "oauth2.grant.jwt.iat_optional" KeyOAuth2GrantJWTMaxDuration = "oauth2.grant.jwt.max_ttl" @@ -669,3 +670,11 @@ func (p *DefaultProvider) cookieSuffix(ctx context.Context, key string) string { return p.getProvider(ctx).String(key) + suffix } + +func (p *DefaultProvider) RefreshTokenRotationGracePeriod(ctx context.Context) time.Duration { + var duration = p.getProvider(ctx).DurationF(KeyRefreshTokenRotationGracePeriod, 0) + if duration > time.Hour { + return time.Hour + } + return p.getProvider(ctx).DurationF(KeyRefreshTokenRotationGracePeriod, 0) +} diff --git a/driver/config/provider_test.go b/driver/config/provider_test.go index 8e5c44a9e2e..82424aa8feb 100644 --- a/driver/config/provider_test.go +++ b/driver/config/provider_test.go @@ -291,6 +291,13 @@ func TestViperProviderValidates(t *testing.T) { assert.Equal(t, "random_salt", c.SubjectIdentifierAlgorithmSalt(ctx)) assert.Equal(t, []string{"whatever"}, c.DefaultClientScope(ctx)) + // refresh + assert.Equal(t, time.Duration(0), c.RefreshTokenRotationGracePeriod()) + require.NoError(t, c.Set(ctx, KeyRefreshTokenRotationGracePeriod, "1s")) + assert.Equal(t, time.Second, c.RefreshTokenRotationGracePeriod()) + require.NoError(t, c.Set(ctx, KeyRefreshTokenRotationGracePeriod, "2h")) + assert.Equal(t, time.Hour, c.RefreshTokenRotationGracePeriod()) + // urls assert.Equal(t, urlx.ParseOrPanic("https://issuer"), c.IssuerURL(ctx)) assert.Equal(t, urlx.ParseOrPanic("https://public/"), c.PublicURL(ctx)) diff --git a/internal/config/config.yaml b/internal/config/config.yaml index f3e8bff399c..ad3e1ba74d6 100644 --- a/internal/config/config.yaml +++ b/internal/config/config.yaml @@ -402,6 +402,19 @@ oauth2: session: # store encrypted data in database, default true encrypt_at_rest: true + ## refresh_token_rotation + # By default Refresh Tokens are rotated and invalidated with each use. See https://datatracker.ietf.org/doc/html/draft-ietf-oauth-security-topics#section-4.13.2 for more details + refresh_token_rotation: + # + ## grace_period + # + # Set the grace period for a refresh token to allow it to be used for the duration of this configuration after its + # first use. New refresh tokens will continue to be issued. + # + # Examples: + # - 5s + # - 1m + grace_period: 0s # The secrets section configures secrets used for encryption and signing of several systems. All secrets can be rotated, # for more information on this topic navigate to: diff --git a/oauth2/fosite_store_helpers.go b/oauth2/fosite_store_helpers.go index 6b64d61991e..fb7162650c8 100644 --- a/oauth2/fosite_store_helpers.go +++ b/oauth2/fosite_store_helpers.go @@ -225,6 +225,7 @@ func TestHelperRunner(t *testing.T, store InternalRegistry, k string) { t.Run(fmt.Sprintf("case=testHelperDeleteAccessTokens/db=%s", k), testHelperDeleteAccessTokens(store)) t.Run(fmt.Sprintf("case=testHelperRevokeAccessToken/db=%s", k), testHelperRevokeAccessToken(store)) t.Run(fmt.Sprintf("case=testFositeJWTBearerGrantStorage/db=%s", k), testFositeJWTBearerGrantStorage(store)) + t.Run(fmt.Sprintf("case=testHelperRevokeRefreshTokenMaybeGracePeriod/db=%s", k), testHelperRevokeRefreshTokenMaybeGracePeriod(store)) } func testHelperRequestIDMultiples(m InternalRegistry, _ string) func(t *testing.T) { @@ -553,6 +554,68 @@ func testHelperRevokeAccessToken(x InternalRegistry) func(t *testing.T) { } } +func testHelperRevokeRefreshTokenMaybeGracePeriod(x InternalRegistry) func(t *testing.T) { + + return func(t *testing.T) { + t.Run("Revokes refresh token when grace period not configured", func(t *testing.T) { + // SETUP + m := x.OAuth2Storage() + ctx := context.Background() + + refreshTokenSession := fmt.Sprintf("refresh_token_%d", time.Now().Unix()) + err := m.CreateRefreshTokenSession(ctx, refreshTokenSession, &defaultRequest) + assert.NoError(t, err, "precondition failed: could not create refresh token session") + + // ACT + err = m.RevokeRefreshTokenMaybeGracePeriod(ctx, defaultRequest.GetID(), refreshTokenSession) + assert.NoError(t, err) + + tmpSession := new(fosite.Session) + _, err = m.GetRefreshTokenSession(ctx, refreshTokenSession, *tmpSession) + + // ASSERT + // a revoked refresh token returns an error when getting the token again + assert.Error(t, err) + assert.True(t, errors.Is(err, fosite.ErrInactiveToken)) + }) + + t.Run("refresh token enters grace period when configured,", func(t *testing.T) { + ctx := context.Background() + + // SETUP + x.Config().MustSet(ctx, "oauth2.refresh_token_rotation.grace_period", "1m") + + // always reset back to the default + t.Cleanup(func() { + x.Config().MustSet(ctx, "oauth2.refresh_token_rotation.grace_period", "0m") + }) + + m := x.OAuth2Storage() + + refreshTokenSession := fmt.Sprintf("refresh_token_%d_with_grace_period", time.Now().Unix()) + err := m.CreateRefreshTokenSession(ctx, refreshTokenSession, &defaultRequest) + assert.NoError(t, err, "precondition failed: could not create refresh token session") + + // ACT + assert.NoError(t, m.RevokeRefreshTokenMaybeGracePeriod(ctx, defaultRequest.GetID(), refreshTokenSession)) + assert.NoError(t, m.RevokeRefreshTokenMaybeGracePeriod(ctx, defaultRequest.GetID(), refreshTokenSession)) + assert.NoError(t, m.RevokeRefreshTokenMaybeGracePeriod(ctx, defaultRequest.GetID(), refreshTokenSession)) + + tmpSession := new(fosite.Session) + req, err := m.GetRefreshTokenSession(ctx, refreshTokenSession, *tmpSession) + assert.NoError(t, err) + + // ASSERT + // when grace period is configured the refresh token can be obtained within + // the grace period without error + assert.NoError(t, err) + + assert.Equal(t, defaultRequest.GetID(), req.GetID()) + }) + } + +} + func testHelperCreateGetDeletePKCERequestSession(x InternalRegistry) func(t *testing.T) { return func(t *testing.T) { m := x.OAuth2Storage() diff --git a/oauth2/oauth2_auth_code_test.go b/oauth2/oauth2_auth_code_test.go index aa6934062ed..971dd2a65a7 100644 --- a/oauth2/oauth2_auth_code_test.go +++ b/oauth2/oauth2_auth_code_test.go @@ -176,8 +176,9 @@ func TestAuthCodeWithDefaultStrategy(t *testing.T) { } assertRefreshToken := func(t *testing.T, token *oauth2.Token, c *oauth2.Config, expectedExp time.Time) { - actualExp, err := strconv.ParseInt(testhelpers.IntrospectToken(t, c, token.RefreshToken, adminTS).Get("exp").String(), 10, 64) - require.NoError(t, err) + introspect := testhelpers.IntrospectToken(t, c, token.RefreshToken, adminTS) + actualExp, err := strconv.ParseInt(introspect.Get("exp").String(), 10, 64) + require.NoError(t, err, "%s", introspect) requirex.EqualTime(t, expectedExp, time.Unix(actualExp, 0), time.Second) } @@ -330,6 +331,150 @@ func TestAuthCodeWithDefaultStrategy(t *testing.T) { }) }) + t.Run("case=graceful token rotation", func(t *testing.T) { + run := func(t *testing.T, strategy string) { + reg.Config().MustSet(ctx, config.KeyRefreshTokenRotationGracePeriod, "5s") + t.Cleanup(func() { + reg.Config().MustSet(ctx, config.KeyRefreshTokenRotationGracePeriod, nil) + }) + + c, conf := newOAuth2Client(t, reg, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler)) + testhelpers.NewLoginConsentUI(t, reg.Config(), + acceptLoginHandler(t, c, subject, nil), + acceptConsentHandler(t, c, subject, nil), + ) + + issueTokens := func(t *testing.T) *oauth2.Token { + code, _ := getAuthorizeCode(t, conf, nil, oauth2.SetAuthURLParam("nonce", nonce)) + require.NotEmpty(t, code) + token, err := conf.Exchange(context.Background(), code) + iat := time.Now() + require.NoError(t, err) + + introspectAccessToken(t, conf, token, subject) + assertJWTAccessToken(t, strategy, conf, token, subject, iat.Add(reg.Config().GetAccessTokenLifespan(ctx)), `["hydra","offline","openid"]`) + assertIDToken(t, token, conf, subject, nonce, iat.Add(reg.Config().GetIDTokenLifespan(ctx))) + assertRefreshToken(t, token, conf, iat.Add(reg.Config().GetRefreshTokenLifespan(ctx))) + return token + } + + refreshTokens := func(t *testing.T, token *oauth2.Token) *oauth2.Token { + require.NotEmpty(t, token.RefreshToken) + token.Expiry = token.Expiry.Add(-time.Hour * 24) + iat := time.Now() + refreshedToken, err := conf.TokenSource(context.Background(), token).Token() + require.NoError(t, err) + + require.NotEqual(t, token.AccessToken, refreshedToken.AccessToken) + require.NotEqual(t, token.RefreshToken, refreshedToken.RefreshToken) + require.NotEqual(t, token.Extra("id_token"), refreshedToken.Extra("id_token")) + + introspectAccessToken(t, conf, refreshedToken, subject) + assertJWTAccessToken(t, strategy, conf, refreshedToken, subject, iat.Add(reg.Config().GetAccessTokenLifespan(ctx)), `["hydra","offline","openid"]`) + assertIDToken(t, refreshedToken, conf, subject, nonce, iat.Add(reg.Config().GetIDTokenLifespan(ctx))) + assertRefreshToken(t, refreshedToken, conf, iat.Add(reg.Config().GetRefreshTokenLifespan(ctx))) + return refreshedToken + } + + t.Run("followup=successfully perform refresh token flow", func(t *testing.T) { + start := time.Now() + + token := issueTokens(t) + var first, second *oauth2.Token + t.Run("followup=first refresh", func(t *testing.T) { + first = refreshTokens(t, token) + }) + + t.Run("followup=second refresh", func(t *testing.T) { + second = refreshTokens(t, token) + }) + + // Sleep until the grace period is over + time.Sleep(time.Until(start.Add(5*time.Second + time.Millisecond*10))) + t.Run("followup=refresh failure invalidates all tokens", func(t *testing.T) { + _, err := conf.TokenSource(context.Background(), token).Token() + assert.Error(t, err) + + i := testhelpers.IntrospectToken(t, conf, first.AccessToken, adminTS) + assert.False(t, i.Get("active").Bool(), "%s", i) + + i = testhelpers.IntrospectToken(t, conf, second.AccessToken, adminTS) + assert.False(t, i.Get("active").Bool(), "%s", i) + + i = testhelpers.IntrospectToken(t, conf, first.RefreshToken, adminTS) + assert.False(t, i.Get("active").Bool(), "%s", i) + + i = testhelpers.IntrospectToken(t, conf, second.RefreshToken, adminTS) + assert.False(t, i.Get("active").Bool(), "%s", i) + }) + }) + + t.Run("followup=graceful refresh tokens are all refreshed", func(t *testing.T) { + start := time.Now() + token := issueTokens(t) + var a1Refresh, b1Refresh, a2RefreshA, a2RefreshB, b2RefreshA, b2RefreshB *oauth2.Token + t.Run("followup=first refresh", func(t *testing.T) { + a1Refresh = refreshTokens(t, token) + }) + + t.Run("followup=second refresh", func(t *testing.T) { + b1Refresh = refreshTokens(t, token) + }) + + t.Run("followup=first refresh from first refresh", func(t *testing.T) { + a2RefreshA = refreshTokens(t, a1Refresh) + }) + + t.Run("followup=second refresh from first refresh", func(t *testing.T) { + a2RefreshB = refreshTokens(t, a1Refresh) + }) + + t.Run("followup=first refresh from second refresh", func(t *testing.T) { + b2RefreshA = refreshTokens(t, b1Refresh) + }) + + t.Run("followup=second refresh from second refresh", func(t *testing.T) { + b2RefreshB = refreshTokens(t, b1Refresh) + }) + + // Sleep until the grace period is over + time.Sleep(time.Until(start.Add(5*time.Second + time.Millisecond*10))) + t.Run("followup=refresh failure invalidates all tokens", func(t *testing.T) { + _, err := conf.TokenSource(context.Background(), token).Token() + assert.Error(t, err) + + for k, token := range []*oauth2.Token{ + a1Refresh, b1Refresh, a2RefreshA, a2RefreshB, b2RefreshA, b2RefreshB, + } { + t.Run(fmt.Sprintf("case=%d", k), func(t *testing.T) { + i := testhelpers.IntrospectToken(t, conf, token.AccessToken, adminTS) + assert.False(t, i.Get("active").Bool(), "%s", i) + + i = testhelpers.IntrospectToken(t, conf, token.AccessToken, adminTS) + assert.False(t, i.Get("active").Bool(), "%s", i) + + i = testhelpers.IntrospectToken(t, conf, token.RefreshToken, adminTS) + assert.False(t, i.Get("active").Bool(), "%s", i) + + i = testhelpers.IntrospectToken(t, conf, token.RefreshToken, adminTS) + assert.False(t, i.Get("active").Bool(), "%s", i) + }) + } + }) + }) + } + + t.Run("strategy=jwt", func(t *testing.T) { + reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, "jwt") + run(t, "jwt") + }) + + t.Run("strategy=opaque", func(t *testing.T) { + reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, "opaque") + run(t, "opaque") + }) + }) + t.Run("case=perform authorize code flow with verifable credentials", func(t *testing.T) { // Make sure we test against all crypto suites that we advertise. cfg, _, err := publicClient.OidcAPI.DiscoverOidcConfiguration(ctx).Execute() diff --git a/persistence/sql/migratest/migration_test.go b/persistence/sql/migratest/migration_test.go index 1fa0ce3836d..8564cfab969 100644 --- a/persistence/sql/migratest/migration_test.go +++ b/persistence/sql/migratest/migration_test.go @@ -144,7 +144,7 @@ func TestMigrations(t *testing.T) { t.Run("case=hydra_oauth2_authentication_session", func(t *testing.T) { ss := []flow.LoginSession{} - c.All(&ss) + require.NoError(t, c.All(&ss)) require.Equal(t, 17, len(ss)) for _, s := range ss { @@ -157,7 +157,7 @@ func TestMigrations(t *testing.T) { t.Run("case=hydra_oauth2_obfuscated_authentication_session", func(t *testing.T) { ss := []consent.ForcedObfuscatedLoginSession{} - c.All(&ss) + require.NoError(t, c.All(&ss)) require.Equal(t, 13, len(ss)) for _, s := range ss { @@ -169,7 +169,7 @@ func TestMigrations(t *testing.T) { t.Run("case=hydra_oauth2_logout_request", func(t *testing.T) { lrs := []flow.LogoutRequest{} - c.All(&lrs) + require.NoError(t, c.All(&lrs)) require.Equal(t, 7, len(lrs)) for _, s := range lrs { @@ -182,7 +182,7 @@ func TestMigrations(t *testing.T) { t.Run("case=hydra_oauth2_jti_blacklist", func(t *testing.T) { bjtis := []oauth2.BlacklistedJTI{} - c.All(&bjtis) + require.NoError(t, c.All(&bjtis)) require.Equal(t, 1, len(bjtis)) for _, bjti := range bjtis { testhelpersuuid.AssertUUID(t, bjti.NID) @@ -194,7 +194,7 @@ func TestMigrations(t *testing.T) { t.Run("case=hydra_oauth2_access", func(t *testing.T) { as := []sql.OAuth2RequestSQL{} - c.RawQuery("SELECT * FROM hydra_oauth2_access").All(&as) + require.NoError(t, c.RawQuery("SELECT * FROM hydra_oauth2_access").All(&as)) require.Equal(t, 13, len(as)) for _, a := range as { @@ -210,7 +210,7 @@ func TestMigrations(t *testing.T) { t.Run("case=hydra_oauth2_refresh", func(t *testing.T) { rs := []sql.OAuth2RequestSQL{} - c.RawQuery("SELECT * FROM hydra_oauth2_refresh").All(&rs) + require.NoError(t, c.RawQuery(`SELECT signature, nid, request_id, challenge_id, requested_at, client_id, scope, granted_scope, requested_audience, granted_audience, form_data, subject, active, session_data, expires_at FROM hydra_oauth2_refresh`).All(&rs)) require.Equal(t, 13, len(rs)) for _, r := range rs { @@ -226,7 +226,7 @@ func TestMigrations(t *testing.T) { t.Run("case=hydra_oauth2_code", func(t *testing.T) { cs := []sql.OAuth2RequestSQL{} - c.RawQuery("SELECT * FROM hydra_oauth2_code").All(&cs) + require.NoError(t, c.RawQuery("SELECT * FROM hydra_oauth2_code").All(&cs)) require.Equal(t, 13, len(cs)) for _, c := range cs { @@ -242,7 +242,7 @@ func TestMigrations(t *testing.T) { t.Run("case=hydra_oauth2_oidc", func(t *testing.T) { os := []sql.OAuth2RequestSQL{} - c.RawQuery("SELECT * FROM hydra_oauth2_oidc").All(&os) + require.NoError(t, c.RawQuery("SELECT * FROM hydra_oauth2_oidc").All(&os)) require.Equal(t, 13, len(os)) for _, o := range os { @@ -258,7 +258,7 @@ func TestMigrations(t *testing.T) { t.Run("case=hydra_oauth2_pkce", func(t *testing.T) { ps := []sql.OAuth2RequestSQL{} - c.RawQuery("SELECT * FROM hydra_oauth2_pkce").All(&ps) + require.NoError(t, c.RawQuery("SELECT * FROM hydra_oauth2_pkce").All(&ps)) require.Equal(t, 11, len(ps)) for _, p := range ps { @@ -274,7 +274,7 @@ func TestMigrations(t *testing.T) { t.Run("case=networks", func(t *testing.T) { ns := []networkx.Network{} - c.RawQuery("SELECT * FROM networks").All(&ns) + require.NoError(t, c.RawQuery("SELECT * FROM networks").All(&ns)) require.Equal(t, 1, len(ns)) for _, n := range ns { testhelpersuuid.AssertUUID(t, n.ID) diff --git a/persistence/sql/migrations/20241014121000000000_add_refresh_token_in_grace_period_flag.down.sql b/persistence/sql/migrations/20241014121000000000_add_refresh_token_in_grace_period_flag.down.sql new file mode 100644 index 00000000000..a30a127e902 --- /dev/null +++ b/persistence/sql/migrations/20241014121000000000_add_refresh_token_in_grace_period_flag.down.sql @@ -0,0 +1 @@ +ALTER TABLE hydra_oauth2_refresh DROP COLUMN first_used_at; diff --git a/persistence/sql/migrations/20241014121000000000_add_refresh_token_in_grace_period_flag.up.sql b/persistence/sql/migrations/20241014121000000000_add_refresh_token_in_grace_period_flag.up.sql new file mode 100644 index 00000000000..8ae823047f7 --- /dev/null +++ b/persistence/sql/migrations/20241014121000000000_add_refresh_token_in_grace_period_flag.up.sql @@ -0,0 +1 @@ +ALTER TABLE hydra_oauth2_refresh ADD first_used_at TIMESTAMP DEFAULT NULL; diff --git a/persistence/sql/persister.go b/persistence/sql/persister.go index 93649fc46ef..98161c55cf6 100644 --- a/persistence/sql/persister.go +++ b/persistence/sql/persister.go @@ -60,6 +60,7 @@ type ( contextx.Provider x.RegistryLogger x.TracingProvider + config.Provider } ) diff --git a/persistence/sql/persister_oauth2.go b/persistence/sql/persister_oauth2.go index 6e1336b80de..0adce30cfb6 100644 --- a/persistence/sql/persister_oauth2.go +++ b/persistence/sql/persister_oauth2.go @@ -58,6 +58,10 @@ type ( // InternalExpiresAt denormalizes the expiry from the session to additionally store it as a row. InternalExpiresAt sqlxx.NullTime `db:"expires_at" json:"-"` } + OAuth2RefreshTable struct { + OAuth2RequestSQL + FirstUsedAt sql.NullTime `db:"first_used_at"` + } ) const ( @@ -72,6 +76,10 @@ func (r OAuth2RequestSQL) TableName() string { return "hydra_oauth2_" + string(r.Table) } +func (r OAuth2RefreshTable) TableName() string { + return "hydra_oauth2_refresh" +} + func (p *Persister) sqlSchemaFromRequest(ctx context.Context, signature string, r fosite.Requester, table tableName, expiresAt time.Time) (*OAuth2RequestSQL, error) { subject := "" if r.GetSession() == nil { @@ -122,6 +130,24 @@ func (p *Persister) sqlSchemaFromRequest(ctx context.Context, signature string, }, nil } +func (p *Persister) marshalSession(ctx context.Context, session fosite.Session) ([]byte, error) { + sessionBytes, err := json.Marshal(session) + if err != nil { + return nil, err + } + + if !p.config.EncryptSessionData(ctx) { + return sessionBytes, nil + } + + ciphertext, err := p.r.KeyCipher().Encrypt(ctx, sessionBytes, nil) + if err != nil { + return nil, err + } + + return []byte(ciphertext), nil +} + func (r *OAuth2RequestSQL) toRequest(ctx context.Context, session fosite.Session, p *Persister) (_ *fosite.Request, err error) { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.toRequest") defer otelx.End(span, &err) @@ -429,7 +455,34 @@ func (p *Persister) CreateRefreshTokenSession(ctx context.Context, signature str func (p *Persister) GetRefreshTokenSession(ctx context.Context, signature string, session fosite.Session) (request fosite.Requester, err error) { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GetRefreshTokenSession") defer otelx.End(span, &err) - return p.findSessionBySignature(ctx, signature, session, sqlTableRefresh) + + r := OAuth2RefreshTable{OAuth2RequestSQL: OAuth2RequestSQL{Table: sqlTableRefresh}} + err = p.QueryWithNetwork(ctx).Where("signature = ?", signature).First(&r) + if errors.Is(err, sql.ErrNoRows) { + return nil, errorsx.WithStack(fosite.ErrNotFound) + } else if err != nil { + return nil, sqlcon.HandleError(err) + } + + fositeRequest, err := r.toRequest(ctx, session, p) + if err != nil { + return nil, err + } + + if r.Active { + return fositeRequest, nil + } + + if gracePeriod := p.r.Config().RefreshTokenRotationGracePeriod(ctx); gracePeriod > 0 && r.FirstUsedAt.Valid { + if r.FirstUsedAt.Time.Add(gracePeriod).Before(time.Now()) { + return fositeRequest, errorsx.WithStack(fosite.ErrInactiveToken) + } + + r.Active = true // We set active to true because we are in the grace period. + return r.toRequest(ctx, session, p) // And re-generate the request + } + + return fositeRequest, errorsx.WithStack(fosite.ErrInactiveToken) } func (p *Persister) DeleteRefreshTokenSession(ctx context.Context, signature string) (err error) { @@ -483,10 +536,25 @@ func (p *Persister) RevokeRefreshToken(ctx context.Context, id string) (err erro return p.deactivateSessionByRequestID(ctx, id, sqlTableRefresh) } -func (p *Persister) RevokeRefreshTokenMaybeGracePeriod(ctx context.Context, id string, _ string) (err error) { +func (p *Persister) RevokeRefreshTokenMaybeGracePeriod(ctx context.Context, id string, signature string) (err error) { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.RevokeRefreshTokenMaybeGracePeriod") defer otelx.End(span, &err) - return p.deactivateSessionByRequestID(ctx, id, sqlTableRefresh) + + gracePeriod := p.config.RefreshTokenRotationGracePeriod(ctx) + if gracePeriod <= 0 { + return p.deactivateSessionByRequestID(ctx, id, sqlTableRefresh) + } + + /* #nosec G201 table is static */ + return sqlcon.HandleError( + p.Connection(ctx). + RawQuery( + fmt.Sprintf("UPDATE %s SET active=false, first_used_at = CURRENT_TIMESTAMP WHERE request_id=? AND nid = ? AND active=true", OAuth2RequestSQL{Table: sqlTableRefresh}.TableName()), + id, + p.NetworkID(ctx), + ). + Exec(), + ) } func (p *Persister) RevokeAccessToken(ctx context.Context, id string) (err error) { diff --git a/spec/config.json b/spec/config.json index 2445cbc6a24..72f81534c66 100644 --- a/spec/config.json +++ b/spec/config.json @@ -1068,6 +1068,21 @@ "type": "object", "additionalProperties": false, "properties": { + "refresh_token": { + "type": "object", + "properties": { + "grace_period": { + "title": "Refresh Token Rotation Grace Period", + "description": "Configures how long a Refresh Token remains valid after it has been used. The maximum value is one hour.", + "default": "0s", + "allOf": [ + { + "$ref": "#/definitions/duration" + } + ] + } + } + }, "jwt": { "type": "object", "additionalProperties": false, @@ -1122,8 +1137,8 @@ } ] } - } - }, + } + }, "secrets": { "type": "object", "additionalProperties": false, diff --git a/x/fosite_storer.go b/x/fosite_storer.go index 23654c519b9..546cfc98870 100644 --- a/x/fosite_storer.go +++ b/x/fosite_storer.go @@ -18,16 +18,13 @@ import ( type FositeStorer interface { fosite.Storage oauth2.CoreStorage + oauth2.TokenRevocationStorage openid.OpenIDConnectRequestStorage pkce.PKCERequestStorage rfc7523.RFC7523KeyStorage verifiable.NonceManager oauth2.ResourceOwnerPasswordCredentialsGrantStorage - RevokeRefreshToken(ctx context.Context, requestID string) error - - RevokeAccessToken(ctx context.Context, requestID string) error - // flush the access token requests from the database. // no data will be deleted after the 'notAfter' timeframe. FlushInactiveAccessTokens(ctx context.Context, notAfter time.Time, limit int, batchSize int) error