diff --git a/rustler/src/types/primitive.rs b/rustler/src/types/primitive.rs index cb47b348..89072c30 100644 --- a/rustler/src/types/primitive.rs +++ b/rustler/src/types/primitive.rs @@ -1,27 +1,43 @@ use crate::types::atom; use crate::{Decoder, Encoder, Env, Error, NifResult, Term}; -macro_rules! impl_number_transcoder { - ($dec_type:ty, $nif_type:ty, $encode_fun:ident, $decode_fun:ident) => { +macro_rules! erl_make { + ($self:expr, $env:ident, $encode_fun:ident, $type:ty) => { + #[allow(clippy::cast_lossless)] + unsafe { + Term::new( + $env, + rustler_sys::$encode_fun($env.as_c_arg(), $self as $type), + ) + } + }; +} + +macro_rules! erl_get { + ($decode_fun:ident, $term:ident, $dest:ident) => { + unsafe { + rustler_sys::$decode_fun($term.get_env().as_c_arg(), $term.as_c_arg(), &mut $dest) + } + }; +} + +macro_rules! impl_number_encoder { + ($dec_type:ty, $nif_type:ty, $encode_fun:ident) => { impl Encoder for $dec_type { fn encode<'a>(&self, env: Env<'a>) -> Term<'a> { - #[allow(clippy::cast_lossless)] - unsafe { - Term::new( - env, - rustler_sys::$encode_fun(env.as_c_arg(), *self as $nif_type), - ) - } + erl_make!(*self, env, $encode_fun, $nif_type) } } + }; +} + +macro_rules! impl_number_decoder { + ($dec_type:ty, $nif_type:ty, $decode_fun:ident) => { impl<'a> Decoder<'a> for $dec_type { fn decode(term: Term) -> NifResult<$dec_type> { #![allow(unused_unsafe)] let mut res: $nif_type = Default::default(); - if unsafe { - rustler_sys::$decode_fun(term.get_env().as_c_arg(), term.as_c_arg(), &mut res) - } == 0 - { + if erl_get!($decode_fun, term, res) == 0 { return Err(Error::BadArg); } Ok(res as $dec_type) @@ -30,12 +46,19 @@ macro_rules! impl_number_transcoder { }; } +macro_rules! impl_number_transcoder { + ($dec_type:ty, $nif_type:ty, $encode_fun:ident, $decode_fun:ident) => { + impl_number_encoder!($dec_type, $nif_type, $encode_fun); + impl_number_decoder!($dec_type, $nif_type, $decode_fun); + }; +} + // Base number types impl_number_transcoder!(i32, i32, enif_make_int, enif_get_int); impl_number_transcoder!(u32, u32, enif_make_uint, enif_get_uint); impl_number_transcoder!(i64, i64, enif_make_int64, enif_get_int64); impl_number_transcoder!(u64, u64, enif_make_uint64, enif_get_uint64); -impl_number_transcoder!(f64, f64, enif_make_double, enif_get_double); +impl_number_encoder!(f64, f64, enif_make_double); // Casted number types impl_number_transcoder!(i8, i32, enif_make_int, enif_get_int); @@ -44,25 +67,18 @@ impl_number_transcoder!(i16, i32, enif_make_int, enif_get_int); impl_number_transcoder!(u16, u32, enif_make_uint, enif_get_uint); impl_number_transcoder!(usize, u64, enif_make_uint64, enif_get_uint64); impl_number_transcoder!(isize, i64, enif_make_int64, enif_get_int64); +impl_number_encoder!(f32, f64, enif_make_double); -impl Encoder for bool { - fn encode<'a>(&self, env: Env<'a>) -> Term<'a> { - if *self { - atom::true_().to_term(env) - } else { - atom::false_().to_term(env) +// Manual Decoder impls for floats so they can fall back to decoding from integer terms +impl<'a> Decoder<'a> for f64 { + fn decode(term: Term) -> NifResult { + #![allow(unused_unsafe)] + let mut res: f64 = Default::default(); + if erl_get!(enif_get_double, term, res) == 0 { + let res_fallback: i64 = term.decode()?; + return Ok(res_fallback as f64); } - } -} -impl<'a> Decoder<'a> for bool { - fn decode(term: Term<'a>) -> NifResult { - atom::decode_bool(term) - } -} - -impl Encoder for f32 { - fn encode<'a>(&self, env: Env<'a>) -> Term<'a> { - f64::from(*self).encode(env) + Ok(res) } } @@ -78,3 +94,18 @@ impl<'a> Decoder<'a> for f32 { } } } + +impl Encoder for bool { + fn encode<'a>(&self, env: Env<'a>) -> Term<'a> { + if *self { + atom::true_().to_term(env) + } else { + atom::false_().to_term(env) + } + } +} +impl<'a> Decoder<'a> for bool { + fn decode(term: Term<'a>) -> NifResult { + atom::decode_bool(term) + } +} diff --git a/rustler_tests/lib/rustler_test.ex b/rustler_tests/lib/rustler_test.ex index 77ca2790..f6f100fd 100644 --- a/rustler_tests/lib/rustler_test.ex +++ b/rustler_tests/lib/rustler_test.ex @@ -23,6 +23,7 @@ defmodule RustlerTest do def add_u32(_, _), do: err() def add_i32(_, _), do: err() + def add_floats(_, _), do: err() def echo_u8(_), do: err() def echo_u128(_), do: err() def echo_i128(_), do: err() diff --git a/rustler_tests/native/rustler_test/src/test_primitives.rs b/rustler_tests/native/rustler_test/src/test_primitives.rs index f6faeb63..6574950f 100644 --- a/rustler_tests/native/rustler_test/src/test_primitives.rs +++ b/rustler_tests/native/rustler_test/src/test_primitives.rs @@ -10,6 +10,11 @@ pub fn add_i32(a: i32, b: i32) -> i32 { a + b } +#[rustler::nif] +pub fn add_floats(a: f32, b: f64) -> f64 { + (a as f64) + b +} + #[rustler::nif] pub fn echo_u8(n: u8) -> u8 { n diff --git a/rustler_tests/test/primitives_test.exs b/rustler_tests/test/primitives_test.exs index e63904d9..0b1a97c7 100644 --- a/rustler_tests/test/primitives_test.exs +++ b/rustler_tests/test/primitives_test.exs @@ -6,6 +6,10 @@ defmodule RustlerTest.PrimitivesTest do assert 3 == RustlerTest.add_i32(6, -3) assert -3 == RustlerTest.add_i32(3, -6) assert 3 == RustlerTest.echo_u8(3) + assert 2.0 == RustlerTest.add_floats(3.0, -1.0) + assert 2.0 == RustlerTest.add_floats(3, -1) + assert 2.0 == RustlerTest.add_floats(3.0, -1) + assert 2.0 == RustlerTest.add_floats(3, -1.0) end test "number decoding should fail on invalid terms" do