Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor JWT decoding logic #443

Merged
merged 5 commits into from
Jan 19, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion auth0/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,6 @@ dependencies {
implementation 'com.squareup.okhttp3:okhttp:4.9.0'
implementation 'com.squareup.okhttp3:logging-interceptor:4.9.0'
implementation 'com.google.code.gson:gson:2.8.6'
implementation 'com.auth0.android:jwtdecode:1.3.0'

testImplementation 'junit:junit:4.13.1'
testImplementation 'org.hamcrest:java-hamcrest:2.0.0.0'
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package com.auth0.android.authentication.storage;

import com.auth0.android.jwt.JWT;
import com.auth0.android.request.internal.Jwt;

/**
* Bridge class for decoding JWTs.
Expand All @@ -11,7 +11,7 @@ class JWTDecoder {
JWTDecoder() {
}

JWT decode(String jwt) {
return new JWT(jwt);
Jwt decode(String jwt) {
return new Jwt(jwt);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import androidx.annotation.NonNull;

import com.auth0.android.jwt.JWT;
import com.auth0.android.request.internal.Jwt;

import java.util.Calendar;
import java.util.Date;
Expand All @@ -25,7 +25,7 @@ class IdTokenVerifier {
* @param verifyOptions the verification options, like audience, issuer, algorithm.
* @throws TokenValidationException If the ID Token is null, its signing algorithm not supported, its signature invalid or one of its claim invalid.
*/
void verify(@NonNull JWT token, @NonNull IdTokenVerificationOptions verifyOptions) throws TokenValidationException {
void verify(@NonNull Jwt token, @NonNull IdTokenVerificationOptions verifyOptions) throws TokenValidationException {
verifyOptions.getSignatureVerifier().verify(token);

if (isEmpty(token.getIssuer())) {
Expand Down Expand Up @@ -69,7 +69,7 @@ void verify(@NonNull JWT token, @NonNull IdTokenVerificationOptions verifyOption
}

if (verifyOptions.getNonce() != null) {
String nonceClaim = token.getClaim(NONCE_CLAIM).asString();
String nonceClaim = token.getNonce();
if (isEmpty(nonceClaim)) {
throw new TokenValidationException("Nonce (nonce) claim must be a string present in the ID token");
}
Expand All @@ -79,7 +79,7 @@ void verify(@NonNull JWT token, @NonNull IdTokenVerificationOptions verifyOption
}

if (audience.size() > 1) {
String azpClaim = token.getClaim(AZP_CLAIM).asString();
String azpClaim = token.getAuthorizedParty();
if (isEmpty(azpClaim)) {
throw new TokenValidationException("Authorized Party (azp) claim must be a string present in the ID token when Audience (aud) claim has multiple values");
}
Expand All @@ -89,7 +89,7 @@ void verify(@NonNull JWT token, @NonNull IdTokenVerificationOptions verifyOption
}

if (verifyOptions.getMaxAge() != null) {
Date authTime = token.getClaim(AUTH_TIME_CLAIM).asDate();
Date authTime = token.getAuthenticationTime();
if (authTime == null) {
throw new TokenValidationException("Authentication Time (auth_time) claim must be a number present in the ID token when Max Age (max_age) is specified");
}
Expand Down
13 changes: 6 additions & 7 deletions auth0/src/main/java/com/auth0/android/provider/OAuthManager.kt
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,7 @@ import com.auth0.android.Auth0Exception
import com.auth0.android.authentication.AuthenticationAPIClient
import com.auth0.android.authentication.AuthenticationException
import com.auth0.android.callback.Callback
import com.auth0.android.jwt.DecodeException
import com.auth0.android.jwt.JWT
import com.auth0.android.request.internal.Jwt
import com.auth0.android.result.Credentials
import java.security.SecureRandom
import java.util.*
Expand Down Expand Up @@ -134,10 +133,10 @@ internal class OAuthManager(
validationCallback.onFailure(TokenValidationException("ID token is required but missing"))
return
}
val decodedIdToken: JWT = try {
JWT(idToken!!)
} catch (ignored: DecodeException) {
validationCallback.onFailure(TokenValidationException("ID token could not be decoded"))
val decodedIdToken: Jwt = try {
Jwt(idToken!!)
} catch (error: Exception) {
validationCallback.onFailure(TokenValidationException("ID token could not be decoded", error))
return
}
val signatureVerifierCallback: Callback<SignatureVerifier, TokenValidationException> =
Expand Down Expand Up @@ -167,7 +166,7 @@ internal class OAuthManager(
}
}
}
val tokenKeyId = decodedIdToken.header["kid"]
val tokenKeyId = decodedIdToken.keyId
SignatureVerifier.forAsymmetricAlgorithm(tokenKeyId, apiClient, signatureVerifierCallback)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import com.auth0.android.authentication.AuthenticationException;
import com.auth0.android.callback.AuthenticationCallback;
import com.auth0.android.callback.Callback;
import com.auth0.android.jwt.JWT;
import com.auth0.android.request.internal.Jwt;

import java.security.InvalidKeyException;
import java.security.PublicKey;
Expand All @@ -31,12 +31,9 @@ abstract class SignatureVerifier {
* @param token the ID token to have its signature validated
* @throws TokenValidationException if the signature is not valid
*/
void verify(@NonNull JWT token) throws TokenValidationException {
String tokenAlg = token.getHeader().get("alg");
String[] tokenParts = token.toString().split("\\.");

checkAlgorithm(tokenAlg);
checkSignature(tokenParts);
void verify(@NonNull Jwt token) throws TokenValidationException {
checkAlgorithm(token.getAlgorithm());
checkSignature(token.getParts());
}

private void checkAlgorithm(String tokenAlgorithm) throws TokenValidationException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,8 @@ import com.auth0.android.Auth0Exception
/**
* Exception thrown when the validation of the ID token failed.
*/
internal class TokenValidationException(message: String) : Auth0Exception(message)
internal class TokenValidationException @JvmOverloads constructor(
message: String,
cause: Throwable? = null
) :
Auth0Exception(message, cause)
82 changes: 82 additions & 0 deletions auth0/src/main/java/com/auth0/android/request/internal/Jwt.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
package com.auth0.android.request.internal

import android.util.Base64
import com.google.gson.reflect.TypeToken
import java.util.*


/**
* Internal class meant to decode the given token of type JWT and provide access to its claims.
*/
internal class Jwt(rawToken: String) {

private val decodedHeader: Map<String, Any>
private val decodedPayload: Map<String, Any>
val parts: Array<String>

// header
val algorithm: String
val keyId: String?

// payload
val subject: String?
val issuer: String?
val nonce: String?
val issuedAt: Date?
val expiresAt: Date?
val authorizedParty: String?
val authenticationTime: Date?
val audience: List<String>

init {
parts = splitToken(rawToken)
val jsonHeader = parts[0].decodeBase64()
val jsonPayload = parts[1].decodeBase64()
val mapAdapter = GsonProvider.gson.getAdapter(object : TypeToken<Map<String, Any>>() {})
decodedHeader = mapAdapter.fromJson(jsonHeader)
decodedPayload = mapAdapter.fromJson(jsonPayload)

// header claims
algorithm = decodedHeader["alg"] as String
keyId = decodedHeader["kid"] as String?

// payload claims
subject = decodedPayload["sub"] as String?
issuer = decodedPayload["iss"] as String?
nonce = decodedPayload["nonce"] as String?
issuedAt = (decodedPayload["iat"] as? Double)?.let { Date(it.toLong() * 1000) }
expiresAt = (decodedPayload["exp"] as? Double)?.let { Date(it.toLong() * 1000) }
authorizedParty = decodedPayload["azp"] as String?
authenticationTime =
(decodedPayload["auth_time"] as? Double)?.let { Date(it.toLong() * 1000) }
audience = when (val aud = decodedPayload["aud"]) {
is String -> listOf(aud)
is List<*> -> aud as List<String>
else -> emptyList()
}
}

private fun splitToken(token: String): Array<String> {
var parts = token.split(".").toTypedArray()
if (parts.size == 2 && token.endsWith(".")) {
// Tokens with alg='none' have empty String as Signature.
parts = arrayOf(parts[0], parts[1], "")
}
if (parts.size != 3) {
throw IllegalArgumentException(
String.format(
"The token was expected to have 3 parts, but got %s.",
parts.size
)
)
}
return parts
}

private fun String.decodeBase64(): String {
val bytes: ByteArray =
Base64.decode(this, Base64.URL_SAFE or Base64.NO_WRAP or Base64.NO_PADDING)
return String(bytes, Charsets.UTF_8)
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ package com.auth0.android.authentication.storage
import com.auth0.android.authentication.AuthenticationAPIClient
import com.auth0.android.authentication.AuthenticationException
import com.auth0.android.callback.Callback
import com.auth0.android.jwt.JWT
import com.auth0.android.request.Request
import com.auth0.android.request.internal.Jwt
import com.auth0.android.result.Credentials
import com.auth0.android.result.CredentialsMock
import com.auth0.android.util.Clock
Expand Down Expand Up @@ -355,7 +355,7 @@ public class CredentialsManagerTest {
client.renewAuth("refreshToken")
).thenReturn(request)
val newDate = Date(CredentialsMock.ONE_HOUR_AHEAD_MS + ONE_HOUR_SECONDS * 1000)
val jwtMock = mock<JWT>()
val jwtMock = mock<Jwt>()
Mockito.`when`(jwtMock.expiresAt).thenReturn(newDate)
Mockito.`when`(jwtDecoder.decode("newId")).thenReturn(jwtMock)
manager.getCredentials("some scope", 0, callback)
Expand Down Expand Up @@ -416,7 +416,7 @@ public class CredentialsManagerTest {
client.renewAuth("refreshToken")
).thenReturn(request)
val newDate = Date(CredentialsMock.ONE_HOUR_AHEAD_MS + ONE_HOUR_SECONDS * 1000)
val jwtMock = mock<JWT>()
val jwtMock = mock<Jwt>()
Mockito.`when`(jwtMock.expiresAt).thenReturn(newDate)
Mockito.`when`(jwtDecoder.decode("newId")).thenReturn(jwtMock)
manager.getCredentials("some scope", 0, callback)
Expand Down Expand Up @@ -478,7 +478,7 @@ public class CredentialsManagerTest {
).thenReturn(request)
val newDate =
Date(CredentialsMock.CURRENT_TIME_MS + 61 * 1000) // New token expires in minTTL + 1 second
val jwtMock = mock<JWT>()
val jwtMock = mock<Jwt>()
Mockito.`when`(jwtMock.expiresAt).thenReturn(newDate)
Mockito.`when`(jwtDecoder.decode("newId")).thenReturn(jwtMock)
manager.getCredentials(null, 60, callback) // 60 seconds of minTTL
Expand Down Expand Up @@ -539,7 +539,7 @@ public class CredentialsManagerTest {
client.renewAuth("refreshToken")
).thenReturn(request)
val newDate = Date(CredentialsMock.ONE_HOUR_AHEAD_MS)
val jwtMock = mock<JWT>()
val jwtMock = mock<Jwt>()
Mockito.`when`(jwtMock.expiresAt).thenReturn(newDate)
Mockito.`when`(jwtDecoder.decode("newId")).thenReturn(jwtMock)
manager.getCredentials(callback)
Expand Down Expand Up @@ -601,7 +601,7 @@ public class CredentialsManagerTest {
).thenReturn(request)
val newDate =
Date(CredentialsMock.CURRENT_TIME_MS + 59 * 1000) // New token expires in minTTL - 1 second
val jwtMock = mock<JWT>()
val jwtMock = mock<Jwt>()
Mockito.`when`(jwtMock.expiresAt).thenReturn(newDate)
Mockito.`when`(jwtDecoder.decode("newId")).thenReturn(jwtMock)
manager.getCredentials(null, 60, callback) // 60 seconds of minTTL
Expand Down Expand Up @@ -654,7 +654,7 @@ public class CredentialsManagerTest {
client.renewAuth("refreshToken")
).thenReturn(request)
val newDate = Date(CredentialsMock.ONE_HOUR_AHEAD_MS)
val jwtMock = mock<JWT>()
val jwtMock = mock<Jwt>()
Mockito.`when`(jwtMock.expiresAt).thenReturn(newDate)
Mockito.`when`(jwtDecoder.decode("newId")).thenReturn(jwtMock)
manager.getCredentials(callback)
Expand Down Expand Up @@ -863,7 +863,7 @@ public class CredentialsManagerTest {
}

private fun prepareJwtDecoderMock(expiresAt: Date?) {
val jwtMock = mock<JWT>()
val jwtMock = mock<Jwt>()
Mockito.`when`(jwtMock.expiresAt).thenReturn(expiresAt)
Mockito.`when`(jwtDecoder.decode("idToken")).thenReturn(jwtMock)
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package com.auth0.android.authentication.storage;

import com.auth0.android.jwt.JWT;
import com.auth0.android.request.internal.Jwt;

import org.junit.Test;
import org.junit.runner.RunWith;
Expand All @@ -14,25 +14,24 @@ public class JWTDecoderTest {

@Test
public void shouldDecodeAToken() {
String token = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c";
JWT jwt1 = new JWT(token);
String token = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCIsImtpZCI6ImFsaWNlIn0.eyJzdWIiOiIxMjM0NTY3ODkwIiwibm9uY2UiOiJyZWFsbHkgcmFuZG9tIHRleHQiLCJpYXQiOjE1MTYyMzkwMjJ9.rYG-HEs1EKKDhwQIoEg32_p-NQzNi5rB7akqGnH_q4k";
Jwt jwt1 = new Jwt(token);

JWT jwt2 = new JWTDecoder().decode(token);
Jwt jwt2 = new JWTDecoder().decode(token);

//Header claims
assertThat(jwt1.getHeader().get("alg"), is("HS256"));
assertThat(jwt1.getHeader().get("typ"), is("JWT"));

assertThat(jwt2.getHeader().get("typ"), is("JWT"));
assertThat(jwt2.getHeader().get("alg"), is("HS256"));
assertThat(jwt1.getAlgorithm(), is("HS256"));
assertThat(jwt1.getKeyId(), is("alice"));
assertThat(jwt2.getAlgorithm(), is("HS256"));
assertThat(jwt2.getKeyId(), is("alice"));

//Payload claims
assertThat(jwt1.getSubject(), is("1234567890"));
assertThat(jwt1.getIssuedAt().getTime(), is(1516239022000L));
assertThat(jwt1.getClaim("name").asString(), is("John Doe"));
assertThat(jwt1.getNonce(), is("really random text"));

assertThat(jwt2.getSubject(), is("1234567890"));
assertThat(jwt2.getIssuedAt().getTime(), is(1516239022000L));
assertThat(jwt2.getClaim("name").asString(), is("John Doe"));
assertThat(jwt2.getNonce(), is("really random text"));
}
}
Loading