From 7b5a07668036a5b3b86bded27fe2eef44af12781 Mon Sep 17 00:00:00 2001 From: scalexm Date: Tue, 5 May 2020 21:54:09 +0200 Subject: [PATCH 1/4] Add `#[classattr]` methods to define Python class attributes --- pyo3-derive-backend/src/method.rs | 14 +++++++++++++- pyo3-derive-backend/src/pymethod.rs | 30 +++++++++++++++++++++++++++++ src/class/methods.rs | 21 +++++++++++++++++++- src/class/mod.rs | 4 +++- src/pyclass.rs | 23 +++++++++++++++++++--- tests/test_class_attributes.rs | 29 ++++++++++++++++++++++++++++ 6 files changed, 115 insertions(+), 6 deletions(-) create mode 100644 tests/test_class_attributes.rs diff --git a/pyo3-derive-backend/src/method.rs b/pyo3-derive-backend/src/method.rs index aa02a90be97..5cb9606efe3 100644 --- a/pyo3-derive-backend/src/method.rs +++ b/pyo3-derive-backend/src/method.rs @@ -29,6 +29,7 @@ pub enum FnType { FnCall, FnClass, FnStatic, + ClassAttribute, /// For methods taht have `self_: &PyCell` instead of self receiver PySelfRef(syn::TypeReference), /// For methods taht have `self_: PyRef` or `PyRefMut` instead of self receiver @@ -139,6 +140,15 @@ impl<'a> FnSpec<'a> { }; } + if let FnType::ClassAttribute = &fn_type { + if self_.is_some() || !arguments.is_empty() { + return Err(syn::Error::new_spanned( + name, + "Class attribute methods cannot take arguments", + )); + } + } + // "Tweak" getter / setter names: strip off set_ and get_ if needed if let FnType::Getter | FnType::Setter = &fn_type { if python_name.is_none() { @@ -178,7 +188,7 @@ impl<'a> FnSpec<'a> { "text_signature not allowed on __new__; if you want to add a signature on \ __new__, put it on the struct definition instead", )?, - FnType::FnCall | FnType::Getter | FnType::Setter => { + FnType::FnCall | FnType::Getter | FnType::Setter | FnType::ClassAttribute => { parse_erroneous_text_signature("text_signature not allowed with this attribute")? } }; @@ -331,6 +341,8 @@ fn parse_method_attributes( res = Some(FnType::FnClass) } else if name.is_ident("staticmethod") { res = Some(FnType::FnStatic) + } else if name.is_ident("classattr") { + res = Some(FnType::ClassAttribute) } else if name.is_ident("setter") || name.is_ident("getter") { if let syn::AttrStyle::Inner(_) = attr.style { return Err(syn::Error::new_spanned( diff --git a/pyo3-derive-backend/src/pymethod.rs b/pyo3-derive-backend/src/pymethod.rs index 64bf7b9a02a..6782a6908ef 100644 --- a/pyo3-derive-backend/src/pymethod.rs +++ b/pyo3-derive-backend/src/pymethod.rs @@ -30,6 +30,9 @@ pub fn gen_py_method( FnType::FnCall => impl_py_method_def_call(&spec, &impl_wrap(cls, &spec, false)), FnType::FnClass => impl_py_method_def_class(&spec, &impl_wrap_class(cls, &spec)), FnType::FnStatic => impl_py_method_def_static(&spec, &impl_wrap_static(cls, &spec)), + FnType::ClassAttribute => { + impl_py_class_attribute(&spec, &impl_wrap_class_attribute(cls, &spec)) + } FnType::Getter => impl_py_getter_def( &spec.python_name, &spec.doc, @@ -246,6 +249,19 @@ pub fn impl_wrap_static(cls: &syn::Type, spec: &FnSpec<'_>) -> TokenStream { } } +/// Generate a wrapper for initialization of a class attribute. +/// To be called in `pyo3::pyclass::initialize_type_object`. +pub fn impl_wrap_class_attribute(cls: &syn::Type, spec: &FnSpec<'_>) -> TokenStream { + let name = &spec.name; + let cb = quote! { #cls::#name() }; + + quote! { + fn __wrap(py: pyo3::Python<'_>) -> pyo3::PyObject { + pyo3::IntoPy::into_py(#cb, py) + } + } +} + fn impl_call_getter(spec: &FnSpec) -> syn::Result { let (py_arg, args) = split_off_python_arg(&spec.args); if !args.is_empty() { @@ -615,6 +631,20 @@ pub fn impl_py_method_def_static(spec: &FnSpec, wrapper: &TokenStream) -> TokenS } } +pub fn impl_py_class_attribute(spec: &FnSpec<'_>, wrapper: &TokenStream) -> TokenStream { + let python_name = &spec.python_name; + quote! { + pyo3::class::PyMethodDefType::ClassAttribute({ + #wrapper + + pyo3::class::PyClassAttributeDef { + name: stringify!(#python_name), + meth: __wrap, + } + }) + } +} + pub fn impl_py_method_def_call(spec: &FnSpec, wrapper: &TokenStream) -> TokenStream { let python_name = &spec.python_name; let doc = &spec.doc; diff --git a/src/class/methods.rs b/src/class/methods.rs index 1240c97296c..125e1eb2947 100644 --- a/src/class/methods.rs +++ b/src/class/methods.rs @@ -1,8 +1,9 @@ // Copyright (c) 2017-present PyO3 Project and Contributors -use crate::ffi; +use crate::{ffi, PyObject, Python}; use libc::c_int; use std::ffi::CString; +use std::fmt; /// `PyMethodDefType` represents different types of Python callable objects. /// It is used by the `#[pymethods]` and `#[pyproto]` annotations. @@ -18,6 +19,8 @@ pub enum PyMethodDefType { Static(PyMethodDef), /// Represents normal method Method(PyMethodDef), + /// Represents class attribute, used by `#[attribute]` + ClassAttribute(PyClassAttributeDef), /// Represents getter descriptor, used by `#[getter]` Getter(PyGetterDef), /// Represents setter descriptor, used by `#[setter]` @@ -40,6 +43,12 @@ pub struct PyMethodDef { pub ml_doc: &'static str, } +#[derive(Copy, Clone)] +pub struct PyClassAttributeDef { + pub name: &'static str, + pub meth: for<'p> fn(Python<'p>) -> PyObject, +} + #[derive(Copy, Clone, Debug)] pub struct PyGetterDef { pub name: &'static str, @@ -85,6 +94,16 @@ impl PyMethodDef { } } +// Manual implementation because `Python<'_>` does not implement `Debug` and +// trait bounds on `fn` compiler-generated derive impls are too restrictive. +impl fmt::Debug for PyClassAttributeDef { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("PyClassAttributeDef") + .field("name", &self.name) + .finish() + } +} + impl PyGetterDef { /// Copy descriptor information to `ffi::PyGetSetDef` pub fn copy_to(&self, dst: &mut ffi::PyGetSetDef) { diff --git a/src/class/mod.rs b/src/class/mod.rs index 46114023866..df828ecb75c 100644 --- a/src/class/mod.rs +++ b/src/class/mod.rs @@ -24,7 +24,9 @@ pub use self::descr::PyDescrProtocol; pub use self::gc::{PyGCProtocol, PyTraverseError, PyVisit}; pub use self::iter::PyIterProtocol; pub use self::mapping::PyMappingProtocol; -pub use self::methods::{PyGetterDef, PyMethodDef, PyMethodDefType, PyMethodType, PySetterDef}; +pub use self::methods::{ + PyClassAttributeDef, PyGetterDef, PyMethodDef, PyMethodDefType, PyMethodType, PySetterDef, +}; pub use self::number::PyNumberProtocol; pub use self::pyasync::PyAsyncProtocol; pub use self::sequence::PySequenceProtocol; diff --git a/src/pyclass.rs b/src/pyclass.rs index 829cf3dbb47..f0da8c35640 100644 --- a/src/pyclass.rs +++ b/src/pyclass.rs @@ -1,7 +1,9 @@ //! `PyClass` trait -use crate::class::methods::{PyMethodDefType, PyMethodsImpl}; +use crate::class::methods::{PyClassAttributeDef, PyMethodDefType, PyMethodsImpl}; +use crate::conversion::{IntoPyPointer, ToPyObject}; use crate::pyclass_slots::{PyClassDict, PyClassWeakRef}; use crate::type_object::{type_flags, PyLayout}; +use crate::types::PyDict; use crate::{class, ffi, PyCell, PyErr, PyNativeType, PyResult, PyTypeInfo, Python}; use std::ffi::CString; use std::os::raw::c_void; @@ -165,13 +167,23 @@ where // buffer protocol type_object.tp_as_buffer = to_ptr(::tp_as_buffer()); + let (new, call, mut methods, attrs) = py_class_method_defs::(); + // normal methods - let (new, call, mut methods) = py_class_method_defs::(); if !methods.is_empty() { methods.push(ffi::PyMethodDef_INIT); type_object.tp_methods = Box::into_raw(methods.into_boxed_slice()) as *mut _; } + // class attributes + if !attrs.is_empty() { + let dict = PyDict::new(py); + for attr in attrs { + dict.set_item(attr.name, (attr.meth)(py))?; + } + type_object.tp_dict = dict.to_object(py).into_ptr(); + } + // __new__ method type_object.tp_new = new; // __call__ method @@ -219,8 +231,10 @@ fn py_class_method_defs() -> ( Option, Option, Vec, + Vec, ) { let mut defs = Vec::new(); + let mut attrs = Vec::new(); let mut call = None; let mut new = None; @@ -243,6 +257,9 @@ fn py_class_method_defs() -> ( | PyMethodDefType::Static(ref def) => { defs.push(def.as_method_def()); } + PyMethodDefType::ClassAttribute(def) => { + attrs.push(def); + } _ => (), } } @@ -265,7 +282,7 @@ fn py_class_method_defs() -> ( py_class_async_methods::(&mut defs); - (new, call, defs) + (new, call, defs, attrs) } fn py_class_async_methods(defs: &mut Vec) { diff --git a/tests/test_class_attributes.rs b/tests/test_class_attributes.rs new file mode 100644 index 00000000000..c5a36a2eccc --- /dev/null +++ b/tests/test_class_attributes.rs @@ -0,0 +1,29 @@ +use pyo3::prelude::*; + +mod common; + +#[pyclass] +struct Foo {} + +#[pymethods] +impl Foo { + #[classattr] + fn a() -> i32 { + 5 + } + + #[classattr] + #[name = "B"] + fn b() -> String { + "bar".to_string() + } +} + +#[test] +fn class_attributes() { + let gil = Python::acquire_gil(); + let py = gil.python(); + let foo_obj = py.get_type::(); + py_assert!(py, foo_obj, "foo_obj.a == 5"); + py_assert!(py, foo_obj, "foo_obj.B == 'bar'"); +} From 8f22d10a145ba8ad8f995c1ac8d869471ce379d3 Mon Sep 17 00:00:00 2001 From: scalexm Date: Wed, 6 May 2020 20:11:29 +0200 Subject: [PATCH 2/4] Add a test showing that class attrs are immutable --- tests/test_class_attributes.rs | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/tests/test_class_attributes.rs b/tests/test_class_attributes.rs index c5a36a2eccc..ab679c024eb 100644 --- a/tests/test_class_attributes.rs +++ b/tests/test_class_attributes.rs @@ -27,3 +27,11 @@ fn class_attributes() { py_assert!(py, foo_obj, "foo_obj.a == 5"); py_assert!(py, foo_obj, "foo_obj.B == 'bar'"); } + +#[test] +fn class_attributes_are_immutable() { + let gil = Python::acquire_gil(); + let py = gil.python(); + let foo_obj = py.get_type::(); + py_expect_exception!(py, foo_obj, "foo_obj.a = 6", TypeError); +} From d3d68eafb4370e338403fba4b4bcf7d036d3008c Mon Sep 17 00:00:00 2001 From: scalexm Date: Thu, 7 May 2020 20:13:10 +0200 Subject: [PATCH 3/4] Add a test with class attrs returning `PyClass` instances --- tests/test_class_attributes.rs | 40 +++++++++++++++++++++++++++++++++- 1 file changed, 39 insertions(+), 1 deletion(-) diff --git a/tests/test_class_attributes.rs b/tests/test_class_attributes.rs index ab679c024eb..692a6629bc1 100644 --- a/tests/test_class_attributes.rs +++ b/tests/test_class_attributes.rs @@ -3,7 +3,16 @@ use pyo3::prelude::*; mod common; #[pyclass] -struct Foo {} +struct Foo { + #[pyo3(get)] + x: i32, +} + +#[pyclass] +struct Bar { + #[pyo3(get)] + x: i32, +} #[pymethods] impl Foo { @@ -17,6 +26,24 @@ impl Foo { fn b() -> String { "bar".to_string() } + + #[classattr] + fn foo() -> Foo { + Foo { x: 1 } + } + + #[classattr] + fn bar() -> Bar { + Bar { x: 2 } + } +} + +#[pymethods] +impl Bar { + #[classattr] + fn foo() -> Foo { + Foo { x: 3 } + } } #[test] @@ -35,3 +62,14 @@ fn class_attributes_are_immutable() { let foo_obj = py.get_type::(); py_expect_exception!(py, foo_obj, "foo_obj.a = 6", TypeError); } + +#[test] +fn recursive_class_attributes() { + let gil = Python::acquire_gil(); + let py = gil.python(); + let foo_obj = py.get_type::(); + let bar_obj = py.get_type::(); + py_assert!(py, foo_obj, "foo_obj.foo.x == 1"); + py_assert!(py, foo_obj, "foo_obj.bar.x == 2"); + py_assert!(py, bar_obj, "bar_obj.foo.x == 3"); +} From e3d9544ae0f624a398b559bead62e76ef512d0dd Mon Sep 17 00:00:00 2001 From: scalexm Date: Thu, 7 May 2020 20:33:15 +0200 Subject: [PATCH 4/4] Add a paragraph to the guide about `#[classattr]` --- guide/src/class.md | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/guide/src/class.md b/guide/src/class.md index 990f10e11e2..84fccd072f7 100644 --- a/guide/src/class.md +++ b/guide/src/class.md @@ -573,6 +573,37 @@ impl MyClass { } ``` +## Class attributes + +To create a class attribute (also called [class variable][classattr]), a method without +any arguments can be annotated with the `#[classattr]` attribute. The return type must be `T` for +some `T` that implements `IntoPy`. + +```rust +# use pyo3::prelude::*; +# #[pyclass] +# struct MyClass {} +#[pymethods] +impl MyClass { + #[classattr] + fn my_attribute() -> String { + "hello".to_string() + } +} + +let gil = Python::acquire_gil(); +let py = gil.python(); +let my_class = py.get_type::(); +pyo3::py_run!(py, my_class, "assert my_class.my_attribute == 'hello'") +``` + +Note that unlike class variables defined in Python code, class attributes defined in Rust cannot +be mutated at all: +```rust,ignore +// Would raise a `TypeError: can't set attributes of built-in/extension type 'MyClass'` +pyo3::py_run!(py, my_class, "my_class.my_attribute = 'foo'") +``` + ## Callable objects To specify a custom `__call__` method for a custom class, the method needs to be annotated with @@ -914,3 +945,5 @@ To escape this we use [inventory](https://github.com/dtolnay/inventory), which a [`PyClassInitializer`]: https://pyo3.rs/master/doc/pyo3/pyclass_init/struct.PyClassInitializer.html [`RefCell`]: https://doc.rust-lang.org/std/cell/struct.RefCell.html + +[classattr]: https://docs.python.org/3/tutorial/classes.html#class-and-instance-variables