Skip to content

Commit

Permalink
Added functionality and test.
Browse files Browse the repository at this point in the history
  • Loading branch information
defyrlt committed Dec 13, 2014
1 parent 139a779 commit f469a14
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 3 deletions.
10 changes: 7 additions & 3 deletions jwt/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ def header(jwt):
raise DecodeError('Invalid header encoding')


def encode(payload, key, algorithm='HS256', headers=None):
def encode(payload, key, algorithm='HS256', headers=None, json_encoder=None):
segments = []

if algorithm is None:
Expand All @@ -231,7 +231,9 @@ def encode(payload, key, algorithm='HS256', headers=None):
if headers:
header.update(headers)

json_header = json.dumps(header, separators=(',', ':')).encode('utf-8')
json_header = json.dumps(header,
separators=(',', ':'),
cls=json_encoder).encode('utf-8')
segments.append(base64url_encode(json_header))

# Payload
Expand All @@ -240,7 +242,9 @@ def encode(payload, key, algorithm='HS256', headers=None):
if isinstance(payload.get(time_claim), datetime):
payload[time_claim] = timegm(payload[time_claim].utctimetuple())

json_payload = json.dumps(payload, separators=(',', ':')).encode('utf-8')
json_payload = json.dumps(payload,
separators=(',', ':'),
cls=json_encoder).encode('utf-8')
segments.append(base64url_encode(json_payload))

# Segments
Expand Down
24 changes: 24 additions & 0 deletions tests/test_jwt.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
from __future__ import unicode_literals

from calendar import timegm
from datetime import datetime
from decimal import Decimal

import sys
import time
import unittest
import json

import jwt

Expand Down Expand Up @@ -811,6 +815,26 @@ def test_raise_exception_token_without_issuer(self):
jwt.InvalidIssuer,
lambda: jwt.decode(token, 'secret', issuer=issuer))

def test_custom_json_encoder(self):

class CustomJSONEncoder(json.JSONEncoder):

def default(self, o):
if isinstance(o, Decimal):
return 'it worked'
return super(CustomJSONEncoder, self).default(o)

data = {
'some_decimal': Decimal('2.2')
}

with self.assertRaises(TypeError):
jwt.encode(data, 'secret')

token = jwt.encode(data, 'secret', json_encoder=CustomJSONEncoder)
payload = jwt.decode(token, 'secret')
self.assertDictEqual(payload, {'some_decimal': 'it worked'})


if __name__ == '__main__':
unittest.main()

0 comments on commit f469a14

Please sign in to comment.