Skip to content

Commit

Permalink
Refactor how EC curves are fetched
Browse files Browse the repository at this point in the history
Deprecates the get_curve() function which shouldn't really be exposed
to users as it is an internal detail.
Change tests and jwa.py to stop using get_curve()

Signed-off-by: Simo Sorce <simo@redhat.com>
  • Loading branch information
simo5 committed Dec 2, 2021
1 parent 4780e07 commit b68702d
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 38 deletions.
2 changes: 1 addition & 1 deletion jwcrypto/jwa.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,9 +138,9 @@ def curve(self):

def sign(self, key, payload):
skey = key.get_op_key('sign', self._curve)
size = skey.key_size
signature = skey.sign(payload, ec.ECDSA(self.hashfn))
r, s = ec_utils.decode_dss_signature(signature)
size = key.get_curve(self._curve).key_size
return _encode_int(r, size) + _encode_int(s, size)

def verify(self, key, payload, signature):
Expand Down
75 changes: 44 additions & 31 deletions jwcrypto/jwk.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,12 +74,15 @@ def from_private_bytes(cls, *args):
X448PrivateKey = UnimplementedOKPCurveKey


_OKP_CURVE = namedtuple('Name', 'pubkey privkey')
_Ed25519_CURVE = namedtuple('Ed25519', 'pubkey privkey')
_Ed448_CURVE = namedtuple('Ed448', 'pubkey privkey')
_X25519_CURVE = namedtuple('X25519', 'pubkey privkey')
_X448_CURVE = namedtuple('X448', 'pubkey privkey')
_OKP_CURVES_TABLE = {
'Ed25519': _OKP_CURVE(Ed25519PublicKey, Ed25519PrivateKey),
'Ed448': _OKP_CURVE(Ed448PublicKey, Ed448PrivateKey),
'X25519': _OKP_CURVE(X25519PublicKey, X25519PrivateKey),
'X448': _OKP_CURVE(X448PublicKey, X448PrivateKey)
'Ed25519': _Ed25519_CURVE(Ed25519PublicKey, Ed25519PrivateKey),
'Ed448': _Ed448_CURVE(Ed448PublicKey, Ed448PrivateKey),
'X25519': _X25519_CURVE(X25519PublicKey, X25519PrivateKey),
'X448': _X448_CURVE(X448PublicKey, X448PrivateKey)
}


Expand Down Expand Up @@ -395,15 +398,33 @@ def _import_pyca_pub_rsa(self, key, **params):
)
self.import_key(**params)

def _curve_name(self, name):
# P-256K is an alias for 'secp256k1' to handle compatibility
# with some implementation using this old drafting name
if name == 'P-256K':
return 'secp256k1'
return name
def _get_curve_by_name(self, name, ctype=None):
crv = self.get('crv')

def _get_curve_by_name(self, name):
cname = self._curve_name(name)
if name is None:
cname = crv
elif name == 'P-256K':
# P-256K is an alias for 'secp256k1' to handle compatibility
# with some implementation using this old drafting name
cname = 'secp256k1'
else:
cname = name

# Check we are asking for the correct curve unless this is being
# requested for generation on a blank JWK object
if crv:
ccrv = crv
if ccrv == 'P-256K':
ccrv = 'secp256k1'
if ccrv != cname:
raise InvalidJWKValue('Curve requested is "%s", but '
'key curve is "%s"' % (name, crv))
kty = self.get('kty')
if kty is not None and ctype is not None and kty != ctype:
raise InvalidJWKType('Curve Requested is of type "%s", but '
'key curve is of type "%s"' % (ctype, kty))

# Return a curve object
if cname == 'P-256':
return ec.SECP256R1()
elif cname == 'P-384':
Expand All @@ -413,9 +434,9 @@ def _get_curve_by_name(self, name):
elif cname == 'secp256k1':
return ec.SECP256K1()
elif cname in _OKP_CURVES_TABLE:
return cname
return _OKP_CURVES_TABLE[cname]
else:
raise InvalidJWKValue('Unknown Elliptic Curve Type')
raise InvalidJWKValue('Unknown Curve Name [%s]' % (name))

def _generate_EC(self, params):
curve = 'P-256'
Expand All @@ -425,8 +446,8 @@ def _generate_EC(self, params):
# precedence
if 'crv' in params:
curve = params.pop('crv')
curve_name = self._get_curve_by_name(curve)
key = ec.generate_private_key(curve_name, default_backend())
curve_fn = self._get_curve_by_name(curve, 'EC')
key = ec.generate_private_key(curve_fn, default_backend())
self._import_pyca_pri_ec(key, **params)

def _import_pyca_pri_ec(self, key, **params):
Expand Down Expand Up @@ -455,11 +476,8 @@ def _import_pyca_pub_ec(self, key, **params):
def _generate_OKP(self, params):
if 'crv' not in params:
raise InvalidJWKValue('Must specify "crv" for OKP key generation')
try:
key = _OKP_CURVES_TABLE[params['crv']].privkey.generate()
except KeyError as e:
raise InvalidJWKValue('"%s" is not a supported curve for the '
'OKP key type' % params['crv']) from e
curve_fn = self._get_curve_by_name(params['crv'], 'OKP')
key = curve_fn.privkey.generate()
self._import_pyca_pri_okp(key, **params)

def _okp_curve_from_pyca_key(self, key):
Expand Down Expand Up @@ -710,6 +728,7 @@ def key_curve(self):
raise InvalidJWKType('Not an EC or OKP key')
return self.get('crv')

@deprecated
def get_curve(self, arg):
"""Gets the Elliptic Curve associated with the key.
Expand All @@ -718,14 +737,7 @@ def get_curve(self, arg):
:raises InvalidJWKType: the key is not an EC or OKP key.
:raises InvalidJWKValue: if the curve name is invalid.
"""
crv = self.get('crv')
if self.get('kty') not in ['EC', 'OKP']:
raise InvalidJWKType('Not an EC or OKP key')
if arg and self._curve_name(crv) != self._curve_name(arg):
raise InvalidJWKValue('Curve requested is "%s", but '
'key curve is "%s"' % (arg, crv))

return self._get_curve_by_name(crv)
return self._get_curve_by_name(arg)

def _check_constraints(self, usage, operation):
use = self.get('use')
Expand Down Expand Up @@ -773,7 +785,8 @@ def _rsa_pri(self):
def _ec_pub_n(self, curve):
x = self._decode_int(self.get('x'))
y = self._decode_int(self.get('y'))
return ec.EllipticCurvePublicNumbers(x, y, self.get_curve(curve))
curve_fn = self._get_curve_by_name(curve, ctype='EC')
return ec.EllipticCurvePublicNumbers(x, y, curve_fn)

def _ec_pri_n(self, curve):
d = self._decode_int(self.get('d'))
Expand Down
12 changes: 6 additions & 6 deletions jwcrypto/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,21 +375,21 @@ def test_generate_oct_key(self):
def test_generate_EC_key(self):
# Backwards compat curve
key = jwk.JWK.generate(kty='EC', curve='P-256')
key.get_curve('P-256')
key.get_op_key('verify', 'P-256')
# New param
key = jwk.JWK.generate(kty='EC', crv='P-521')
key.get_curve('P-521')
key.get_op_key('verify', 'P-521')
# New param prevails
key = jwk.JWK.generate(kty='EC', curve='P-256', crv='P-521')
key.get_curve('P-521')
key.get_op_key('verify', 'P-521')
# New secp256k curve
key = jwk.JWK.generate(kty='EC', curve='secp256k1')
key.get_curve('secp256k1')
key.get_op_key('verify', 'secp256k1')

def test_generate_OKP_keys(self):
for crv in jwk.ImplementedOkpCurves:
key = jwk.JWK.generate(kty='OKP', crv=crv)
self.assertEqual(key.get_curve(crv), crv)
self.assertEqual(key['crv'], crv)

def test_import_pyca_keys(self):
rsa1 = rsa.generate_private_key(65537, 1024, default_backend())
Expand Down Expand Up @@ -610,7 +610,7 @@ def test_jwk_from_password(self):

def test_p256k_alias(self):
key = jwk.JWK.generate(kty='EC', curve='P-256K')
key.get_curve('secp256k1')
key.get_op_key('verify', 'secp256k1')

pub_k = jwk.JWK(**PrivateKeys_secp256k1['keys'][0])
pri_k = jwk.JWK(**PrivateKeys_secp256k1['keys'][1])
Expand Down

0 comments on commit b68702d

Please sign in to comment.