diff --git a/include/nanobind/nb_class.h b/include/nanobind/nb_class.h index 730b2811..cb957867 100644 --- a/include/nanobind/nb_class.h +++ b/include/nanobind/nb_class.h @@ -730,6 +730,17 @@ template class enum_ : public object { return *this; } + template + NB_INLINE enum_ &def_static(const char *name_, Func &&f, + const Extra &... extra) { + static_assert( + !std::is_member_function_pointer_v, + "def_static(...) called with a non-static member function pointer"); + cpp_function_def((detail::forward_t) f, scope(*this), name(name_), + extra...); + return *this; + } + template NB_INLINE enum_ &def_prop_rw(const char *name_, Getter &&getter, Setter &&setter, const Extra &...extra) { diff --git a/tests/test_enum.cpp b/tests/test_enum.cpp index 7d90e5dc..6fe5d390 100644 --- a/tests/test_enum.cpp +++ b/tests/test_enum.cpp @@ -25,7 +25,9 @@ NB_MODULE(test_enum_ext, m) { .export_values(); ce.def("get_value", [](ClassicEnum &x) { return (int) x; }) - .def_prop_ro("my_value", [](ClassicEnum &x) { return (int) x; }); + .def_prop_ro("my_value", [](ClassicEnum &x) { return (int) x; }) + .def("foo", [](ClassicEnum x) { return x; }) + .def_static("bar", [](ClassicEnum x) { return x; }); m.def("from_enum", [](Enum value) { return (uint32_t) value; }, nb::arg().noconvert()); m.def("to_enum", [](uint32_t value) { return (Enum) value; }); diff --git a/tests/test_enum.py b/tests/test_enum.py index 7456f9fa..7453164a 100644 --- a/tests/test_enum.py +++ b/tests/test_enum.py @@ -139,3 +139,5 @@ def test08_enum_comparisons(): def test09_enum_methods(): assert t.Item1.my_value == 0 and t.Item2.my_value == 1 assert t.Item1.get_value() == 0 and t.Item2.get_value() == 1 + assert t.Item1.foo() == t.Item1 + assert t.ClassicEnum.bar(t.Item1) == t.Item1 diff --git a/tests/test_enum_ext.pyi.ref b/tests/test_enum_ext.pyi.ref index e6553546..7f0a7937 100644 --- a/tests/test_enum_ext.pyi.ref +++ b/tests/test_enum_ext.pyi.ref @@ -12,6 +12,11 @@ class ClassicEnum(enum.Enum): @property def my_value(self) -> int: ... + def foo(self) -> ClassicEnum: ... + + @staticmethod + def bar(arg: ClassicEnum, /) -> ClassicEnum: ... + class Enum(enum.Enum): """enum-level docstring"""