Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change map helper functions' arguments #453

Merged
merged 2 commits into from
May 27, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 3 additions & 5 deletions rustler/src/types/elixir_struct.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,15 @@ use super::map::map_new;
use crate::{Env, NifResult, Term};

pub fn get_ex_struct_name(map: Term) -> NifResult<Atom> {
let env = map.get_env();
// In an Elixir struct the value in the __struct__ field is always an atom.
map.map_get(atom::__struct__().to_term(env))
.and_then(Atom::from_term)
map.map_get(atom::__struct__()).and_then(Atom::from_term)
}

pub fn make_ex_struct<'a>(env: Env<'a>, struct_module: &str) -> NifResult<Term<'a>> {
let map = map_new(env);

let struct_atom = atom::__struct__().to_term(env);
evnu marked this conversation as resolved.
Show resolved Hide resolved
let module_atom = Atom::from_str(env, struct_module)?.to_term(env);
let struct_atom = atom::__struct__();
let module_atom = Atom::from_str(env, struct_module)?;

map.map_put(struct_atom, module_atom)
}
89 changes: 36 additions & 53 deletions rustler/src/types/map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

use super::atom;
use crate::wrapper::map;
use crate::{Decoder, Env, Error, NifResult, Term};
use crate::{Decoder, Encoder, Env, Error, NifResult, Term};
use std::ops::RangeInclusive;

pub fn map_new(env: Env) -> Term {
Expand Down Expand Up @@ -31,12 +31,12 @@ impl<'a> Term<'a> {
/// ```
pub fn map_from_arrays(
env: Env<'a>,
keys: &[Term<'a>],
values: &[Term<'a>],
keys: &[impl Encoder],
values: &[impl Encoder],
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't this be impl Encoder + 'a as @hansihe noted?

Copy link
Contributor Author

@SeokminHong SeokminHong May 19, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Encoder::encode encodes a Rust value into a term bounds to env's lifetime. So, the arguments actually don't need the lifetime bound!

) -> NifResult<Term<'a>> {
if keys.len() == values.len() {
let keys: Vec<_> = keys.iter().map(|k| k.as_c_arg()).collect();
let values: Vec<_> = values.iter().map(|v| v.as_c_arg()).collect();
let keys: Vec<_> = keys.iter().map(|k| k.encode(env).as_c_arg()).collect();
let values: Vec<_> = values.iter().map(|v| v.encode(env).as_c_arg()).collect();

unsafe {
map::make_map_from_arrays(env.as_c_arg(), &keys, &values)
Expand All @@ -57,10 +57,13 @@ impl<'a> Term<'a> {
/// ```elixir
/// Map.new([{"foo", 1}, {"bar", 2}])
/// ```
pub fn map_from_pairs(env: Env<'a>, pairs: &[(Term<'a>, Term<'a>)]) -> NifResult<Term<'a>> {
pub fn map_from_pairs(
env: Env<'a>,
pairs: &[(impl Encoder, impl Encoder)],
) -> NifResult<Term<'a>> {
let (keys, values): (Vec<_>, Vec<_>) = pairs
.iter()
.map(|(k, v)| (k.as_c_arg(), v.as_c_arg()))
.map(|(k, v)| (k.encode(env).as_c_arg(), v.encode(env).as_c_arg()))
.unzip();

unsafe {
Expand All @@ -78,9 +81,11 @@ impl<'a> Term<'a> {
/// ```elixir
/// Map.get(self_term, key)
/// ```
pub fn map_get(self, key: Term) -> NifResult<Term<'a>> {
pub fn map_get(self, key: impl Encoder) -> NifResult<Term<'a>> {
let env = self.get_env();
match unsafe { map::get_map_value(env.as_c_arg(), self.as_c_arg(), key.as_c_arg()) } {
match unsafe {
map::get_map_value(env.as_c_arg(), self.as_c_arg(), key.encode(env).as_c_arg())
} {
Some(value) => Ok(unsafe { Term::new(env, value) }),
None => Err(Error::BadArg),
}
Expand Down Expand Up @@ -108,27 +113,18 @@ impl<'a> Term<'a> {
/// ```elixir
/// Map.put(self_term, key, value)
/// ```
pub fn map_put(self, key: Term<'a>, value: Term<'a>) -> NifResult<Term<'a>> {
let map_env = self.get_env();

assert!(
map_env == key.get_env(),
"key is from different environment as map"
);
assert!(
map_env == value.get_env(),
"value is from different environment as map"
);
pub fn map_put(self, key: impl Encoder, value: impl Encoder) -> NifResult<Term<'a>> {
let env = self.get_env();

match unsafe {
map::map_put(
map_env.as_c_arg(),
env.as_c_arg(),
self.as_c_arg(),
key.as_c_arg(),
value.as_c_arg(),
key.encode(env).as_c_arg(),
value.encode(env).as_c_arg(),
)
} {
Some(inner) => Ok(unsafe { Term::new(map_env, inner) }),
Some(inner) => Ok(unsafe { Term::new(env, inner) }),
None => Err(Error::BadArg),
}
}
Expand All @@ -142,16 +138,13 @@ impl<'a> Term<'a> {
/// ```elixir
/// Map.delete(self_term, key)
/// ```
pub fn map_remove(self, key: Term<'a>) -> NifResult<Term<'a>> {
let map_env = self.get_env();

assert!(
map_env == key.get_env(),
"key is from different environment as map"
);
pub fn map_remove(self, key: impl Encoder) -> NifResult<Term<'a>> {
let env = self.get_env();

match unsafe { map::map_remove(map_env.as_c_arg(), self.as_c_arg(), key.as_c_arg()) } {
Some(inner) => Ok(unsafe { Term::new(map_env, inner) }),
match unsafe {
map::map_remove(env.as_c_arg(), self.as_c_arg(), key.encode(env).as_c_arg())
} {
Some(inner) => Ok(unsafe { Term::new(env, inner) }),
None => Err(Error::BadArg),
}
}
Expand All @@ -160,27 +153,18 @@ impl<'a> Term<'a> {
///
/// Returns Err(Error::BadArg) if the term is not a map of if key
/// doesn't exist.
pub fn map_update(self, key: Term<'a>, new_value: Term<'a>) -> NifResult<Term<'a>> {
let map_env = self.get_env();

assert!(
map_env == key.get_env(),
"key is from different environment as map"
);
assert!(
map_env == new_value.get_env(),
"value is from different environment as map"
);
pub fn map_update(self, key: impl Encoder, new_value: impl Encoder) -> NifResult<Term<'a>> {
evnu marked this conversation as resolved.
Show resolved Hide resolved
let env = self.get_env();

match unsafe {
map::map_update(
map_env.as_c_arg(),
env.as_c_arg(),
self.as_c_arg(),
key.as_c_arg(),
new_value.as_c_arg(),
key.encode(env).as_c_arg(),
new_value.encode(env).as_c_arg(),
)
} {
Some(inner) => Ok(unsafe { Term::new(map_env, inner) }),
Some(inner) => Ok(unsafe { Term::new(env, inner) }),
None => Err(Error::BadArg),
}
}
Expand Down Expand Up @@ -234,17 +218,16 @@ where
T: Decoder<'a>,
{
fn decode(term: Term<'a>) -> NifResult<Self> {
let env = term.get_env();
let name = term.map_get(atom::__struct__().to_term(env))?;
let name = term.map_get(atom::__struct__())?;

match name.atom_to_string()?.as_ref() {
"Elixir.Range" => (),
_ => return Err(Error::BadArg),
}

let first = term.map_get(atom::first().to_term(env))?.decode::<T>()?;
let last = term.map_get(atom::last().to_term(env))?.decode::<T>()?;
if let Ok(step) = term.map_get(atom::step().to_term(env)) {
let first = term.map_get(atom::first())?.decode::<T>()?;
let last = term.map_get(atom::last())?.decode::<T>()?;
if let Ok(step) = term.map_get(atom::step()) {
match step.decode::<i64>()? {
1 => (),
_ => return Err(Error::BadArg),
Expand Down
5 changes: 1 addition & 4 deletions rustler/src/types/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -154,10 +154,7 @@ where
V: Encoder,
{
fn encode<'c>(&self, env: Env<'c>) -> Term<'c> {
let (keys, values): (Vec<_>, Vec<_>) = self
.iter()
.map(|(k, v)| (k.encode(env), v.encode(env)))
.unzip();
let (keys, values): (Vec<_>, Vec<_>) = self.iter().unzip();
Term::map_from_arrays(env, &keys, &values).unwrap()
}
}
15 changes: 6 additions & 9 deletions rustler_codegen/src/ex_struct.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ fn gen_decoder(ctx: &Context, fields: &[&Field], atoms_module_name: &Ident) -> T
let variable = Context::escape_ident_with_index(&ident.to_string(), index, "struct");

let assignment = quote_spanned! { field.span() =>
let #variable = try_decode_field(env, term, #atom_fun())?;
let #variable = try_decode_field(term, #atom_fun())?;
};

let field_def = quote! {
Expand All @@ -100,18 +100,15 @@ fn gen_decoder(ctx: &Context, fields: &[&Field], atoms_module_name: &Ident) -> T
use #atoms_module_name::*;
use ::rustler::Encoder;

let env = term.get_env();

fn try_decode_field<'a, T>(
env: rustler::Env<'a>,
term: rustler::Term<'a>,
field: rustler::Atom,
) -> ::rustler::NifResult<T>
where
T: rustler::Decoder<'a>,
{
use rustler::Encoder;
match ::rustler::Decoder::decode(term.map_get(field.encode(env))?) {
match ::rustler::Decoder::decode(term.map_get(field)?) {
Err(_) => Err(::rustler::Error::RaiseTerm(Box::new(format!(
"Could not decode field :{:?} on %{}{{}}",
field, #struct_name_str
Expand All @@ -120,7 +117,7 @@ fn gen_decoder(ctx: &Context, fields: &[&Field], atoms_module_name: &Ident) -> T
}
};

let module: ::rustler::types::atom::Atom = term.map_get(atom_struct().to_term(env))?.decode()?;
let module: ::rustler::types::atom::Atom = term.map_get(atom_struct())?.decode()?;
if module != atom_module() {
return Err(::rustler::Error::RaiseAtom("invalid_struct"));
}
Expand Down Expand Up @@ -149,14 +146,14 @@ fn gen_encoder(
let field_ident = field.ident.as_ref().unwrap();
let atom_fun = Context::field_to_atom_fun(field);
quote_spanned! { field.span() =>
map = map.map_put(#atom_fun().encode(env), self.#field_ident.encode(env)).unwrap();
map = map.map_put(#atom_fun(), &self.#field_ident).unwrap();
}
})
.collect();

let exception_field = if add_exception {
quote! {
map = map.map_put(atom_exception().encode(env), true.encode(env)).unwrap();
map = map.map_put(atom_exception(), true).unwrap();
}
} else {
quote! {}
Expand All @@ -167,7 +164,7 @@ fn gen_encoder(
fn encode<'a>(&self, env: ::rustler::Env<'a>) -> ::rustler::Term<'a> {
use #atoms_module_name::*;
let mut map = ::rustler::types::map::map_new(env);
map = map.map_put(atom_struct().encode(env), atom_module().encode(env)).unwrap();
map = map.map_put(atom_struct(), atom_module()).unwrap();
#exception_field
#(#field_defs)*
map
Expand Down
9 changes: 3 additions & 6 deletions rustler_codegen/src/map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ fn gen_decoder(ctx: &Context, fields: &[&Field], atoms_module_name: &Ident) -> T
let variable = Context::escape_ident_with_index(&ident.to_string(), index, "map");

let assignment = quote_spanned! { field.span() =>
let #variable = try_decode_field(env, term, #atom_fun())?;
let #variable = try_decode_field(term, #atom_fun())?;
};

let field_def = quote! {
Expand All @@ -81,18 +81,15 @@ fn gen_decoder(ctx: &Context, fields: &[&Field], atoms_module_name: &Ident) -> T
fn decode(term: ::rustler::Term<'a>) -> ::rustler::NifResult<Self> {
use #atoms_module_name::*;

let env = term.get_env();

fn try_decode_field<'a, T>(
env: rustler::Env<'a>,
term: rustler::Term<'a>,
field: rustler::Atom,
) -> ::rustler::NifResult<T>
where
T: rustler::Decoder<'a>,
{
use rustler::Encoder;
match ::rustler::Decoder::decode(term.map_get(field.encode(env))?) {
match ::rustler::Decoder::decode(term.map_get(field)?) {
Err(_) => Err(::rustler::Error::RaiseTerm(Box::new(format!(
"Could not decode field :{:?} on %{{}}",
field
Expand Down Expand Up @@ -121,7 +118,7 @@ fn gen_encoder(ctx: &Context, fields: &[&Field], atoms_module_name: &Ident) -> T
let atom_fun = Context::field_to_atom_fun(field);

quote_spanned! { field.span() =>
map = map.map_put(#atom_fun().encode(env), self.#field_ident.encode(env)).unwrap();
map = map.map_put(#atom_fun(), self.#field_ident).unwrap();
}
})
.collect();
Expand Down
26 changes: 13 additions & 13 deletions rustler_codegen/src/tagged_enum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,19 +107,17 @@ fn gen_decoder(ctx: &Context, variants: &[&Variant], atoms_module_name: &Ident)
fn decode(term: ::rustler::Term<'a>) -> ::rustler::NifResult<Self> {
use #atoms_module_name::*;

let env = term.get_env();
let value = ::rustler::types::atom::Atom::from_term(term);

fn try_decode_field<'a, T>(
env: ::rustler::Env<'a>,
term: ::rustler::Term<'a>,
field: ::rustler::Atom,
) -> ::rustler::NifResult<T>
where
T: ::rustler::Decoder<'a>,
{
use ::rustler::Encoder;
match ::rustler::Decoder::decode(term.map_get(field.encode(env))?) {
match ::rustler::Decoder::decode(term.map_get(field)?) {
Err(_) => Err(::rustler::Error::RaiseTerm(Box::new(format!(
"Could not decode field :{:?} on %{{}}",
field
Expand Down Expand Up @@ -244,7 +242,7 @@ fn gen_named_decoder(
let enum_name_string = enum_name.to_string();

let assignment = quote_spanned! { field.span() =>
let #variable = try_decode_field(env, tuple[1], #atom_fun()).map_err(|_|{
let #variable = try_decode_field(tuple[1], #atom_fun()).map_err(|_|{
::rustler::Error::RaiseTerm(Box::new(format!(
"Could not decode field '{}' on Enum '{}'",
#ident_string, #enum_name_string
Expand Down Expand Up @@ -315,22 +313,24 @@ fn gen_named_encoder(
}
})
.collect::<Vec<_>>();
let field_defs = fields.named.iter()
let (keys, values): (Vec<_>, Vec<_>) = fields
.named
.iter()
.map(|field| {
let field_ident = field.ident.as_ref().expect("Named fields must have an ident.");
let field_ident = field
.ident
.as_ref()
.expect("Named fields must have an ident.");
let atom_fun = Context::field_to_atom_fun(field);

quote_spanned! { field.span() =>
map = map.map_put(#atom_fun().encode(env), #field_ident.encode(env)).expect("Failed to putting map");
}
(atom_fun, field_ident)
})
.collect::<Vec<_>>();
.unzip();
quote! {
#enum_name :: #variant_ident{
#(#field_decls)*
} => {
let mut map = ::rustler::types::map::map_new(env);
#(#field_defs)*
let map = ::rustler::Term::map_from_arrays(env, &[#(#keys()),*], &[#(#values),*])
.expect("Failed to create map");
::rustler::types::tuple::make_tuple(env, &[#atom_fn().encode(env), map])
}
}
Expand Down