From 5f25ae0eb9691fbe03a20bcb9f604277ccc1884b Mon Sep 17 00:00:00 2001 From: Wenzel Jakob Date: Mon, 9 Oct 2023 20:59:36 +0200 Subject: [PATCH] fixes issue reported in #318 also in other type casters --- include/nanobind/nb_cast.h | 5 ++++- include/nanobind/stl/detail/nb_array.h | 7 +++++-- include/nanobind/stl/detail/nb_dict.h | 29 ++++++++++++++++---------- include/nanobind/stl/detail/nb_list.h | 9 +++++--- include/nanobind/stl/detail/nb_set.h | 15 +++++++------ include/nanobind/stl/optional.h | 4 +++- include/nanobind/stl/variant.h | 6 ++---- src/nb_type.cpp | 2 +- tests/test_stl.py | 6 ++++++ 9 files changed, 54 insertions(+), 29 deletions(-) diff --git a/include/nanobind/nb_cast.h b/include/nanobind/nb_cast.h index 71da158a..7067a9cc 100644 --- a/include/nanobind/nb_cast.h +++ b/include/nanobind/nb_cast.h @@ -34,7 +34,10 @@ enum cast_flags : uint8_t { convert = (1 << 0), // Passed to the 'self' argument in a constructor call (__init__) - construct = (1 << 1) + construct = (1 << 1), + + // Don't accept 'None' Python objects in the base class caster + none_disallowed = (1 << 2), }; /** diff --git a/include/nanobind/stl/detail/nb_array.h b/include/nanobind/stl/detail/nb_array.h index 17202c5b..da743b9d 100644 --- a/include/nanobind/stl/detail/nb_array.h +++ b/include/nanobind/stl/detail/nb_array.h @@ -5,8 +5,8 @@ NAMESPACE_BEGIN(NB_NAMESPACE) NAMESPACE_BEGIN(detail) -template struct array_caster { - NB_TYPE_CASTER(Value_, const_name(NB_TYPING_LIST "[") + +template struct array_caster { + NB_TYPE_CASTER(Array, const_name(NB_TYPING_LIST "[") + make_caster::Name + const_name("]")); using Caster = make_caster; @@ -20,6 +20,9 @@ template struct array_caster { Caster caster; bool success = o != nullptr; + if (is_base_caster_v && !std::is_pointer_v) + flags |= (uint8_t) cast_flags::none_disallowed; + if (success) { for (size_t i = 0; i < Size; ++i) { if (!caster.from_python(o[i], flags, cleanup)) { diff --git a/include/nanobind/stl/detail/nb_dict.h b/include/nanobind/stl/detail/nb_dict.h index e1e9971e..0fc31d90 100644 --- a/include/nanobind/stl/detail/nb_dict.h +++ b/include/nanobind/stl/detail/nb_dict.h @@ -14,13 +14,13 @@ NAMESPACE_BEGIN(NB_NAMESPACE) NAMESPACE_BEGIN(detail) -template struct dict_caster { - NB_TYPE_CASTER(Value_, const_name(NB_TYPING_DICT "[") + make_caster::Name + - const_name(", ") + make_caster::Name + +template struct dict_caster { + NB_TYPE_CASTER(Dict, const_name(NB_TYPING_DICT "[") + make_caster::Name + + const_name(", ") + make_caster::Name + const_name("]")); using KeyCaster = make_caster; - using ElementCaster = make_caster; + using ValCaster = make_caster; bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) noexcept { value.clear(); @@ -32,27 +32,34 @@ template struct dict_caster { } Py_ssize_t size = NB_LIST_GET_SIZE(items); - bool success = (size >= 0); + bool success = size >= 0; + + uint8_t flags_key = flags, flags_val = flags; + + if (is_base_caster_v && !std::is_pointer_v) + flags_key |= (uint8_t) cast_flags::none_disallowed; + if (is_base_caster_v && !std::is_pointer_v) + flags_val |= (uint8_t) cast_flags::none_disallowed; KeyCaster key_caster; - ElementCaster element_caster; + ValCaster val_caster; for (Py_ssize_t i = 0; i < size; ++i) { PyObject *item = NB_LIST_GET_ITEM(items, i); PyObject *key = NB_TUPLE_GET_ITEM(item, 0); - PyObject *element = NB_TUPLE_GET_ITEM(item, 1); + PyObject *val = NB_TUPLE_GET_ITEM(item, 1); - if (!key_caster.from_python(key, flags, cleanup)) { + if (!key_caster.from_python(key, flags_key, cleanup)) { success = false; break; } - if (!element_caster.from_python(element, flags, cleanup)) { + if (!val_caster.from_python(val, flags_val, cleanup)) { success = false; break; } value.emplace(key_caster.operator cast_t(), - element_caster.operator cast_t()); + val_caster.operator cast_t()); } Py_DECREF(items); @@ -68,7 +75,7 @@ template struct dict_caster { for (auto &item : src) { object k = steal(KeyCaster::from_cpp( forward_like(item.first), policy, cleanup)); - object e = steal(ElementCaster::from_cpp( + object e = steal(ValCaster::from_cpp( forward_like(item.second), policy, cleanup)); if (!k.is_valid() || !e.is_valid() || diff --git a/include/nanobind/stl/detail/nb_list.h b/include/nanobind/stl/detail/nb_list.h index a1c82b10..03fe1e7b 100644 --- a/include/nanobind/stl/detail/nb_list.h +++ b/include/nanobind/stl/detail/nb_list.h @@ -14,8 +14,8 @@ NAMESPACE_BEGIN(NB_NAMESPACE) NAMESPACE_BEGIN(detail) -template struct list_caster { - NB_TYPE_CASTER(Value_, const_name(NB_TYPING_LIST "[") + +template struct list_caster { + NB_TYPE_CASTER(List, const_name(NB_TYPING_LIST "[") + make_caster::Name + const_name("]")); using Caster = make_caster; @@ -32,12 +32,15 @@ template struct list_caster { value.clear(); - if constexpr (is_detected_v) + if constexpr (is_detected_v) value.reserve(size); Caster caster; bool success = o != nullptr; + if (is_base_caster_v && !std::is_pointer_v) + flags |= (uint8_t) cast_flags::none_disallowed; + for (size_t i = 0; i < size; ++i) { if (!caster.from_python(o[i], flags, cleanup)) { success = false; diff --git a/include/nanobind/stl/detail/nb_set.h b/include/nanobind/stl/detail/nb_set.h index 21fe58ce..a221c2e8 100644 --- a/include/nanobind/stl/detail/nb_set.h +++ b/include/nanobind/stl/detail/nb_set.h @@ -14,24 +14,27 @@ NAMESPACE_BEGIN(NB_NAMESPACE) NAMESPACE_BEGIN(detail) -template struct set_caster { - NB_TYPE_CASTER(Value_, const_name(NB_TYPING_SET "[") + make_caster::Name + const_name("]")); +template struct set_caster { + NB_TYPE_CASTER(Set, const_name(NB_TYPING_SET "[") + make_caster::Name + const_name("]")); - using KeyCaster = make_caster; + using Caster = make_caster; bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) noexcept { value.clear(); PyObject* iter = obj_iter(src.ptr()); - if (iter == nullptr) { + if (!iter) { PyErr_Clear(); return false; } bool success = true; - KeyCaster key_caster; + Caster key_caster; PyObject *key; + if (is_base_caster_v && !std::is_pointer_v) + flags |= (uint8_t) cast_flags::none_disallowed; + while ((key = PyIter_Next(iter)) != nullptr) { success &= key_caster.from_python(key, flags, cleanup); Py_DECREF(key); @@ -59,7 +62,7 @@ template struct set_caster { if (ret.is_valid()) { for (auto& key : src) { object k = steal( - KeyCaster::from_cpp(forward_like(key), policy, cleanup)); + Caster::from_cpp(forward_like(key), policy, cleanup)); if (!k.is_valid() || PySet_Add(ret.ptr(), k.ptr()) != 0) { ret.reset(); diff --git a/include/nanobind/stl/optional.h b/include/nanobind/stl/optional.h index 8a10296f..af35c100 100644 --- a/include/nanobind/stl/optional.h +++ b/include/nanobind/stl/optional.h @@ -29,8 +29,10 @@ struct type_caster> { type_caster() : value(std::nullopt) { } bool from_python(handle src, uint8_t flags, cleanup_list* cleanup) noexcept { - if (src.is_none()) + if (src.is_none()) { + value = std::nullopt; return true; + } Caster caster; if (!caster.from_python(src, flags, cleanup)) diff --git a/include/nanobind/stl/variant.h b/include/nanobind/stl/variant.h index 9dc7cb77..91f5ad6a 100644 --- a/include/nanobind/stl/variant.h +++ b/include/nanobind/stl/variant.h @@ -52,10 +52,8 @@ template struct type_caster> { "type caster was registered to intercept this particular " "type, which is not allowed."); - if constexpr (!std::is_pointer_v && is_base_caster_v) { - if (src.is_none()) - return false; - } + if (is_base_caster_v && !std::is_pointer_v) + flags |= (uint8_t) cast_flags::none_disallowed; CasterT caster; diff --git a/src/nb_type.cpp b/src/nb_type.cpp index de3f4e80..c99fb245 100644 --- a/src/nb_type.cpp +++ b/src/nb_type.cpp @@ -1013,7 +1013,7 @@ bool nb_type_get(const std::type_info *cpp_type, PyObject *src, uint8_t flags, // Convert None -> nullptr if (src == Py_None) { *out = nullptr; - return true; + return (flags & (uint8_t) cast_flags::none_disallowed) == 0; } PyTypeObject *src_type = Py_TYPE(src); diff --git a/tests/test_stl.py b/tests/test_stl.py index 933d526f..c4547760 100644 --- a/tests/test_stl.py +++ b/tests/test_stl.py @@ -824,3 +824,9 @@ def test69_complex_array(): def test70_vec_char(): assert isinstance(t.vector_str("123"), str) assert isinstance(t.vector_str(["123", "345"]), list) + +def test71_null_input(): + with pytest.raises(TypeError): + t.vec_movable_in_value([None]) + with pytest.raises(TypeError): + t.map_copyable_in_value({'a': None})