Skip to content

Commit

Permalink
Merge pull request #3799 from lreimer/jwks-kid-jwt-support
Browse files Browse the repository at this point in the history
PAYARA-3824 Use KeyID from JWT header to find public key in JSON Web Key Set
  • Loading branch information
Alan authored Jul 5, 2019
2 parents c09d823 + ca49b72 commit 601701f
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 41 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
import javax.json.JsonArray;
import javax.json.JsonObject;
import javax.json.JsonReader;
import javax.json.JsonValue;
import javax.security.enterprise.identitystore.CredentialValidationResult;
import javax.security.enterprise.identitystore.IdentityStore;
import java.io.ByteArrayInputStream;
Expand All @@ -71,8 +72,8 @@
import static org.eclipse.microprofile.jwt.config.Names.*;

/**
* Identity store capable of asserting that a signed JWT token is valid according to
* the MP-JWT 1.0 spec.
* Identity store capable of asserting that a signed JWT token is valid
* according to the MP-JWT 1.0 spec.
*
* @author Arjan Tijms
*/
Expand All @@ -82,39 +83,43 @@ public class SignedJWTIdentityStore implements IdentityStore {

private static final String RSA_ALGORITHM = "RSA";

private final JwtTokenParser jwtTokenParser;

private final String acceptedIssuer;
private final Optional<Boolean> enabledNamespace;
private final Optional<String> customNamespace;

private final Config config;

public SignedJWTIdentityStore() {
config = ConfigProvider.getConfig();

Optional<Properties> properties = readVendorProperties();
acceptedIssuer = readVendorIssuer(properties)
.orElseGet(() -> config.getOptionalValue(ISSUER, String.class)
.orElseThrow(() -> new IllegalStateException("No issuer found")));

jwtTokenParser = new JwtTokenParser(readEnabledNamespace(properties), readCustomNamespace(properties));

enabledNamespace = readEnabledNamespace(properties);
customNamespace = readCustomNamespace(properties);
}

public CredentialValidationResult validate(SignedJWTCredential signedJWTCredential) {
final JwtTokenParser jwtTokenParser = new JwtTokenParser(enabledNamespace, customNamespace);
try {
jwtTokenParser.parse(signedJWTCredential.getSignedJWT());
String keyID = jwtTokenParser.getKeyID();

Optional<PublicKey> publicKey = readPublicKeyFromLocation("/publicKey.pem");
Optional<PublicKey> publicKey = readDefaultPublicKey();
if (!publicKey.isPresent()) {
publicKey = readMPEmbeddedPublicKey();
publicKey = readMPEmbeddedPublicKey(keyID);
}
if (!publicKey.isPresent()) {
publicKey = readMPPublicKeyFromLocation();
publicKey = readMPPublicKeyFromLocation(keyID);
}
if (!publicKey.isPresent()) {
throw new IllegalStateException("No PublicKey found");
}

jwtTokenParser.parse(signedJWTCredential.getSignedJWT());
JsonWebTokenImpl jsonWebToken = jwtTokenParser.verify(acceptedIssuer, publicKey.get());
JsonWebTokenImpl jsonWebToken
= jwtTokenParser.verify(acceptedIssuer, publicKey.get());

List<String> groups = new ArrayList<>(
jsonWebToken.getClaim("groups"));
Expand All @@ -129,7 +134,7 @@ public CredentialValidationResult validate(SignedJWTCredential signedJWTCredenti

return INVALID_RESULT;
}

private Optional<Properties> readVendorProperties() {
URL mpJwtResource = currentThread().getContextClassLoader().getResource("/payara-mp-jwt.properties");
Properties properties = null;
Expand All @@ -143,28 +148,32 @@ private Optional<Properties> readVendorProperties() {
}
return Optional.ofNullable(properties);
}

private Optional<String> readVendorIssuer(Optional<Properties> properties) {
return properties.isPresent() ? Optional.ofNullable(properties.get().getProperty("accepted.issuer")) : Optional.empty();
}
private Optional<Boolean> readEnabledNamespace(Optional<Properties> properties){

private Optional<Boolean> readEnabledNamespace(Optional<Properties> properties) {
return properties.isPresent() ? Optional.ofNullable(Boolean.valueOf(properties.get().getProperty("enable.namespace", "false"))) : Optional.empty();
}
private Optional<String> readCustomNamespace(Optional<Properties> properties){

private Optional<String> readCustomNamespace(Optional<Properties> properties) {
return properties.isPresent() ? Optional.ofNullable(properties.get().getProperty("custom.namespace", null)) : Optional.empty();
}

private Optional<PublicKey> readMPEmbeddedPublicKey() throws Exception {
private Optional<PublicKey> readDefaultPublicKey() throws Exception {
return readPublicKeyFromLocation("/publicKey.pem", null);
}

private Optional<PublicKey> readMPEmbeddedPublicKey(String keyID) throws Exception {
Optional<String> key = config.getOptionalValue(VERIFIER_PUBLIC_KEY, String.class);
if (!key.isPresent()) {
return Optional.empty();
}
return createPublicKey(key.get());
return createPublicKey(key.get(), keyID);
}

private Optional<PublicKey> readMPPublicKeyFromLocation() throws Exception {
private Optional<PublicKey> readMPPublicKeyFromLocation(String keyID) throws Exception {
Optional<String> locationOpt = config.getOptionalValue(VERIFIER_PUBLIC_KEY_LOCATION, String.class);

if (!locationOpt.isPresent()) {
Expand All @@ -173,10 +182,10 @@ private Optional<PublicKey> readMPPublicKeyFromLocation() throws Exception {

String publicKeyLocation = locationOpt.get();

return readPublicKeyFromLocation(publicKeyLocation);
return readPublicKeyFromLocation(publicKeyLocation, keyID);
}

private Optional<PublicKey> readPublicKeyFromLocation(String publicKeyLocation) throws Exception {
private Optional<PublicKey> readPublicKeyFromLocation(String publicKeyLocation, String keyID) throws Exception {

URL publicKeyURL = currentThread().getContextClassLoader().getResource(publicKeyLocation);

Expand All @@ -193,17 +202,17 @@ private Optional<PublicKey> readPublicKeyFromLocation(String publicKeyLocation)

byte[] byteBuffer = new byte[16384];
try (InputStream inputStream = publicKeyURL.openStream()) {
return createPublicKey(new String(byteBuffer, 0, inputStream.read(byteBuffer)));
String key = new String(byteBuffer, 0, inputStream.read(byteBuffer));
return createPublicKey(key, keyID);
}
}


private Optional<PublicKey> createPublicKey(String key) throws Exception {
private Optional<PublicKey> createPublicKey(String key, String keyID) throws Exception {
try {
return Optional.of(createPublicKeyFromPem(key));
} catch (Exception pemEx) {
try {
return Optional.of(createPublicKeyFromJWKS(key));
return Optional.of(createPublicKeyFromJWKS(key, keyID));
} catch (Exception jwksEx) {
throw new DeploymentException(jwksEx);
}
Expand All @@ -224,10 +233,10 @@ private PublicKey createPublicKeyFromPem(String key) throws Exception {

}

private PublicKey createPublicKeyFromJWKS(String jwksValue) throws Exception {
private PublicKey createPublicKeyFromJWKS(String jwksValue, String keyID) throws Exception {
JsonObject jwks = parseJwks(jwksValue);
JsonArray keys = jwks.getJsonArray("keys");
JsonObject jwk = keys != null ? keys.getJsonObject(0) : jwks;
JsonObject jwk = keys != null ? findJwk(keys, keyID) : jwks;

// the public exponent
byte[] exponentBytes = Base64.getUrlDecoder().decode(jwk.getString("e"));
Expand All @@ -250,13 +259,26 @@ private JsonObject parseJwks(String jwksValue) throws Exception {
// if jwks is encoded
byte[] jwksDecodedValue = Base64.getDecoder().decode(jwksValue);
try (InputStream jwksStream = new ByteArrayInputStream(jwksDecodedValue);
JsonReader reader = Json.createReader(jwksStream)) {
JsonReader reader = Json.createReader(jwksStream)) {
jwks = reader.readObject();
}
}
return jwks;
}

private JsonObject findJwk(JsonArray keys, String keyID) {
if (Objects.isNull(keyID) && keys.size() > 0) {
return keys.getJsonObject(0);
}

for (JsonValue value : keys) {
JsonObject jwk = value.asJsonObject();
if (Objects.equals(keyID, jwk.getString("kid"))) {
return jwk;
}
}

throw new IllegalStateException("No matching JWK for KeyID.");
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
import java.util.Map;
import java.util.Optional;

import static com.google.common.base.Preconditions.checkState;
import static com.nimbusds.jose.JWSAlgorithm.RS256;
import static java.util.Arrays.asList;
import static javax.json.Json.createObjectBuilder;
Expand All @@ -71,7 +72,7 @@ public class JwtTokenParser {

private final boolean enableNamespacedClaims;
private final Optional<String> customNamespace;

private String rawToken;
private SignedJWT signedJWT;

Expand All @@ -87,22 +88,16 @@ public JwtTokenParser() {
public void parse(String bearerToken) throws Exception {
rawToken = bearerToken;
signedJWT = SignedJWT.parse(rawToken);


// MP-JWT 1.0 4.1 typ
if (!checkIsJWT(signedJWT.getHeader())) {
throw new IllegalStateException("Not JWT");
}
}

public JsonWebTokenImpl verify(String issuer, PublicKey publicKey) throws Exception {
if (signedJWT == null) {
parse(rawToken);
}

// MP-JWT 1.0 4.1 typ
if (!checkIsJWT(signedJWT.getHeader())) {
throw new IllegalStateException("Not JWT");
}

checkState(signedJWT != null, "No parsed SignedJWT.");

// 1.0 4.1 alg + MP-JWT 1.0 6.1 1
if (!signedJWT.getHeader().getAlgorithm().equals(RS256)) {
throw new IllegalStateException("Not RS256");
Expand Down Expand Up @@ -147,6 +142,11 @@ public JsonWebTokenImpl verify(String issuer, PublicKey publicKey) throws Except
}
}

public String getKeyID() {
checkState(signedJWT != null, "No parsed SignedJWT.");
return signedJWT.getHeader().getKeyID();
}

private Map<String, JsonValue> handleNamespacedClaims(Map<String, JsonValue> currentClaims){
if(this.enableNamespacedClaims){
final String namespace = customNamespace.orElse(DEFAULT_NAMESPACE);
Expand Down

0 comments on commit 601701f

Please sign in to comment.