From f469a1461ec13d88aa4ceb7af9ec8f75645265e2 Mon Sep 17 00:00:00 2001 From: defyrlt Date: Sun, 14 Dec 2014 00:09:13 +0200 Subject: [PATCH] Added functionality and test. --- jwt/__init__.py | 10 +++++++--- tests/test_jwt.py | 24 ++++++++++++++++++++++++ 2 files changed, 31 insertions(+), 3 deletions(-) diff --git a/jwt/__init__.py b/jwt/__init__.py index f885d2293..0eb6172da 100644 --- a/jwt/__init__.py +++ b/jwt/__init__.py @@ -215,7 +215,7 @@ def header(jwt): raise DecodeError('Invalid header encoding') -def encode(payload, key, algorithm='HS256', headers=None): +def encode(payload, key, algorithm='HS256', headers=None, json_encoder=None): segments = [] if algorithm is None: @@ -231,7 +231,9 @@ def encode(payload, key, algorithm='HS256', headers=None): if headers: header.update(headers) - json_header = json.dumps(header, separators=(',', ':')).encode('utf-8') + json_header = json.dumps(header, + separators=(',', ':'), + cls=json_encoder).encode('utf-8') segments.append(base64url_encode(json_header)) # Payload @@ -240,7 +242,9 @@ def encode(payload, key, algorithm='HS256', headers=None): if isinstance(payload.get(time_claim), datetime): payload[time_claim] = timegm(payload[time_claim].utctimetuple()) - json_payload = json.dumps(payload, separators=(',', ':')).encode('utf-8') + json_payload = json.dumps(payload, + separators=(',', ':'), + cls=json_encoder).encode('utf-8') segments.append(base64url_encode(json_payload)) # Segments diff --git a/tests/test_jwt.py b/tests/test_jwt.py index b65480091..5920a6cc6 100644 --- a/tests/test_jwt.py +++ b/tests/test_jwt.py @@ -1,9 +1,13 @@ from __future__ import unicode_literals + from calendar import timegm from datetime import datetime +from decimal import Decimal + import sys import time import unittest +import json import jwt @@ -811,6 +815,26 @@ def test_raise_exception_token_without_issuer(self): jwt.InvalidIssuer, lambda: jwt.decode(token, 'secret', issuer=issuer)) + def test_custom_json_encoder(self): + + class CustomJSONEncoder(json.JSONEncoder): + + def default(self, o): + if isinstance(o, Decimal): + return 'it worked' + return super(CustomJSONEncoder, self).default(o) + + data = { + 'some_decimal': Decimal('2.2') + } + + with self.assertRaises(TypeError): + jwt.encode(data, 'secret') + + token = jwt.encode(data, 'secret', json_encoder=CustomJSONEncoder) + payload = jwt.decode(token, 'secret') + self.assertDictEqual(payload, {'some_decimal': 'it worked'}) + if __name__ == '__main__': unittest.main()