Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Generic types #425

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 32 additions & 4 deletions rustler_codegen/src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,11 @@ use super::RustlerAttr;
pub(crate) struct Context<'a> {
pub attrs: Vec<RustlerAttr>,
pub ident: &'a proc_macro2::Ident,
pub ident_with_lifetime: proc_macro2::TokenStream,
pub ident_with_generics: proc_macro2::TokenStream,
pub variants: Option<Vec<&'a Variant>>,
pub struct_fields: Option<Vec<&'a Field>>,
pub is_tuple_struct: bool,
pub type_param_idents: Vec<&'a Ident>,
}

impl<'a> Context<'a> {
Expand All @@ -44,11 +45,37 @@ impl<'a> Context<'a> {
_ => panic!("Struct can only have one lifetime argument"),
};

let type_param_idents = ast
.generics
.type_params()
.map(|type_param| {
if !type_param.attrs.is_empty() {
panic!("Attributes for type parameters are currently not supported.");
}
if !type_param.bounds.is_empty() {
panic!("Type parameter bounds are currently not supported.");
}
if type_param.eq_token.is_some() {
panic!("Type parameter constraints are currently not supported.");
}
if type_param.default.is_some() {
panic!("Type parameter defaults are currently not supported.");
}
&type_param.ident
})
.collect::<Vec<_>>();

let ident = &ast.ident;
let ident_with_lifetime = if has_lifetime {
quote! { #ident <'a> }

let lifetimes = if has_lifetime {
vec![quote! {'a}]
evnu marked this conversation as resolved.
Show resolved Hide resolved
} else {
Vec::new()
};
let ident_with_generics = if type_param_idents.is_empty() && lifetimes.is_empty() {
quote! { #ident }
} else {
quote! { #ident < #(#lifetimes,)* #(#type_param_idents),* > }
};

let variants = match ast.data {
Expand All @@ -72,10 +99,11 @@ impl<'a> Context<'a> {
Self {
attrs,
ident,
ident_with_lifetime,
ident_with_generics,
variants,
struct_fields,
is_tuple_struct,
type_param_idents,
}
}

Expand Down
10 changes: 6 additions & 4 deletions rustler_codegen/src/ex_struct.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,10 @@ pub fn transcoder_decorator(ast: &syn::DeriveInput, add_exception: bool) -> Toke
}

fn gen_decoder(ctx: &Context, fields: &[&Field], atoms_module_name: &Ident) -> TokenStream {
let struct_type = &ctx.ident_with_lifetime;
let struct_type = &ctx.ident_with_generics;
let struct_name = ctx.ident;
let struct_name_str = struct_name.to_string();
let type_param_idents = &ctx.type_param_idents;

let idents: Vec<_> = fields
.iter()
Expand Down Expand Up @@ -95,7 +96,7 @@ fn gen_decoder(ctx: &Context, fields: &[&Field], atoms_module_name: &Ident) -> T
.unzip();

let gen = quote! {
impl<'a> ::rustler::Decoder<'a> for #struct_type {
impl<'a #(, #type_param_idents : 'a + ::rustler::Decoder<'a>)*> ::rustler::Decoder<'a> for #struct_type {
fn decode(term: ::rustler::Term<'a>) -> ::rustler::NifResult<Self> {
use #atoms_module_name::*;
use ::rustler::Encoder;
Expand Down Expand Up @@ -138,7 +139,8 @@ fn gen_encoder(
atoms_module_name: &Ident,
add_exception: bool,
) -> TokenStream {
let struct_type = &ctx.ident_with_lifetime;
let struct_type = &ctx.ident_with_generics;
let type_param_idents = &ctx.type_param_idents;

let field_defs: Vec<TokenStream> = fields
.iter()
Expand All @@ -160,7 +162,7 @@ fn gen_encoder(
};

let gen = quote! {
impl<'b> ::rustler::Encoder for #struct_type {
impl<'b #(, #type_param_idents: ::rustler::Encoder)*> ::rustler::Encoder for #struct_type {
fn encode<'a>(&self, env: ::rustler::Env<'a>) -> ::rustler::Term<'a> {
use #atoms_module_name::*;
let mut map = ::rustler::types::map::map_new(env);
Expand Down
10 changes: 6 additions & 4 deletions rustler_codegen/src/map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,9 @@ pub fn transcoder_decorator(ast: &syn::DeriveInput) -> TokenStream {
}

fn gen_decoder(ctx: &Context, fields: &[&Field], atoms_module_name: &Ident) -> TokenStream {
let struct_type = &ctx.ident_with_lifetime;
let struct_type = &ctx.ident_with_generics;
let struct_name = ctx.ident;
let type_param_idents = &ctx.type_param_idents;

let idents: Vec<_> = fields
.iter()
Expand All @@ -77,7 +78,7 @@ fn gen_decoder(ctx: &Context, fields: &[&Field], atoms_module_name: &Ident) -> T
.unzip();

let gen = quote! {
impl<'a> ::rustler::Decoder<'a> for #struct_type {
impl<'a #(, #type_param_idents : 'a + ::rustler::Decoder<'a>)*> ::rustler::Decoder<'a> for #struct_type {
fn decode(term: ::rustler::Term<'a>) -> ::rustler::NifResult<Self> {
use #atoms_module_name::*;

Expand Down Expand Up @@ -109,7 +110,8 @@ fn gen_decoder(ctx: &Context, fields: &[&Field], atoms_module_name: &Ident) -> T
}

fn gen_encoder(ctx: &Context, fields: &[&Field], atoms_module_name: &Ident) -> TokenStream {
let struct_type = &ctx.ident_with_lifetime;
let struct_type = &ctx.ident_with_generics;
let type_param_idents = &ctx.type_param_idents;

let field_defs: Vec<TokenStream> = fields
.iter()
Expand All @@ -124,7 +126,7 @@ fn gen_encoder(ctx: &Context, fields: &[&Field], atoms_module_name: &Ident) -> T
.collect();

let gen = quote! {
impl<'b> ::rustler::Encoder for #struct_type {
impl<'b #(, #type_param_idents: ::rustler::Encoder)*> ::rustler::Encoder for #struct_type {
fn encode<'a>(&self, env: ::rustler::Env<'a>) -> ::rustler::Term<'a> {
use #atoms_module_name::*;

Expand Down
9 changes: 9 additions & 0 deletions rustler_codegen/src/nif.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,15 @@ pub fn transcoder_decorator(args: syn::AttributeArgs, fun: syn::ItemFn) -> Token
.map(|ref n| syn::Ident::new(n, Span::call_site()))
.unwrap_or_else(|| name.clone());

if fun.sig.generics.type_params().next().is_some() {
evnu marked this conversation as resolved.
Show resolved Hide resolved
panic!(
"Cannot apply the nif macro to polymorphic functions. \
Since Erlang is untyped, rustler cannot know the type of expected inputs and outputs, \
and therefore doesn't know which decoder and encoder to apply. \
You need to monomorphize your function by giving explicit, non-generic types."
);
}

quote! {
#[allow(non_camel_case_types)]
pub struct #name;
Expand Down
10 changes: 6 additions & 4 deletions rustler_codegen/src/record.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,9 @@ pub fn transcoder_decorator(ast: &syn::DeriveInput) -> TokenStream {
}

fn gen_decoder(ctx: &Context, fields: &[&Field], atoms_module_name: &Ident) -> TokenStream {
let struct_type = &ctx.ident_with_lifetime;
let struct_type = &ctx.ident_with_generics;
let struct_name = ctx.ident;
let type_param_idents = &ctx.type_param_idents;

// Make a decoder for each of the fields in the struct.
let (assignments, field_defs): (Vec<TokenStream>, Vec<TokenStream>) = fields
Expand Down Expand Up @@ -98,7 +99,7 @@ fn gen_decoder(ctx: &Context, fields: &[&Field], atoms_module_name: &Ident) -> T
}
};
let gen = quote! {
impl<'a> ::rustler::Decoder<'a> for #struct_type {
impl<'a #(, #type_param_idents : 'a + ::rustler::Decoder<'a>)*> ::rustler::Decoder<'a> for #struct_type {
fn decode(term: ::rustler::Term<'a>) -> ::rustler::NifResult<Self> {
use #atoms_module_name::*;

Expand Down Expand Up @@ -138,7 +139,8 @@ fn gen_decoder(ctx: &Context, fields: &[&Field], atoms_module_name: &Ident) -> T
}

fn gen_encoder(ctx: &Context, fields: &[&Field], atoms_module_name: &Ident) -> TokenStream {
let struct_type = &ctx.ident_with_lifetime;
let struct_type = &ctx.ident_with_generics;
let type_param_idents = &ctx.type_param_idents;

// Make a field encoder expression for each of the items in the struct.
let field_encoders: Vec<TokenStream> = fields
Expand All @@ -165,7 +167,7 @@ fn gen_encoder(ctx: &Context, fields: &[&Field], atoms_module_name: &Ident) -> T

// The implementation itself
let gen = quote! {
impl<'b> ::rustler::Encoder for #struct_type {
impl<'b #(, #type_param_idents: ::rustler::Encoder)*> ::rustler::Encoder for #struct_type {
fn encode<'a>(&self, env: ::rustler::Env<'a>) -> ::rustler::Term<'a> {
use #atoms_module_name::*;

Expand Down
11 changes: 7 additions & 4 deletions rustler_codegen/src/tagged_enum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,10 @@ pub fn transcoder_decorator(ast: &syn::DeriveInput) -> TokenStream {
}

fn gen_decoder(ctx: &Context, variants: &[&Variant], atoms_module_name: &Ident) -> TokenStream {
let enum_type = &ctx.ident_with_lifetime;
let enum_type = &ctx.ident_with_generics;
let enum_name = ctx.ident;
let type_param_idents = &ctx.type_param_idents;

let unit_decoders: Vec<TokenStream> = variants
.iter()
.filter_map(|variant| {
Expand Down Expand Up @@ -118,7 +120,7 @@ fn gen_decoder(ctx: &Context, variants: &[&Variant], atoms_module_name: &Ident)
.collect();

let gen = quote! {
impl<'a> ::rustler::Decoder<'a> for #enum_type {
impl<'a #(, #type_param_idents : 'a + ::rustler::Decoder<'a>)*> ::rustler::Decoder<'a> for #enum_type {
fn decode(term: ::rustler::Term<'a>) -> ::rustler::NifResult<Self> {
use #atoms_module_name::*;

Expand Down Expand Up @@ -158,8 +160,9 @@ fn gen_decoder(ctx: &Context, variants: &[&Variant], atoms_module_name: &Ident)
}

fn gen_encoder(ctx: &Context, variants: &[&Variant], atoms_module_name: &Ident) -> TokenStream {
let enum_type = &ctx.ident_with_lifetime;
let enum_type = &ctx.ident_with_generics;
let enum_name = ctx.ident;
let type_param_idents = &ctx.type_param_idents;

let variant_defs: Vec<TokenStream> = variants
.iter()
Expand All @@ -180,7 +183,7 @@ fn gen_encoder(ctx: &Context, variants: &[&Variant], atoms_module_name: &Ident)
.collect();

let gen = quote! {
impl<'b> ::rustler::Encoder for #enum_type {
impl<'b #(, #type_param_idents: ::rustler::Encoder)*> ::rustler::Encoder for #enum_type {
fn encode<'a>(&self, env: ::rustler::Env<'a>) -> ::rustler::Term<'a> {
use #atoms_module_name::*;

Expand Down
12 changes: 8 additions & 4 deletions rustler_codegen/src/tuple.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,12 @@ pub fn transcoder_decorator(ast: &syn::DeriveInput) -> TokenStream {
}

fn gen_decoder(ctx: &Context, fields: &[&Field]) -> TokenStream {
let struct_type = &ctx.ident_with_lifetime;
let struct_type = &ctx.ident_with_generics;
let struct_name = ctx.ident;
let struct_name_str = struct_name.to_string();

let type_param_idents = &ctx.type_param_idents;

// Make a decoder for each of the fields in the struct.
let (assignments, field_defs): (Vec<TokenStream>, Vec<TokenStream>) = fields
.iter()
Expand Down Expand Up @@ -82,7 +84,7 @@ fn gen_decoder(ctx: &Context, fields: &[&Field]) -> TokenStream {
}
};
let gen = quote! {
impl<'a> ::rustler::Decoder<'a> for #struct_type {
impl<'a #(, #type_param_idents : 'a + ::rustler::Decoder<'a>)*> ::rustler::Decoder<'a> for #struct_type {
fn decode(term: ::rustler::Term<'a>) -> ::rustler::NifResult<Self> {
let terms = ::rustler::types::tuple::get_tuple(term)?;
if terms.len() != #field_num {
Expand All @@ -108,7 +110,9 @@ fn gen_decoder(ctx: &Context, fields: &[&Field]) -> TokenStream {
}

fn gen_encoder(ctx: &Context, fields: &[&Field]) -> TokenStream {
let struct_type = &ctx.ident_with_lifetime;
let struct_type = &ctx.ident_with_generics;

let type_param_idents = &ctx.type_param_idents;

// Make a field encoder expression for each of the items in the struct.
let field_encoders: Vec<TokenStream> = fields
Expand All @@ -132,7 +136,7 @@ fn gen_encoder(ctx: &Context, fields: &[&Field]) -> TokenStream {

// The implementation itself
let gen = quote! {
impl<'b> ::rustler::Encoder for #struct_type {
impl<'b #(, #type_param_idents: ::rustler::Encoder)*> ::rustler::Encoder for #struct_type {
fn encode<'a>(&self, env: ::rustler::Env<'a>) -> ::rustler::Term<'a> {
use ::rustler::Encoder;
let arr = #field_list_ast;
Expand Down
4 changes: 2 additions & 2 deletions rustler_codegen/src/unit_enum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ pub fn transcoder_decorator(ast: &syn::DeriveInput) -> TokenStream {
}

fn gen_decoder(ctx: &Context, variants: &[&Variant], atoms_module_name: &Ident) -> TokenStream {
let enum_type = &ctx.ident_with_lifetime;
let enum_type = &ctx.ident_with_generics;
let enum_name = ctx.ident;

let variant_defs: Vec<TokenStream> = variants
Expand Down Expand Up @@ -103,7 +103,7 @@ fn gen_decoder(ctx: &Context, variants: &[&Variant], atoms_module_name: &Ident)
}

fn gen_encoder(ctx: &Context, variants: &[&Variant], atoms_module_name: &Ident) -> TokenStream {
let enum_type = &ctx.ident_with_lifetime;
let enum_type = &ctx.ident_with_generics;
let enum_name = ctx.ident;

let variant_defs: Vec<TokenStream> = variants
Expand Down
10 changes: 6 additions & 4 deletions rustler_codegen/src/untagged_enum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,9 @@ pub fn transcoder_decorator(ast: &syn::DeriveInput) -> TokenStream {
}

fn gen_decoder(ctx: &Context, variants: &[&Variant]) -> TokenStream {
let enum_type = &ctx.ident_with_lifetime;
let enum_type = &ctx.ident_with_generics;
let enum_name = ctx.ident;
let type_param_idents = &ctx.type_param_idents;

let variant_defs: Vec<_> = variants
.iter()
Expand All @@ -66,7 +67,7 @@ fn gen_decoder(ctx: &Context, variants: &[&Variant]) -> TokenStream {
.collect();

let gen = quote! {
impl<'a> ::rustler::Decoder<'a> for #enum_type {
impl<'a #(, #type_param_idents : 'a + ::rustler::Decoder<'a>)*> ::rustler::Decoder<'a> for #enum_type {
fn decode(term: ::rustler::Term<'a>) -> ::rustler::NifResult<Self> {
#(#variant_defs)*

Expand All @@ -79,8 +80,9 @@ fn gen_decoder(ctx: &Context, variants: &[&Variant]) -> TokenStream {
}

fn gen_encoder(ctx: &Context, variants: &[&Variant]) -> TokenStream {
let enum_type = &ctx.ident_with_lifetime;
let enum_type = &ctx.ident_with_generics;
let enum_name = ctx.ident;
let type_param_idents = &ctx.type_param_idents;

let variant_defs: Vec<_> = variants
.iter()
Expand All @@ -94,7 +96,7 @@ fn gen_encoder(ctx: &Context, variants: &[&Variant]) -> TokenStream {
.collect();

let gen = quote! {
impl<'b> ::rustler::Encoder for #enum_type {
impl<'b #(, #type_param_idents: ::rustler::Encoder)*> ::rustler::Encoder for #enum_type {
fn encode<'a>(&self, env: ::rustler::Env<'a>) -> ::rustler::Term<'a> {
match *self {
#(#variant_defs)*
Expand Down
10 changes: 9 additions & 1 deletion rustler_tests/lib/rustler_test.ex
Original file line number Diff line number Diff line change
Expand Up @@ -62,15 +62,23 @@ defmodule RustlerTest do
def sublists(_), do: err()

def tuple_echo(_), do: err()
def generic_tuple_echo_usize(_), do: err()
def generic_tuple_echo_str(_), do: err()
def generic_tuple2_echo(_), do: err()
def record_echo(_), do: err()
def generic_record_echo(_), do: err()
def map_echo(_), do: err()
def exception_echo(_), do: err()
def generic_map_echo(_), do: err()
def struct_echo(_), do: err()
def generic_struct_echo(_), do: err()
def exception_echo(_), do: err()
def generic_exception_echo(_), do: err()
def unit_enum_echo(_), do: err()
def tagged_enum_1_echo(_), do: err()
def tagged_enum_2_echo(_), do: err()
def tagged_enum_3_echo(_), do: err()
def untagged_enum_echo(_), do: err()
def generic_untagged_enum_echo(_), do: err()
def untagged_enum_with_truthy(_), do: err()
def untagged_enum_for_issue_370(_), do: err()
def newtype_echo(_), do: err()
Expand Down
10 changes: 9 additions & 1 deletion rustler_tests/native/rustler_test/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,15 +58,23 @@ rustler::init!(
test_env::whereis_pid,
test_env::sublists,
test_codegen::tuple_echo,
test_codegen::generic_tuple_echo_usize,
test_codegen::generic_tuple_echo_str,
test_codegen::generic_tuple2_echo,
test_codegen::record_echo,
test_codegen::generic_record_echo,
test_codegen::map_echo,
test_codegen::exception_echo,
test_codegen::generic_map_echo,
test_codegen::struct_echo,
test_codegen::generic_struct_echo,
test_codegen::exception_echo,
test_codegen::generic_exception_echo,
test_codegen::unit_enum_echo,
test_codegen::tagged_enum_1_echo,
test_codegen::tagged_enum_2_echo,
test_codegen::tagged_enum_3_echo,
test_codegen::untagged_enum_echo,
test_codegen::generic_untagged_enum_echo,
test_codegen::untagged_enum_with_truthy,
test_codegen::untagged_enum_for_issue_370,
test_codegen::newtype_echo,
Expand Down
Loading