Skip to content

Commit

Permalink
Enable ECDSA algorithms supported by PyJWT (#520)
Browse files Browse the repository at this point in the history
* Parameterize some tests to reduce duplication and make it easy to add more algorithms

This way new algorithms can be added to the basic test set simply by
adding their backends to TestTokenBackend.backends.

* Enable ECDSA algorithms supported by PyJWT

Enable the algorithms and add basic tests for them.

Also convert the ALLOWED_ALGORITHMS constant to a set for a minor style
cleanup.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
vainu-arto and pre-commit-ci[bot] authored Jan 28, 2022
1 parent 72dd1a5 commit 92124cf
Show file tree
Hide file tree
Showing 3 changed files with 121 additions and 122 deletions.
7 changes: 5 additions & 2 deletions rest_framework_simplejwt/backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,17 @@
from .exceptions import TokenBackendError
from .utils import format_lazy

ALLOWED_ALGORITHMS = (
ALLOWED_ALGORITHMS = {
"HS256",
"HS384",
"HS512",
"RS256",
"RS384",
"RS512",
)
"ES256",
"ES384",
"ES512",
}


class TokenBackend:
Expand Down
15 changes: 15 additions & 0 deletions tests/keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,3 +110,18 @@
E01hmaHk9xlOpo73IjUxhXUCAwEAAQ==
-----END PUBLIC KEY-----
"""

ES256_PRIVATE_KEY = """
-----BEGIN EC PRIVATE KEY-----
MHcCAQEEIMtBPxiLHcJCrAGdz4jHvTtAh6Rw7351AckG3whXq2WOoAoGCCqGSM49
AwEHoUQDQgAEMZHyNxbkr7+zqQ1dQk/zug2pwYdztmjhpC+XqK88q5NfIS1cBYYt
zhHUS4vGpazNqbW8HA3ZIvJRmx4L96O6/w==
-----END EC PRIVATE KEY-----
"""

ES256_PUBLIC_KEY = """
-----BEGIN PUBLIC KEY-----
MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEMZHyNxbkr7+zqQ1dQk/zug2pwYdz
tmjhpC+XqK88q5NfIS1cBYYtzhHUS4vGpazNqbW8HA3ZIvJRmx4L96O6/w==
-----END PUBLIC KEY-----
"""
221 changes: 101 additions & 120 deletions tests/test_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,14 @@
from rest_framework_simplejwt.backends import TokenBackend
from rest_framework_simplejwt.exceptions import TokenBackendError
from rest_framework_simplejwt.utils import aware_utcnow, datetime_to_epoch, make_utc
from tests.keys import PRIVATE_KEY, PRIVATE_KEY_2, PUBLIC_KEY, PUBLIC_KEY_2
from tests.keys import (
ES256_PRIVATE_KEY,
ES256_PUBLIC_KEY,
PRIVATE_KEY,
PRIVATE_KEY_2,
PUBLIC_KEY,
PUBLIC_KEY_2,
)

SECRET = "not_secret"

Expand All @@ -31,6 +38,13 @@ def setUp(self):
"RS256", PRIVATE_KEY, PUBLIC_KEY, AUDIENCE, ISSUER
)
self.payload = {"foo": "bar"}
self.backends = (
self.hmac_token_backend,
self.rsa_token_backend,
TokenBackend("ES256", ES256_PRIVATE_KEY, ES256_PUBLIC_KEY),
TokenBackend("ES384", ES256_PRIVATE_KEY, ES256_PUBLIC_KEY),
TokenBackend("ES512", ES256_PRIVATE_KEY, ES256_PUBLIC_KEY),
)

def test_init(self):
# Should reject unknown algorithms
Expand All @@ -41,18 +55,12 @@ def test_init(self):

@patch.object(algorithms, "has_crypto", new=False)
def test_init_fails_for_rs_algorithms_when_crypto_not_installed(self):
with self.assertRaisesRegex(
TokenBackendError, "You must have cryptography installed to use RS256."
):
TokenBackend("RS256", "not_secret")
with self.assertRaisesRegex(
TokenBackendError, "You must have cryptography installed to use RS384."
):
TokenBackend("RS384", "not_secret")
with self.assertRaisesRegex(
TokenBackendError, "You must have cryptography installed to use RS512."
):
TokenBackend("RS512", "not_secret")
for algo in ("RS256", "RS384", "RS512", "ES256"):
with self.assertRaisesRegex(
TokenBackendError,
f"You must have cryptography installed to use {algo}.",
):
TokenBackend(algo, "not_secret")

def test_encode_hmac(self):
# Should return a JSON web token for the given payload
Expand Down Expand Up @@ -113,127 +121,100 @@ def test_encode_aud_iss(self):
),
)

def test_decode_hmac_with_no_expiry(self):
no_exp_token = jwt.encode(self.payload, SECRET, algorithm="HS256")
def test_decode_with_no_expiry(self):
for backend in self.backends:
with self.subTest("Test decode with no expiry for f{backend.algorithm}"):
no_exp_token = jwt.encode(
self.payload, backend.signing_key, algorithm=backend.algorithm
)

self.hmac_token_backend.decode(no_exp_token)
backend.decode(no_exp_token)

def test_decode_hmac_with_no_expiry_no_verify(self):
no_exp_token = jwt.encode(self.payload, SECRET, algorithm="HS256")
def test_decode_with_no_expiry_no_verify(self):
for backend in self.backends:
with self.subTest(
"Test decode with no expiry and no verify for f{backend.algorithm}"
):
no_exp_token = jwt.encode(
self.payload, backend.signing_key, algorithm=backend.algorithm
)

self.assertEqual(
self.hmac_token_backend.decode(no_exp_token, verify=False),
self.payload,
)
self.assertEqual(
backend.decode(no_exp_token, verify=False),
self.payload,
)

def test_decode_hmac_with_expiry(self):
def test_decode_with_expiry(self):
self.payload["exp"] = aware_utcnow() - timedelta(seconds=1)
for backend in self.backends:
with self.subTest("Test decode with expiry for f{backend.algorithm}"):

expired_token = jwt.encode(self.payload, SECRET, algorithm="HS256")

with self.assertRaises(TokenBackendError):
self.hmac_token_backend.decode(expired_token)

def test_decode_hmac_with_invalid_sig(self):
self.payload["exp"] = aware_utcnow() + timedelta(days=1)
token_1 = jwt.encode(self.payload, SECRET, algorithm="HS256")
self.payload["foo"] = "baz"
token_2 = jwt.encode(self.payload, SECRET, algorithm="HS256")

token_2_payload = token_2.rsplit(".", 1)[0]
token_1_sig = token_1.rsplit(".", 1)[-1]
invalid_token = token_2_payload + "." + token_1_sig

with self.assertRaises(TokenBackendError):
self.hmac_token_backend.decode(invalid_token)
expired_token = jwt.encode(
self.payload, backend.signing_key, algorithm=backend.algorithm
)

def test_decode_hmac_with_invalid_sig_no_verify(self):
self.payload["exp"] = aware_utcnow() + timedelta(days=1)
token_1 = jwt.encode(self.payload, SECRET, algorithm="HS256")
self.payload["foo"] = "baz"
token_2 = jwt.encode(self.payload, SECRET, algorithm="HS256")
# Payload copied
self.payload["exp"] = datetime_to_epoch(self.payload["exp"])

token_2_payload = token_2.rsplit(".", 1)[0]
token_1_sig = token_1.rsplit(".", 1)[-1]
invalid_token = token_2_payload + "." + token_1_sig

self.assertEqual(
self.hmac_token_backend.decode(invalid_token, verify=False),
self.payload,
)

def test_decode_hmac_success(self):
self.payload["exp"] = aware_utcnow() + timedelta(days=1)
self.payload["foo"] = "baz"

token = jwt.encode(self.payload, SECRET, algorithm="HS256")
# Payload copied
self.payload["exp"] = datetime_to_epoch(self.payload["exp"])
with self.assertRaises(TokenBackendError):
backend.decode(expired_token)

self.assertEqual(self.hmac_token_backend.decode(token), self.payload)

def test_decode_rsa_with_no_expiry(self):
no_exp_token = jwt.encode(self.payload, PRIVATE_KEY, algorithm="RS256")

self.rsa_token_backend.decode(no_exp_token)

def test_decode_rsa_with_no_expiry_no_verify(self):
no_exp_token = jwt.encode(self.payload, PRIVATE_KEY, algorithm="RS256")

self.assertEqual(
self.hmac_token_backend.decode(no_exp_token, verify=False),
self.payload,
)

def test_decode_rsa_with_expiry(self):
def test_decode_with_invalid_sig(self):
self.payload["exp"] = aware_utcnow() - timedelta(seconds=1)

expired_token = jwt.encode(self.payload, PRIVATE_KEY, algorithm="RS256")

with self.assertRaises(TokenBackendError):
self.rsa_token_backend.decode(expired_token)

def test_decode_rsa_with_invalid_sig(self):
self.payload["exp"] = aware_utcnow() + timedelta(days=1)
token_1 = jwt.encode(self.payload, PRIVATE_KEY, algorithm="RS256")
self.payload["foo"] = "baz"
token_2 = jwt.encode(self.payload, PRIVATE_KEY, algorithm="RS256")

token_2_payload = token_2.rsplit(".", 1)[0]
token_1_sig = token_1.rsplit(".", 1)[-1]
invalid_token = token_2_payload + "." + token_1_sig

with self.assertRaises(TokenBackendError):
self.rsa_token_backend.decode(invalid_token)

def test_decode_rsa_with_invalid_sig_no_verify(self):
for backend in self.backends:
with self.subTest("Test decode with invalid sig for f{backend.algorithm}"):
payload = self.payload.copy()
payload["exp"] = aware_utcnow() + timedelta(days=1)
token_1 = jwt.encode(
payload, backend.signing_key, algorithm=backend.algorithm
)
payload["foo"] = "baz"
token_2 = jwt.encode(
payload, backend.signing_key, algorithm=backend.algorithm
)

token_2_payload = token_2.rsplit(".", 1)[0]
token_1_sig = token_1.rsplit(".", 1)[-1]
invalid_token = token_2_payload + "." + token_1_sig

with self.assertRaises(TokenBackendError):
backend.decode(invalid_token)

def test_decode_with_invalid_sig_no_verify(self):
self.payload["exp"] = aware_utcnow() + timedelta(days=1)
token_1 = jwt.encode(self.payload, PRIVATE_KEY, algorithm="RS256")
self.payload["foo"] = "baz"
token_2 = jwt.encode(self.payload, PRIVATE_KEY, algorithm="RS256")

token_2_payload = token_2.rsplit(".", 1)[0]
token_1_sig = token_1.rsplit(".", 1)[-1]
invalid_token = token_2_payload + "." + token_1_sig
# Payload copied
self.payload["exp"] = datetime_to_epoch(self.payload["exp"])

self.assertEqual(
self.hmac_token_backend.decode(invalid_token, verify=False),
self.payload,
)

def test_decode_rsa_success(self):
for backend in self.backends:
with self.subTest("Test decode with invalid sig for f{backend.algorithm}"):
payload = self.payload.copy()
token_1 = jwt.encode(
payload, backend.signing_key, algorithm=backend.algorithm
)
payload["foo"] = "baz"
token_2 = jwt.encode(
payload, backend.signing_key, algorithm=backend.algorithm
)
# Payload copied
payload["exp"] = datetime_to_epoch(payload["exp"])

token_2_payload = token_2.rsplit(".", 1)[0]
token_1_sig = token_1.rsplit(".", 1)[-1]
invalid_token = token_2_payload + "." + token_1_sig

self.assertEqual(
backend.decode(invalid_token, verify=False),
payload,
)

def test_decode_success(self):
self.payload["exp"] = aware_utcnow() + timedelta(days=1)
self.payload["foo"] = "baz"
for backend in self.backends:
with self.subTest("Test decode success for f{backend.algorithm}"):

token = jwt.encode(self.payload, PRIVATE_KEY, algorithm="RS256")
# Payload copied
self.payload["exp"] = datetime_to_epoch(self.payload["exp"])
token = jwt.encode(
self.payload, backend.signing_key, algorithm=backend.algorithm
)
# Payload copied
payload = self.payload.copy()
payload["exp"] = datetime_to_epoch(self.payload["exp"])

self.assertEqual(self.rsa_token_backend.decode(token), self.payload)
self.assertEqual(backend.decode(token), payload)

def test_decode_aud_iss_success(self):
self.payload["exp"] = aware_utcnow() + timedelta(days=1)
Expand Down

0 comments on commit 92124cf

Please sign in to comment.