From 778ab3e06f18e969bd64fe73525466d916e1e80d Mon Sep 17 00:00:00 2001 From: Simo Sorce Date: Fri, 20 May 2022 14:24:00 -0400 Subject: [PATCH] Introduce a new JWKeyNotFound exception This new Exception is returned only for the newly introduced support for using JWKset. This patch also includes a bugfix for jwe to be able to successfully decrypt using a JWKSet, which was non-functional, and a direct test for both JWE and JWS to insure no regressions in JWKSet support. Also restores use of JWTMissingKey for backwards compatibility. Signed-off-by: Simo Sorce --- docs/source/common.rst | 3 +++ jwcrypto/common.py | 16 ++++++++++++++ jwcrypto/jwe.py | 17 +++++++++++---- jwcrypto/jws.py | 17 +++++++++++---- jwcrypto/jwt.py | 8 ++++--- jwcrypto/tests.py | 48 ++++++++++++++++++++++++++++++++++++++---- 6 files changed, 94 insertions(+), 15 deletions(-) diff --git a/docs/source/common.rst b/docs/source/common.rst index d3395de..f9040c1 100644 --- a/docs/source/common.rst +++ b/docs/source/common.rst @@ -36,3 +36,6 @@ Exceptions .. autoclass:: jwcrypto.common.InvalidJWSERegOperation :show-inheritance: + +.. autoclass:: jwcrypto.common.JWKeyNotFound + :show-inheritance: diff --git a/jwcrypto/common.py b/jwcrypto/common.py index 6ada7ef..9db2d95 100644 --- a/jwcrypto/common.py +++ b/jwcrypto/common.py @@ -126,6 +126,22 @@ def __init__(self, message=None, exception=None): super(InvalidJWSERegOperation, self).__init__(msg) +class JWKeyNotFound(JWException): + """The key needed to complete the operation was not found. + + This exception is raised when a JWKSet is used to perform + some operation and the key required to successfully complete + the operation is not found. + """ + + def __init__(self, message=None): + if message: + msg = message + else: + msg = 'Key Not Found' + super(JWKeyNotFound, self).__init__(msg) + + # JWSE Header Registry definitions # RFC 7515 - 9.1: JSON Web Signature and Encryption Header Parameters Registry diff --git a/jwcrypto/jwe.py b/jwcrypto/jwe.py index 382639a..0567810 100644 --- a/jwcrypto/jwe.py +++ b/jwcrypto/jwe.py @@ -3,7 +3,7 @@ import zlib from jwcrypto import common -from jwcrypto.common import JWException +from jwcrypto.common import JWException, JWKeyNotFound from jwcrypto.common import JWSEHeaderParameter, JWSEHeaderRegistry from jwcrypto.common import base64url_decode, base64url_encode from jwcrypto.common import json_decode, json_encode @@ -393,8 +393,8 @@ def _decrypt(self, key, ppe): if 'kid' in self.jose_header: kid_keys = key.get_keys(self.jose_header['kid']) if not kid_keys: - raise ValueError('Key ID {} not in key set'.format( - self.jose_header['kid'])) + raise JWKeyNotFound('Key ID {} not in key set'.format( + self.jose_header['kid'])) keys = kid_keys for k in keys: @@ -404,6 +404,7 @@ def _decrypt(self, key, ppe): jh, aad, self.objects['iv'], self.objects['ciphertext'], self.objects['tag']) + self.decryptlog.append("Success") break except Exception as e: # pylint: disable=broad-except keyid = k.get('kid', k.thumbprint()) @@ -411,7 +412,7 @@ def _decrypt(self, key, ppe): keyid, repr(e))) if "Success" not in self.decryptlog: - raise ValueError('No working key found in key set') + raise JWKeyNotFound('No working key found in key set') else: data = self._unwrap_decrypt(alg, enc, key, ppe.get('encrypted_key', b''), @@ -438,25 +439,33 @@ def decrypt(self, key): :raises InvalidJWEOperation: if the key is not a JWK object. :raises InvalidJWEData: if the ciphertext can't be decrypted or the object is otherwise malformed. + :raises JWKeyNotFound: if key is a JWKSet and the key is not found. """ if 'ciphertext' not in self.objects: raise InvalidJWEOperation("No available ciphertext") self.decryptlog = [] + missingkey = False if 'recipients' in self.objects: for rec in self.objects['recipients']: try: self._decrypt(key, rec) except Exception as e: # pylint: disable=broad-except + if isinstance(e, JWKeyNotFound): + missingkey = True self.decryptlog.append('Failed: [%s]' % repr(e)) else: try: self._decrypt(key, self.objects) except Exception as e: # pylint: disable=broad-except + if isinstance(e, JWKeyNotFound): + missingkey = True self.decryptlog.append('Failed: [%s]' % repr(e)) if not self.plaintext: + if missingkey: + raise JWKeyNotFound("Key Not found in JWKSet") raise InvalidJWEData('No recipient matched the provided ' 'key' + repr(self.decryptlog)) diff --git a/jwcrypto/jws.py b/jwcrypto/jws.py index fb04518..d0e0ccb 100644 --- a/jwcrypto/jws.py +++ b/jwcrypto/jws.py @@ -1,6 +1,6 @@ # Copyright (C) 2015 JWCrypto Project Contributors - see LICENSE file -from jwcrypto.common import JWException +from jwcrypto.common import JWException, JWKeyNotFound from jwcrypto.common import JWSEHeaderParameter, JWSEHeaderRegistry from jwcrypto.common import base64url_decode, base64url_encode from jwcrypto.common import json_decode, json_encode @@ -297,8 +297,8 @@ def _verify(self, alg, key, payload, signature, protected, header=None): if 'kid' in self.jose_header: kid_keys = key.get_keys(self.jose_header['kid']) if not kid_keys: - raise ValueError('Key ID {} not in key set'.format( - self.jose_header['kid'])) + raise JWKeyNotFound('Key ID {} not in key set'.format( + self.jose_header['kid'])) keys = kid_keys for k in keys: @@ -312,7 +312,7 @@ def _verify(self, alg, key, payload, signature, protected, header=None): self.verifylog.append('Key [{}] failed: [{}]'.format( keyid, repr(e))) if "Success" not in self.verifylog: - raise ValueError('No working key found in key set') + raise JWKeyNotFound('No working key found in key set') else: raise ValueError("Unrecognized key type") @@ -341,11 +341,13 @@ def verify(self, key, alg=None, detached_payload=None): :raises InvalidJWSSignature: if the verification fails. :raises InvalidJWSOperation: if a detached_payload is provided but an object payload exists + :raises JWKeyNotFound: if key is a JWKSet and the key is not found. """ self.verifylog = [] self.objects['valid'] = False obj = self.objects + missingkey = False if 'signature' in obj: payload = self._get_obj_payload(obj, detached_payload) try: @@ -356,6 +358,8 @@ def verify(self, key, alg=None, detached_payload=None): obj.get('header', None)) obj['valid'] = True except Exception as e: # pylint: disable=broad-except + if isinstance(e, JWKeyNotFound): + missingkey = True self.verifylog.append('Failed: [%s]' % repr(e)) elif 'signatures' in obj: @@ -370,11 +374,15 @@ def verify(self, key, alg=None, detached_payload=None): # Ok if at least one verifies obj['valid'] = True except Exception as e: # pylint: disable=broad-except + if isinstance(e, JWKeyNotFound): + missingkey = True self.verifylog.append('Failed: [%s]' % repr(e)) else: raise InvalidJWSSignature('No signatures available') if not self.is_valid: + if missingkey: + raise JWKeyNotFound('No working key found in key set') raise InvalidJWSSignature('Verification failed for all ' 'signatures' + repr(self.verifylog)) @@ -423,6 +431,7 @@ def deserialize(self, raw_jws, key=None, alg=None): :raises InvalidJWSObject: if the raw object is an invalid JWS token. :raises InvalidJWSSignature: if the verification fails. + :raises JWKeyNotFound: if key is a JWKSet and the key is not found. """ self.objects = {} o = {} diff --git a/jwcrypto/jwt.py b/jwcrypto/jwt.py index 7f3339c..5e8cae7 100644 --- a/jwcrypto/jwt.py +++ b/jwcrypto/jwt.py @@ -6,7 +6,8 @@ from deprecated import deprecated -from jwcrypto.common import JWException, json_decode, json_encode +from jwcrypto.common import JWException, JWKeyNotFound +from jwcrypto.common import json_decode, json_encode from jwcrypto.jwe import JWE from jwcrypto.jws import JWS @@ -127,8 +128,7 @@ def __init__(self, message=None, exception=None): super(JWTMissingKeyID, self).__init__(msg) -@deprecated -class JWTMissingKey(JWException): +class JWTMissingKey(JWKeyNotFound): """JSON Web Token is using a key not in the key set. This exception is raised if the key that was used is not available @@ -515,6 +515,8 @@ def validate(self, key): self.deserializelog = self.token.decryptlog self.deserializelog.append( 'Validation failed: [{}]'.format(repr(e))) + if isinstance(e, JWKeyNotFound): + raise JWTMissingKey() from e raise self.header = self.token.jose_header diff --git a/jwcrypto/tests.py b/jwcrypto/tests.py index e94f327..eaa00f8 100644 --- a/jwcrypto/tests.py +++ b/jwcrypto/tests.py @@ -15,6 +15,7 @@ from jwcrypto import jws from jwcrypto import jwt from jwcrypto.common import InvalidJWSERegOperation +from jwcrypto.common import JWKeyNotFound from jwcrypto.common import JWSEHeaderParameter from jwcrypto.common import base64url_decode, base64url_encode from jwcrypto.common import json_decode, json_encode @@ -1007,6 +1008,25 @@ def test_jws_issue_281(self): self.assertEqual(header, header_copy) + def test_decrypt_keyset(self): + ks = jwk.JWKSet() + key1 = jwk.JWK.generate(kty='oct', alg='HS256', kid='key1') + key2 = jwk.JWK.generate(kty='oct', alg='HS384', kid='key2') + key3 = jwk.JWK.generate(kty='oct', alg='HS512', kid='key3') + ks.add(key1) + ks.add(key2) + s1 = jws.JWS(payload=b'secret') + s1.add_signature(key1, protected='{"alg":"HS256"}') + s2 = jws.JWS() + s2.deserialize(s1.serialize(), ks) + self.assertEqual(s2.payload, b'secret') + + s3 = jws.JWS(payload=b'secret') + s3.add_signature(key3, protected='{"alg":"HS256"}') + s4 = jws.JWS() + with self.assertRaises(JWKeyNotFound): + s4.deserialize(s3.serialize(), ks) + E_A1_plaintext = \ [84, 104, 101, 32, 116, 114, 117, 101, 32, 115, 105, 103, 110, 32, @@ -1330,6 +1350,25 @@ def test_X25519_ECDH(self): e2.deserialize(enc, x25519key) self.assertEqual(e2.payload, plaintext) + def test_decrypt_keyset(self): + ks = jwk.JWKSet() + key1 = jwk.JWK.generate(kty='oct', alg='A128KW', kid='key1') + key2 = jwk.JWK.generate(kty='oct', alg='A192KW', kid='key2') + key3 = jwk.JWK.generate(kty='oct', alg='A256KW', kid='key3') + ks.add(key1) + ks.add(key2) + e1 = jwe.JWE(plaintext=b'secret') + e1.add_recipient(key1, '{"alg":"A128KW","enc":"A128GCM"}') + e2 = jwe.JWE() + e2.deserialize(e1.serialize(), ks) + self.assertEqual(e2.payload, b'secret') + + e3 = jwe.JWE(plaintext=b'secret') + e3.add_recipient(key3, '{"alg":"A256KW","enc":"A256GCM"}') + e4 = jwe.JWE() + with self.assertRaises(JWKeyNotFound): + e4.deserialize(e3.serialize(), ks) + MMA_vector_key = jwk.JWK(**E_A2_key) MMA_vector_ok_cek = \ @@ -1500,7 +1539,7 @@ def test_decrypt_keyset(self): t.make_encrypted_token(key) token = t.serialize() # try to decrypt without a matching key - self.assertRaises(Exception, jwt.JWT, jwt=token, key=keyset, + self.assertRaises(jwt.JWTMissingKey, jwt.JWT, jwt=token, key=keyset, algs=jwe_algs_and_rsa1_5, check_claims={'exp': 1300819380}) # now decrypt with key @@ -1514,7 +1553,8 @@ def test_decrypt_keyset(self): t = jwt.JWT(header, A1_claims, algs=jwe_algs_and_rsa1_5) t.make_encrypted_token(key) token = t.serialize() - self.assertRaises(Exception, jwt.JWT, jwt=token, key=keyset) + self.assertRaises(jwt.JWTMissingKey, jwt.JWT, jwt=token, key=keyset, + algs=jwe_algs_and_rsa1_5) keyset = jwk.JWKSet.from_json(json_encode(PrivateKeys)) # encrypt a new JWT with no kid @@ -1523,7 +1563,7 @@ def test_decrypt_keyset(self): t.make_encrypted_token(key) token = t.serialize() # try to decrypt without a matching key - self.assertRaises(Exception, jwt.JWT, jwt=token, key=keyset, + self.assertRaises(jwt.JWTMissingKey, jwt.JWT, jwt=token, key=keyset, algs=jwe_algs_and_rsa1_5, check_claims={'exp': 1300819380}) # now decrypt with key @@ -1546,7 +1586,7 @@ def test_decrypt_keyset_dup_kid(self): token = t.serialize() # try to decrypt without a matching key - with self.assertRaises(Exception): + with self.assertRaises(jwt.JWTMissingKey): jwt.JWT(jwt=token, key=keyset, algs=jwe_algs_and_rsa1_5, check_claims={'exp': 1300819380})