Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Some Improvments for getter #137

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion examples/pure/pure.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,10 @@ from enum import Enum, auto

MY_CONSTANT: int
class A:
x: int
x: int = 3
r"""
Class variable `x`, the default value is 3
"""
def __new__(cls,x:int): ...
def show_x(self) -> None:
...
Expand Down
10 changes: 10 additions & 0 deletions examples/pure/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,16 @@ fn create_dict(n: usize) -> HashMap<usize, Vec<usize>> {
#[pyclass]
#[derive(Debug)]
struct A {
/// Class variable `x`, the default value is 3
#[pyo3(get, set)]
#[gen_stub(default = A::default().x)]
x: usize,
}
impl Default for A {
fn default() -> Self {
Self { x: 1 + 2 }
}
}

#[gen_stub_pymethods]
#[pymethods]
Expand All @@ -56,6 +63,9 @@ impl A {
fn ref_test<'a>(&self, x: Bound<'a, PyDict>) -> Bound<'a, PyDict> {
x
}

#[gen_stub(skip)]
fn need_skip(&self) {}
}

#[gen_stub_pyfunction]
Expand Down
24 changes: 18 additions & 6 deletions pyo3-stub-gen-derive/src/gen_stub.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,20 @@
//! MemberInfo {
//! name: "name",
//! r#type: <String as ::pyo3_stub_gen::PyStubType>::type_output,
//! default: None,
//! doc: "",
//! },
//! MemberInfo {
//! name: "ndim",
//! r#type: <usize as ::pyo3_stub_gen::PyStubType>::type_output,
//! default: None,
//! doc: "",
//! },
//! MemberInfo {
//! name: "description",
//! r#type: <Option<String> as ::pyo3_stub_gen::PyStubType>::type_output,
//! default: None,
//! doc: "",
//! },
//! ],
//! doc: "",
Expand Down Expand Up @@ -90,10 +96,12 @@ use quote::quote;
use syn::{parse2, ItemEnum, ItemFn, ItemImpl, ItemStruct, Result};

pub fn pyclass(item: TokenStream2) -> Result<TokenStream2> {
let inner = PyClassInfo::try_from(parse2::<ItemStruct>(item.clone())?)?;
let mut item_struct = parse2::<ItemStruct>(item)?;
let inner = PyClassInfo::try_from(item_struct.clone())?;
let derive_stub_type = StubType::from(&inner);
pyclass::prune_attrs(&mut item_struct);
Ok(quote! {
#item
#item_struct
#derive_stub_type
pyo3_stub_gen::inventory::submit! {
#inner
Expand All @@ -114,9 +122,11 @@ pub fn pyclass_enum(item: TokenStream2) -> Result<TokenStream2> {
}

pub fn pymethods(item: TokenStream2) -> Result<TokenStream2> {
let inner = PyMethodsInfo::try_from(parse2::<ItemImpl>(item.clone())?)?;
let mut item_impl = parse2::<ItemImpl>(item)?;
let inner = PyMethodsInfo::try_from(item_impl.clone())?;
pymethods::prune_attrs(&mut item_impl);
Ok(quote! {
#item
#item_impl
#[automatically_derived]
pyo3_stub_gen::inventory::submit! {
#inner
Expand All @@ -125,10 +135,12 @@ pub fn pymethods(item: TokenStream2) -> Result<TokenStream2> {
}

pub fn pyfunction(attr: TokenStream2, item: TokenStream2) -> Result<TokenStream2> {
let mut inner = PyFunctionInfo::try_from(parse2::<ItemFn>(item.clone())?)?;
let mut item_fn = parse2::<ItemFn>(item)?;
let mut inner = PyFunctionInfo::try_from(item_fn.clone())?;
inner.parse_attr(attr)?;
pyfunction::prune_attrs(&mut item_fn);
Ok(quote! {
#item
#item_fn
#[automatically_derived]
pyo3_stub_gen::inventory::submit! {
#inner
Expand Down
125 changes: 124 additions & 1 deletion pyo3-stub-gen-derive/src/gen_stub/attr.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
use super::Signature;
use proc_macro2::TokenTree;
use quote::ToTokens;
use syn::{Attribute, Expr, ExprLit, Ident, Lit, Meta, MetaList, Result};
use syn::{
parse::ParseStream, Attribute, Expr, ExprLit, Ident, Lit, Meta, MetaList, Result, Token,
};

pub fn extract_documents(attrs: &[Attribute]) -> Vec<String> {
let mut docs = Vec::new();
Expand Down Expand Up @@ -141,6 +143,91 @@ pub fn parse_pyo3_attr(attr: &Attribute) -> Result<Vec<Attr>> {
Ok(pyo3_attrs)
}

#[derive(Debug, Clone, PartialEq)]
pub enum StubGenAttr {
/// Default value for getter
Default(Expr),
/// Skip a function in #[pymethods]
Skip,
}

pub fn prune_attrs(attrs: &mut Vec<Attribute>) {
*attrs = std::mem::take(attrs)
.into_iter()
.filter_map(|attr| {
if attr.path().is_ident("gen_stub") {
None
} else {
Some(attr)
}
})
.collect();
}

pub fn parse_gen_stub_default(attrs: &[Attribute]) -> Result<Option<Expr>> {
for attr in parse_gen_stub_attrs(attrs)? {
if let StubGenAttr::Default(default) = attr {
return Ok(Some(default));
}
}
Ok(None)
}
pub fn parse_gen_stub_skip(attrs: &[Attribute]) -> Result<bool> {
let skip = parse_gen_stub_attrs(attrs)?
.iter()
.any(|attr| matches!(attr, StubGenAttr::Skip));
Ok(skip)
}
fn parse_gen_stub_attrs(attrs: &[Attribute]) -> Result<Vec<StubGenAttr>> {
let mut out = Vec::new();
for attr in attrs {
let mut new = parse_gen_stub_attr(attr)?;
out.append(&mut new);
}
Ok(out)
}

fn parse_gen_stub_attr(attr: &Attribute) -> Result<Vec<StubGenAttr>> {
let mut gen_stub_attrs = Vec::new();
let path = attr.path();
if path.is_ident("gen_stub") {
attr.parse_args_with(|input: ParseStream| {
while !input.is_empty() {
let ident: Ident = input.parse()?;
#[allow(clippy::collapsible_else_if)]
if input.peek(Token![=]) {
input.parse::<Token![=]>()?;
if ident == "default" {
gen_stub_attrs.push(StubGenAttr::Default(input.parse()?));
} else {
return Err(syn::Error::new(
ident.span(),
format!("Unsupport keyword `{ident}`, valid is `default=xxx`"),
));
}
} else {
if ident == "skip" {
gen_stub_attrs.push(StubGenAttr::Skip);
} else {
return Err(syn::Error::new(
ident.span(),
format!("Unsupport keyword `{ident}`, valid is `skip`"),
));
}
}
if input.peek(Token![,]) {
input.parse::<Token![,]>()?;
} else {
break;
}
}
Ok(())
})?;
}

Ok(gen_stub_attrs)
}

#[cfg(test)]
mod test {
use super::*;
Expand Down Expand Up @@ -176,4 +263,40 @@ mod test {
}
Ok(())
}

#[test]
fn test_parse_gen_stub_attr() -> Result<()> {
let item: ItemStruct = parse_str(
r#"
pub struct PyPlaceholder {
#[gen_stub(default = String::from("foo"), skip)]
pub field0: String,
#[gen_stub(skip)]
pub field1: String,
#[gen_stub(default = 1+2)]
pub field2: usize,
}
"#,
)?;
let fields: Vec<_> = item.fields.into_iter().collect();
let field0_attrs = parse_gen_stub_attrs(&fields[0].attrs)?;
if let StubGenAttr::Default(expr) = &field0_attrs[0] {
assert_eq!(
expr.to_token_stream().to_string(),
"String :: from (\"foo\")"
);
} else {
panic!("attr shoubd be Default");
};
assert_eq!(&StubGenAttr::Skip, &field0_attrs[1]);
let field1_attrs = parse_gen_stub_attrs(&fields[1].attrs)?;
assert_eq!(vec![StubGenAttr::Skip], field1_attrs);
let field2_attrs = parse_gen_stub_attrs(&fields[2].attrs)?;
if let StubGenAttr::Default(expr) = &field2_attrs[0] {
assert_eq!(expr.to_token_stream().to_string(), "1 + 2");
} else {
panic!("attr shoubd be Default");
};
Ok(())
}
}
47 changes: 43 additions & 4 deletions pyo3-stub-gen-derive/src/gen_stub/member.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
use super::{escape_return_type, parse_pyo3_attrs, Attr};
use super::{
escape_return_type, extract_documents, parse_gen_stub_default, parse_pyo3_attrs, Attr,
};

use proc_macro2::TokenStream as TokenStream2;
use quote::{quote, ToTokens, TokenStreamExt};
use syn::{Error, Field, ImplItemFn, Result, Type};
use syn::{Error, Expr, Field, ImplItemFn, Result, Type};

#[derive(Debug)]
pub struct MemberInfo {
name: String,
r#type: Type,
default: Option<Expr>,
doc: String,
}

impl MemberInfo {
Expand All @@ -29,12 +33,16 @@ impl TryFrom<ImplItemFn> for MemberInfo {
fn try_from(item: ImplItemFn) -> Result<Self> {
assert!(Self::is_candidate_item(&item)?);
let ImplItemFn { attrs, sig, .. } = &item;
let default = parse_gen_stub_default(attrs)?;
let doc = extract_documents(attrs).join("\n");
let attrs = parse_pyo3_attrs(attrs)?;
for attr in attrs {
if let Attr::Getter(name) = attr {
return Ok(MemberInfo {
name: name.unwrap_or(sig.ident.to_string()),
r#type: escape_return_type(&sig.output).expect("Getter must return a type"),
default,
doc,
});
}
}
Expand All @@ -54,21 +62,52 @@ impl TryFrom<Field> for MemberInfo {
field_name = Some(name);
}
}
let default = parse_gen_stub_default(&attrs)?;
let doc = extract_documents(&attrs).join("\n");
Ok(Self {
name: field_name.unwrap_or(ident.unwrap().to_string()),
r#type: ty,
default,
doc,
})
}
}

impl ToTokens for MemberInfo {
fn to_tokens(&self, tokens: &mut TokenStream2) {
let Self { name, r#type: ty } = self;
let Self {
name,
r#type: ty,
default,
doc,
} = self;
let name = name.strip_prefix("get_").unwrap_or(name);
let default_tt = if let Some(default) = default {
quote! {
Some({
static DEFAULT: std::sync::LazyLock<String> = std::sync::LazyLock::new(|| {
::pyo3::prepare_freethreaded_python();
::pyo3::Python::with_gil(|py| -> String {
let v: #ty = #default;
if let Ok(py_obj) = <#ty as ::pyo3::IntoPyObject>::into_pyobject(v, py) {
::pyo3_stub_gen::util::fmt_py_obj(&py_obj)
} else {
"...".to_owned()
}
})
});
&DEFAULT
})
}
} else {
quote! {None}
};
tokens.append_all(quote! {
::pyo3_stub_gen::type_info::MemberInfo {
name: #name,
r#type: <#ty as ::pyo3_stub_gen::PyStubType>::type_output
r#type: <#ty as ::pyo3_stub_gen::PyStubType>::type_output,
default: #default_tt,
doc: #doc,
}
})
}
Expand Down
18 changes: 18 additions & 0 deletions pyo3-stub-gen-derive/src/gen_stub/pyclass.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,16 @@ impl ToTokens for PyClassInfo {
}
}

// `#[gen_stub(xxx)]` is not a valid proc_macro_attribute
// it's only designed to receive user's setting.
// We need to remove all `#[gen_stub(xxx)]` before print the item_struct back
pub fn prune_attrs(item_struct: &mut ItemStruct) {
super::attr::prune_attrs(&mut item_struct.attrs);
for field in item_struct.fields.iter_mut() {
super::attr::prune_attrs(&mut field.attrs);
}
}

#[cfg(test)]
mod test {
use super::*;
Expand All @@ -105,6 +115,8 @@ mod test {
Debug, Clone, PyNeg, PyAdd, PySub, PyMul, PyDiv, PyMod, PyPow, PyCmp, PyIndex, PyPrint,
)]
pub struct PyPlaceholder {
/// doc line 1
/// doc line 2
#[pyo3(get)]
pub name: String,
#[pyo3(get)]
Expand All @@ -124,14 +136,20 @@ mod test {
::pyo3_stub_gen::type_info::MemberInfo {
name: "name",
r#type: <String as ::pyo3_stub_gen::PyStubType>::type_output,
default: None,
doc: "doc line 1\ndoc line 2",
},
::pyo3_stub_gen::type_info::MemberInfo {
name: "ndim",
r#type: <usize as ::pyo3_stub_gen::PyStubType>::type_output,
default: None,
doc: "",
},
::pyo3_stub_gen::type_info::MemberInfo {
name: "description",
r#type: <Option<String> as ::pyo3_stub_gen::PyStubType>::type_output,
default: None,
doc: "",
},
],
module: Some("my_module"),
Expand Down
7 changes: 7 additions & 0 deletions pyo3-stub-gen-derive/src/gen_stub/pyfunction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,3 +102,10 @@ impl ToTokens for PyFunctionInfo {
})
}
}

// `#[gen_stub(xxx)]` is not a valid proc_macro_attribute
// it's only designed to receive user's setting.
// We need to remove all `#[gen_stub(xxx)]` before print the item_fn back
pub fn prune_attrs(item_fn: &mut ItemFn) {
super::attr::prune_attrs(&mut item_fn.attrs);
}
Loading
Loading