Skip to content

Commit

Permalink
Introduce a new JWKeyNotFound exception
Browse files Browse the repository at this point in the history
This new Exception is returned only for the newly introduced support for
using JWKset.
This patch also includes a bugfix for jwe to be able to successfully
decrypt using a JWKSet, which was non-functional, and a direct test for
both JWE and JWS to insure no regressions in JWKSet support.

Also restores use of JWTMissingKey for backwards compatibility.

Signed-off-by: Simo Sorce <simo@redhat.com>
  • Loading branch information
simo5 committed May 21, 2022
1 parent 60fc7ee commit 997b900
Show file tree
Hide file tree
Showing 6 changed files with 94 additions and 15 deletions.
3 changes: 3 additions & 0 deletions docs/source/common.rst
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,6 @@ Exceptions

.. autoclass:: jwcrypto.common.InvalidJWSERegOperation
:show-inheritance:

.. autoclass:: jwcrypto.common.JWKeyNotFound
:show-inheritance:
16 changes: 16 additions & 0 deletions jwcrypto/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,22 @@ def __init__(self, message=None, exception=None):
super(InvalidJWSERegOperation, self).__init__(msg)


class JWKeyNotFound(JWException):
"""The key needed to complete the operation was not found.
This exception is raised when a JWKSet is used to perform
some operation and the key required to successfully complete
the operation is not found.
"""

def __init__(self, message=None):
if message:
msg = message
else:
msg = 'Key Not Found'
super(JWKeyNotFound, self).__init__(msg)


# JWSE Header Registry definitions

# RFC 7515 - 9.1: JSON Web Signature and Encryption Header Parameters Registry
Expand Down
17 changes: 13 additions & 4 deletions jwcrypto/jwe.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import zlib

from jwcrypto import common
from jwcrypto.common import JWException
from jwcrypto.common import JWException, JWKeyNotFound
from jwcrypto.common import JWSEHeaderParameter, JWSEHeaderRegistry
from jwcrypto.common import base64url_decode, base64url_encode
from jwcrypto.common import json_decode, json_encode
Expand Down Expand Up @@ -393,8 +393,8 @@ def _decrypt(self, key, ppe):
if 'kid' in self.jose_header:
kid_keys = key.get_keys(self.jose_header['kid'])
if not kid_keys:
raise ValueError('Key ID {} not in key set'.format(
self.jose_header['kid']))
raise JWKeyNotFound('Key ID {} not in key set'.format(
self.jose_header['kid']))
keys = kid_keys

for k in keys:
Expand All @@ -404,14 +404,15 @@ def _decrypt(self, key, ppe):
jh, aad, self.objects['iv'],
self.objects['ciphertext'],
self.objects['tag'])
self.decryptlog.append("Success")
break
except Exception as e: # pylint: disable=broad-except
keyid = k.get('kid', k.thumbprint())
self.decryptlog.append('Key [{}] failed: [{}]'.format(
keyid, repr(e)))

if "Success" not in self.decryptlog:
raise ValueError('No working key found in key set')
raise JWKeyNotFound('No working key found in key set')
else:
data = self._unwrap_decrypt(alg, enc, key,
ppe.get('encrypted_key', b''),
Expand All @@ -438,25 +439,33 @@ def decrypt(self, key):
:raises InvalidJWEOperation: if the key is not a JWK object.
:raises InvalidJWEData: if the ciphertext can't be decrypted or
the object is otherwise malformed.
:raises JWKeyNotFound: if key is a JWKSet and the key is not found.
"""

if 'ciphertext' not in self.objects:
raise InvalidJWEOperation("No available ciphertext")
self.decryptlog = []
missingkey = False

if 'recipients' in self.objects:
for rec in self.objects['recipients']:
try:
self._decrypt(key, rec)
except Exception as e: # pylint: disable=broad-except
if isinstance(e, JWKeyNotFound):
missingkey = True
self.decryptlog.append('Failed: [%s]' % repr(e))
else:
try:
self._decrypt(key, self.objects)
except Exception as e: # pylint: disable=broad-except
if isinstance(e, JWKeyNotFound):
missingkey = True
self.decryptlog.append('Failed: [%s]' % repr(e))

if not self.plaintext:
if missingkey:
raise JWKeyNotFound("Key Not found in JWKSet")
raise InvalidJWEData('No recipient matched the provided '
'key' + repr(self.decryptlog))

Expand Down
17 changes: 13 additions & 4 deletions jwcrypto/jws.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Copyright (C) 2015 JWCrypto Project Contributors - see LICENSE file

from jwcrypto.common import JWException
from jwcrypto.common import JWException, JWKeyNotFound
from jwcrypto.common import JWSEHeaderParameter, JWSEHeaderRegistry
from jwcrypto.common import base64url_decode, base64url_encode
from jwcrypto.common import json_decode, json_encode
Expand Down Expand Up @@ -297,8 +297,8 @@ def _verify(self, alg, key, payload, signature, protected, header=None):
if 'kid' in self.jose_header:
kid_keys = key.get_keys(self.jose_header['kid'])
if not kid_keys:
raise ValueError('Key ID {} not in key set'.format(
self.jose_header['kid']))
raise JWKeyNotFound('Key ID {} not in key set'.format(
self.jose_header['kid']))
keys = kid_keys

for k in keys:
Expand All @@ -312,7 +312,7 @@ def _verify(self, alg, key, payload, signature, protected, header=None):
self.verifylog.append('Key [{}] failed: [{}]'.format(
keyid, repr(e)))
if "Success" not in self.verifylog:
raise ValueError('No working key found in key set')
raise JWKeyNotFound('No working key found in key set')
else:
raise ValueError("Unrecognized key type")

Expand Down Expand Up @@ -341,11 +341,13 @@ def verify(self, key, alg=None, detached_payload=None):
:raises InvalidJWSSignature: if the verification fails.
:raises InvalidJWSOperation: if a detached_payload is provided but
an object payload exists
:raises JWKeyNotFound: if key is a JWKSet and the key is not found.
"""

self.verifylog = []
self.objects['valid'] = False
obj = self.objects
missingkey = False
if 'signature' in obj:
payload = self._get_obj_payload(obj, detached_payload)
try:
Expand All @@ -356,6 +358,8 @@ def verify(self, key, alg=None, detached_payload=None):
obj.get('header', None))
obj['valid'] = True
except Exception as e: # pylint: disable=broad-except
if isinstance(e, JWKeyNotFound):
missingkey = True
self.verifylog.append('Failed: [%s]' % repr(e))

elif 'signatures' in obj:
Expand All @@ -370,11 +374,15 @@ def verify(self, key, alg=None, detached_payload=None):
# Ok if at least one verifies
obj['valid'] = True
except Exception as e: # pylint: disable=broad-except
if isinstance(e, JWKeyNotFound):
missingkey = True
self.verifylog.append('Failed: [%s]' % repr(e))
else:
raise InvalidJWSSignature('No signatures available')

if not self.is_valid:
if missingkey:
raise JWKeyNotFound('No working key found in key set')
raise InvalidJWSSignature('Verification failed for all '
'signatures' + repr(self.verifylog))

Expand Down Expand Up @@ -423,6 +431,7 @@ def deserialize(self, raw_jws, key=None, alg=None):
:raises InvalidJWSObject: if the raw object is an invalid JWS token.
:raises InvalidJWSSignature: if the verification fails.
:raises JWKeyNotFound: if key is a JWKSet and the key is not found.
"""
self.objects = {}
o = {}
Expand Down
8 changes: 5 additions & 3 deletions jwcrypto/jwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@

from deprecated import deprecated

from jwcrypto.common import JWException, json_decode, json_encode
from jwcrypto.common import JWException, JWKeyNotFound
from jwcrypto.common import json_decode, json_encode
from jwcrypto.jwe import JWE
from jwcrypto.jws import JWS

Expand Down Expand Up @@ -127,8 +128,7 @@ def __init__(self, message=None, exception=None):
super(JWTMissingKeyID, self).__init__(msg)


@deprecated
class JWTMissingKey(JWException):
class JWTMissingKey(JWKeyNotFound):
"""JSON Web Token is using a key not in the key set.
This exception is raised if the key that was used is not available
Expand Down Expand Up @@ -515,6 +515,8 @@ def validate(self, key):
self.deserializelog = self.token.decryptlog
self.deserializelog.append(
'Validation failed: [{}]'.format(repr(e)))
if isinstance(e, JWKeyNotFound):
raise JWTMissingKey() from e
raise

self.header = self.token.jose_header
Expand Down
48 changes: 44 additions & 4 deletions jwcrypto/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from jwcrypto import jws
from jwcrypto import jwt
from jwcrypto.common import InvalidJWSERegOperation
from jwcrypto.common import JWKeyNotFound
from jwcrypto.common import JWSEHeaderParameter
from jwcrypto.common import base64url_decode, base64url_encode
from jwcrypto.common import json_decode, json_encode
Expand Down Expand Up @@ -1007,6 +1008,25 @@ def test_jws_issue_281(self):

self.assertEqual(header, header_copy)

def test_decrypt_keyset(self):
ks = jwk.JWKSet()
key1 = jwk.JWK.generate(kty='oct', alg='HS256', kid='key1')
key2 = jwk.JWK.generate(kty='oct', alg='HS384', kid='key2')
key3 = jwk.JWK.generate(kty='oct', alg='HS512', kid='key3')
ks.add(key1)
ks.add(key2)
s1 = jws.JWS(payload=b'secret')
s1.add_signature(key1, protected='{"alg":"HS256"}')
s2 = jws.JWS()
s2.deserialize(s1.serialize(), ks)
self.assertEqual(s2.payload, b'secret')

s3 = jws.JWS(payload=b'secret')
s3.add_signature(key3, protected='{"alg":"HS256"}')
s4 = jws.JWS()
with self.assertRaises(JWKeyNotFound):
s4.deserialize(s3.serialize(), ks)


E_A1_plaintext = \
[84, 104, 101, 32, 116, 114, 117, 101, 32, 115, 105, 103, 110, 32,
Expand Down Expand Up @@ -1330,6 +1350,25 @@ def test_X25519_ECDH(self):
e2.deserialize(enc, x25519key)
self.assertEqual(e2.payload, plaintext)

def test_decrypt_keyset(self):
ks = jwk.JWKSet()
key1 = jwk.JWK.generate(kty='oct', alg='A128KW', kid='key1')
key2 = jwk.JWK.generate(kty='oct', alg='A192KW', kid='key2')
key3 = jwk.JWK.generate(kty='oct', alg='A256KW', kid='key3')
ks.add(key1)
ks.add(key2)
e1 = jwe.JWE(plaintext=b'secret')
e1.add_recipient(key1, '{"alg":"A128KW","enc":"A128GCM"}')
e2 = jwe.JWE()
e2.deserialize(e1.serialize(), ks)
self.assertEqual(e2.payload, b'secret')

e3 = jwe.JWE(plaintext=b'secret')
e3.add_recipient(key3, '{"alg":"A256KW","enc":"A256GCM"}')
e4 = jwe.JWE()
with self.assertRaises(JWKeyNotFound):
e4.deserialize(e3.serialize(), ks)


MMA_vector_key = jwk.JWK(**E_A2_key)
MMA_vector_ok_cek = \
Expand Down Expand Up @@ -1500,7 +1539,7 @@ def test_decrypt_keyset(self):
t.make_encrypted_token(key)
token = t.serialize()
# try to decrypt without a matching key
self.assertRaises(Exception, jwt.JWT, jwt=token, key=keyset,
self.assertRaises(jwt.JWTMissingKey, jwt.JWT, jwt=token, key=keyset,
algs=jwe_algs_and_rsa1_5,
check_claims={'exp': 1300819380})
# now decrypt with key
Expand All @@ -1514,7 +1553,8 @@ def test_decrypt_keyset(self):
t = jwt.JWT(header, A1_claims, algs=jwe_algs_and_rsa1_5)
t.make_encrypted_token(key)
token = t.serialize()
self.assertRaises(Exception, jwt.JWT, jwt=token, key=keyset)
self.assertRaises(jwt.JWTMissingKey, jwt.JWT, jwt=token, key=keyset,
algs=jwe_algs_and_rsa1_5)

keyset = jwk.JWKSet.from_json(json_encode(PrivateKeys))
# encrypt a new JWT with no kid
Expand All @@ -1523,7 +1563,7 @@ def test_decrypt_keyset(self):
t.make_encrypted_token(key)
token = t.serialize()
# try to decrypt without a matching key
self.assertRaises(Exception, jwt.JWT, jwt=token, key=keyset,
self.assertRaises(jwt.JWTMissingKey, jwt.JWT, jwt=token, key=keyset,
algs=jwe_algs_and_rsa1_5,
check_claims={'exp': 1300819380})
# now decrypt with key
Expand All @@ -1546,7 +1586,7 @@ def test_decrypt_keyset_dup_kid(self):
token = t.serialize()

# try to decrypt without a matching key
with self.assertRaises(Exception):
with self.assertRaises(jwt.JWTMissingKey):
jwt.JWT(jwt=token, key=keyset, algs=jwe_algs_and_rsa1_5,
check_claims={'exp': 1300819380})

Expand Down

0 comments on commit 997b900

Please sign in to comment.