From ca73cbe4bbcbdb4f3d30455a090ff50209dc1fc3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?L=C3=A9o=20Meira=20Vital?= Date: Fri, 25 Aug 2023 15:15:43 -0300 Subject: [PATCH] feat: add TryFrom implementation to Enumeration (#853) --- README.md | 13 +++++++------ prost-derive/src/field/map.rs | 10 ++++++++-- prost-derive/src/field/scalar.rs | 19 +++++++++++++------ prost-derive/src/lib.rs | 16 ++++++++++++++++ tests/src/lib.rs | 26 ++++++++++++++++++++++++++ 5 files changed, 70 insertions(+), 14 deletions(-) diff --git a/README.md b/README.md index bed30d0ad..1633a4255 100644 --- a/README.md +++ b/README.md @@ -163,21 +163,22 @@ The `#[derive(::prost::Enumeration)]` annotation added to the generated ```rust,ignore impl PhoneType { pub fn is_valid(value: i32) -> bool { ... } + #[deprecated] pub fn from_i32(value: i32) -> Option { ... } } ``` -so you can convert an `i32` to its corresponding `PhoneType` value by doing, +It also adds an `impl TryFrom for PhoneType`, so you can convert an `i32` to its corresponding `PhoneType` value by doing, for example: ```rust,ignore let phone_type = 2i32; -match PhoneType::from_i32(phone_type) { - Some(PhoneType::Mobile) => ..., - Some(PhoneType::Home) => ..., - Some(PhoneType::Work) => ..., - None => ..., +match PhoneType::try_from(phone_type) { + Ok(PhoneType::Mobile) => ..., + Ok(PhoneType::Home) => ..., + Ok(PhoneType::Work) => ..., + Err(_) => ..., } ``` diff --git a/prost-derive/src/field/map.rs b/prost-derive/src/field/map.rs index 540074dc7..aabceb1f1 100644 --- a/prost-derive/src/field/map.rs +++ b/prost-derive/src/field/map.rs @@ -275,11 +275,17 @@ impl Field { Some(quote! { #[doc=#get_doc] pub fn #get(&self, key: #key_ref_ty) -> ::core::option::Option<#ty> { - self.#ident.get(#take_ref key).cloned().and_then(#ty::from_i32) + self.#ident.get(#take_ref key).cloned().and_then(|x| { + let result: Result<#ty, _> = ::core::convert::TryFrom::try_from(x); + result.ok() + }) } #[doc=#insert_doc] pub fn #insert(&mut self, key: #key_ty, value: #ty) -> ::core::option::Option<#ty> { - self.#ident.insert(key, value as i32).and_then(#ty::from_i32) + self.#ident.insert(key, value as i32).and_then(|x| { + let result: Result<#ty, _> = ::core::convert::TryFrom::try_from(x); + result.ok() + }) } }) } else { diff --git a/prost-derive/src/field/scalar.rs b/prost-derive/src/field/scalar.rs index 5fd0c7174..fa426b7fd 100644 --- a/prost-derive/src/field/scalar.rs +++ b/prost-derive/src/field/scalar.rs @@ -219,9 +219,10 @@ impl Field { struct #wrap_name<'a>(&'a i32); impl<'a> ::core::fmt::Debug for #wrap_name<'a> { fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result { - match #ty::from_i32(*self.0) { - None => ::core::fmt::Debug::fmt(&self.0, f), - Some(en) => ::core::fmt::Debug::fmt(&en, f), + let res: Result<#ty, _> = ::core::convert::TryFrom::try_from(*self.0); + match res { + Err(_) => ::core::fmt::Debug::fmt(&self.0, f), + Ok(en) => ::core::fmt::Debug::fmt(&en, f), } } } @@ -296,7 +297,7 @@ impl Field { quote! { #[doc=#get_doc] pub fn #get(&self) -> #ty { - #ty::from_i32(self.#ident).unwrap_or(#default) + ::core::convert::TryFrom::try_from(self.#ident).unwrap_or(#default) } #[doc=#set_doc] @@ -314,7 +315,10 @@ impl Field { quote! { #[doc=#get_doc] pub fn #get(&self) -> #ty { - self.#ident.and_then(#ty::from_i32).unwrap_or(#default) + self.#ident.and_then(|x| { + let result: Result<#ty, _> = ::core::convert::TryFrom::try_from(x); + result.ok() + }).unwrap_or(#default) } #[doc=#set_doc] @@ -336,7 +340,10 @@ impl Field { ::core::iter::Cloned<::core::slice::Iter>, fn(i32) -> ::core::option::Option<#ty>, > { - self.#ident.iter().cloned().filter_map(#ty::from_i32) + self.#ident.iter().cloned().filter_map(|x| { + let result: Result<#ty, _> = ::core::convert::TryFrom::try_from(x); + result.ok() + }) } #[doc=#push_doc] pub fn #push(&mut self, value: #ty) { diff --git a/prost-derive/src/lib.rs b/prost-derive/src/lib.rs index 488b407a5..974e55598 100644 --- a/prost-derive/src/lib.rs +++ b/prost-derive/src/lib.rs @@ -291,6 +291,10 @@ fn try_enumeration(input: TokenStream) -> Result { |&(ref variant, ref value)| quote!(#value => ::core::option::Option::Some(#ident::#variant)), ); + let try_from = variants.iter().map( + |&(ref variant, ref value)| quote!(#value => ::core::result::Result::Ok(#ident::#variant)), + ); + let is_valid_doc = format!("Returns `true` if `value` is a variant of `{}`.", ident); let from_i32_doc = format!( "Converts an `i32` to a `{}`, or `None` if `value` is not a valid variant.", @@ -307,6 +311,7 @@ fn try_enumeration(input: TokenStream) -> Result { } } + #[deprecated = "Use the TryFrom implementation instead"] #[doc=#from_i32_doc] pub fn from_i32(value: i32) -> ::core::option::Option<#ident> { match value { @@ -327,6 +332,17 @@ fn try_enumeration(input: TokenStream) -> Result { value as i32 } } + + impl #impl_generics ::core::convert::TryFrom:: for #ident #ty_generics #where_clause { + type Error = ::prost::DecodeError; + + fn try_from(value: i32) -> ::core::result::Result<#ident, ::prost::DecodeError> { + match value { + #(#try_from,)* + _ => ::core::result::Result::Err(::prost::DecodeError::new("invalid enumeration value")), + } + } + } }; Ok(expanded.into()) diff --git a/tests/src/lib.rs b/tests/src/lib.rs index d672fbc40..e755736fe 100644 --- a/tests/src/lib.rs +++ b/tests/src/lib.rs @@ -578,6 +578,32 @@ mod tests { ); } + #[test] + fn test_enum_try_from_i32() { + use core::convert::TryFrom; + use default_enum_value::{ERemoteClientBroadcastMsg, PrivacyLevel}; + + assert_eq!(Ok(PrivacyLevel::One), PrivacyLevel::try_from(1)); + assert_eq!(Ok(PrivacyLevel::Two), PrivacyLevel::try_from(2)); + assert_eq!( + Ok(PrivacyLevel::PrivacyLevelThree), + PrivacyLevel::try_from(3) + ); + assert_eq!( + Ok(PrivacyLevel::PrivacyLevelprivacyLevelFour), + PrivacyLevel::try_from(4) + ); + assert_eq!( + Err(prost::DecodeError::new("invalid enumeration value")), + PrivacyLevel::try_from(5) + ); + + assert_eq!( + Ok(ERemoteClientBroadcastMsg::KERemoteClientBroadcastMsgDiscovery), + ERemoteClientBroadcastMsg::try_from(0) + ); + } + #[test] fn test_default_string_escape() { let msg = default_string_escape::Person::default();