diff --git a/security-jwt/src/main/java/io/micronaut/security/token/jwt/generator/DefaultAccessRefreshTokenGenerator.java b/security-jwt/src/main/java/io/micronaut/security/token/jwt/generator/DefaultAccessRefreshTokenGenerator.java index cdd20e7922..e77139d4c4 100644 --- a/security-jwt/src/main/java/io/micronaut/security/token/jwt/generator/DefaultAccessRefreshTokenGenerator.java +++ b/security-jwt/src/main/java/io/micronaut/security/token/jwt/generator/DefaultAccessRefreshTokenGenerator.java @@ -29,6 +29,7 @@ import io.micronaut.security.token.jwt.render.TokenRenderer; import io.micronaut.security.token.refresh.RefreshTokenPersistence; import io.micronaut.security.token.validator.RefreshTokenValidator; +import jakarta.inject.Inject; import jakarta.inject.Singleton; import java.util.Map; import java.util.Optional; @@ -49,13 +50,28 @@ public class DefaultAccessRefreshTokenGenerator implements AccessRefreshTokenGen private static final Logger LOG = LoggerFactory.getLogger(DefaultAccessRefreshTokenGenerator.class); + /** + * @deprecated Not used anymore. + */ + @Deprecated protected final BeanContext beanContext; + protected final RefreshTokenGenerator refreshTokenGenerator; + protected final RefreshTokenPersistence refreshTokenPersistence; + protected final RefreshTokenValidator refreshTokenValidator; protected final ClaimsGenerator claimsGenerator; protected final AccessTokenConfiguration accessTokenConfiguration; protected final TokenRenderer tokenRenderer; protected final TokenGenerator tokenGenerator; - protected final ApplicationEventPublisher eventPublisher; + + /** + * @deprecated Not used any more. + */ + @Deprecated + protected final ApplicationEventPublisher eventPublisher; + + protected final ApplicationEventPublisher refreshTokenGeneratedEventPublisher; + protected final ApplicationEventPublisher accessTokenGeneratedEventPublisher; /** * @@ -66,7 +82,9 @@ public class DefaultAccessRefreshTokenGenerator implements AccessRefreshTokenGen * @param refreshTokenGenerator The refresh token generator * @param claimsGenerator Claims generator * @param eventPublisher The Application event publisher + * @deprecated Use {@link DefaultAccessRefreshTokenGenerator(AccessTokenConfiguration, TokenRenderer, TokenGenerator, RefreshTokenGenerator, RefreshTokenPersistence, RefreshTokenValidator, ClaimsGenerator, ApplicationEventPublisher, ApplicationEventPublisher)} instead. */ + @Deprecated public DefaultAccessRefreshTokenGenerator(AccessTokenConfiguration accessTokenConfiguration, TokenRenderer tokenRenderer, TokenGenerator tokenGenerator, @@ -74,15 +92,53 @@ public DefaultAccessRefreshTokenGenerator(AccessTokenConfiguration accessTokenCo @Nullable RefreshTokenGenerator refreshTokenGenerator, ClaimsGenerator claimsGenerator, ApplicationEventPublisher eventPublisher) { + this(accessTokenConfiguration, + tokenRenderer, + tokenGenerator, + refreshTokenGenerator, + beanContext.findBean(RefreshTokenPersistence.class).orElse(null), + beanContext.findBean(RefreshTokenValidator.class).orElse(null), + claimsGenerator, + eventPublisher, + eventPublisher); + } + + /** + * + * @param accessTokenConfiguration The access token generator config + * @param tokenRenderer The token renderer + * @param tokenGenerator The token generator + * @param refreshTokenGenerator The refresh token generator + * @param refreshTokenPersistence Refresh Token Persistence + * @param refreshTokenValidator Refresh Token Validator + * @param claimsGenerator Claims generator + * @param refreshTokenGeneratedEventPublisher The Application event publisher for {@link RefreshTokenGeneratedEvent}. + * @param accessTokenGeneratedEventPublisher The Application event publisher for {@link AccessTokenGeneratedEvent}. + */ + @Inject + public DefaultAccessRefreshTokenGenerator(AccessTokenConfiguration accessTokenConfiguration, + TokenRenderer tokenRenderer, + TokenGenerator tokenGenerator, + @Nullable RefreshTokenGenerator refreshTokenGenerator, + @Nullable RefreshTokenPersistence refreshTokenPersistence, + @Nullable RefreshTokenValidator refreshTokenValidator, + ClaimsGenerator claimsGenerator, + ApplicationEventPublisher refreshTokenGeneratedEventPublisher, + ApplicationEventPublisher accessTokenGeneratedEventPublisher) { this.accessTokenConfiguration = accessTokenConfiguration; this.tokenRenderer = tokenRenderer; this.tokenGenerator = tokenGenerator; - this.beanContext = beanContext; this.refreshTokenGenerator = refreshTokenGenerator; + this.refreshTokenPersistence = refreshTokenPersistence; + this.refreshTokenValidator = refreshTokenValidator; this.claimsGenerator = claimsGenerator; - this.eventPublisher = eventPublisher; + this.refreshTokenGeneratedEventPublisher = refreshTokenGeneratedEventPublisher; + this.accessTokenGeneratedEventPublisher = accessTokenGeneratedEventPublisher; + this.beanContext = null; + this.eventPublisher = null; } + /** * Generate an {@link AccessRefreshToken} response for the given * user details. @@ -103,24 +159,22 @@ public Optional generate(@NonNull Authentication authenticat */ @NonNull public Optional generateRefreshToken(@NonNull Authentication authentication) { - Optional refreshToken = Optional.empty(); String msg = "Skipped refresh token generation because no {} implementation is present"; - if (beanContext.containsBean(RefreshTokenValidator.class)) { - if (beanContext.containsBean(RefreshTokenPersistence.class)) { - if (refreshTokenGenerator != null) { - String key = refreshTokenGenerator.createKey(authentication); - refreshToken = refreshTokenGenerator.generate(authentication, key); - refreshToken.ifPresent(t -> eventPublisher.publishEvent(new RefreshTokenGeneratedEvent(authentication, key))); - } else { - debug(LOG, msg, RefreshTokenGenerator.class.getName()); - } - } else { - debug(LOG, msg, RefreshTokenPersistence.class.getName()); - } - } else { + if (refreshTokenValidator == null) { debug(LOG, msg, RefreshTokenValidator.class.getName()); + return Optional.empty(); } - + if (refreshTokenPersistence == null) { + debug(LOG, msg, RefreshTokenPersistence.class.getName()); + return Optional.empty(); + } + if (refreshTokenGenerator == null) { + debug(LOG, msg, RefreshTokenGenerator.class.getName()); + return Optional.empty(); + } + String key = refreshTokenGenerator.createKey(authentication); + Optional refreshToken = refreshTokenGenerator.generate(authentication, key); + refreshToken.ifPresent(t -> refreshTokenGeneratedEventPublisher.publishEvent(new RefreshTokenGeneratedEvent(authentication, key))); return refreshToken; } @@ -145,7 +199,7 @@ public Optional generate(@Nullable String refreshToken, @Non return Optional.empty(); } String accessToken = optionalAccessToken.get(); - eventPublisher.publishEvent(new AccessTokenGeneratedEvent(accessToken)); + accessTokenGeneratedEventPublisher.publishEvent(new AccessTokenGeneratedEvent(accessToken)); return Optional.of(tokenRenderer.render(accessTokenExpiration(oldClaims), accessToken, refreshToken)); } @@ -166,7 +220,7 @@ public Optional generate(@Nullable String refreshToken, @Non } String accessToken = optionalAccessToken.get(); - eventPublisher.publishEvent(new AccessTokenGeneratedEvent(accessToken)); + accessTokenGeneratedEventPublisher.publishEvent(new AccessTokenGeneratedEvent(accessToken)); return Optional.of(tokenRenderer.render(authentication, accessTokenExpiration(authentication), accessToken, refreshToken)); } diff --git a/security-jwt/src/test/groovy/io/micronaut/security/token/jwt/signature/jwks/JwksCacheSpec.groovy b/security-jwt/src/test/groovy/io/micronaut/security/token/jwt/signature/jwks/JwksCacheSpec.groovy index 7f91e3d5a1..23de7a0138 100644 --- a/security-jwt/src/test/groovy/io/micronaut/security/token/jwt/signature/jwks/JwksCacheSpec.groovy +++ b/security-jwt/src/test/groovy/io/micronaut/security/token/jwt/signature/jwks/JwksCacheSpec.groovy @@ -1,7 +1,11 @@ package io.micronaut.security.token.jwt.signature.jwks import com.fasterxml.jackson.databind.ObjectMapper -import com.nimbusds.jose.* +import com.nimbusds.jose.JWSAlgorithm +import com.nimbusds.jose.JWSHeader +import com.nimbusds.jose.JWSObject +import com.nimbusds.jose.JWSSigner +import com.nimbusds.jose.Payload import com.nimbusds.jose.crypto.MACSigner import com.nimbusds.jose.jwk.JWK import com.nimbusds.jose.jwk.KeyUse @@ -42,8 +46,6 @@ import io.micronaut.security.token.jwt.signature.rsa.RSASignatureGeneratorConfig import jakarta.inject.Named import jakarta.inject.Singleton import org.reactivestreams.Publisher -import spock.lang.AutoCleanup -import spock.lang.Shared import spock.lang.Specification import java.security.SecureRandom @@ -52,47 +54,6 @@ import java.security.interfaces.RSAPublicKey class JwksCacheSpec extends Specification { - @Shared - Map authServerConfig = [ - 'micronaut.http.client.read-timeout': '30s', - 'micronaut.security.authentication': 'bearer', - ] - - @AutoCleanup - @Shared - EmbeddedServer googleEmbeddedServer = ApplicationContext.run(EmbeddedServer, authServerConfig + [ - 'spec.name': 'GoogleJwksCacheSpec', - 'endpoints.refresh.enabled': StringUtils.TRUE, - 'endpoints.refresh.sensitive': StringUtils.FALSE, - ]) - - @AutoCleanup - @Shared - EmbeddedServer cognitoEmbeddedServer = ApplicationContext.run(EmbeddedServer, authServerConfig + [ - 'spec.name': 'CognitoJwksCacheSpec', - 'endpoints.refresh.enabled': StringUtils.TRUE, - 'endpoints.refresh.sensitive': StringUtils.FALSE, - ]) - - @AutoCleanup - @Shared - EmbeddedServer appleEmbeddedServer = ApplicationContext.run(EmbeddedServer, authServerConfig + [ - 'spec.name': 'AppleJwksCacheSpec', - ]) - - @AutoCleanup - @Shared - EmbeddedServer embeddedServer = ApplicationContext.run(EmbeddedServer, [ - 'micronaut.http.client.read-timeout': '30s', - 'micronaut.security.token.jwt.signatures.jwks.apple.url': "http://localhost:${appleEmbeddedServer.port}/keys", - 'micronaut.security.token.jwt.signatures.jwks.apple.cache-expiration': 5, - 'micronaut.security.token.jwt.signatures.jwks.google.url': "http://localhost:${googleEmbeddedServer.port}/keys", - 'micronaut.security.token.jwt.signatures.jwks.google.cache-expiration': 5, - 'micronaut.security.token.jwt.signatures.jwks.cognito.url': "http://localhost:${cognitoEmbeddedServer.port}/keys", - 'micronaut.security.token.jwt.signatures.jwks.cognito.cache-expiration': 5, - 'spec.name': 'JwksCacheSpec' - ]) - private void hello(BlockingHttpClient client, String token, boolean doAssertion = true) { HttpRequest request = HttpRequest.GET('/hello').bearerAuth(token) if (doAssertion) { @@ -109,20 +70,55 @@ class JwksCacheSpec extends Specification { void "JWK are cached"() { given: - HttpClient googleHttpClient = embeddedServer.applicationContext.createBean(HttpClient, googleEmbeddedServer.URL) - BlockingHttpClient googleClient = googleHttpClient.toBlocking() + Map authServerConfig = [ + 'micronaut.http.client.read-timeout': '30s', + 'micronaut.security.authentication': 'bearer', + ] - HttpClient appleHttpClient = embeddedServer.applicationContext.createBean(HttpClient, appleEmbeddedServer.URL) - BlockingHttpClient appleClient = appleHttpClient.toBlocking() + EmbeddedServer googleEmbeddedServer = ApplicationContext.run(EmbeddedServer, authServerConfig + [ + 'spec.name': 'GoogleJwksCacheSpec', + 'endpoints.refresh.enabled': StringUtils.TRUE, + 'endpoints.refresh.sensitive': StringUtils.FALSE, + ]) + + and: + EmbeddedServer cognitoEmbeddedServer = ApplicationContext.run(EmbeddedServer, authServerConfig + [ + 'spec.name': 'CognitoJwksCacheSpec', + 'endpoints.refresh.enabled': StringUtils.TRUE, + 'endpoints.refresh.sensitive': StringUtils.FALSE, + ]) + + and: + EmbeddedServer appleEmbeddedServer = ApplicationContext.run(EmbeddedServer, authServerConfig + [ + 'spec.name': 'AppleJwksCacheSpec', + ]) + and: + EmbeddedServer embeddedServer = ApplicationContext.run(EmbeddedServer, [ + 'micronaut.http.client.read-timeout': '30s', + 'micronaut.security.token.jwt.signatures.jwks.apple.url': "http://localhost:${appleEmbeddedServer.port}/keys", + 'micronaut.security.token.jwt.signatures.jwks.apple.cache-expiration': 5, + 'micronaut.security.token.jwt.signatures.jwks.google.url': "http://localhost:${googleEmbeddedServer.port}/keys", + 'micronaut.security.token.jwt.signatures.jwks.google.cache-expiration': 5, + 'micronaut.security.token.jwt.signatures.jwks.cognito.url': "http://localhost:${cognitoEmbeddedServer.port}/keys", + 'micronaut.security.token.jwt.signatures.jwks.cognito.cache-expiration': 5, + 'spec.name': 'JwksCacheSpec' + ]) + HttpClient googleHttpClient = embeddedServer.applicationContext.createBean(HttpClient, googleEmbeddedServer.URL) + BlockingHttpClient googleClient = googleHttpClient.toBlocking() HttpClient cognitoHttpClient = embeddedServer.applicationContext.createBean(HttpClient, cognitoEmbeddedServer.URL) BlockingHttpClient cognitoClient = cognitoHttpClient.toBlocking() - + HttpClient appleHttpClient = embeddedServer.applicationContext.createBean(HttpClient, appleEmbeddedServer.URL) + BlockingHttpClient appleClient = appleHttpClient.toBlocking() HttpClient httpClient = embeddedServer.applicationContext.createBean(HttpClient, embeddedServer.URL) BlockingHttpClient client = httpClient.toBlocking() expect: - 0 == totalInvocations() + googleEmbeddedServer.isRunning() + cognitoEmbeddedServer.isRunning() + appleEmbeddedServer.isRunning() + embeddedServer.isRunning() + 0 == totalInvocations(googleEmbeddedServer, appleEmbeddedServer, cognitoEmbeddedServer) when: BearerAccessRefreshToken googleBearerAccessRefreshToken = login(googleClient) @@ -150,27 +146,27 @@ class JwksCacheSpec extends Specification { cognitoBearerAccessRefreshToken.accessToken and: - 0 == totalInvocations() + 0 == totalInvocations(googleEmbeddedServer, appleEmbeddedServer, cognitoEmbeddedServer) when: - int oldInvocations = totalInvocations() + int oldInvocations = totalInvocations(googleEmbeddedServer, appleEmbeddedServer, cognitoEmbeddedServer) hello(client, googleAccessToken) hello(client, appleAccessToken) hello(client, cognitoAccessToken) then: - totalInvocations() >= (oldInvocations + 3) + totalInvocations(googleEmbeddedServer, appleEmbeddedServer, cognitoEmbeddedServer) >= (oldInvocations + 3) when: 'when you invoke it again all the keys are cached' - oldInvocations = totalInvocations() + oldInvocations = totalInvocations(googleEmbeddedServer, appleEmbeddedServer, cognitoEmbeddedServer) hello(client, appleAccessToken) then: - totalInvocations() == oldInvocations + totalInvocations(googleEmbeddedServer, appleEmbeddedServer, cognitoEmbeddedServer) == oldInvocations when: "generate new keys for cognito but with same id, other JWK sets do not match the ID, for cognito the verification key fails and a new one is fetched from the server" - oldInvocations = totalInvocations() - int invocations = cognitoInvocations() + oldInvocations = totalInvocations(googleEmbeddedServer, appleEmbeddedServer, cognitoEmbeddedServer) + int invocations = cognitoInvocations(cognitoEmbeddedServer) refresh(cognitoClient) cognitoEmbeddedServer.applicationContext.getBean(CognitoKeysController).invocations = invocations cognitoAccessToken = loginAccessToken(cognitoClient) @@ -178,12 +174,12 @@ class JwksCacheSpec extends Specification { hello(client, cognitoAccessToken) then: - totalInvocations() >= (oldInvocations + 1) + totalInvocations(googleEmbeddedServer, appleEmbeddedServer, cognitoEmbeddedServer) >= (oldInvocations + 1) when: 'generate a new JWKS with new kid, JWKS attempt to refresh' - oldInvocations = totalInvocations() + oldInvocations = totalInvocations(googleEmbeddedServer, appleEmbeddedServer, cognitoEmbeddedServer) CognitoSignatureConfiguration cognitoSignatureConfiguration = cognitoEmbeddedServer.applicationContext.getBean(CognitoSignatureConfiguration) - invocations = cognitoInvocations() + invocations = cognitoInvocations(cognitoEmbeddedServer) refresh(cognitoClient) cognitoEmbeddedServer.applicationContext.getBean(CognitoKeysController).invocations = invocations cognitoSignatureConfiguration.rotateKid() @@ -192,12 +188,12 @@ class JwksCacheSpec extends Specification { hello(client, cognitoAccessToken) then: - totalInvocations() >= (oldInvocations + 1) + totalInvocations(googleEmbeddedServer, appleEmbeddedServer, cognitoEmbeddedServer) >= (oldInvocations + 1) when: 'generate a new JWT without kid, JWKS attempt to refresh' - oldInvocations = totalInvocations() + oldInvocations = totalInvocations(googleEmbeddedServer, appleEmbeddedServer, cognitoEmbeddedServer) GoogleSignatureConfiguration googleSignatureConfiguration = googleEmbeddedServer.applicationContext.getBean(GoogleSignatureConfiguration) - invocations = googleInvocations() + invocations = googleInvocations(googleEmbeddedServer) refresh(googleClient) googleEmbeddedServer.applicationContext.getBean(GoogleKeysController).invocations = invocations googleSignatureConfiguration.clearKid() @@ -206,27 +202,33 @@ class JwksCacheSpec extends Specification { hello(client, googleAccessToken) then: - totalInvocations() >= (oldInvocations + 1) + totalInvocations(googleEmbeddedServer, appleEmbeddedServer, cognitoEmbeddedServer) >= (oldInvocations + 1) when: - oldInvocations = totalInvocations() + oldInvocations = totalInvocations(googleEmbeddedServer, appleEmbeddedServer, cognitoEmbeddedServer) String randomSignedJwt = randomSignedJwt() hello(client, randomSignedJwt, false) then: - totalInvocations() >= oldInvocations + totalInvocations(googleEmbeddedServer, appleEmbeddedServer, cognitoEmbeddedServer) >= oldInvocations when: - oldInvocations = totalInvocations() + oldInvocations = totalInvocations(googleEmbeddedServer, appleEmbeddedServer, cognitoEmbeddedServer) sleep(6_000) // cache expires the token is still invalid but JWKS attempts to refresh hello(client, randomSignedJwt, false) then: - totalInvocations() == (oldInvocations + 3) + totalInvocations(googleEmbeddedServer, appleEmbeddedServer, cognitoEmbeddedServer) == (oldInvocations + 3) + + cleanup: + embeddedServer.close() + appleEmbeddedServer.close() + cognitoEmbeddedServer.close() + googleEmbeddedServer.close() } - private int totalInvocations() { - googleInvocations() + appleInvocations() + cognitoInvocations() + private int totalInvocations(EmbeddedServer googlServer, EmbeddedServer appleServer, EmbeddedServer cognitoServer) { + googleInvocations(googlServer) + appleInvocations(appleServer) + cognitoInvocations(cognitoServer) } private void assertKeyId(JWT jwt, String keyId) { @@ -513,17 +515,16 @@ class JwksCacheSpec extends Specification { } } - - private int googleInvocations() { - googleEmbeddedServer.applicationContext.getBean(GoogleKeysController).invocations + private int googleInvocations(EmbeddedServer server) { + server.applicationContext.getBean(GoogleKeysController).invocations } - private int appleInvocations() { - appleEmbeddedServer.applicationContext.getBean(AppleKeysController).invocations + private int appleInvocations(EmbeddedServer server) { + server.applicationContext.getBean(AppleKeysController).invocations } - private int cognitoInvocations() { - cognitoEmbeddedServer.applicationContext.getBean(CognitoKeysController).invocations + private int cognitoInvocations(EmbeddedServer server) { + server.applicationContext.getBean(CognitoKeysController).invocations } private static BearerAccessRefreshToken login(BlockingHttpClient client) {