Skip to content

Commit

Permalink
Add support for Flag enums
Browse files Browse the repository at this point in the history
  • Loading branch information
Matthias Wittgen authored and Matthias Wittgen committed May 28, 2024
1 parent 4ed5fdf commit 972971f
Show file tree
Hide file tree
Showing 9 changed files with 88 additions and 8 deletions.
8 changes: 8 additions & 0 deletions docs/api_core.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1926,6 +1926,14 @@ The following annotations can be specified using the variable-length
mixed enum types (such as ``Shape.Circle + Color.Red``) are
permissible.

.. cpp:struct:: is_flag_enum

Indicate that the enumeration may be used with bitwise
operations. This enables the bitwise operators ``| & ^ ~``
with two enumeration as operands.
The result will an enumeration of the same type.
So ``Shape(2) | Shape(1) -> Shepe(3)`.
Function binding
----------------
Expand Down
4 changes: 3 additions & 1 deletion docs/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -129,9 +129,11 @@ according to `SemVer <http://semver.org>`__. The following changes are
noteworthy:

* The :cpp:class:`nb::enum_\<T\>() <enum_>` binding declaration is now a
wrapper that creates either a ``enum.Enum`` or ``enum.IntEnum``-derived type.
wrapper that creates either a ``enum.Enum``, ``enum.IntEnum`` or ``enum.Flag``-derived type.
Previously, nanobind relied on a custom enumeration base class that was a
frequent source of friction for users.
A new flag :cpp:class:`nb::is_flag_enum() <is_flag_enum>`
creates a ``enum.Flag``-derived type.

This change may break code that casts entries to integers, which now only
works for arithmetic (``enum.IntEnum``-derived) enumerations. Replace
Expand Down
8 changes: 6 additions & 2 deletions docs/classes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -295,8 +295,12 @@ C++11-style strongly typed enumerations.

When the annotation :cpp:class:`nb::is_arithmetic() <is_arithmetic>` is
passed to :cpp:class:`nb::enum_\<T\> <enum_>`, the resulting Python type
will support arithmetic and bit-level operations like comparisons, and, or,
xor, negation, etc.
will support arithmetic and bit-level operations and, or,
xor, negation.
Passing the annotation :cpp:class:`nb::is_flag_enum() <is_flag_enum>` to
to :cpp:class:`nb::enum_\<T\> <enum_>`, will result in a Python type
`enum.Flags`, that supports bit operations without losing their `Flag` membership.


.. code-block:: cpp
Expand Down
1 change: 1 addition & 0 deletions include/nanobind/nb_attr.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ struct is_method {};
struct is_implicit {};
struct is_operator {};
struct is_arithmetic {};
struct is_flag_enum {};
struct is_final {};
struct is_generic {};
struct kw_only {};
Expand Down
11 changes: 9 additions & 2 deletions include/nanobind/nb_class.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,12 @@ enum class type_flags : uint32_t {
is_arithmetic = (1 << 16),

/// Is the number type underlying the enumeration signed?
is_signed = (1 << 17)
is_signed = (1 << 17),

// One more flag bits available (18) without needing
/// Is the underlying enumeration type Flag?
is_flag_enum = (1 << 18)

// No more flag bits available (18). Needs
// a larger reorganization
};

Expand Down Expand Up @@ -201,6 +204,10 @@ NB_INLINE void enum_extra_apply(enum_init_data &e, is_arithmetic) {
e.flags |= (uint32_t) type_flags::is_arithmetic;
}

NB_INLINE void enum_extra_apply(enum_init_data &e, is_flag_enum) {
e.flags |= (uint32_t) type_flags::is_flag_enum;
}

NB_INLINE void enum_extra_apply(enum_init_data &e, const char *doc) {
e.docstr = doc;
}
Expand Down
10 changes: 7 additions & 3 deletions src/nb_enum.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ PyObject *enum_create(enum_init_data *ed) noexcept {
handle scope(ed->scope);

bool is_arithmetic = ed->flags & (uint32_t) type_flags::is_arithmetic;

bool is_flag_enum = ed->flags & (uint32_t) type_flags::is_flag_enum;
str name(ed->name), qualname = name;
object modname;

Expand All @@ -43,7 +43,7 @@ PyObject *enum_create(enum_init_data *ed) noexcept {
PyUnicode_FromFormat("%U.%U", scope_qualname.ptr(), name.ptr()));
}

const char *factory_name = is_arithmetic ? "IntEnum" : "Enum";
const char *factory_name = (is_arithmetic || is_flag_enum) ? (is_flag_enum ? "Flag" : "IntEnum") : "Enum";

object enum_mod = module_::import_("enum"),
factory = enum_mod.attr(factory_name),
Expand All @@ -56,7 +56,11 @@ PyObject *enum_create(enum_init_data *ed) noexcept {

if (is_arithmetic)
result.attr("__str__") = enum_mod.attr("Enum").attr("__str__");

if (is_flag_enum)
result.attr("__str__") = enum_mod.attr("Flag").attr("__str__");
#if PY_VERSION_HEX >= 0x030B0000
result.attr("_boundary_") = enum_mod.attr("FlagBoundary").attr("KEEP");
#endif
result.attr("__repr__") = result.attr("__str__");

type_init_data *t = new type_init_data();
Expand Down
14 changes: 14 additions & 0 deletions tests/test_enum.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
namespace nb = nanobind;

enum class Enum : uint32_t { A, B, C = (uint32_t) -1 };
enum class Flag : uint32_t { A = 1, B = 2, C = 4};
enum class SEnum : int32_t { A, B, C = (int32_t) -1 };
enum ClassicEnum { Item1, Item2 };

Expand All @@ -14,6 +15,12 @@ NB_MODULE(test_enum_ext, m) {
.value("B", Enum::B, "Value B")
.value("C", Enum::C, "Value C");

nb::enum_<Flag>(m, "Flag", "enum-level docstring", nb::is_flag_enum())
.value("A", Flag::A, "Value A")
.value("B", Flag::B, "Value B")
.value("C", Flag::C, "Value C")
.export_values();

nb::enum_<SEnum>(m, "SEnum", nb::is_arithmetic())
.value("A", SEnum::A)
.value("B", SEnum::B)
Expand All @@ -31,11 +38,18 @@ NB_MODULE(test_enum_ext, m) {

m.def("from_enum", [](Enum value) { return (uint32_t) value; }, nb::arg().noconvert());
m.def("to_enum", [](uint32_t value) { return (Enum) value; });
m.def("from_enum", [](Flag value) { return (uint32_t) value; }, nb::arg().noconvert());
m.def("to_enum", [](uint32_t value) { return (Flag) value; });
m.def("from_enum", [](SEnum value) { return (int32_t) value; }, nb::arg().noconvert());

m.def("from_enum_implicit", [](Enum value) { return (uint32_t) value; });

m.def("from_enum_default_0", [](Enum value) { return (uint32_t) value; }, nb::arg("value") = Enum::A);

m.def("from_enum_implicit", [](Flag value) { return (uint32_t) value; });

m.def("from_enum_default_0", [](Flag value) { return (uint32_t) value; }, nb::arg("value") = Enum::A);

m.def("from_enum_default_1", [](SEnum value) { return (uint32_t) value; }, nb::arg("value") = SEnum::A);

// test for issue #39
Expand Down
7 changes: 7 additions & 0 deletions tests/test_enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,13 @@ def test08_enum_comparisons():
assert not (t.Enum.B == t.SEnum.B) and not (t.SEnum.B == t.Enum.B)
assert t.Enum.B != t.SEnum.C and t.SEnum.C != t.Enum.B

def test06_enum_flag():
assert (t.Flag(1) | t.Flag(2)).value == 3
assert (t.Flag(3) & t.Flag(1)).value == 1
assert (t.Flag(3) ^ t.Flag(1)).value == 2
assert (t.Flag(3) == (t.Flag.A | t.Flag.B))


def test09_enum_methods():
assert t.Item1.my_value == 0 and t.Item2.my_value == 1
assert t.Item1.get_value() == 0 and t.Item2.get_value() == 1
Expand Down
33 changes: 33 additions & 0 deletions tests/test_enum_ext.pyi.ref
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,12 @@ import enum
from typing import overload


A: Flag = Flag.A

B: Flag = Flag.B

C: Flag = Flag.C

class ClassicEnum(enum.Enum):
Item1 = 0

Expand Down Expand Up @@ -35,6 +41,18 @@ class EnumProperty:
@property
def read_enum(self) -> Enum: ...

class Flag(enum.Flag):
"""enum-level docstring"""

A = 1
"""Value A"""

B = 2
"""Value B"""

C = 4
"""Value C"""

Item1: ClassicEnum = ClassicEnum.Item1

Item2: ClassicEnum = ClassicEnum.Item2
Expand All @@ -49,13 +67,28 @@ class SEnum(enum.IntEnum):
@overload
def from_enum(arg: Enum) -> int: ...

@overload
def from_enum(arg: Flag) -> int: ...

@overload
def from_enum(arg: SEnum) -> int: ...

@overload
def from_enum_default_0(value: Enum = Enum.A) -> int: ...

@overload
def from_enum_default_0(value: Flag = Enum.A) -> int: ...

def from_enum_default_1(value: SEnum = SEnum.A) -> int: ...

@overload
def from_enum_implicit(arg: Enum, /) -> int: ...

@overload
def from_enum_implicit(arg: Flag, /) -> int: ...

@overload
def to_enum(arg: int, /) -> Enum: ...

@overload
def to_enum(arg: int, /) -> Flag: ...

0 comments on commit 972971f

Please sign in to comment.