Skip to content

Commit

Permalink
bugfix: py contains raises errors when appropiate (#4209)
Browse files Browse the repository at this point in the history
* bugfix: contains now throws an exception if the key is not hashable

* Fix tests and improve robustness

* Remove todo

* Workaround PyPy corner case

* PyPy xfail

* Fix typo

* fix xfail

* Make clang-tidy happy

* Remove redundant exc checking
  • Loading branch information
Skylion007 authored Oct 17, 2022
1 parent 5b5547b commit b926396
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 3 deletions.
12 changes: 10 additions & 2 deletions include/pybind11/pytypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -1967,7 +1967,11 @@ class dict : public object {
void clear() /* py-non-const */ { PyDict_Clear(ptr()); }
template <typename T>
bool contains(T &&key) const {
return PyDict_Contains(m_ptr, detail::object_or_cast(std::forward<T>(key)).ptr()) == 1;
auto result = PyDict_Contains(m_ptr, detail::object_or_cast(std::forward<T>(key)).ptr());
if (result == -1) {
throw error_already_set();
}
return result == 1;
}

private:
Expand Down Expand Up @@ -2053,7 +2057,11 @@ class anyset : public object {
bool empty() const { return size() == 0; }
template <typename T>
bool contains(T &&val) const {
return PySet_Contains(m_ptr, detail::object_or_cast(std::forward<T>(val)).ptr()) == 1;
auto result = PySet_Contains(m_ptr, detail::object_or_cast(std::forward<T>(val)).ptr());
if (result == -1) {
throw error_already_set();
}
return result == 1;
}
};

Expand Down
5 changes: 4 additions & 1 deletion tests/test_pytypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ TEST_SUBMODULE(pytypes, m) {
return d2;
});
m.def("dict_contains",
[](const py::dict &dict, py::object val) { return dict.contains(val); });
[](const py::dict &dict, const py::object &val) { return dict.contains(val); });
m.def("dict_contains",
[](const py::dict &dict, const char *val) { return dict.contains(val); });

Expand Down Expand Up @@ -538,6 +538,9 @@ TEST_SUBMODULE(pytypes, m) {

m.def("hash_function", [](py::object obj) { return py::hash(std::move(obj)); });

m.def("obj_contains",
[](py::object &obj, const py::object &key) { return obj.contains(key); });

m.def("test_number_protocol", [](const py::object &a, const py::object &b) {
py::list l;
l.append(a.equal(b));
Expand Down
25 changes: 25 additions & 0 deletions tests/test_pytypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,31 @@ def test_dict(capture, doc):
assert m.dict_keyword_constructor() == {"x": 1, "y": 2, "z": 3}


class CustomContains:
d = {"key": None}

def __contains__(self, m):
return m in self.d


@pytest.mark.parametrize(
"arg,func",
[
(set(), m.anyset_contains),
(dict(), m.dict_contains),
(CustomContains(), m.obj_contains),
],
)
@pytest.mark.xfail("env.PYPY and sys.pypy_version_info < (7, 3, 10)", strict=False)
def test_unhashable_exceptions(arg, func):
class Unhashable:
__hash__ = None

with pytest.raises(TypeError) as exc_info:
func(arg, Unhashable())
assert "unhashable type:" in str(exc_info.value)


def test_tuple():
assert m.tuple_no_args() == ()
assert m.tuple_ssize_t() == ()
Expand Down

0 comments on commit b926396

Please sign in to comment.