Skip to content

Commit

Permalink
Cache pub/pri keys on retrieval
Browse files Browse the repository at this point in the history
Pyca rightfully performs consistency checks when importing keys and
these operations are rather expensive. So cache keys once generated so
that repeated uses of the same JWK do not incur undue cost of reloading
the keys from scratch for each subsequent operation.

with a simple test by hand:
$ python
>>> from jwcrypto import jwk
>>> def test():
...     key = jwk.JWK.generate(kty='RSA', size=2048)
...     for i in range(1000):
...             k = key._get_private_key()
...
>>> import timeit

Before the patch:
>>> print(timeit.timeit("test()", setup="from __main__ import test", number=10))
35.80328264506534

After the patch:
>>> print(timeit.timeit("test()", setup="from __main__ import test", number=10))
0.9109518649056554

Resolves #243

Signed-off-by: Simo Sorce <simo@redhat.com>
  • Loading branch information
simo5 committed Nov 30, 2021
1 parent 3ba7408 commit ddff0b0
Showing 1 changed file with 73 additions and 24 deletions.
97 changes: 73 additions & 24 deletions jwcrypto/jwk.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,8 @@ def __init__(self, **kwargs):
are provided.
"""
super(JWK, self).__init__()
self._cache_pub_k = None
self._cache_pri_k = None

if 'generate' in kwargs:
self.generate_key(**kwargs)
Expand Down Expand Up @@ -485,6 +487,8 @@ def _import_pyca_pub_okp(self, key, **params):
def import_key(self, **kwargs):
newkey = {}
key_vals = 0
self._cache_pub_k = None
self._cache_pri_k = None

names = list(kwargs.keys())

Expand Down Expand Up @@ -730,57 +734,93 @@ def _check_constraints(self, usage, operation):
def _decode_int(self, n):
return int(hexlify(base64url_decode(n)), 16)

def _rsa_pub(self):
def _rsa_pub_n(self):
e = self._decode_int(self.get('e'))
n = self._decode_int(self.get('n'))
return rsa.RSAPublicNumbers(e, n)

def _rsa_pri(self):
def _rsa_pri_n(self):
p = self._decode_int(self.get('p'))
q = self._decode_int(self.get('q'))
d = self._decode_int(self.get('d'))
dp = self._decode_int(self.get('dp'))
dq = self._decode_int(self.get('dq'))
qi = self._decode_int(self.get('qi'))
return rsa.RSAPrivateNumbers(p, q, d, dp, dq, qi, self._rsa_pub())
return rsa.RSAPrivateNumbers(p, q, d, dp, dq, qi, self._rsa_pub_n())

def _ec_pub(self, curve):
def _rsa_pub(self):
k = self._cache_pub_k
if k is None:
k = self._rsa_pub_n().public_key(default_backend())
self._cache_pub_k = k
return k

def _rsa_pri(self):
k = self._cache_pri_k
if k is None:
k = self._rsa_pri_n().private_key(default_backend())
self._cache_pri_k = k
return k

def _ec_pub_n(self, curve):
x = self._decode_int(self.get('x'))
y = self._decode_int(self.get('y'))
return ec.EllipticCurvePublicNumbers(x, y, self.get_curve(curve))

def _ec_pri(self, curve):
def _ec_pri_n(self, curve):
d = self._decode_int(self.get('d'))
return ec.EllipticCurvePrivateNumbers(d, self._ec_pub(curve))
return ec.EllipticCurvePrivateNumbers(d, self._ec_pub_n(curve))

def _ec_pub(self, curve):
k = self._cache_pub_k
if k is None:
k = self._ec_pub_n(curve).public_key(default_backend())
self._cache_pub_k = k
return k

def _ec_pri(self, curve):
k = self._cache_pri_k
if k is None:
k = self._ec_pri_n(curve).private_key(default_backend())
self._cache_pri_k = k
return k

def _okp_pub(self):
crv = self.get('crv')
try:
pubkey = _OKP_CURVES_TABLE[crv].pubkey
except KeyError as e:
raise InvalidJWKValue('Unknown curve "%s"' % crv) from e
k = self._cache_pub_k
if k is None:
crv = self.get('crv')
try:
pubkey = _OKP_CURVES_TABLE[crv].pubkey
except KeyError as e:
raise InvalidJWKValue('Unknown curve "%s"' % crv) from e

x = base64url_decode(self.get('x'))
return pubkey.from_public_bytes(x)
x = base64url_decode(self.get('x'))
k = pubkey.from_public_bytes(x)
self._cache_pub_k = k
return k

def _okp_pri(self):
crv = self.get('crv')
try:
privkey = _OKP_CURVES_TABLE[crv].privkey
except KeyError as e:
raise InvalidJWKValue('Unknown curve "%s"' % crv) from e
k = self._cache_pri_k
if k is None:
crv = self.get('crv')
try:
privkey = _OKP_CURVES_TABLE[crv].privkey
except KeyError as e:
raise InvalidJWKValue('Unknown curve "%s"' % crv) from e

d = base64url_decode(self.get('d'))
return privkey.from_private_bytes(d)
d = base64url_decode(self.get('d'))
k = privkey.from_private_bytes(d)
self._cache_pri_k = k
return k

def _get_public_key(self, arg=None):
ktype = self.get('kty')
if ktype == 'oct':
return self.get('k')
elif ktype == 'RSA':
return self._rsa_pub().public_key(default_backend())
return self._rsa_pub()
elif ktype == 'EC':
return self._ec_pub(arg).public_key(default_backend())
return self._ec_pub(arg)
elif ktype == 'OKP':
return self._okp_pub()
else:
Expand All @@ -791,9 +831,9 @@ def _get_private_key(self, arg=None):
if ktype == 'oct':
return self.get('k')
elif ktype == 'RSA':
return self._rsa_pri().private_key(default_backend())
return self._rsa_pri()
elif ktype == 'EC':
return self._ec_pri(arg).private_key(default_backend())
return self._ec_pri(arg)
elif ktype == 'OKP':
return self._okp_pri()
else:
Expand Down Expand Up @@ -969,6 +1009,9 @@ def __setitem__(self, item, value):

# Check if item is a key value and verify its format
if item in list(JWKValuesRegistry[kty].keys()):
# Invalidate cached keys if any
self._cache_pub_k = None
self._cache_pri_k = None
if JWKValuesRegistry[kty][item].type == ParmType.b64:
try:
v = base64url_decode(value)
Expand Down Expand Up @@ -1028,6 +1071,12 @@ def __delitem__(self, item):
if self.get(name) is not None:
raise KeyError("Cannot remove 'kty', values present")

kty = self.get('kty')
if kty is not None and item in list(JWKValuesRegistry[kty].keys()):
# Invalidate cached keys if any
self._cache_pub_k = None
self._cache_pri_k = None

super(JWK, self).__delitem__(item)

def __eq__(self, other):
Expand Down

0 comments on commit ddff0b0

Please sign in to comment.