Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use scale-type-resolver to be generic over the type resolver used #45

Merged
merged 16 commits into from
Feb 9, 2024
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions CODEOWNERS
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# main codeowner @paritytech/tools-team
* @paritytech/tools-team
# main codeowner @paritytech/subxt-team
* @paritytech/subxt-team

# CI
/.github/ @paritytech/ci @paritytech/tools-team
/.github/ @paritytech/ci @paritytech/subxt-team
189 changes: 135 additions & 54 deletions scale-decode-derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ extern crate alloc;

use alloc::string::ToString;
use darling::FromAttributes;
use proc_macro2::TokenStream as TokenStream2;
use proc_macro2::{Span, TokenStream as TokenStream2};
use quote::quote;
use syn::{parse_macro_input, punctuated::Punctuated, DeriveInput};

Expand Down Expand Up @@ -59,10 +59,17 @@ fn generate_enum_impl(
) -> TokenStream2 {
let path_to_scale_decode = &attrs.crate_path;
let path_to_type: syn::Path = input.ident.clone().into();
let (impl_generics, ty_generics, where_clause, phantomdata_type) =
handle_generics(&attrs, &input.generics);
let variant_names = details.variants.iter().map(|v| v.ident.to_string());

let generic_types = handle_generics(&attrs, input.generics.clone());
let ty_generics = generic_types.ty_generics();
let impl_generics = generic_types.impl_generics();
let visitor_where_clause = generic_types.visitor_where_clause();
let visitor_ty_generics = generic_types.visitor_ty_generics();
let visitor_impl_generics = generic_types.visitor_impl_generics();
let visitor_phantomdata_type = generic_types.visitor_phantomdata_type();
let type_resolver_ident = generic_types.type_resolver_ident();

// determine what the body of our visitor functions will be based on the type of enum fields
// that we're trying to generate output for.
let variant_ifs = details.variants.iter().map(|variant| {
Expand Down Expand Up @@ -130,28 +137,29 @@ fn generate_enum_impl(

quote!(
const _: () = {
#visibility struct Visitor #impl_generics (
::core::marker::PhantomData<#phantomdata_type>
#visibility struct Visitor #visitor_impl_generics (
::core::marker::PhantomData<#visitor_phantomdata_type>
);

use #path_to_scale_decode::vec;
use #path_to_scale_decode::ToString;

impl #impl_generics #path_to_scale_decode::IntoVisitor for #path_to_type #ty_generics #where_clause {
type Visitor = Visitor #ty_generics;
fn into_visitor() -> Self::Visitor {
impl #impl_generics #path_to_scale_decode::IntoVisitor for #path_to_type #ty_generics #visitor_where_clause {
type AnyVisitor<#type_resolver_ident: #path_to_scale_decode::TypeResolver> = Visitor #visitor_ty_generics;
fn into_visitor<#type_resolver_ident: #path_to_scale_decode::TypeResolver>() -> Self::AnyVisitor<#type_resolver_ident> {
Visitor(::core::marker::PhantomData)
}
}

impl #impl_generics #path_to_scale_decode::Visitor for Visitor #ty_generics #where_clause {
impl #visitor_impl_generics #path_to_scale_decode::Visitor for Visitor #visitor_ty_generics #visitor_where_clause {
type Error = #path_to_scale_decode::Error;
type Value<'scale, 'info> = #path_to_type #ty_generics;
type TypeResolver = #type_resolver_ident;

fn visit_variant<'scale, 'info>(
self,
value: &mut #path_to_scale_decode::visitor::types::Variant<'scale, 'info>,
type_id: #path_to_scale_decode::visitor::TypeId,
value: &mut #path_to_scale_decode::visitor::types::Variant<'scale, 'info, Self::TypeResolver>,
type_id: &<Self::TypeResolver as #path_to_scale_decode::TypeResolver>::TypeId,
) -> Result<Self::Value<'scale, 'info>, Self::Error> {
#(
#variant_ifs
Expand All @@ -164,8 +172,8 @@ fn generate_enum_impl(
// Allow an enum to be decoded through nested 1-field composites and tuples:
fn visit_composite<'scale, 'info>(
self,
value: &mut #path_to_scale_decode::visitor::types::Composite<'scale, 'info>,
_type_id: #path_to_scale_decode::visitor::TypeId,
value: &mut #path_to_scale_decode::visitor::types::Composite<'scale, 'info, Self::TypeResolver>,
_type_id: &<Self::TypeResolver as #path_to_scale_decode::TypeResolver>::TypeId,
) -> Result<Self::Value<'scale, 'info>, Self::Error> {
if value.remaining() != 1 {
return self.visit_unexpected(#path_to_scale_decode::visitor::Unexpected::Composite);
Expand All @@ -174,8 +182,8 @@ fn generate_enum_impl(
}
fn visit_tuple<'scale, 'info>(
self,
value: &mut #path_to_scale_decode::visitor::types::Tuple<'scale, 'info>,
_type_id: #path_to_scale_decode::visitor::TypeId,
value: &mut #path_to_scale_decode::visitor::types::Tuple<'scale, 'info, Self::TypeResolver>,
_type_id: &<Self::TypeResolver as #path_to_scale_decode::TypeResolver>::TypeId,
) -> Result<Self::Value<'scale, 'info>, Self::Error> {
if value.remaining() != 1 {
return self.visit_unexpected(#path_to_scale_decode::visitor::Unexpected::Tuple);
Expand All @@ -195,8 +203,15 @@ fn generate_struct_impl(
) -> TokenStream2 {
let path_to_scale_decode = &attrs.crate_path;
let path_to_type: syn::Path = input.ident.clone().into();
let (impl_generics, ty_generics, where_clause, phantomdata_type) =
handle_generics(&attrs, &input.generics);

let generic_types = handle_generics(&attrs, input.generics.clone());
let ty_generics = generic_types.ty_generics();
let impl_generics = generic_types.impl_generics();
let visitor_where_clause = generic_types.visitor_where_clause();
let visitor_ty_generics = generic_types.visitor_ty_generics();
let visitor_impl_generics = generic_types.visitor_impl_generics();
let visitor_phantomdata_type = generic_types.visitor_phantomdata_type();
let type_resolver_ident = generic_types.type_resolver_ident();

// determine what the body of our visitor functions will be based on the type of struct
// that we're trying to generate output for.
Expand Down Expand Up @@ -260,48 +275,51 @@ fn generate_struct_impl(

quote!(
const _: () = {
#visibility struct Visitor #impl_generics (
::core::marker::PhantomData<#phantomdata_type>
#visibility struct Visitor #visitor_impl_generics (
::core::marker::PhantomData<#visitor_phantomdata_type>
);

use #path_to_scale_decode::vec;
use #path_to_scale_decode::ToString;

impl #impl_generics #path_to_scale_decode::IntoVisitor for #path_to_type #ty_generics #where_clause {
type Visitor = Visitor #ty_generics;
fn into_visitor() -> Self::Visitor {
impl #impl_generics #path_to_scale_decode::IntoVisitor for #path_to_type #ty_generics #visitor_where_clause {
type AnyVisitor<#type_resolver_ident: #path_to_scale_decode::TypeResolver> = Visitor #visitor_ty_generics;
fn into_visitor<#type_resolver_ident: #path_to_scale_decode::TypeResolver>() -> Self::AnyVisitor<#type_resolver_ident> {
Visitor(::core::marker::PhantomData)
}
}

impl #impl_generics #path_to_scale_decode::Visitor for Visitor #ty_generics #where_clause {
impl #visitor_impl_generics #path_to_scale_decode::Visitor for Visitor #visitor_ty_generics #visitor_where_clause {
type Error = #path_to_scale_decode::Error;
type Value<'scale, 'info> = #path_to_type #ty_generics;
type TypeResolver = #type_resolver_ident;

fn visit_composite<'scale, 'info>(
self,
value: &mut #path_to_scale_decode::visitor::types::Composite<'scale, 'info>,
type_id: #path_to_scale_decode::visitor::TypeId,
value: &mut #path_to_scale_decode::visitor::types::Composite<'scale, 'info, Self::TypeResolver>,
type_id: &<Self::TypeResolver as #path_to_scale_decode::TypeResolver>::TypeId,
) -> Result<Self::Value<'scale, 'info>, Self::Error> {
#visit_composite_body
}
fn visit_tuple<'scale, 'info>(
self,
value: &mut #path_to_scale_decode::visitor::types::Tuple<'scale, 'info>,
type_id: #path_to_scale_decode::visitor::TypeId,
value: &mut #path_to_scale_decode::visitor::types::Tuple<'scale, 'info, Self::TypeResolver>,
type_id: &<Self::TypeResolver as #path_to_scale_decode::TypeResolver>::TypeId,
) -> Result<Self::Value<'scale, 'info>, Self::Error> {
#visit_tuple_body
}
}

impl #impl_generics #path_to_scale_decode::DecodeAsFields for #path_to_type #ty_generics #where_clause {
fn decode_as_fields<'info>(input: &mut &[u8], fields: &mut dyn #path_to_scale_decode::FieldIter<'info>, types: &'info #path_to_scale_decode::PortableRegistry)
-> Result<Self, #path_to_scale_decode::Error>
impl #impl_generics #path_to_scale_decode::DecodeAsFields for #path_to_type #ty_generics #visitor_where_clause {
fn decode_as_fields<'info, R: #path_to_scale_decode::TypeResolver>(
input: &mut &[u8],
fields: &mut dyn #path_to_scale_decode::FieldIter<'info, R::TypeId>,
types: &'info R
) -> Result<Self, #path_to_scale_decode::Error>
{
let path = #path_to_scale_decode::EMPTY_SCALE_INFO_PATH;
let mut composite = #path_to_scale_decode::visitor::types::Composite::new(input, path, fields, types, false);
let mut composite = #path_to_scale_decode::visitor::types::Composite::new(input, fields, types, false);
use #path_to_scale_decode::{ Visitor, IntoVisitor };
let val = <#path_to_type #ty_generics>::into_visitor().visit_composite(&mut composite, #path_to_scale_decode::visitor::TypeId(0));
let val = <#path_to_type #ty_generics>::into_visitor().visit_composite(&mut composite, &Default::default());
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using a type ID of 0 was a bit of a hack before, so here and in one other place we change it to being Default::default() instead. This requires our TypeId to have a Default bound on it, which ideally it would do without.

In the future I expect that we can tweak these cases and remove the bound entirely, but for now I don't suspect it will be a big deal.


// Consume any remaining bytes and update input:
composite.skip_decoding()?;
Expand Down Expand Up @@ -392,29 +410,31 @@ fn unnamed_field_vals<'f>(
(field_count, field_vals)
}

fn handle_generics<'a>(
attrs: &TopLevelAttrs,
generics: &'a syn::Generics,
) -> (syn::ImplGenerics<'a>, syn::TypeGenerics<'a>, syn::WhereClause, syn::Type) {
fn handle_generics(attrs: &TopLevelAttrs, generics: syn::Generics) -> GenericTypes {
let path_to_crate = &attrs.crate_path;
let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();

let mut where_clause = where_clause.cloned().unwrap_or(syn::parse_quote!(where));

if let Some(where_predicates) = &attrs.trait_bounds {
// if custom trait bounds are given, append those to the where clause.
where_clause.predicates.extend(where_predicates.clone());
} else {
// else, append our default bounds to each parameter to ensure that it all lines up with our generated impls and such:
for param in generics.type_params() {
let ty = &param.ident;
where_clause.predicates.push(syn::parse_quote!(#ty: #path_to_crate::IntoVisitor));
where_clause.predicates.push(syn::parse_quote!(#path_to_crate::Error: From<<<#ty as #path_to_crate::IntoVisitor>::Visitor as #path_to_crate::Visitor>::Error>));

let type_resolver_ident =
syn::Ident::new(GenericTypes::TYPE_RESOLVER_IDENT_STR, Span::call_site());

// Where clause to use on Visitor/IntoVisitor
let visitor_where_clause = {
let (_, _, where_clause) = generics.split_for_impl();
let mut where_clause = where_clause.cloned().unwrap_or(syn::parse_quote!(where));
if let Some(where_predicates) = &attrs.trait_bounds {
// if custom trait bounds are given, append those to the where clause.
where_clause.predicates.extend(where_predicates.clone());
} else {
// else, append our default bounds to each parameter to ensure that it all lines up with our generated impls and such:
for param in generics.type_params() {
let ty = &param.ident;
where_clause.predicates.push(syn::parse_quote!(#ty: #path_to_crate::IntoVisitor));
}
}
}
where_clause
};

// Construct a type to put into PhantomData<$ty>. This takes lifetimes into account too.
let phantomdata_type: syn::Type = {
// (A, B, C, ScaleDecodeTypeResolver) style PhantomData type to use in Visitor struct.
let visitor_phantomdata_type = {
let tys = generics.params.iter().filter_map::<syn::Type, _>(|p| match p {
syn::GenericParam::Type(ty) => {
let ty = &ty.ident;
Expand All @@ -427,10 +447,71 @@ fn handle_generics<'a>(
// We don't need to mention const's in the PhantomData type.
syn::GenericParam::Const(_) => None,
});

// Add a param for the type resolver generic.
let tys = tys.chain(core::iter::once(syn::parse_quote!(#type_resolver_ident)));

syn::parse_quote!( (#( #tys, )*) )
};

(impl_generics, ty_generics, where_clause, phantomdata_type)
// generics for our Visitor/IntoVisitor; we just add the type resolver param to the list.
let visitor_generics = {
let mut type_generics = generics.clone();
let type_resolver_generic_param: syn::GenericParam =
syn::parse_quote!(#type_resolver_ident: #path_to_crate::TypeResolver);

type_generics.params.push(type_resolver_generic_param);
type_generics
};

// generics for the type itself
let type_generics = generics;

GenericTypes {
type_generics,
type_resolver_ident,
visitor_generics,
visitor_phantomdata_type,
visitor_where_clause,
}
}

struct GenericTypes {
type_resolver_ident: syn::Ident,
type_generics: syn::Generics,
visitor_generics: syn::Generics,
visitor_where_clause: syn::WhereClause,
visitor_phantomdata_type: syn::Type,
}

impl GenericTypes {
const TYPE_RESOLVER_IDENT_STR: &'static str = "ScaleDecodeTypeResolver";

pub fn ty_generics(&self) -> syn::TypeGenerics<'_> {
let (_, ty_generics, _) = self.type_generics.split_for_impl();
ty_generics
}
pub fn impl_generics(&self) -> syn::ImplGenerics<'_> {
let (impl_generics, _, _) = self.type_generics.split_for_impl();
impl_generics
}
pub fn visitor_where_clause(&self) -> &syn::WhereClause {
&self.visitor_where_clause
}
pub fn visitor_ty_generics(&self) -> syn::TypeGenerics<'_> {
let (_, ty_generics, _) = self.visitor_generics.split_for_impl();
ty_generics
}
pub fn visitor_impl_generics(&self) -> syn::ImplGenerics<'_> {
let (impl_generics, _, _) = self.visitor_generics.split_for_impl();
impl_generics
}
pub fn visitor_phantomdata_type(&self) -> &syn::Type {
&self.visitor_phantomdata_type
}
pub fn type_resolver_ident(&self) -> &syn::Ident {
&self.type_resolver_ident
}
}

struct TopLevelAttrs {
Expand Down
3 changes: 2 additions & 1 deletion scale-decode/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,12 @@ derive = ["dep:scale-decode-derive"]
[dependencies]
scale-info = { version = "2.7.0", default-features = false, features = ["bit-vec"] }
codec = { package = "parity-scale-codec", version = "3.0.0", default-features = false, features = ["derive"] }
scale-bits = { version = "0.4.0", default-features = false, features = ["scale-info"] }
scale-bits = { path = "../../scale-bits", version = "0.4.0", default-features = false, features = ["scale-info"] }
scale-decode-derive = { workspace = true, optional = true }
primitive-types = { version = "0.12.0", optional = true, default-features = false }
smallvec = "1.10.0"
derive_more = { version = "0.99.17", default-features = false, features = ["from", "display"] }
scale-type-resolver = { path = "../../scale-type-resolver" }

[dev-dependencies]
scale-info = { version = "2.7.0", default-features = false, features = ["bit-vec", "derive"] }
Expand Down
Loading
Loading