diff --git a/jwcrypto/jwa.py b/jwcrypto/jwa.py index cf1879f..e0ce824 100644 --- a/jwcrypto/jwa.py +++ b/jwcrypto/jwa.py @@ -608,13 +608,25 @@ def _get_key(self, alg, key, p2s, p2c): return JWK(kty="oct", use="enc", k=base64url_encode(rk)) def wrap(self, key, bitsize, cek, headers): - p2s = _randombits(128) - p2c = 8192 + ret_header = {} + if 'p2s' in headers: + p2s = base64url_decode(headers['p2s']) + if len(p2s) < 8: + raise ValueError('Invalid Salt, must be 8 or more octects') + else: + p2s = _randombits(128) + ret_header['p2s'] = base64url_encode(p2s) + if 'p2c' in headers: + p2c = headers['p2c'] + else: + p2c = 8192 + ret_header['p2c'] = p2c kek = self._get_key(headers['alg'], key, p2s, p2c) aeskw = self.aeskwmap[self.keysize]() ret = aeskw.wrap(kek, bitsize, cek, headers) - ret['header'] = {'p2s': base64url_encode(p2s), 'p2c': p2c} + if len(ret_header) > 0: + ret['header'] = ret_header return ret def unwrap(self, key, bitsize, ek, headers): diff --git a/jwcrypto/tests.py b/jwcrypto/tests.py index 10cfe49..fa62af1 100644 --- a/jwcrypto/tests.py +++ b/jwcrypto/tests.py @@ -1788,6 +1788,28 @@ def test_pbes2_hs256_aeskw(self): check.decrypt(key) self.assertEqual(check.payload, b'plain') + def test_pbes2_hs256_aeskw_custom_params(self): + enc = jwe.JWE(plaintext='plain', + protected={"alg": "PBES2-HS256+A128KW", + "enc": "A256CBC-HS512", + "p2c": 4096, + "p2s": base64url_encode("A" * 16)}) + key = jwk.JWK.from_password('password') + enc.add_recipient(key) + o = enc.serialize() + check = jwe.JWE() + check.deserialize(o) + check.decrypt(key) + self.assertEqual(check.payload, b'plain') + + enc = jwe.JWE(plaintext='plain', + protected={"alg": "PBES2-HS256+A128KW", + "enc": "A256CBC-HS512", + "p2c": 4096, + "p2s": base64url_encode("A" * 7)}) + key = jwk.JWK.from_password('password') + self.assertRaises(ValueError, enc.add_recipient, key) + class JWATests(unittest.TestCase): def test_jwa_create(self):