Skip to content

Commit

Permalink
type inference for from_py_with using function pointers
Browse files Browse the repository at this point in the history
  • Loading branch information
Icxolu committed Feb 12, 2024
1 parent 586a7dc commit 90e62bf
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 17 deletions.
8 changes: 4 additions & 4 deletions pyo3-macros-backend/src/frompyobject.rs
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ impl<'a> Container<'a> {
value: expr_path, ..
}) => quote! {
Ok(#self_ty {
#ident: _pyo3::impl_::frompyobject::extract_struct_field_with(#expr_path, obj, #struct_name, #field_name)?
#ident: _pyo3::impl_::frompyobject::extract_struct_field_with(#expr_path as fn(_) -> _, obj, #struct_name, #field_name)?
})
},
}
Expand All @@ -283,7 +283,7 @@ impl<'a> Container<'a> {
Some(FromPyWithAttribute {
value: expr_path, ..
}) => quote! (
_pyo3::impl_::frompyobject::extract_tuple_struct_field_with(#expr_path, obj, #struct_name, 0).map(#self_ty)
_pyo3::impl_::frompyobject::extract_tuple_struct_field_with(#expr_path as fn(_) -> _, obj, #struct_name, 0).map(#self_ty)
),
}
}
Expand All @@ -303,7 +303,7 @@ impl<'a> Container<'a> {
Some(FromPyWithAttribute {
value: expr_path, ..
}) => quote! (
_pyo3::impl_::frompyobject::extract_tuple_struct_field_with(#expr_path, &#ident, #struct_name, #index)?
_pyo3::impl_::frompyobject::extract_tuple_struct_field_with(#expr_path as fn(_) -> _, &#ident, #struct_name, #index)?
),
}
});
Expand Down Expand Up @@ -344,7 +344,7 @@ impl<'a> Container<'a> {
Some(FromPyWithAttribute {
value: expr_path, ..
}) => {
quote! (_pyo3::impl_::frompyobject::extract_struct_field_with(#expr_path, &obj.#getter?, #struct_name, #field_name)?)
quote! (_pyo3::impl_::frompyobject::extract_struct_field_with(#expr_path as fn(_) -> _, &obj.#getter?, #struct_name, #field_name)?)
}
};

Expand Down
42 changes: 34 additions & 8 deletions src/impl_/frompyobject.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,32 @@ use crate::types::any::PyAnyMethods;
use crate::Bound;
use crate::{exceptions::PyTypeError, FromPyObject, PyAny, PyErr, PyResult, Python};

pub enum Extractor<'a, 'py, T> {
Bound(fn(&'a Bound<'py, PyAny>) -> PyResult<T>),
GilRef(fn(&'a PyAny) -> PyResult<T>),
}

impl<'a, 'py, T> From<fn(&'a Bound<'py, PyAny>) -> PyResult<T>> for Extractor<'a, 'py, T> {
fn from(value: fn(&'a Bound<'py, PyAny>) -> PyResult<T>) -> Self {
Self::Bound(value)
}
}

impl<'a, T> From<fn(&'a PyAny) -> PyResult<T>> for Extractor<'a, '_, T> {
fn from(value: fn(&'a PyAny) -> PyResult<T>) -> Self {
Self::GilRef(value)
}
}

impl<'a, 'py, T> Extractor<'a, 'py, T> {
fn call(self, obj: &'a Bound<'py, PyAny>) -> PyResult<T> {
match self {
Extractor::Bound(f) => f(obj),
Extractor::GilRef(f) => f(obj.as_gil_ref()),
}
}
}

#[cold]
pub fn failed_to_extract_enum(
py: Python<'_>,
Expand Down Expand Up @@ -61,13 +87,13 @@ where
}
}

pub fn extract_struct_field_with<'py, T>(
extractor: impl FnOnce(&Bound<'py, PyAny>) -> PyResult<T>,
obj: &Bound<'py, PyAny>,
pub fn extract_struct_field_with<'a, 'py, T>(
extractor: impl Into<Extractor<'a, 'py, T>>,
obj: &'a Bound<'py, PyAny>,
struct_name: &str,
field_name: &str,
) -> PyResult<T> {
match extractor(obj) {
match extractor.into().call(obj) {
Ok(value) => Ok(value),
Err(err) => Err(failed_to_extract_struct_field(
obj.py(),
Expand Down Expand Up @@ -112,13 +138,13 @@ where
}
}

pub fn extract_tuple_struct_field_with<'py, T>(
extractor: impl FnOnce(&Bound<'py, PyAny>) -> PyResult<T>,
obj: &Bound<'py, PyAny>,
pub fn extract_tuple_struct_field_with<'a, 'py, T>(
extractor: impl Into<Extractor<'a, 'py, T>>,
obj: &'a Bound<'py, PyAny>,
struct_name: &str,
index: usize,
) -> PyResult<T> {
match extractor(obj) {
match extractor.into().call(obj) {
Ok(value) => Ok(value),
Err(err) => Err(failed_to_extract_tuple_struct_field(
obj.py(),
Expand Down
16 changes: 11 additions & 5 deletions tests/test_frompyobject.rs
Original file line number Diff line number Diff line change
Expand Up @@ -502,7 +502,7 @@ pub struct Zap {
#[pyo3(item)]
name: String,

#[pyo3(from_py_with = "PyAnyMethods::len", item("my_object"))]
#[pyo3(from_py_with = "Bound::<'_, PyAny>::len", item("my_object"))]
some_object_length: usize,
}

Expand All @@ -525,7 +525,10 @@ fn test_from_py_with() {
}

#[derive(Debug, FromPyObject)]
pub struct ZapTuple(String, #[pyo3(from_py_with = "PyAnyMethods::len")] usize);
pub struct ZapTuple(
String,
#[pyo3(from_py_with = "Bound::<'_, PyAny>::len")] usize,
);

#[test]
fn test_from_py_with_tuple_struct() {
Expand Down Expand Up @@ -560,8 +563,11 @@ fn test_from_py_with_tuple_struct_error() {

#[derive(Debug, FromPyObject, PartialEq, Eq)]
pub enum ZapEnum {
Zip(#[pyo3(from_py_with = "PyAnyMethods::len")] usize),
Zap(String, #[pyo3(from_py_with = "PyAnyMethods::len")] usize),
Zip(#[pyo3(from_py_with = "Bound::<'_, PyAny>::len")] usize),
Zap(
String,
#[pyo3(from_py_with = "Bound::<'_, PyAny>::len")] usize,
),
}

#[test]
Expand All @@ -581,7 +587,7 @@ fn test_from_py_with_enum() {
#[derive(Debug, FromPyObject, PartialEq, Eq)]
#[pyo3(transparent)]
pub struct TransparentFromPyWith {
#[pyo3(from_py_with = "PyAnyMethods::len")]
#[pyo3(from_py_with = "Bound::<'_, PyAny>::len")]
len: usize,
}

Expand Down

0 comments on commit 90e62bf

Please sign in to comment.