Add support for Flag enums #599

Closed
Changes from 3 commits
8 changes: 8 additions & 0 deletions docs/api_core.rst
@@ -1926,6 +1926,14 @@ The following annotations can be specified using the variable-length
mixed enum types (such as ``Shape.Circle + Color.Red``) are
permissible.

.. cpp:struct:: is_flag_enum
Owner commented:

Shall we give it a shorter name? is_flag? (for symmetry with the arithmetic flag, which is called is_arithmetic)


Indicate that the enumeration may be used with bitwise
operations. This enables the bitwise operators ``| & ^ ~``
with two enumeration as operands.
The result will an enumeration of the same type.
Contributor commented:

Suggested change
- with two enumeration as operands.
- The result will an enumeration of the same type.
+ with two enumerators as operands.
+ The result will have the same enumeration type as the operands.

The enumeration is the type; the named values are called enumerators. (It is a little bit ambiguous whether 3 is properly an "enumerator" if the defined values are 1 and 2, so I suggest wording so as to sidestep that question.)
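
The terminology point can be checked against the standard-library `enum.Flag` that this PR targets. The sketch below uses a hypothetical `Shape` enumeration with two enumerators; it is an illustration of the wording question, not code from the PR.

```python
import enum

class Shape(enum.Flag):
    # Hypothetical enumerators, chosen only for illustration.
    CIRCLE = 1
    SQUARE = 2

combined = Shape.CIRCLE | Shape.SQUARE

# The result has the same enumeration type as the operands.
same_type = type(combined) is Shape
combined_value = combined.value

# The combination is not one of the named enumerators of the type.
is_named = combined in list(Shape)

print(same_type, combined_value, is_named)
```

Whether `Shape(3)` should be called an "enumerator" is exactly the ambiguity the suggested wording avoids.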

So ``Shape(2) | Shape(1) -> Shape(3)``.
mwittgen marked this conversation as resolved.

Function binding
----------------

9 changes: 8 additions & 1 deletion docs/changelog.rst
@@ -15,6 +15,11 @@ case, both modules must use the same nanobind ABI version, or they will be
isolated from each other. Releases that don't explicitly mention an ABI version
below inherit that of the preceding release.

Version 2.1.0 (TBA)
-------------------
* The :cpp:class:`nb::enum_\<T\>() <enum_>` can now create an``enum.Flag``-derived type
with flag :cpp:class:`nb::is_flag_enum() <is_flag_enum>`.
Contributor commented:

Suggested change
- * The :cpp:class:`nb::enum_\<T\>() <enum_>` can now create an``enum.Flag``-derived type
- with flag :cpp:class:`nb::is_flag_enum() <is_flag_enum>`.
+ * nanobind now allows a :cpp:class:`nb::enum_\<T\>() <enum_>` to specify the
+ new class binding annotation :cpp:class:`nb::is_flag_enum() <is_flag_enum>`,
+ which produces an enumeration type derived from `enum.Flag`. Instances of such
+ an enumeration type represent a bitwise combination of the defined enumerators,
+ and they may be combined using bitwise operators ``& | ^ ~``.


Version 2.0.1 (TBA)
---------------------------

@@ -138,9 +143,11 @@ according to `SemVer <http://semver.org>`__. The following changes are
noteworthy:

* The :cpp:class:`nb::enum_\<T\>() <enum_>` binding declaration is now a
- wrapper that creates either a ``enum.Enum`` or ``enum.IntEnum``-derived type.
+ wrapper that creates either a ``enum.Enum``, ``enum.IntEnum`` or ``enum.Flag``-derived type.
Owner commented:

I think it would make sense to support a matrix of 4 base classes: nb::is_arithmetic switches between enum.* and enum.Int* while nb::is_flag switches between Enum and Flag.
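
The proposed matrix can be sketched in a few lines of Python. The helper name `pick_base` and its boolean parameters are hypothetical; they only mirror the two annotations discussed in this comment and are not part of nanobind.

```python
import enum

def pick_base(is_arithmetic: bool, is_flag: bool) -> type:
    """Hypothetical helper: map the two annotations to one of four enum bases."""
    # is_arithmetic toggles the "Int" prefix, is_flag toggles Enum vs. Flag.
    name = ("Int" if is_arithmetic else "") + ("Flag" if is_flag else "Enum")
    return getattr(enum, name)

for arithmetic in (False, True):
    for flag in (False, True):
        print(arithmetic, flag, pick_base(arithmetic, flag).__name__)
```

All four names (`Enum`, `IntEnum`, `Flag`, `IntFlag`) exist in the standard `enum` module, so the two annotations can be treated as independent switches.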

Previously, nanobind relied on a custom enumeration base class that was a
frequent source of friction for users.
A new flag :cpp:class:`nb::is_flag_enum() <is_flag_enum>`
Owner commented:

This will need to go to a separate changelog entry for 2.1.0

Contributor commented:

Looks like you added the new changelog entry but didn't remove the old one.

creates a ``enum.Flag``-derived type.

This change may break code that casts entries to integers, which now only
works for arithmetic (``enum.IntEnum``-derived) enumerations. Replace
7 changes: 5 additions & 2 deletions docs/classes.rst
@@ -295,8 +295,11 @@ C++11-style strongly typed enumerations.

When the annotation :cpp:class:`nb::is_arithmetic() <is_arithmetic>` is
passed to :cpp:class:`nb::enum_\<T\> <enum_>`, the resulting Python type
- will support arithmetic and bit-level operations like comparisons, and, or,
- xor, negation, etc.
+ will support arithmetic and bit-level operations (and, or,
+ xor, negation).
Contributor commented:

Suggested change
- xor, negation).
+ xor, negation). The operands of these operations may be either enumerators
+ (including of two different `is_arithmetic` enumeration types) or integers, and the
+ result will be an integer.
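
The behavior described in this suggestion matches what the standard-library `enum.IntEnum` does on its own. The sketch below uses two hypothetical enumerations, `Shape` and `Color`, assuming the bound type behaves like a plain `IntEnum`.

```python
import enum

class Shape(enum.IntEnum):
    CIRCLE = 1
    SQUARE = 2

class Color(enum.IntEnum):
    RED = 4

# Enumerators of two different IntEnum types can be mixed ...
mixed = Shape.CIRCLE | Color.RED
# ... and so can an enumerator and a plain integer.
shifted = Shape.SQUARE + 1

# In both cases the result is a plain int, not an enumerator.
print(type(mixed).__name__, mixed)
print(type(shifted).__name__, shifted)
```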

Passing the annotation :cpp:class:`nb::is_flag_enum() <is_flag_enum>` to
to :cpp:class:`nb::enum_\<T\> <enum_>`, will result in a Python type
`enum.Flags`, that supports bit operations without losing their `Flag` membership.
Contributor commented:

Suggested change
- Passing the annotation :cpp:class:`nb::is_flag_enum() <is_flag_enum>` to
- to :cpp:class:`nb::enum_\<T\> <enum_>`, will result in a Python type
- `enum.Flags`, that supports bit operations without losing their `Flag` membership.
+ When the annotation :cpp:class:`nb::is_flag_enum() <is_flag_enum>` is passed to
+ to :cpp:class:`nb::enum_\<T\> <enum_>`, the resulting Python type will be an
+ `enum.Flag`, meaning its enumerators can be combined using bitwise operators
+ in a type-safe way: the result will have the same enumeration type as the operands,
+ and only enumerators of the same type can be combined.
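
The type-safety claim can be illustrated with plain `enum.Flag` types. `Perm` and `Mode` below are hypothetical enumerations; the sketch assumes the nanobind-generated type behaves like an ordinary `enum.Flag` subclass.

```python
import enum

class Perm(enum.Flag):
    READ = 1
    WRITE = 2

class Mode(enum.Flag):
    FAST = 1

# Same-type operands: the result stays inside the enumeration type.
both = Perm.READ | Perm.WRITE
stays_in_type = isinstance(both, Perm)

# Operands of different Flag types are rejected with a TypeError.
try:
    Perm.READ | Mode.FAST
    mixing_allowed = True
except TypeError:
    mixing_allowed = False

print(stays_in_type, mixing_allowed)
```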


.. code-block:: cpp

1 change: 1 addition & 0 deletions include/nanobind/nb_attr.h
@@ -58,6 +58,7 @@ struct is_method {};
struct is_implicit {};
struct is_operator {};
struct is_arithmetic {};
struct is_flag_enum {};
struct is_final {};
struct is_generic {};
struct kw_only {};
11 changes: 9 additions & 2 deletions include/nanobind/nb_class.h
@@ -62,9 +62,12 @@ enum class type_flags : uint32_t {
is_arithmetic = (1 << 16),

/// Is the number type underlying the enumeration signed?
- is_signed = (1 << 17)
+ is_signed = (1 << 17),

// One more flag bits available (18) without needing
/// Is the underlying enumeration type Flag?
is_flag_enum = (1 << 18)
Contributor commented:

Please line this up with all the others.


// No more flag bits available (18). Needs
Owner commented:

🥲

Contributor commented:

When we need this (as we undoubtedly will eventually), suggest splitting the type_init_flags to be in a separate field of type_init_data that is not preserved in type_data. That will free up six bits. I think some of the new type_flags might be able to move to type_init_flags also (is_generic?), which would free up further space.

// a larger reorganization
};

@@ -201,6 +204,10 @@ NB_INLINE void enum_extra_apply(enum_init_data &e, is_arithmetic) {
    e.flags |= (uint32_t) type_flags::is_arithmetic;
}

NB_INLINE void enum_extra_apply(enum_init_data &e, is_flag_enum) {
    e.flags |= (uint32_t) type_flags::is_flag_enum;
}

NB_INLINE void enum_extra_apply(enum_init_data &e, const char *doc) {
    e.docstr = doc;
}
1 change: 1 addition & 0 deletions include/nanobind/nb_python.h
@@ -20,6 +20,7 @@

#include <Python.h>
#include <frameobject.h>
#include <object.h>
Contributor commented:

I think this include should come with Python.h automatically?

#include <pythread.h>
#include <structmember.h>

43 changes: 35 additions & 8 deletions src/nb_enum.cpp
@@ -1,5 +1,5 @@
#include "nb_internals.h"

#include <string>
Owner commented:

Can we avoid depending on std::string? I've tried hard not to have anything depend on it in the core library.

NAMESPACE_BEGIN(NB_NAMESPACE)
NAMESPACE_BEGIN(detail)

@@ -28,7 +28,7 @@ PyObject *enum_create(enum_init_data *ed) noexcept {
handle scope(ed->scope);

bool is_arithmetic = ed->flags & (uint32_t) type_flags::is_arithmetic;

bool is_flag_enum = ed->flags & (uint32_t) type_flags::is_flag_enum;
str name(ed->name), qualname = name;
object modname;

@@ -43,7 +43,7 @@ PyObject *enum_create(enum_init_data *ed) noexcept {
PyUnicode_FromFormat("%U.%U", scope_qualname.ptr(), name.ptr()));
}

- const char *factory_name = is_arithmetic ? "IntEnum" : "Enum";
+ const char *factory_name = (is_arithmetic || is_flag_enum) ? (is_flag_enum ? "Flag" : "IntEnum") : "Enum";
Contributor commented:

This could be more straightforwardly written as is_flag_enum ? "Flag" : is_arithmetic ? "IntEnum" : "Enum". But I think you might want the combination of flag+arithmetic to give you IntFlag as well.


object enum_mod = module_::import_("enum"),
factory = enum_mod.attr(factory_name),
@@ -56,7 +56,11 @@ PyObject *enum_create(enum_init_data *ed) noexcept {

if (is_arithmetic)
result.attr("__str__") = enum_mod.attr("Enum").attr("__str__");

if (is_flag_enum)
result.attr("__str__") = enum_mod.attr("Flag").attr("__str__");
Contributor commented:

Should this be IntFlag.__str__, since we use that factory above (L46)?

#if PY_VERSION_HEX >= 0x030B0000
result.attr("_boundary_") = enum_mod.attr("FlagBoundary").attr("KEEP");
Contributor commented:

This choice seems not in keeping with nanobind's general defaults of strictness around enums. I think the default STRICT policy is more appropriate. People can use is_arithmetic (producing an IntFlag) if they want more flexibility. For IntFlag, KEEP is the default.
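
The difference between the two boundary policies can be seen with the standard library alone. The sketch below assumes Python 3.11 or newer, where `enum.FlagBoundary` exists, and returns `None` on older interpreters; the class names `Strict`, `Keep`, and `Loose` are hypothetical.

```python
import enum

def boundary_demo():
    """Compare STRICT (Flag default) with KEEP; returns None before Python 3.11."""
    if not hasattr(enum, "FlagBoundary"):
        return None

    class Strict(enum.Flag):                      # default boundary is STRICT
        A = 1
        B = 2

    class Keep(enum.Flag, boundary=enum.KEEP):    # explicit KEEP
        A = 1
        B = 2

    class Loose(enum.IntFlag):                    # IntFlag defaults to KEEP
        A = 1
        B = 2

    try:
        Strict(4)                                 # bit 4 is not defined
        strict_rejects = False
    except ValueError:
        strict_rejects = True

    return strict_rejects, Keep(4).value, Loose(4).value

print(boundary_demo())
```

With STRICT, an undefined bit is an error at construction time, which is the stricter default argued for in this comment; KEEP silently preserves the extra bit.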

Owner commented:

+1

Owner commented:

See also my point above. We likely want the full matrix of combinations.

#endif
result.attr("__repr__") = result.attr("__str__");

type_init_data *t = new type_init_data();
@@ -150,6 +154,28 @@ bool enum_from_python(const std::type_info *tp, PyObject *o, int64_t *out, uint8
if (!t)
return false;

if ((t->flags & (uint32_t) type_flags::is_flag_enum) !=0) {
Contributor commented:

Please follow the code style of the rest of this file. Spaces on both sides of binary operators, space after if (several instances of that further below), etc.

Owner commented:

+1

auto base = PyObject_GetAttrString((PyObject *)o->ob_type, "__base__");
Contributor commented:

These Python API calls return new object references which you must drop (Py_DECREF) when you're done using them. Also, this is way too much effort to spend on every individual enum-value conversion from Python to C++. You should do this only if the map lookup in enum_tbl.rev fails. If this slower-path conversion succeeds, then you can add to the enum_tbl.rev map to save time when converting the same flags value in the future.
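
The lookup order suggested here (table lookup first, slower conversion only on a miss, then remember the result) can be modeled in Python. This is only a sketch of the control flow: `rev_cache` and `from_python` are hypothetical stand-ins for nanobind's reverse table and conversion routine, and the real code would operate on `PyObject *` pointers in C++.

```python
import enum

class Perm(enum.Flag):
    READ = 1
    WRITE = 2

# Hypothetical stand-in for the reverse table: object identity -> integer value.
# Only the named enumerators are registered up front.
rev_cache = {id(member): member.value for member in Perm}

def from_python(obj, expected_type=Perm):
    """Return the integer value of obj, or None if it cannot be converted."""
    key = id(obj)
    if key in rev_cache:                  # fast path: plain table lookup
        return rev_cache[key]
    if type(obj) is not expected_type:    # slow path runs only after a miss
        return None
    rev_cache[key] = obj.value            # remember this combination
    return obj.value

combo = Perm.READ | Perm.WRITE
first = from_python(combo)    # miss, takes the slow path
second = from_python(combo)   # hit, served from the table
print(first, second, len(rev_cache))
```

Keying by identity works here because `enum.Flag` keeps each combination alive as a singleton on the class.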

Owner commented:

I'm actually wondering if it makes sense to cache these. With flags, it would seem one can generate 2^n cache entries for n bits, which has the potential of being undesirably large. Maybe flags enums are OK to go on a slow path? Or we somehow restrict the size of the cache?

Contributor commented:

AFAICT, they are already being cached in members of the enum type. This "singleton-izes" so that one can test for flag combinations with is.

>>> import enum
>>> class Test(enum.Flag):
...   A, B, C, D, E, F, G, H = 1, 2, 4, 8, 16, 32, 64, 128
...
>>> for i in range(256): Test(i)
...
<Test: 0>
<Test.A: 1>
<Test.B: 2>
[etc ...]
>>> len(Test._value2member_map_)
256

The existing cache takes, for each referenced combination, one dict entry (like 32 bytes I think?) plus the footprint of the enum value object itself (sys.getsizeof says 48 bytes, but that doesn't include the 4-entry instance dictionary, name string, etc -- it's probably at least 128 bytes total). We're adding one robin_map entry (24 bytes?) in each direction (C++->Py and Py->C++) for each combination that is passed across the language boundary. So the additional cost of caching is something like 30% above what we're already paying just for using enum.Flag. That seems reasonable to me.

auto basename = PyObject_GetAttrString(base, "__name__");
Py_ssize_t size;
const char* data = PyUnicode_AsUTF8AndSize(basename, &size);
std::string s(data, size);
Contributor commented:

This is not a sufficient reason to use std::string; nanobind generally tries to keep a low footprint of standard-library includes, to avoid code bloat. Check the size, then use memcmp().

Also, I don't think this logic is correct. You should instead check whether Py_TYPE(o) -- the Python type of the incoming might-be-an-enumeration-value -- matches t->type_py -- the Python type of the nanobind enumeration you're trying to convert to. What you have here is both much slower than that, and is foolable:

class Flag:
    pass

class Dummy(Flag):
    value = 42

some_nanobind_function(Dummy())  # result is treated as flags of 42!
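
The two checks can be compared side by side in Python. Both helper functions below are hypothetical models: `check_by_base_name` mimics the base-class-name comparison from the diff, and `check_by_type` mimics the suggested comparison of the object's type against the expected enumeration type.

```python
import enum

class RealFlags(enum.Flag):
    A = 1

class Flag:              # unrelated class that merely shares the name "Flag"
    pass

class Dummy(Flag):
    value = 42

def check_by_base_name(obj):
    """Model of the check in the diff: compare the base class name."""
    return type(obj).__base__.__name__ == "Flag"

def check_by_type(obj, expected):
    """Model of the suggested check: exact type match."""
    return type(obj) is expected

print(check_by_base_name(RealFlags.A), check_by_base_name(Dummy()))
print(check_by_type(RealFlags.A, RealFlags), check_by_type(Dummy(), RealFlags))
```

The name-based check accepts `Dummy()`, while the type-based check rejects it and needs no string handling at all.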

Owner commented:

+1

if(s == "Flag") {
auto pValue = PyObject_GetAttrString(o, "value");
if (pValue == nullptr) {
PyErr_Clear();
return false;
}
long long value = PyLong_AsLongLong(pValue);
if (value == -1 && PyErr_Occurred()) {
PyErr_Clear();
return false;
}
*out = (int64_t) value;
return true;
}
}

enum_map *rev = (enum_map *) t->enum_tbl.rev;
enum_map::iterator it = rev->find((int64_t) (uintptr_t) o);

@@ -175,7 +201,6 @@ bool enum_from_python(const std::type_info *tp, PyObject *o, int64_t *out, uint8
} else {
unsigned long long value = PyLong_AsUnsignedLongLong(o);
if (value == (unsigned long long) -1 && PyErr_Occurred()) {
PyErr_Clear();
return false;
}
enum_map::iterator it2 = fwd->find((int64_t) value);
@@ -186,17 +211,19 @@ bool enum_from_python(const std::type_info *tp, PyObject *o, int64_t *out, uint8
}

}

return false;
}

PyObject *enum_from_cpp(const std::type_info *tp, int64_t key) noexcept {
type_data *t = nb_type_c2p(internals, tp);
if (!t)
return nullptr;

enum_map *fwd = (enum_map *) t->enum_tbl.fwd;

if(t->flags & (uint32_t) type_flags::is_flag_enum) {
PyObject *value = PyLong_FromLongLong(key);
Contributor commented:

This is definitely wrong and the fact that it wasn't caught by a unit test means you need better test coverage. This will turn any flag-enum value into a plain Python int, i.e., it loses its flag-enum membership. That kind of defeats the point of a flag enum, and is extra confusing because you're not allowed to pass that thing back into a C++ function that expects an instance of the flag-enum type.
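
The distinction at stake is between returning a bare integer and returning a member of the flag type. In Python terms, calling the enumeration type with the integer yields the member; the sketch below uses a hypothetical `Perm` enumeration to show the difference a unit test could assert on.

```python
import enum

class Perm(enum.Flag):
    A = 1
    B = 2
    C = 4

raw = 5                   # what the code in the diff hands back to Python
as_member = Perm(raw)     # what a flag-enum conversion should produce

raw_is_member = isinstance(raw, Perm)
member_is_member = isinstance(as_member, Perm)
same_as_combination = as_member is (Perm.A | Perm.C)

print(raw_is_member, member_is_member, same_as_combination)
```

A round-trip test along these lines (convert from C++, check the result's type) would have caught the loss of flag-enum membership.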

Py_INCREF(value);
return value;
}
enum_map::iterator it = fwd->find(key);
if (it != fwd->end()) {
PyObject *value = (PyObject *) it->second;
15 changes: 14 additions & 1 deletion tests/test_enum.cpp
@@ -3,6 +3,7 @@
namespace nb = nanobind;

enum class Enum : uint32_t { A, B, C = (uint32_t) -1 };
enum class Flag : uint32_t { A = 1, B = 2, C = 4};
enum class SEnum : int32_t { A, B, C = (int32_t) -1 };
enum ClassicEnum { Item1, Item2 };

@@ -14,6 +15,12 @@ NB_MODULE(test_enum_ext, m) {
        .value("B", Enum::B, "Value B")
        .value("C", Enum::C, "Value C");

    nb::enum_<Flag>(m, "Flag", "enum-level docstring", nb::is_flag_enum())
        .value("A", Flag::A, "Value A")
        .value("B", Flag::B, "Value B")
        .value("C", Flag::C, "Value C")
        .export_values();

    nb::enum_<SEnum>(m, "SEnum", nb::is_arithmetic())
        .value("A", SEnum::A)
        .value("B", SEnum::B)
@@ -31,11 +38,17 @@ NB_MODULE(test_enum_ext, m) {

m.def("from_enum", [](Enum value) { return (uint32_t) value; }, nb::arg().noconvert());
m.def("to_enum", [](uint32_t value) { return (Enum) value; });
m.def("from_enum", [](Flag value) { return (uint32_t) value; }, nb::arg().noconvert());
m.def("to_enum", [](uint32_t value) { return (Flag) value; });
m.def("from_enum", [](SEnum value) { return (int32_t) value; }, nb::arg().noconvert());

m.def("from_enum_implicit", [](Enum value) { return (uint32_t) value; });

m.def("from_enum_default_0", [](Enum value) { return (uint32_t) value; }, nb::arg("value") = Enum::A);

m.def("from_enum_implicit", [](Flag value) { return (uint32_t) value; });

m.def("from_enum_default_0", [](Flag value) { return (uint32_t) value; }, nb::arg("value") = Enum::A);

m.def("from_enum_default_1", [](SEnum value) { return (uint32_t) value; }, nb::arg("value") = SEnum::A);

// test for issue #39
9 changes: 9 additions & 0 deletions tests/test_enum.py
@@ -1,4 +1,5 @@
import test_enum_ext as t
import enum
import pytest

def test01_unsigned_enum():
@@ -136,6 +137,14 @@ def test08_enum_comparisons():
    assert not (t.Enum.B == t.SEnum.B) and not (t.SEnum.B == t.Enum.B)
    assert t.Enum.B != t.SEnum.C and t.SEnum.C != t.Enum.B

def test06_enum_flag():
    assert (t.Flag(1) | t.Flag(2)).value == 3
    assert (t.Flag(3) & t.Flag(1)).value == 1
    assert (t.Flag(3) ^ t.Flag(1)).value == 2
    assert (t.Flag(3) == (t.Flag.A | t.Flag.B))
    assert (t.from_enum(t.Flag.A | t.Flag.C) == 5)
    assert (t.from_enum_implicit(t.Flag(1) | t.Flag(4)) == 5)

def test09_enum_methods():
    assert t.Item1.my_value == 0 and t.Item2.my_value == 1
    assert t.Item1.get_value() == 0 and t.Item2.get_value() == 1
33 changes: 33 additions & 0 deletions tests/test_enum_ext.pyi.ref
@@ -2,6 +2,12 @@ import enum
from typing import overload


A: Flag = Flag.A

B: Flag = Flag.B

C: Flag = Flag.C

class ClassicEnum(enum.Enum):
    Item1 = 0

@@ -35,6 +41,18 @@ class EnumProperty:
    @property
    def read_enum(self) -> Enum: ...

class Flag(enum.Flag):
    """enum-level docstring"""

    A = 1
    """Value A"""

    B = 2
    """Value B"""

    C = 4
    """Value C"""

Item1: ClassicEnum = ClassicEnum.Item1

Item2: ClassicEnum = ClassicEnum.Item2
@@ -49,13 +67,28 @@ class SEnum(enum.IntEnum):
@overload
def from_enum(arg: Enum) -> int: ...

@overload
def from_enum(arg: Flag) -> int: ...

@overload
def from_enum(arg: SEnum) -> int: ...

@overload
def from_enum_default_0(value: Enum = Enum.A) -> int: ...

@overload
def from_enum_default_0(value: Flag = Enum.A) -> int: ...

def from_enum_default_1(value: SEnum = SEnum.A) -> int: ...

@overload
def from_enum_implicit(arg: Enum, /) -> int: ...

@overload
def from_enum_implicit(arg: Flag, /) -> int: ...

@overload
def to_enum(arg: int, /) -> Enum: ...

@overload
def to_enum(arg: int, /) -> Flag: ...