From 56dad027132c83379e81e29038bd229da840b30f Mon Sep 17 00:00:00 2001 From: Seokmin Hong Date: Thu, 12 May 2022 13:49:07 +0900 Subject: [PATCH 1/2] `Term` -> `impl Encoder` for map --- rustler/src/types/elixir_struct.rs | 8 +-- rustler/src/types/map.rs | 89 ++++++++++++------------------ rustler/src/types/mod.rs | 5 +- rustler_codegen/src/ex_struct.rs | 15 ++--- rustler_codegen/src/map.rs | 9 +-- rustler_codegen/src/tagged_enum.rs | 26 ++++----- 6 files changed, 62 insertions(+), 90 deletions(-) diff --git a/rustler/src/types/elixir_struct.rs b/rustler/src/types/elixir_struct.rs index 3cbe65e9..d0e60dc3 100644 --- a/rustler/src/types/elixir_struct.rs +++ b/rustler/src/types/elixir_struct.rs @@ -12,17 +12,15 @@ use super::map::map_new; use crate::{Env, NifResult, Term}; pub fn get_ex_struct_name(map: Term) -> NifResult { - let env = map.get_env(); // In an Elixir struct the value in the __struct__ field is always an atom. - map.map_get(atom::__struct__().to_term(env)) - .and_then(Atom::from_term) + map.map_get(atom::__struct__()).and_then(Atom::from_term) } pub fn make_ex_struct<'a>(env: Env<'a>, struct_module: &str) -> NifResult> { let map = map_new(env); - let struct_atom = atom::__struct__().to_term(env); - let module_atom = Atom::from_str(env, struct_module)?.to_term(env); + let struct_atom = atom::__struct__(); + let module_atom = Atom::from_str(env, struct_module)?; map.map_put(struct_atom, module_atom) } diff --git a/rustler/src/types/map.rs b/rustler/src/types/map.rs index 66c87517..e3ea5764 100644 --- a/rustler/src/types/map.rs +++ b/rustler/src/types/map.rs @@ -2,7 +2,7 @@ use super::atom; use crate::wrapper::map; -use crate::{Decoder, Env, Error, NifResult, Term}; +use crate::{Decoder, Encoder, Env, Error, NifResult, Term}; use std::ops::RangeInclusive; pub fn map_new(env: Env) -> Term { @@ -31,12 +31,12 @@ impl<'a> Term<'a> { /// ``` pub fn map_from_arrays( env: Env<'a>, - keys: &[Term<'a>], - values: &[Term<'a>], + keys: &[impl Encoder], + values: &[impl Encoder], ) -> NifResult> { if keys.len() == values.len() { - let keys: Vec<_> = keys.iter().map(|k| k.as_c_arg()).collect(); - let values: Vec<_> = values.iter().map(|v| v.as_c_arg()).collect(); + let keys: Vec<_> = keys.iter().map(|k| k.encode(env).as_c_arg()).collect(); + let values: Vec<_> = values.iter().map(|v| v.encode(env).as_c_arg()).collect(); unsafe { map::make_map_from_arrays(env.as_c_arg(), &keys, &values) @@ -57,10 +57,13 @@ impl<'a> Term<'a> { /// ```elixir /// Map.new([{"foo", 1}, {"bar", 2}]) /// ``` - pub fn map_from_pairs(env: Env<'a>, pairs: &[(Term<'a>, Term<'a>)]) -> NifResult> { + pub fn map_from_pairs( + env: Env<'a>, + pairs: &[(impl Encoder, impl Encoder)], + ) -> NifResult> { let (keys, values): (Vec<_>, Vec<_>) = pairs .iter() - .map(|(k, v)| (k.as_c_arg(), v.as_c_arg())) + .map(|(k, v)| (k.encode(env).as_c_arg(), v.encode(env).as_c_arg())) .unzip(); unsafe { @@ -78,9 +81,11 @@ impl<'a> Term<'a> { /// ```elixir /// Map.get(self_term, key) /// ``` - pub fn map_get(self, key: Term) -> NifResult> { + pub fn map_get(self, key: impl Encoder) -> NifResult> { let env = self.get_env(); - match unsafe { map::get_map_value(env.as_c_arg(), self.as_c_arg(), key.as_c_arg()) } { + match unsafe { + map::get_map_value(env.as_c_arg(), self.as_c_arg(), key.encode(env).as_c_arg()) + } { Some(value) => Ok(unsafe { Term::new(env, value) }), None => Err(Error::BadArg), } @@ -108,27 +113,18 @@ impl<'a> Term<'a> { /// ```elixir /// Map.put(self_term, key, value) /// ``` - pub fn map_put(self, key: Term<'a>, value: Term<'a>) -> NifResult> { - let map_env = self.get_env(); - - assert!( - map_env == key.get_env(), - "key is from different environment as map" - ); - assert!( - map_env == value.get_env(), - "value is from different environment as map" - ); + pub fn map_put(self, key: impl Encoder, value: impl Encoder) -> NifResult> { + let env = self.get_env(); match unsafe { map::map_put( - map_env.as_c_arg(), + env.as_c_arg(), self.as_c_arg(), - key.as_c_arg(), - value.as_c_arg(), + key.encode(env).as_c_arg(), + value.encode(env).as_c_arg(), ) } { - Some(inner) => Ok(unsafe { Term::new(map_env, inner) }), + Some(inner) => Ok(unsafe { Term::new(env, inner) }), None => Err(Error::BadArg), } } @@ -142,16 +138,13 @@ impl<'a> Term<'a> { /// ```elixir /// Map.delete(self_term, key) /// ``` - pub fn map_remove(self, key: Term<'a>) -> NifResult> { - let map_env = self.get_env(); - - assert!( - map_env == key.get_env(), - "key is from different environment as map" - ); + pub fn map_remove(self, key: impl Encoder) -> NifResult> { + let env = self.get_env(); - match unsafe { map::map_remove(map_env.as_c_arg(), self.as_c_arg(), key.as_c_arg()) } { - Some(inner) => Ok(unsafe { Term::new(map_env, inner) }), + match unsafe { + map::map_remove(env.as_c_arg(), self.as_c_arg(), key.encode(env).as_c_arg()) + } { + Some(inner) => Ok(unsafe { Term::new(env, inner) }), None => Err(Error::BadArg), } } @@ -160,27 +153,18 @@ impl<'a> Term<'a> { /// /// Returns Err(Error::BadArg) if the term is not a map of if key /// doesn't exist. - pub fn map_update(self, key: Term<'a>, new_value: Term<'a>) -> NifResult> { - let map_env = self.get_env(); - - assert!( - map_env == key.get_env(), - "key is from different environment as map" - ); - assert!( - map_env == new_value.get_env(), - "value is from different environment as map" - ); + pub fn map_update(self, key: impl Encoder, new_value: impl Encoder) -> NifResult> { + let env = self.get_env(); match unsafe { map::map_update( - map_env.as_c_arg(), + env.as_c_arg(), self.as_c_arg(), - key.as_c_arg(), - new_value.as_c_arg(), + key.encode(env).as_c_arg(), + new_value.encode(env).as_c_arg(), ) } { - Some(inner) => Ok(unsafe { Term::new(map_env, inner) }), + Some(inner) => Ok(unsafe { Term::new(env, inner) }), None => Err(Error::BadArg), } } @@ -234,17 +218,16 @@ where T: Decoder<'a>, { fn decode(term: Term<'a>) -> NifResult { - let env = term.get_env(); - let name = term.map_get(atom::__struct__().to_term(env))?; + let name = term.map_get(atom::__struct__())?; match name.atom_to_string()?.as_ref() { "Elixir.Range" => (), _ => return Err(Error::BadArg), } - let first = term.map_get(atom::first().to_term(env))?.decode::()?; - let last = term.map_get(atom::last().to_term(env))?.decode::()?; - if let Ok(step) = term.map_get(atom::step().to_term(env)) { + let first = term.map_get(atom::first())?.decode::()?; + let last = term.map_get(atom::last())?.decode::()?; + if let Ok(step) = term.map_get(atom::step()) { match step.decode::()? { 1 => (), _ => return Err(Error::BadArg), diff --git a/rustler/src/types/mod.rs b/rustler/src/types/mod.rs index 106e1bff..f3f19356 100644 --- a/rustler/src/types/mod.rs +++ b/rustler/src/types/mod.rs @@ -154,10 +154,7 @@ where V: Encoder, { fn encode<'c>(&self, env: Env<'c>) -> Term<'c> { - let (keys, values): (Vec<_>, Vec<_>) = self - .iter() - .map(|(k, v)| (k.encode(env), v.encode(env))) - .unzip(); + let (keys, values): (Vec<_>, Vec<_>) = self.iter().map(|(k, v)| (k, v)).unzip(); Term::map_from_arrays(env, &keys, &values).unwrap() } } diff --git a/rustler_codegen/src/ex_struct.rs b/rustler_codegen/src/ex_struct.rs index a034e67d..1ed93bd5 100644 --- a/rustler_codegen/src/ex_struct.rs +++ b/rustler_codegen/src/ex_struct.rs @@ -83,7 +83,7 @@ fn gen_decoder(ctx: &Context, fields: &[&Field], atoms_module_name: &Ident) -> T let variable = Context::escape_ident_with_index(&ident.to_string(), index, "struct"); let assignment = quote_spanned! { field.span() => - let #variable = try_decode_field(env, term, #atom_fun())?; + let #variable = try_decode_field(term, #atom_fun())?; }; let field_def = quote! { @@ -100,10 +100,7 @@ fn gen_decoder(ctx: &Context, fields: &[&Field], atoms_module_name: &Ident) -> T use #atoms_module_name::*; use ::rustler::Encoder; - let env = term.get_env(); - fn try_decode_field<'a, T>( - env: rustler::Env<'a>, term: rustler::Term<'a>, field: rustler::Atom, ) -> ::rustler::NifResult @@ -111,7 +108,7 @@ fn gen_decoder(ctx: &Context, fields: &[&Field], atoms_module_name: &Ident) -> T T: rustler::Decoder<'a>, { use rustler::Encoder; - match ::rustler::Decoder::decode(term.map_get(field.encode(env))?) { + match ::rustler::Decoder::decode(term.map_get(field)?) { Err(_) => Err(::rustler::Error::RaiseTerm(Box::new(format!( "Could not decode field :{:?} on %{}{{}}", field, #struct_name_str @@ -120,7 +117,7 @@ fn gen_decoder(ctx: &Context, fields: &[&Field], atoms_module_name: &Ident) -> T } }; - let module: ::rustler::types::atom::Atom = term.map_get(atom_struct().to_term(env))?.decode()?; + let module: ::rustler::types::atom::Atom = term.map_get(atom_struct())?.decode()?; if module != atom_module() { return Err(::rustler::Error::RaiseAtom("invalid_struct")); } @@ -149,14 +146,14 @@ fn gen_encoder( let field_ident = field.ident.as_ref().unwrap(); let atom_fun = Context::field_to_atom_fun(field); quote_spanned! { field.span() => - map = map.map_put(#atom_fun().encode(env), self.#field_ident.encode(env)).unwrap(); + map = map.map_put(#atom_fun(), &self.#field_ident).unwrap(); } }) .collect(); let exception_field = if add_exception { quote! { - map = map.map_put(atom_exception().encode(env), true.encode(env)).unwrap(); + map = map.map_put(atom_exception(), true).unwrap(); } } else { quote! {} @@ -167,7 +164,7 @@ fn gen_encoder( fn encode<'a>(&self, env: ::rustler::Env<'a>) -> ::rustler::Term<'a> { use #atoms_module_name::*; let mut map = ::rustler::types::map::map_new(env); - map = map.map_put(atom_struct().encode(env), atom_module().encode(env)).unwrap(); + map = map.map_put(atom_struct(), atom_module()).unwrap(); #exception_field #(#field_defs)* map diff --git a/rustler_codegen/src/map.rs b/rustler_codegen/src/map.rs index 5c68cfc3..ed35cf78 100644 --- a/rustler_codegen/src/map.rs +++ b/rustler_codegen/src/map.rs @@ -66,7 +66,7 @@ fn gen_decoder(ctx: &Context, fields: &[&Field], atoms_module_name: &Ident) -> T let variable = Context::escape_ident_with_index(&ident.to_string(), index, "map"); let assignment = quote_spanned! { field.span() => - let #variable = try_decode_field(env, term, #atom_fun())?; + let #variable = try_decode_field(term, #atom_fun())?; }; let field_def = quote! { @@ -81,10 +81,7 @@ fn gen_decoder(ctx: &Context, fields: &[&Field], atoms_module_name: &Ident) -> T fn decode(term: ::rustler::Term<'a>) -> ::rustler::NifResult { use #atoms_module_name::*; - let env = term.get_env(); - fn try_decode_field<'a, T>( - env: rustler::Env<'a>, term: rustler::Term<'a>, field: rustler::Atom, ) -> ::rustler::NifResult @@ -92,7 +89,7 @@ fn gen_decoder(ctx: &Context, fields: &[&Field], atoms_module_name: &Ident) -> T T: rustler::Decoder<'a>, { use rustler::Encoder; - match ::rustler::Decoder::decode(term.map_get(field.encode(env))?) { + match ::rustler::Decoder::decode(term.map_get(field)?) { Err(_) => Err(::rustler::Error::RaiseTerm(Box::new(format!( "Could not decode field :{:?} on %{{}}", field @@ -121,7 +118,7 @@ fn gen_encoder(ctx: &Context, fields: &[&Field], atoms_module_name: &Ident) -> T let atom_fun = Context::field_to_atom_fun(field); quote_spanned! { field.span() => - map = map.map_put(#atom_fun().encode(env), self.#field_ident.encode(env)).unwrap(); + map = map.map_put(#atom_fun(), self.#field_ident).unwrap(); } }) .collect(); diff --git a/rustler_codegen/src/tagged_enum.rs b/rustler_codegen/src/tagged_enum.rs index fae102fe..f0405273 100644 --- a/rustler_codegen/src/tagged_enum.rs +++ b/rustler_codegen/src/tagged_enum.rs @@ -107,11 +107,9 @@ fn gen_decoder(ctx: &Context, variants: &[&Variant], atoms_module_name: &Ident) fn decode(term: ::rustler::Term<'a>) -> ::rustler::NifResult { use #atoms_module_name::*; - let env = term.get_env(); let value = ::rustler::types::atom::Atom::from_term(term); fn try_decode_field<'a, T>( - env: ::rustler::Env<'a>, term: ::rustler::Term<'a>, field: ::rustler::Atom, ) -> ::rustler::NifResult @@ -119,7 +117,7 @@ fn gen_decoder(ctx: &Context, variants: &[&Variant], atoms_module_name: &Ident) T: ::rustler::Decoder<'a>, { use ::rustler::Encoder; - match ::rustler::Decoder::decode(term.map_get(field.encode(env))?) { + match ::rustler::Decoder::decode(term.map_get(field)?) { Err(_) => Err(::rustler::Error::RaiseTerm(Box::new(format!( "Could not decode field :{:?} on %{{}}", field @@ -244,7 +242,7 @@ fn gen_named_decoder( let enum_name_string = enum_name.to_string(); let assignment = quote_spanned! { field.span() => - let #variable = try_decode_field(env, tuple[1], #atom_fun()).map_err(|_|{ + let #variable = try_decode_field(tuple[1], #atom_fun()).map_err(|_|{ ::rustler::Error::RaiseTerm(Box::new(format!( "Could not decode field '{}' on Enum '{}'", #ident_string, #enum_name_string @@ -315,22 +313,24 @@ fn gen_named_encoder( } }) .collect::>(); - let field_defs = fields.named.iter() + let (keys, values): (Vec<_>, Vec<_>) = fields + .named + .iter() .map(|field| { - let field_ident = field.ident.as_ref().expect("Named fields must have an ident."); + let field_ident = field + .ident + .as_ref() + .expect("Named fields must have an ident."); let atom_fun = Context::field_to_atom_fun(field); - - quote_spanned! { field.span() => - map = map.map_put(#atom_fun().encode(env), #field_ident.encode(env)).expect("Failed to putting map"); - } + (atom_fun, field_ident) }) - .collect::>(); + .unzip(); quote! { #enum_name :: #variant_ident{ #(#field_decls)* } => { - let mut map = ::rustler::types::map::map_new(env); - #(#field_defs)* + let map = ::rustler::Term::map_from_arrays(env, &[#(#keys()),*], &[#(#values),*]) + .expect("Failed to create map"); ::rustler::types::tuple::make_tuple(env, &[#atom_fn().encode(env), map]) } } From 150cae1ae2cd686625ac5e93c0ea99fab9ee9d28 Mon Sep 17 00:00:00 2001 From: Seokmin Hong Date: Thu, 12 May 2022 14:54:31 +0900 Subject: [PATCH 2/2] Remove unnecessary map --- rustler/src/types/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rustler/src/types/mod.rs b/rustler/src/types/mod.rs index f3f19356..bd1d5364 100644 --- a/rustler/src/types/mod.rs +++ b/rustler/src/types/mod.rs @@ -154,7 +154,7 @@ where V: Encoder, { fn encode<'c>(&self, env: Env<'c>) -> Term<'c> { - let (keys, values): (Vec<_>, Vec<_>) = self.iter().map(|(k, v)| (k, v)).unzip(); + let (keys, values): (Vec<_>, Vec<_>) = self.iter().unzip(); Term::map_from_arrays(env, &keys, &values).unwrap() } }