diff --git a/include/nanobind/nb_cast.h b/include/nanobind/nb_cast.h index e646098c..b8875fe5 100644 --- a/include/nanobind/nb_cast.h +++ b/include/nanobind/nb_cast.h @@ -338,8 +338,13 @@ T cast(const detail::api &value, bool convert = true) { if constexpr (std::is_same_v) { return; } else { - using Ti = detail::intrinsic_t; - using Caster = detail::make_caster; + using Caster = detail::make_caster; + using Output = typename Caster::template Cast; + + static_assert( + !(std::is_reference_v || std::is_pointer_v) || Caster::IsClass || + std::is_same_v, + "nanobind::cast(): cannot return a reference to a temporary."); Caster caster; if (!caster.from_python(value.derived().ptr(), @@ -347,18 +352,7 @@ T cast(const detail::api &value, bool convert = true) { : (uint8_t) 0, nullptr)) detail::raise_cast_error(); - if constexpr (std::is_same_v) { - return caster.operator const char *(); - } else { - static_assert( - !(std::is_reference_v || std::is_pointer_v) || Caster::IsClass, - "nanobind::cast(): cannot return a reference to a temporary."); - - if constexpr (detail::is_pointer_v) - return caster.operator Ti*(); - else - return caster.operator Ti&(); - } + return caster.operator Output(); } } diff --git a/tests/test_functions.cpp b/tests/test_functions.cpp index 7028dae2..0534ecb9 100644 --- a/tests/test_functions.cpp +++ b/tests/test_functions.cpp @@ -212,4 +212,12 @@ NB_MODULE(test_functions_ext, m) { return nb::cpp_function(callback); }); + + m.def("test_cast_char", [](nb::handle h) { + return nb::cast(h); + }); + + m.def("test_cast_str", [](nb::handle h) { + return nb::cast(h); + }); } diff --git a/tests/test_functions.py b/tests/test_functions.py index 3ee08dfc..2f9d195b 100644 --- a/tests/test_functions.py +++ b/tests/test_functions.py @@ -307,3 +307,16 @@ def test34_module_docstring(): def test35_return_capture(): x = t.test_35() assert x() == 'Test Foo' + +def test36_test_char(): + assert t.test_cast_char('c') == 'c' + with pytest.raises(TypeError): + assert t.test_cast_char('abc') + with pytest.raises(RuntimeError): + assert t.test_cast_char(123) + +def test37_test_str(): + assert t.test_cast_str('c') == 'c' + assert t.test_cast_str('abc') == 'abc' + with pytest.raises(RuntimeError): + assert t.test_cast_str(123)