Skip to content

Commit

Permalink
update #[derive(FromPyObject)] to use extract_bound
Browse files Browse the repository at this point in the history
  • Loading branch information
Icxolu committed Feb 12, 2024
1 parent 1279467 commit 586a7dc
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 16 deletions.
11 changes: 6 additions & 5 deletions pyo3-macros-backend/src/frompyobject.rs
Original file line number Diff line number Diff line change
Expand Up @@ -298,12 +298,12 @@ impl<'a> Container<'a> {
let fields = struct_fields.iter().zip(&field_idents).enumerate().map(|(index, (field, ident))| {
match &field.from_py_with {
None => quote!(
_pyo3::impl_::frompyobject::extract_tuple_struct_field(#ident, #struct_name, #index)?
_pyo3::impl_::frompyobject::extract_tuple_struct_field(&#ident, #struct_name, #index)?
),
Some(FromPyWithAttribute {
value: expr_path, ..
}) => quote! (
_pyo3::impl_::frompyobject::extract_tuple_struct_field_with(#expr_path, #ident, #struct_name, #index)?
_pyo3::impl_::frompyobject::extract_tuple_struct_field_with(#expr_path, &#ident, #struct_name, #index)?
),
}
});
Expand Down Expand Up @@ -339,12 +339,12 @@ impl<'a> Container<'a> {
};
let extractor = match &field.from_py_with {
None => {
quote!(_pyo3::impl_::frompyobject::extract_struct_field(obj.#getter?, #struct_name, #field_name)?)
quote!(_pyo3::impl_::frompyobject::extract_struct_field(&obj.#getter?, #struct_name, #field_name)?)
}
Some(FromPyWithAttribute {
value: expr_path, ..
}) => {
quote! (_pyo3::impl_::frompyobject::extract_struct_field_with(#expr_path, obj.#getter?, #struct_name, #field_name)?)
quote! (_pyo3::impl_::frompyobject::extract_struct_field_with(#expr_path, &obj.#getter?, #struct_name, #field_name)?)
}
};

Expand Down Expand Up @@ -606,10 +606,11 @@ pub fn build_derive_from_pyobject(tokens: &DeriveInput) -> Result<TokenStream> {
Ok(quote!(
const _: () = {
use #krate as _pyo3;
use _pyo3::prelude::PyAnyMethods;

#[automatically_derived]
impl #trait_generics _pyo3::FromPyObject<#lt_param> for #ident #generics #where_clause {
fn extract(obj: &#lt_param _pyo3::PyAny) -> _pyo3::PyResult<Self> {
fn extract_bound(obj: &_pyo3::Bound<#lt_param, _pyo3::PyAny>) -> _pyo3::PyResult<Self> {
#derives
}
}
Expand Down
14 changes: 8 additions & 6 deletions src/impl_/frompyobject.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use crate::types::any::PyAnyMethods;
use crate::Bound;
use crate::{exceptions::PyTypeError, FromPyObject, PyAny, PyErr, PyResult, Python};

#[cold]
Expand Down Expand Up @@ -41,7 +43,7 @@ fn extract_traceback(py: Python<'_>, mut error: PyErr) -> String {
}

pub fn extract_struct_field<'py, T>(
obj: &'py PyAny,
obj: &Bound<'py, PyAny>,
struct_name: &str,
field_name: &str,
) -> PyResult<T>
Expand All @@ -60,8 +62,8 @@ where
}

pub fn extract_struct_field_with<'py, T>(
extractor: impl FnOnce(&'py PyAny) -> PyResult<T>,
obj: &'py PyAny,
extractor: impl FnOnce(&Bound<'py, PyAny>) -> PyResult<T>,
obj: &Bound<'py, PyAny>,
struct_name: &str,
field_name: &str,
) -> PyResult<T> {
Expand Down Expand Up @@ -92,7 +94,7 @@ fn failed_to_extract_struct_field(
}

pub fn extract_tuple_struct_field<'py, T>(
obj: &'py PyAny,
obj: &Bound<'py, PyAny>,
struct_name: &str,
index: usize,
) -> PyResult<T>
Expand All @@ -111,8 +113,8 @@ where
}

pub fn extract_tuple_struct_field_with<'py, T>(
extractor: impl FnOnce(&'py PyAny) -> PyResult<T>,
obj: &'py PyAny,
extractor: impl FnOnce(&Bound<'py, PyAny>) -> PyResult<T>,
obj: &Bound<'py, PyAny>,
struct_name: &str,
index: usize,
) -> PyResult<T> {
Expand Down
10 changes: 5 additions & 5 deletions tests/test_frompyobject.rs
Original file line number Diff line number Diff line change
Expand Up @@ -502,7 +502,7 @@ pub struct Zap {
#[pyo3(item)]
name: String,

#[pyo3(from_py_with = "PyAny::len", item("my_object"))]
#[pyo3(from_py_with = "PyAnyMethods::len", item("my_object"))]
some_object_length: usize,
}

Expand All @@ -525,7 +525,7 @@ fn test_from_py_with() {
}

#[derive(Debug, FromPyObject)]
pub struct ZapTuple(String, #[pyo3(from_py_with = "PyAny::len")] usize);
pub struct ZapTuple(String, #[pyo3(from_py_with = "PyAnyMethods::len")] usize);

#[test]
fn test_from_py_with_tuple_struct() {
Expand Down Expand Up @@ -560,8 +560,8 @@ fn test_from_py_with_tuple_struct_error() {

#[derive(Debug, FromPyObject, PartialEq, Eq)]
pub enum ZapEnum {
Zip(#[pyo3(from_py_with = "PyAny::len")] usize),
Zap(String, #[pyo3(from_py_with = "PyAny::len")] usize),
Zip(#[pyo3(from_py_with = "PyAnyMethods::len")] usize),
Zap(String, #[pyo3(from_py_with = "PyAnyMethods::len")] usize),
}

#[test]
Expand All @@ -581,7 +581,7 @@ fn test_from_py_with_enum() {
#[derive(Debug, FromPyObject, PartialEq, Eq)]
#[pyo3(transparent)]
pub struct TransparentFromPyWith {
#[pyo3(from_py_with = "PyAny::len")]
#[pyo3(from_py_with = "PyAnyMethods::len")]
len: usize,
}

Expand Down

0 comments on commit 586a7dc

Please sign in to comment.