Skip to content

Commit

Permalink
Refactor alises support on ToSchema derive (#546)
Browse files Browse the repository at this point in the history
Prior to this commit the implementation was not able to resolve
nested generics within aliases. That lead scenarios where types with
extensive use of lifetimes was not possible.

This commit takes another approach on aliases support for `ToSchema`
derive macro that provides generic schema types. Instead of trying to
parse `Generics` manually we parse `syn::Type` instead that contains
generics as is allowing complex generic arguments with lifetimes to be
used.

Fundamental difference is that we create `TypeTree` for alias and the
implementor type. Then we compare generic arguments to the field
arguments and replace matching occurrences.
```rust
 #[derive(ToSchema)]
 #[aliases(Paginated1 = Paginated<'b, String>, Paginated2 = Paginated<'b, Cow<'b, bool>>)]
 struct Paginated<'r, T> {
     pub total: usize,
     pub data: Vec<T>,
     pub next: Option<&'r str>,
     pub prev: Option<&'r str>,
 }
```

Also removed the need to define lifetime on left side of the equals (=) sign.
  • Loading branch information
juhaku authored Mar 28, 2023
1 parent 7b505fb commit 1d26a65
Show file tree
Hide file tree
Showing 5 changed files with 78 additions and 89 deletions.
2 changes: 1 addition & 1 deletion utoipa-gen/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ proc-macro = true

[dependencies]
proc-macro2 = "1.0"
syn = { version = "2.0", features = ["full"] }
syn = { version = "2.0", features = ["full", "extra-traits"] }
quote = "1.0"
proc-macro-error = "1.0"
regex = { version = "1.7", optional = true }
Expand Down
30 changes: 3 additions & 27 deletions utoipa-gen/src/component.rs
Original file line number Diff line number Diff line change
Expand Up @@ -295,11 +295,11 @@ impl<'t> TypeTree<'t> {
is
}

fn find_mut_by_ident(&mut self, ident: &'_ Ident) -> Option<&mut Self> {
fn find_mut(&mut self, type_tree: &TypeTree) -> Option<&mut Self> {
let is = self
.path
.as_mut()
.map(|path| path.segments.iter().any(|segment| &segment.ident == ident))
.map(|p| matches!(&type_tree.path, Some(path) if path.as_ref() == p.as_ref()))
.unwrap_or(false);

if is {
Expand All @@ -308,35 +308,11 @@ impl<'t> TypeTree<'t> {
self.children.as_mut().and_then(|children| {
children
.iter_mut()
.find_map(|child| Self::find_mut_by_ident(child, ident))
.find_map(|child| Self::find_mut(child, type_tree))
})
}
}

/// Update current [`TypeTree`] from given `ident`.
///
/// It will update everything else except `children` for the `TypeTree`. This means that the
/// `TypeTree` will not be changed and will be traveled as before update.
fn update(&mut self, ident: Ident) {
let new_path = Path::from(ident);

let segments = &new_path.segments;
let last_segment = segments
.last()
.expect("TypeTree::update path should have at least one segment");

let generic_type = Self::get_generic_type(last_segment);
let value_type = if SchemaType(&new_path).is_primitive() {
ValueType::Primitive
} else {
ValueType::Object
};

self.value_type = value_type;
self.generic_type = generic_type;
self.path = Some(Cow::Owned(new_path));
}

/// `Object` virtual type is used when generic object is required in OpenAPI spec. Typically used
/// with `value_type` attribute to hinder the actual type.
pub fn is_object(&self) -> bool {
Expand Down
127 changes: 70 additions & 57 deletions utoipa-gen/src/component/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@ use proc_macro2::{Ident, Span, TokenStream};
use proc_macro_error::abort;
use quote::{format_ident, quote, ToTokens};
use syn::{
parse::Parse, punctuated::Punctuated, token::Comma, Attribute, Data, Field, Fields,
FieldsNamed, FieldsUnnamed, GenericParam, Generics, Lifetime, LifetimeParam, Path,
PathArguments, Token, Variant, Visibility,
parse::Parse, parse_quote, punctuated::Punctuated, spanned::Spanned, token::Comma, Attribute,
Data, Field, Fields, FieldsNamed, FieldsUnnamed, GenericArgument, GenericParam, Generics,
Lifetime, LifetimeParam, Path, PathArguments, Token, Type, Variant, Visibility,
};

use crate::{
Expand Down Expand Up @@ -78,24 +78,36 @@ impl<'a> Schema<'a> {
impl ToTokens for Schema<'_> {
fn to_tokens(&self, tokens: &mut TokenStream) {
let ident = self.ident;
let variant = SchemaVariant::new(self.data, self.attributes, ident, self.generics, None);
let variant = SchemaVariant::new(
self.data,
self.attributes,
ident,
self.generics,
None::<Vec<(TypeTree, &TypeTree)>>,
);

let (_, ty_generics, where_clause) = self.generics.split_for_impl();

let life = &Lifetime::new(Schema::TO_SCHEMA_LIFETIME, Span::call_site());

let schema_ty: Type = parse_quote!(#ident #ty_generics);
let schema_children = &*TypeTree::from_type(&schema_ty).children.unwrap_or_default();

let aliases = self.aliases.as_ref().map(|aliases| {
let alias_schemas = aliases
.iter()
.map(|alias| {
let name = &*alias.name;
let alias_type_tree = TypeTree::from_type(&alias.ty);

let variant = SchemaVariant::new(
self.data,
self.attributes,
ident,
self.generics,
Some(alias),
alias_type_tree
.children
.map(|children| children.into_iter().zip(schema_children)),
);
quote! { (#name, #variant.into()) }
})
Expand All @@ -114,12 +126,17 @@ impl ToTokens for Schema<'_> {
.map(|alias| {
let name = quote::format_ident!("{}", alias.name);
let ty = &alias.ty;
let (_, alias_type_generics, _) = &alias.generics.split_for_impl();
let vis = self.vis;
let name_generics = &alias.get_name_lifetime_generics();
let name_generics = alias.get_lifetimes().fold(
Punctuated::<&GenericArgument, Comma>::new(),
|mut acc, lifetime| {
acc.push(lifetime);
acc
},
);

quote! {
#vis type #name #name_generics = #ty #alias_type_generics;
#vis type #name < #name_generics > = #ty;
}
})
.collect::<TokenStream>()
Expand Down Expand Up @@ -164,12 +181,12 @@ enum SchemaVariant<'a> {
}

impl<'a> SchemaVariant<'a> {
pub fn new(
pub fn new<I: IntoIterator<Item = (TypeTree<'a>, &'a TypeTree<'a>)>>(
data: &'a Data,
attributes: &'a [Attribute],
ident: &'a Ident,
generics: &'a Generics,
alias: Option<&'a AliasSchema>,
aliases: Option<I>,
) -> SchemaVariant<'a> {
match data {
Data::Struct(content) => match &content.fields {
Expand Down Expand Up @@ -203,7 +220,7 @@ impl<'a> SchemaVariant<'a> {
fields: named,
generics: Some(generics),
schema_as,
alias,
aliases: aliases.map(|aliases| aliases.into_iter().collect()),
})
}
Fields::Unit => Self::Unit(UnitStructVariant),
Expand Down Expand Up @@ -260,7 +277,7 @@ pub struct NamedStructSchema<'a> {
pub features: Option<Vec<Feature>>,
pub rename_all: Option<RenameAll>,
pub generics: Option<&'a Generics>,
pub alias: Option<&'a AliasSchema>,
pub aliases: Option<Vec<(TypeTree<'a>, &'a TypeTree<'a>)>>,
pub schema_as: Option<As>,
}

Expand All @@ -279,6 +296,13 @@ impl NamedStructSchema<'_> {
yield_: impl FnOnce(NamedStructFieldOptions<'_>) -> R,
) -> R {
let type_tree = &mut TypeTree::from_type(&field.ty);
if let Some(aliases) = &self.aliases {
for (new_generic, old_generic_matcher) in aliases.iter() {
if let Some(generic_match) = type_tree.find_mut(old_generic_matcher) {
*generic_match = new_generic.clone();
}
}
}

let mut field_features = field
.attrs
Expand Down Expand Up @@ -315,25 +339,6 @@ impl NamedStructSchema<'_> {
_ => None,
});

if let Some((generic_types, alias)) = self.generics.zip(self.alias) {
generic_types
.type_params()
.enumerate()
.for_each(|(index, generic)| {
if let Some(generic_type) = type_tree.find_mut_by_ident(&generic.ident) {
generic_type.update(
alias
.generics
.type_params()
.nth(index)
.unwrap()
.ident
.clone(),
);
};
})
}

let deprecated = super::get_deprecated(&field.attrs);
let value_type = field_features
.as_mut()
Expand Down Expand Up @@ -953,7 +958,7 @@ impl ComplexEnum<'_> {
features: Some(named_struct_features),
fields: &named_fields.named,
generics: None,
alias: None,
aliases: None,
schema_as: None,
},
})
Expand Down Expand Up @@ -1041,7 +1046,7 @@ impl ComplexEnum<'_> {
features: Some(named_struct_features),
fields: &named_fields.named,
generics: None,
alias: None,
aliases: None,
schema_as: None,
}
.to_token_stream()
Expand Down Expand Up @@ -1109,7 +1114,7 @@ impl ComplexEnum<'_> {
features: Some(named_struct_features),
fields: &named_fields.named,
generics: None,
alias: None,
aliases: None,
schema_as: None,
};
let title = title_features.first().map(ToTokens::to_token_stream);
Expand Down Expand Up @@ -1261,7 +1266,7 @@ impl ComplexEnum<'_> {
features: Some(named_struct_features),
fields: &named_fields.named,
generics: None,
alias: None,
aliases: None,
schema_as: None,
};
let title = title_features.first().map(ToTokens::to_token_stream);
Expand Down Expand Up @@ -1494,42 +1499,50 @@ fn is_flatten(rule: &Option<SerdeValue>) -> bool {
#[cfg_attr(feature = "debug", derive(Debug))]
pub struct AliasSchema {
pub name: String,
pub ty: Ident,
pub generics: Generics,
pub ty: Type,
}

impl AliasSchema {
fn get_name_lifetime_generics(&self) -> Option<Generics> {
let lifetimes = self
.generics
.lifetimes()
.filter(|lifetime| lifetime.lifetime.ident != "'static")
.map(|lifetime| GenericParam::Lifetime(lifetime.clone()))
.collect::<Punctuated<GenericParam, Comma>>();

if !lifetimes.is_empty() {
Some(Generics {
params: lifetimes,
..Default::default()
})
} else {
None
fn get_lifetimes(&self) -> impl Iterator<Item = &GenericArgument> {
fn lifetimes_from_type(ty: &Type) -> impl Iterator<Item = &GenericArgument> {
match ty {
Type::Path(type_path) => type_path
.path
.segments
.iter()
.flat_map(|segment| match &segment.arguments {
PathArguments::AngleBracketed(angle_bracketed_args) => {
Some(angle_bracketed_args.args.iter())
}
_ => None,
})
.flatten()
.flat_map(|arg| match arg {
GenericArgument::Type(type_argument) => {
lifetimes_from_type(type_argument).collect::<Vec<_>>()
}
_ => vec![arg],
})
.filter(|generic_arg| matches!(generic_arg, syn::GenericArgument::Lifetime(lifetime) if lifetime.ident != "'static")),
_ => abort!(
&ty.span(),
"AliasSchema `get_lifetimes` only supports syn::TypePath types"
),
}
}

lifetimes_from_type(&self.ty)
}
}

impl Parse for AliasSchema {
fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
let name = input.parse::<Ident>()?;
if input.peek(Token![<]) {
input.parse::<Generics>()?;
}
input.parse::<Token![=]>()?;

Ok(Self {
name: name.to_string(),
ty: input.parse::<Ident>()?,
generics: input.parse()?,
ty: input.parse::<Type>()?,
})
}
}
Expand Down
4 changes: 2 additions & 2 deletions utoipa-gen/src/path/response/derive.rs
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,7 @@ impl NamedStructResponse<'_> {
let inline_schema = NamedStructSchema {
attributes,
fields,
alias: None,
aliases: None,
features: None,
generics: None,
rename_all: None,
Expand Down Expand Up @@ -364,7 +364,7 @@ impl<'p> ToResponseNamedStructResponse<'p> {
let ty = Self::to_type(ident);

let inline_schema = NamedStructSchema {
alias: None,
aliases: None,
fields,
features: None,
generics: None,
Expand Down
4 changes: 2 additions & 2 deletions utoipa-gen/tests/schema_derive_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4023,7 +4023,7 @@ fn derive_schema_with_generics_and_lifetimes() {
struct TResult;

let value = api_doc_aliases! {
#[aliases(Paginated1<'b> = Paginated<'b, String>, Paginated2 = Paginated<'b, Value>)]
#[aliases(Paginated1 = Paginated<'b, String>, Paginated2 = Paginated<'b, Cow<'c, bool>>)]
struct Paginated<'r, TResult> {
pub total: usize,
pub data: Vec<TResult>,
Expand Down Expand Up @@ -4072,7 +4072,7 @@ fn derive_schema_with_generics_and_lifetimes() {
"data": {
"type": "array",
"items": {
"$ref": "#/components/schemas/Value",
"type": "boolean"
}
},
"next": {
Expand Down

0 comments on commit 1d26a65

Please sign in to comment.