Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add float_hook to json decoder #511

Merged
merged 1 commit into from
Aug 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 57 additions & 5 deletions msgspec/_core.c
Original file line number Diff line number Diff line change
Expand Up @@ -11092,6 +11092,22 @@ parse_number_nonfinite(
return ms_post_decode_float(val, type, path, strict, true);
}

/* Decode a raw JSON float literal by handing its text to a user callback.
 *
 * Parameters:
 *   buf        - start of the literal's bytes (`size` bytes are copied).
 *   size       - number of bytes in the literal.
 *   path       - location passed to `ms_maybe_wrap_validation_error` if the
 *                hook call fails.
 *   float_hook - callable invoked with a single `str` argument.
 *
 * Returns whatever `float_hook` returned, or NULL if either the temporary
 * string could not be allocated or the hook call failed.
 */
static MS_NOINLINE PyObject *
json_float_hook(
    const char *buf, Py_ssize_t size, PathNode *path, PyObject *float_hook
) {
    /* maxchar=127 requests an ASCII-only string; the literal's bytes are
     * copied straight into its buffer.
     * NOTE(review): presumably `ascii_get_buffer` returns the writable
     * character storage of a compact ASCII str -- confirm at its definition. */
    PyObject *text = PyUnicode_New(size, 127);
    if (text == NULL) {
        return NULL;
    }
    memcpy(ascii_get_buffer(text), buf, size);

    PyObject *result = CALL_ONE_ARG(float_hook, text);
    /* The temporary string is no longer needed once the call has returned */
    Py_DECREF(text);

    if (result != NULL) {
        return result;
    }

    /* Hook call failed: let the error be annotated with the decode path */
    ms_maybe_wrap_validation_error(path);
    return NULL;
}

static MS_INLINE PyObject *
parse_number_inline(
const unsigned char *p,
Expand All @@ -11101,6 +11117,7 @@ parse_number_inline(
TypeNode *type,
PathNode *path,
bool strict,
PyObject *float_hook,
bool from_str
) {
uint64_t mantissa = 0;
Expand Down Expand Up @@ -11286,6 +11303,9 @@ parse_number_inline(
(char *)start, p - start, true, path, NULL
);
}
else if (MS_UNLIKELY(float_hook != NULL && type->types & MS_TYPE_ANY)) {
return json_float_hook((char *)start, p - start, path, float_hook);
}
else {
if (MS_UNLIKELY(exponent > 288 || exponent < -307)) {
/* Exponent is out of bounds */
Expand Down Expand Up @@ -11363,6 +11383,7 @@ maybe_parse_number(
type,
path,
strict,
NULL,
true
);
return (*out != NULL || errmsg == NULL);
Expand Down Expand Up @@ -15403,6 +15424,7 @@ typedef struct JSONDecoderState {
/* Configuration */
TypeNode *type;
PyObject *dec_hook;
PyObject *float_hook;
bool strict;

/* Temporary scratch space */
Expand All @@ -15425,10 +15447,11 @@ typedef struct JSONDecoder {
TypeNode *type;
char strict;
PyObject *dec_hook;
PyObject *float_hook;
} JSONDecoder;

PyDoc_STRVAR(JSONDecoder__doc__,
"Decoder(type='Any', *, strict=True, dec_hook=None)\n"
"Decoder(type='Any', *, strict=True, dec_hook=None, float_hook=None)\n"
"--\n"
"\n"
"A JSON decoder.\n"
Expand All @@ -15449,19 +15472,28 @@ PyDoc_STRVAR(JSONDecoder__doc__,
" signature ``dec_hook(type: Type, obj: Any) -> Any``, where ``type`` is the\n"
" expected message type, and ``obj`` is the decoded representation composed\n"
" of only basic JSON types. This hook should transform ``obj`` into type\n"
" ``type``, or raise a ``NotImplementedError`` if unsupported."
" ``type``, or raise a ``NotImplementedError`` if unsupported.\n"
"float_hook : callable, optional\n"
" An optional callback for handling decoding untyped float literals. Should\n"
" have the signature ``float_hook(val: str) -> Any``, where ``val`` is the\n"
" raw string value of the JSON float. This hook is called to decode any\n"
" \"untyped\" float value (e.g. ``typing.Any`` typed). The default is\n"
" equivalent to ``float_hook=float``, where all untyped JSON floats are\n"
" decoded as python floats. Specifying ``float_hook=decimal.Decimal``\n"
" will decode all untyped JSON floats as decimals instead."
);
static int
JSONDecoder_init(JSONDecoder *self, PyObject *args, PyObject *kwds)
{
char *kwlist[] = {"type", "strict", "dec_hook", NULL};
char *kwlist[] = {"type", "strict", "dec_hook", "float_hook", NULL};
MsgspecState *st = msgspec_get_global_state();
PyObject *type = st->typing_any;
PyObject *dec_hook = NULL;
PyObject *float_hook = NULL;
int strict = 1;

if (!PyArg_ParseTupleAndKeywords(
args, kwds, "|O$pO", kwlist, &type, &strict, &dec_hook)
args, kwds, "|O$pOO", kwlist, &type, &strict, &dec_hook, &float_hook)
) {
return -1;
}
Expand All @@ -15479,6 +15511,19 @@ JSONDecoder_init(JSONDecoder *self, PyObject *args, PyObject *kwds)
}
self->dec_hook = dec_hook;

/* Handle float_hook */
if (float_hook == Py_None) {
float_hook = NULL;
}
if (float_hook != NULL) {
if (!PyCallable_Check(float_hook)) {
PyErr_SetString(PyExc_TypeError, "float_hook must be callable");
return -1;
}
Py_INCREF(float_hook);
}
self->float_hook = float_hook;

/* Handle strict */
self->strict = strict;

Expand All @@ -15498,6 +15543,7 @@ JSONDecoder_traverse(JSONDecoder *self, visitproc visit, void *arg)
if (out != 0) return out;
Py_VISIT(self->orig_type);
Py_VISIT(self->dec_hook);
Py_VISIT(self->float_hook);
return 0;
}

Expand All @@ -15508,6 +15554,7 @@ JSONDecoder_dealloc(JSONDecoder *self)
TypeNode_Free(self->type);
Py_XDECREF(self->orig_type);
Py_XDECREF(self->dec_hook);
Py_XDECREF(self->float_hook);
Py_TYPE(self)->tp_free((PyObject *)self);
}

Expand Down Expand Up @@ -17551,7 +17598,7 @@ json_maybe_decode_number(JSONDecoderState *self, TypeNode *type, PathNode *path)
PyObject *out = parse_number_inline(
self->input_pos, self->input_end,
&pout, &errmsg,
type, path, self->strict, false
type, path, self->strict, self->float_hook, false
);
self->input_pos = (unsigned char *)pout;

Expand Down Expand Up @@ -18014,6 +18061,7 @@ msgspec_json_format(PyObject *self, PyObject *args, PyObject *kwargs)

/* Init decoder */
dec.dec_hook = NULL;
dec.float_hook = NULL;
dec.type = NULL;
dec.scratch = NULL;
dec.scratch_capacity = 0;
Expand Down Expand Up @@ -18095,6 +18143,7 @@ JSONDecoder_decode(JSONDecoder *self, PyObject *const *args, Py_ssize_t nargs)
.type = self->type,
.strict = self->strict,
.dec_hook = self->dec_hook,
.float_hook = self->float_hook,
.scratch = NULL,
.scratch_capacity = 0,
.scratch_len = 0
Expand Down Expand Up @@ -18161,6 +18210,7 @@ JSONDecoder_decode_lines(JSONDecoder *self, PyObject *const *args, Py_ssize_t na
.type = self->type,
.strict = self->strict,
.dec_hook = self->dec_hook,
.float_hook = self->float_hook,
.scratch = NULL,
.scratch_capacity = 0,
.scratch_len = 0
Expand Down Expand Up @@ -18237,6 +18287,7 @@ static PyMemberDef JSONDecoder_members[] = {
{"type", T_OBJECT_EX, offsetof(JSONDecoder, orig_type), READONLY, "The Decoder type"},
{"strict", T_BOOL, offsetof(JSONDecoder, strict), READONLY, "The Decoder strict setting"},
{"dec_hook", T_OBJECT, offsetof(JSONDecoder, dec_hook), READONLY, "The Decoder dec_hook"},
{"float_hook", T_OBJECT, offsetof(JSONDecoder, float_hook), READONLY, "The Decoder float_hook"},
{NULL},
};

Expand Down Expand Up @@ -18334,6 +18385,7 @@ msgspec_json_decode(PyObject *self, PyObject *const *args, Py_ssize_t nargs, PyO
JSONDecoderState state = {
.strict = strict,
.dec_hook = dec_hook,
.float_hook = NULL,
.scratch = NULL,
.scratch_capacity = 0,
.scratch_len = 0
Expand Down
5 changes: 5 additions & 0 deletions msgspec/json.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ T = TypeVar("T")

enc_hook_sig = Optional[Callable[[Any], Any]]
dec_hook_sig = Optional[Callable[[type, Any], Any]]
float_hook_sig = Optional[Callable[[str], Any]]

class Encoder:
enc_hook: enc_hook_sig
Expand All @@ -41,13 +42,15 @@ class Decoder(Generic[T]):
type: Type[T]
strict: bool
dec_hook: dec_hook_sig
float_hook: float_hook_sig

@overload
def __init__(
self: Decoder[Any],
*,
strict: bool = True,
dec_hook: dec_hook_sig = None,
float_hook: float_hook_sig = None,
) -> None: ...
@overload
def __init__(
Expand All @@ -56,6 +59,7 @@ class Decoder(Generic[T]):
*,
strict: bool = True,
dec_hook: dec_hook_sig = None,
float_hook: float_hook_sig = None,
) -> None: ...
@overload
def __init__(
Expand All @@ -64,6 +68,7 @@ class Decoder(Generic[T]):
*,
strict: bool = True,
dec_hook: dec_hook_sig = None,
float_hook: float_hook_sig = None,
) -> None: ...
def decode(self, data: Union[bytes, str]) -> T: ...
def decode_lines(self, data: Union[bytes, str]) -> list[T]: ...
Expand Down
9 changes: 9 additions & 0 deletions tests/basic_typing_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from __future__ import annotations

import datetime
import decimal
import pickle
from typing import Any, Dict, Final, List, Type, Union

Expand Down Expand Up @@ -826,6 +827,14 @@ def dec_hook(typ: Type, obj: Any) -> Any:
msgspec.json.Decoder(dec_hook=dec_hook)


def check_json_Decoder_float_hook() -> None:
    """Typing example: `float_hook` accepts None or a str-taking callable."""
    # Each accepted form of the keyword argument.
    msgspec.json.Decoder(float_hook=None)
    msgspec.json.Decoder(float_hook=float)
    decoder = msgspec.json.Decoder(float_hook=decimal.Decimal)

    # The attribute is Optional, so it is narrowed before being called.
    if decoder.float_hook is not None:
        decoder.float_hook("1.5")


def check_json_Decoder_strict() -> None:
dec = msgspec.json.Decoder(List[int], strict=False)
reveal_type(dec.strict) # assert "bool" in typ
Expand Down
50 changes: 50 additions & 0 deletions tests/test_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import base64
import datetime
import decimal
import enum
import gc
import itertools
Expand Down Expand Up @@ -524,6 +525,19 @@ def test_decode_lines_bad_call(self):
with pytest.raises(TypeError):
dec.decode(1)

def test_decoder_init_float_hook(self):
    """`float_hook` is None by default, kept as given, and must be callable."""
    # Omitting the argument and passing None both leave the hook unset.
    for kwargs in ({}, {"float_hook": None}):
        unset = msgspec.json.Decoder(**kwargs)
        assert unset.float_hook is None

    # A callable is exposed unchanged on the decoder.
    with_hook = msgspec.json.Decoder(float_hook=decimal.Decimal)
    assert with_hook.float_hook is decimal.Decimal

    # A non-callable value is rejected at construction time.
    with pytest.raises(TypeError):
        msgspec.json.Decoder(float_hook=1)


class TestBoolAndNone:
def test_encode_none(self):
Expand Down Expand Up @@ -1567,6 +1581,42 @@ def test_decode_float_err_expected_int(self, s):
):
msgspec.json.decode(s, type=int)

def test_float_hook_untyped(self):
    """An untyped decoder routes a top-level JSON float through `float_hook`."""
    decoder = msgspec.json.Decoder(float_hook=decimal.Decimal)
    decoded = decoder.decode(b"1.33")

    expected = decimal.Decimal("1.33")
    assert decoded == expected
    # Equality alone would also hold for some other numeric types, so the
    # exact type is checked as well.
    assert type(decoded) is decimal.Decimal

def test_float_hook_typed(self):
    """Only an `Any`-typed field holding a JSON float is expected to use the hook."""

    class MyFloat(NamedTuple):
        x: str

    class Ex(msgspec.Struct):
        a: float
        b: decimal.Decimal
        c: Any
        d: Any

    payload = b'{"a": 1.5, "b": 1.3, "c": 1.3, "d": 123}'
    # `a` and `b` have explicit types, `d` is an integer literal; only `c`
    # is expected to come back wrapped by the hook.
    expected = Ex(
        a=1.5,
        b=decimal.Decimal("1.3"),
        c=MyFloat("1.3"),
        d=123,
    )

    decoder = msgspec.json.Decoder(Ex, float_hook=MyFloat)
    assert decoder.decode(payload) == expected

def test_float_hook_error(self):
    """An exception raised by `float_hook` surfaces as a ValidationError with a path."""

    class Ex(msgspec.Struct):
        a: float
        b: Any

    def float_hook(val):
        raise ValueError("Oh no!")

    decoder = msgspec.json.Decoder(Ex, float_hook=float_hook)

    # No untyped float in this payload, so decoding is expected to succeed
    # even though the hook always raises.
    decoded = decoder.decode(b'{"a": 1.5, "b": 2}')
    assert decoded == Ex(a=1.5, b=2)

    # An untyped float in `b` is expected to fail with the hook's message
    # and the field's location.
    with pytest.raises(msgspec.ValidationError) as rec:
        decoder.decode(b'{"a": 1.5, "b": 2.5}')
    message = str(rec.value)
    assert "Oh no!" in message
    assert "at `$.b`" in message


class TestDecimal:
"""Most decimal tests are in test_common.py, the ones here are for json
Expand Down