Skip to content

Commit

Permalink
Merge pull request #2091 from davidhewitt/use-pyfunction
Browse files Browse the repository at this point in the history
pyfunction: allow wrap_pyfunction to work on imports (even cross-crate)
  • Loading branch information
davidhewitt authored Jan 9, 2022
2 parents 2cee7fe + de81746 commit 43077da
Show file tree
Hide file tree
Showing 9 changed files with 89 additions and 37 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Add buffer magic methods `__getbuffer__` and `__releasebuffer__` to `#[pymethods]`. [#2067](https://github.com/PyO3/pyo3/pull/2067)
- Accept paths in `wrap_pyfunction` and `wrap_pymodule`. [#2081](https://github.com/PyO3/pyo3/pull/2081)
- Add check for correct number of arguments on magic methods. [#2083](https://github.com/PyO3/pyo3/pull/2083)
- `wrap_pyfunction!` can now wrap a `#[pyfunction]` which is implemented in a different Rust module or crate. [#2091](https://github.com/PyO3/pyo3/pull/2091)

### Changed

Expand Down
5 changes: 3 additions & 2 deletions pyo3-macros-backend/src/module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,10 +99,11 @@ pub fn process_functions_in_module(func: &mut syn::ItemFn) -> syn::Result<()> {
if let syn::Stmt::Item(syn::Item::Fn(func)) = &mut stmt {
if let Some(pyfn_args) = get_pyfn_attr(&mut func.attrs)? {
let module_name = pyfn_args.modname;
let (ident, wrapped_function) = impl_wrap_pyfunction(func, pyfn_args.options)?;
let wrapped_function = impl_wrap_pyfunction(func, pyfn_args.options)?;
let name = &func.sig.ident;
let statements: Vec<syn::Stmt> = syn::parse_quote! {
#wrapped_function
#module_name.add_function(#ident(#module_name)?)?;
#module_name.add_function(#name::wrap(#name::DEF, #module_name)?)?;
};
stmts.extend(statements);
}
Expand Down
39 changes: 28 additions & 11 deletions pyo3-macros-backend/src/pyfunction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,11 @@ use crate::{
method::{self, CallingConvention, FnArg},
pymethod::check_generic,
utils::{self, ensure_not_async_fn, get_pyo3_crate},
wrap::function_wrapper_ident,
};
use proc_macro2::{Span, TokenStream};
use quote::{format_ident, quote};
use syn::punctuated::Punctuated;
use syn::{ext::IdentExt, spanned::Spanned, Ident, NestedMeta, Path, Result};
use syn::{ext::IdentExt, spanned::Spanned, NestedMeta, Path, Result};
use syn::{
parse::{Parse, ParseBuffer, ParseStream},
token::Comma,
Expand Down Expand Up @@ -364,15 +363,15 @@ pub fn build_py_function(
mut options: PyFunctionOptions,
) -> syn::Result<TokenStream> {
options.add_attributes(take_pyo3_options(&mut ast.attrs)?)?;
Ok(impl_wrap_pyfunction(ast, options)?.1)
impl_wrap_pyfunction(ast, options)
}

/// Generates python wrapper over a function that allows adding it to a python module as a python
/// function
pub fn impl_wrap_pyfunction(
func: &mut syn::ItemFn,
options: PyFunctionOptions,
) -> syn::Result<(Ident, TokenStream)> {
) -> syn::Result<TokenStream> {
check_generic(&func.sig)?;
ensure_not_async_fn(&func.sig)?;

Expand Down Expand Up @@ -412,7 +411,6 @@ pub fn impl_wrap_pyfunction(
.map(|attr| (&python_name, attr)),
);

let function_wrapper_ident = function_wrapper_ident(&func.sig.ident);
let krate = get_pyo3_crate(&options.krate);

let spec = method::FnSpec {
Expand All @@ -434,21 +432,40 @@ pub fn impl_wrap_pyfunction(
unsafety: func.sig.unsafety,
};

let wrapper_ident = format_ident!("__pyo3_raw_{}", spec.name);
let vis = &func.vis;
let name = &func.sig.ident;

let wrapper_ident = format_ident!("__pyfunction_{}", spec.name);
let wrapper = spec.get_wrapper_function(&wrapper_ident, None)?;
let methoddef = spec.get_methoddef(wrapper_ident);

let wrapped_pyfunction = quote! {
#wrapper

pub(crate) fn #function_wrapper_ident<'a>(
args: impl ::std::convert::Into<#krate::derive_utils::PyFunctionArguments<'a>>
) -> #krate::PyResult<&'a #krate::types::PyCFunction> {
// Create a module with the same name as the `#[pyfunction]` - this way `use <the function>`
// will actually bring both the module and the function into scope.
#[doc(hidden)]
#vis mod #name {
use #krate as _pyo3;
_pyo3::types::PyCFunction::internal_new(#methoddef, args.into())
pub(crate) struct PyO3Def;

// Exported for `wrap_pyfunction!`
pub use _pyo3::impl_::pyfunction::wrap_pyfunction as wrap;
pub const DEF: _pyo3::PyMethodDef = <PyO3Def as _pyo3::impl_::pyfunction::PyFunctionDef>::DEF;
}

// Generate the definition inside an anonymous function in the same scope as the original function -
// this avoids complications around the fact that the generated module has a different scope
// (and `super` doesn't always refer to the outer scope, e.g. if the `#[pyfunction] is
// inside a function body)
const _: () = {
use #krate as _pyo3;
impl _pyo3::impl_::pyfunction::PyFunctionDef for #name::PyO3Def {
const DEF: _pyo3::PyMethodDef = #methoddef;
}
};
};
Ok((function_wrapper_ident, wrapped_pyfunction))
Ok(wrapped_pyfunction)
}

fn type_is_pymodule(ty: &syn::Type) -> bool {
Expand Down
25 changes: 6 additions & 19 deletions pyo3-macros-backend/src/wrap.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,25 +22,16 @@ impl Parse for WrapPyFunctionArgs {
}
}

pub fn wrap_pyfunction_impl(args: WrapPyFunctionArgs) -> syn::Result<TokenStream> {
pub fn wrap_pyfunction_impl(args: WrapPyFunctionArgs) -> TokenStream {
let WrapPyFunctionArgs {
mut function,
function,
comma_and_arg,
} = args;
let span = function.span();
let last_segment = function
.segments
.last_mut()
.ok_or_else(|| err_spanned!(span => "expected non-empty path"))?;

last_segment.ident = function_wrapper_ident(&last_segment.ident);

let output = if let Some((_, arg)) = comma_and_arg {
quote! { #function(#arg) }
if let Some((_, arg)) = comma_and_arg {
quote! { #function::wrap(#function::DEF, #arg) }
} else {
quote! { &|arg| #function(arg) }
};
Ok(output)
quote! { &|arg| #function::wrap(#function::DEF, arg) }
}
}

pub fn wrap_pymodule_impl(mut module_path: syn::Path) -> syn::Result<TokenStream> {
Expand All @@ -58,10 +49,6 @@ pub fn wrap_pymodule_impl(mut module_path: syn::Path) -> syn::Result<TokenStream
})
}

pub(crate) fn function_wrapper_ident(name: &Ident) -> Ident {
format_ident!("__pyo3_get_function_{}", name)
}

pub(crate) fn module_def_ident(name: &Ident) -> Ident {
format_ident!("__PYO3_PYMODULE_DEF_{}", name.to_string().to_uppercase())
}
10 changes: 8 additions & 2 deletions pyo3-macros/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,8 @@ pub fn pymethods(_: TokenStream, input: TokenStream) -> TokenStream {

/// A proc macro used to expose Rust functions to Python.
///
/// Functions annotated with `#[pyfunction]` can also be annotated with the following `#[pyo3]` options:
/// Functions annotated with `#[pyfunction]` can also be annotated with the following `#[pyo3]`
/// options:
///
/// | Annotation | Description |
/// | :- | :- |
Expand All @@ -176,6 +177,11 @@ pub fn pymethods(_: TokenStream, input: TokenStream) -> TokenStream {
///
/// For more on exposing functions see the [function section of the guide][1].
///
/// Due to technical limitations on how `#[pyfunction]` is implemented, a function marked
/// `#[pyfunction]` cannot have a module with the same name in the same scope. (The
/// `#[pyfunction]` implementation generates a hidden module with the same name containing
/// metadata about the function, which is used by `wrap_pyfunction!`).
///
/// [1]: https://pyo3.rs/latest/function.html
#[proc_macro_attribute]
pub fn pyfunction(attr: TokenStream, input: TokenStream) -> TokenStream {
Expand Down Expand Up @@ -208,7 +214,7 @@ pub fn derive_from_py_object(item: TokenStream) -> TokenStream {
#[proc_macro]
pub fn wrap_pyfunction(input: TokenStream) -> TokenStream {
let args = parse_macro_input!(input as WrapPyFunctionArgs);
wrap_pyfunction_impl(args).unwrap_or_compile_error().into()
wrap_pyfunction_impl(args).into()
}

/// Returns a function that takes a `Python` instance and returns a Python module.
Expand Down
4 changes: 1 addition & 3 deletions src/derive_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,7 @@

//! Functionality for the code generated by the derive backend
use crate::err::PyErr;
use crate::types::PyModule;
use crate::{PyCell, PyClass, Python};
use crate::{types::PyModule, PyCell, PyClass, PyErr, Python};

/// Utility trait to enable &PyClass as a pymethod/function argument
#[doc(hidden)]
Expand Down
2 changes: 2 additions & 0 deletions src/impl_.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,7 @@ pub(crate) mod not_send;
#[doc(hidden)]
pub mod pyclass;
#[doc(hidden)]
pub mod pyfunction;
#[doc(hidden)]
pub mod pymethods;
pub mod pymodule;
14 changes: 14 additions & 0 deletions src/impl_/pyfunction.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
use crate::{
class::methods::PyMethodDef, derive_utils::PyFunctionArguments, types::PyCFunction, PyResult,
};

pub trait PyFunctionDef {
const DEF: crate::PyMethodDef;
}

pub fn wrap_pyfunction<'a>(
method_def: PyMethodDef,
args: impl Into<PyFunctionArguments<'a>>,
) -> PyResult<&'a PyCFunction> {
PyCFunction::internal_new(method_def, args.into())
}
26 changes: 26 additions & 0 deletions tests/test_pyfunction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -269,3 +269,29 @@ fn test_closure_counter() {
py_assert!(py, counter_py, "counter_py() == 2");
py_assert!(py, counter_py, "counter_py() == 3");
}

#[test]
fn use_pyfunction() {
mod function_in_module {
use pyo3::prelude::*;

#[pyfunction]
pub fn foo(x: i32) -> i32 {
x
}
}

Python::with_gil(|py| {
use function_in_module::foo;

// check imported name can be wrapped
let f = wrap_pyfunction!(foo, py).unwrap();
assert_eq!(f.call1((5,)).unwrap().extract::<i32>().unwrap(), 5);
assert_eq!(f.call1((42,)).unwrap().extract::<i32>().unwrap(), 42);

// check path import can be wrapped
let f2 = wrap_pyfunction!(function_in_module::foo, py).unwrap();
assert_eq!(f2.call1((5,)).unwrap().extract::<i32>().unwrap(), 5);
assert_eq!(f2.call1((42,)).unwrap().extract::<i32>().unwrap(), 42);
})
}

0 comments on commit 43077da

Please sign in to comment.