diff --git a/prost-build/src/code_generator.rs b/prost-build/src/code_generator.rs index 4014dcb99..b8e26dc2a 100644 --- a/prost-build/src/code_generator.rs +++ b/prost-build/src/code_generator.rs @@ -402,7 +402,7 @@ impl CodeGenerator<'_> { fn append_field(&mut self, fq_message_name: &str, field: &Field) { let type_ = field.descriptor.r#type(); - let repeated = field.descriptor.label.and_then(|v| v.known()) == Some(Label::Repeated); + let repeated = field.descriptor.label == Some(Label::Repeated as i32); let deprecated = self.deprecated(&field.descriptor); let optional = self.optional(&field.descriptor); let boxed = self.boxed(&field.descriptor, fq_message_name, None); @@ -427,15 +427,19 @@ impl CodeGenerator<'_> { let type_tag = self.field_type_tag(&field.descriptor); self.buf.push_str(&type_tag); - if type_ == Type::Bytes { - let bytes_type = self - .config - .bytes_type - .get_first_field(fq_message_name, field.descriptor.name()) - .copied() - .unwrap_or_default(); - self.buf - .push_str(&format!("={:?}", bytes_type.annotation())); + match type_ { + Type::Bytes => { + let bytes_type = self + .config + .bytes_type + .get_first_field(fq_message_name, field.descriptor.name()) + .copied() + .unwrap_or_default(); + self.buf + .push_str(&format!("={:?}", bytes_type.annotation())); + } + Type::Enum => self.push_enum_type_annotation(fq_message_name, field.descriptor.name()), + _ => {} } match field.descriptor.label() { @@ -555,12 +559,16 @@ impl CodeGenerator<'_> { let value_tag = self.map_value_type_tag(value); self.buf.push_str(&format!( - "#[prost({}=\"{}, {}\", tag=\"{}\")]\n", + "#[prost({}=\"{}, {}\"", map_type.annotation(), key_tag, value_tag, - field.descriptor.number() )); + if value.r#type() == Type::Enum { + self.push_enum_type_annotation(fq_message_name, field.descriptor.name()); + } + self.buf + .push_str(&format!(", tag=\"{}\")]\n", field.descriptor.number())); self.append_field_attributes(fq_message_name, field.descriptor.name()); self.push_indent(); self.buf.push_str(&format!( @@ -639,11 +647,12 @@ impl CodeGenerator<'_> { self.push_indent(); let ty_tag = self.field_type_tag(&field.descriptor); - self.buf.push_str(&format!( - "#[prost({}, tag=\"{}\")]\n", - ty_tag, - field.descriptor.number() - )); + self.buf.push_str(&format!("#[prost({}", ty_tag,)); + if field.descriptor.r#type() == Type::Enum { + self.push_enum_type_annotation(&oneof_name, field.descriptor.name()); + } + self.buf + .push_str(&format!(", tag=\"{}\")]\n", field.descriptor.number())); self.append_field_attributes(&oneof_name, field.descriptor.name()); self.push_indent(); @@ -947,6 +956,14 @@ impl CodeGenerator<'_> { self.buf.push_str("}\n"); } + fn push_enum_type_annotation(&mut self, fq_message_name: &str, field_name: &str) { + match self.enum_field_repr(fq_message_name, field_name) { + EnumRepr::Int => {} + EnumRepr::Open => self.buf.push_str(", enum_type=\"open\""), + EnumRepr::Closed => self.buf.push_str(", enum_type=\"closed\""), + } + } + fn resolve_type(&self, field: &FieldDescriptorProto, fq_message_name: &str) -> String { match field.r#type() { Type::Float => String::from("f32"), @@ -966,11 +983,15 @@ impl CodeGenerator<'_> { .rust_type() .to_owned(), Type::Group | Type::Message => self.resolve_ident(field.type_name()), - Type::Enum => format!( - "{}::OpenEnum<{}>", - prost_path(self.config), - self.resolve_ident(field.type_name()) - ), + Type::Enum => match self.enum_field_repr(fq_message_name, field.name()) { + EnumRepr::Int => String::from("i32"), + EnumRepr::Open => format!( + "{}::OpenEnum<{}>", + prost_path(self.config), + self.resolve_ident(field.type_name()) + ), + EnumRepr::Closed => self.resolve_ident(field.type_name()), + }, } } @@ -1012,6 +1033,22 @@ impl CodeGenerator<'_> { .join("::") } + fn enum_field_repr(&self, fq_message_name: &str, field_name: &str) -> EnumRepr { + if self + .config + .typed_enum_fields + .get_first_field(fq_message_name, field_name) + .is_some() + { + match self.syntax { + Syntax::Proto2 => EnumRepr::Closed, + Syntax::Proto3 => EnumRepr::Open, + } + } else { + EnumRepr::Int + } + } + fn field_type_tag(&self, field: &FieldDescriptorProto) -> Cow<'static, str> { match field.r#type() { Type::Float => Cow::Borrowed("float"), @@ -1074,7 +1111,7 @@ impl CodeGenerator<'_> { fq_message_name: &str, oneof: Option<&str>, ) -> bool { - let repeated = field.label.and_then(|v| v.known()) == Some(Label::Repeated); + let repeated = field.label == Some(Label::Repeated as i32); let fd_type = field.r#type(); if !repeated && (fd_type == Type::Message || fd_type == Type::Group) @@ -1148,6 +1185,12 @@ fn can_pack(field: &FieldDescriptorProto) -> bool { ) } +enum EnumRepr { + Int, + Closed, + Open, +} + struct EnumVariantMapping<'a> { path_idx: usize, proto_name: &'a str, diff --git a/prost-build/src/config.rs b/prost-build/src/config.rs index 896726b16..5d555d12e 100644 --- a/prost-build/src/config.rs +++ b/prost-build/src/config.rs @@ -36,6 +36,7 @@ pub struct Config { pub(crate) enum_attributes: PathMap, pub(crate) field_attributes: PathMap, pub(crate) boxed: PathMap<()>, + pub(crate) typed_enum_fields: PathMap<()>, pub(crate) prost_types: bool, pub(crate) strip_enum_prefix: bool, pub(crate) out_dir: Option, @@ -372,6 +373,30 @@ impl Config { self } + /// Represent Protobuf enum types encountered in matched fields with types + /// bound to their corresponding Rust enum types, rather than the default `i32`. + /// + /// Depending on the proto file syntax, the representation type can be: + /// * For closed enums (in proto2), the corresponding Rust enum type. + /// * For open enums (in proto3), the Rust enum type wrapped in [`OpenEnum`]. + /// + /// # Arguments + /// + /// **`path`** - a path matching any number of fields. These fields will get the type-checked + /// enum representation. + /// For details about matching fields see [`btree_map`](#method.btree_map). + /// + /// # Examples + /// + /// ```rust + /// # let mut config = prost_build::Config::new(); + /// config.typed_enum_fields(".my_messages"); + /// ``` + pub fn typed_enum_fields(&mut self, path: impl AsRef) -> &mut Self { + self.typed_enum_fields.insert(path.as_ref().to_owned(), ()); + self + } + /// Configures the code generator to use the provided service generator. pub fn service_generator(&mut self, service_generator: Box) -> &mut Self { self.service_generator = Some(service_generator); @@ -1158,6 +1183,7 @@ impl default::Default for Config { enum_attributes: PathMap::default(), field_attributes: PathMap::default(), boxed: PathMap::default(), + typed_enum_fields: PathMap::default(), prost_types: true, strip_enum_prefix: true, out_dir: None, diff --git a/prost-derive/src/field/map.rs b/prost-derive/src/field/map.rs index a20a586ad..64842930e 100644 --- a/prost-derive/src/field/map.rs +++ b/prost-derive/src/field/map.rs @@ -4,7 +4,7 @@ use quote::quote; use syn::punctuated::Punctuated; use syn::{Expr, ExprLit, Ident, Lit, Meta, MetaNameValue, Token}; -use crate::field::{scalar, set_option, tag_attr}; +use crate::field::{scalar, set_option, tag_attr, EnumType}; #[derive(Clone, Debug)] pub enum MapTy { @@ -36,11 +36,12 @@ impl MapTy { } } -fn fake_scalar(ty: scalar::Ty) -> scalar::Field { +fn fake_scalar(ty: scalar::Ty, enum_type: Option) -> scalar::Field { let kind = scalar::Kind::Plain(scalar::DefaultValue::new(&ty)); scalar::Field { ty, kind, + enum_type, tag: 0, // Not used here } } @@ -50,6 +51,7 @@ pub struct Field { pub map_ty: MapTy, pub key_ty: scalar::Ty, pub value_ty: ValueTy, + pub enum_type: Option, pub tag: u32, } @@ -57,10 +59,13 @@ impl Field { pub fn new(attrs: &[Meta], inferred_tag: Option) -> Result, Error> { let mut types = None; let mut tag = None; + let mut enum_type = None; for attr in attrs { if let Some(t) = tag_attr(attr)? { set_option(&mut tag, t, "duplicate tag attributes")?; + } else if let Some(et) = EnumType::from_attr(attr)? { + set_option(&mut enum_type, et, "duplicate enum_type attributes")?; } else if let Some(map_ty) = attr .path() .get_ident() @@ -113,6 +118,7 @@ impl Field { map_ty, key_ty, value_ty, + enum_type, tag, }), _ => None, @@ -126,19 +132,19 @@ impl Field { /// Returns a statement which encodes the map field. pub fn encode(&self, ident: TokenStream) -> TokenStream { let tag = self.tag; - let key_mod = self.key_ty.module(); + let key_mod = self.key_ty.module(None); let ke = quote!(::prost::encoding::#key_mod::encode); let kl = quote!(::prost::encoding::#key_mod::encoded_len); let module = self.map_ty.module(); match &self.value_ty { - ValueTy::Scalar(scalar::Ty::Enumeration(ty)) => { - let default = quote!(::prost::OpenEnum::from(#ty::default())); + ValueTy::Scalar(scalar::Ty::Enumeration(ty)) if self.enum_type.is_none() => { + let default = quote!(#ty::default() as i32); quote! { ::prost::encoding::#module::encode_with_default( #ke, #kl, - ::prost::encoding::enumeration::encode, - ::prost::encoding::enumeration::encoded_len, + ::prost::encoding::int32::encode, + ::prost::encoding::int32::encoded_len, &(#default), #tag, &#ident, @@ -147,7 +153,7 @@ impl Field { } } ValueTy::Scalar(value_ty) => { - let val_mod = value_ty.module(); + let val_mod = value_ty.module(self.enum_type); let ve = quote!(::prost::encoding::#val_mod::encode); let vl = quote!(::prost::encoding::#val_mod::encoded_len); quote! { @@ -179,16 +185,16 @@ impl Field { /// Returns an expression which evaluates to the result of merging a decoded key value pair /// into the map. pub fn merge(&self, ident: TokenStream) -> TokenStream { - let key_mod = self.key_ty.module(); + let key_mod = self.key_ty.module(None); let km = quote!(::prost::encoding::#key_mod::merge); let module = self.map_ty.module(); match &self.value_ty { - ValueTy::Scalar(scalar::Ty::Enumeration(ty)) => { - let default = quote!(::prost::OpenEnum::from(#ty::default())); + ValueTy::Scalar(scalar::Ty::Enumeration(ty)) if self.enum_type.is_none() => { + let default = quote!(#ty::default() as i32); quote! { ::prost::encoding::#module::merge_with_default( #km, - ::prost::encoding::enumeration::merge, + ::prost::encoding::int32::merge, #default, &mut #ident, buf, @@ -197,7 +203,7 @@ impl Field { } } ValueTy::Scalar(value_ty) => { - let val_mod = value_ty.module(); + let val_mod = value_ty.module(self.enum_type); let vm = quote!(::prost::encoding::#val_mod::merge); quote!(::prost::encoding::#module::merge(#km, #vm, &mut #ident, buf, ctx)) } @@ -216,16 +222,16 @@ impl Field { /// Returns an expression which evaluates to the encoded length of the map. pub fn encoded_len(&self, ident: TokenStream) -> TokenStream { let tag = self.tag; - let key_mod = self.key_ty.module(); + let key_mod = self.key_ty.module(None); let kl = quote!(::prost::encoding::#key_mod::encoded_len); let module = self.map_ty.module(); match &self.value_ty { - ValueTy::Scalar(scalar::Ty::Enumeration(ty)) => { - let default = quote!(::prost::OpenEnum::from(#ty::default())); + ValueTy::Scalar(scalar::Ty::Enumeration(ty)) if self.enum_type.is_none() => { + let default = quote!(#ty::default() as i32); quote! { ::prost::encoding::#module::encoded_len_with_default( #kl, - ::prost::encoding::enumeration::encoded_len, + ::prost::encoding::int32::encoded_len, &(#default), #tag, &#ident, @@ -233,7 +239,7 @@ impl Field { } } ValueTy::Scalar(value_ty) => { - let val_mod = value_ty.module(); + let val_mod = value_ty.module(self.enum_type); let vl = quote!(::prost::encoding::#val_mod::encoded_len); quote!(::prost::encoding::#module::encoded_len(#kl, #vl, #tag, &#ident)) } @@ -254,36 +260,43 @@ impl Field { /// Returns methods to embed in the message. pub fn methods(&self, ident: &TokenStream) -> Option { - if let ValueTy::Scalar(scalar::Ty::Enumeration(ty)) = &self.value_ty { - let key_ty = self.key_ty.rust_type(); - let key_ref_ty = self.key_ty.rust_ref_type(); + match &self.value_ty { + ValueTy::Scalar(scalar::Ty::Enumeration(ty)) if self.enum_type.is_none() => { + let key_ty = self.key_ty.owned_type(None); + let key_ref_ty = self.key_ty.ref_type(); - let get = Ident::new(&format!("get_{}", ident), Span::call_site()); - let insert = Ident::new(&format!("insert_{}", ident), Span::call_site()); - let take_ref = if self.key_ty.is_numeric() { - quote!(&) - } else { - quote!() - }; + let get = Ident::new(&format!("get_{}", ident), Span::call_site()); + let insert = Ident::new(&format!("insert_{}", ident), Span::call_site()); + let take_ref = if self.key_ty.is_numeric() { + quote!(&) + } else { + quote!() + }; - let get_doc = format!( - "Returns the enum value for the corresponding key in `{}`, \ + let get_doc = format!( + "Returns the enum value for the corresponding key in `{}`, \ or `None` if the entry does not exist or it is not a valid enum value.", - ident, - ); - let insert_doc = format!("Inserts a key value pair into `{}`.", ident); - Some(quote! { - #[doc=#get_doc] - pub fn #get(&self, key: #key_ref_ty) -> ::core::option::Option<#ty> { - self.#ident.get(#take_ref key).cloned().and_then(|x| { x.known() }) - } - #[doc=#insert_doc] - pub fn #insert(&mut self, key: #key_ty, value: #ty) -> ::core::option::Option<#ty> { - self.#ident.insert(key, value.into()).and_then(|x| { x.known() }) - } - }) - } else { - None + ident, + ); + let insert_doc = format!("Inserts a key value pair into `{}`.", ident); + Some(quote! { + #[doc=#get_doc] + pub fn #get(&self, key: #key_ref_ty) -> ::core::option::Option<#ty> { + self.#ident.get(#take_ref key).cloned().and_then(|x| { + let result: ::core::result::Result<#ty, _> = ::core::convert::TryFrom::try_from(x); + result.ok() + }) + } + #[doc=#insert_doc] + pub fn #insert(&mut self, key: #key_ty, value: #ty) -> ::core::option::Option<#ty> { + self.#ident.insert(key, value as i32).and_then(|x| { + let result: ::core::result::Result<#ty, _> = ::core::convert::TryFrom::try_from(x); + result.ok() + }) + } + }) + } + _ => None, } } @@ -298,9 +311,9 @@ impl Field { }; // A fake field for generating the debug wrapper - let key_wrapper = fake_scalar(self.key_ty.clone()).debug(quote!(KeyWrapper)); - let key = self.key_ty.rust_type(); - let value_wrapper = self.value_ty.debug(); + let key_wrapper = fake_scalar(self.key_ty.clone(), None).debug(quote!(KeyWrapper)); + let key = self.key_ty.owned_type(None); + let value_wrapper = self.value_ty.debug(self.enum_type); let libname = self.map_ty.lib(); let fmt = quote! { fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result { @@ -326,7 +339,7 @@ impl Field { }; } - let value = ty.rust_type(); + let value = ty.owned_type(self.enum_type); quote! { struct #wrapper_name<'a>(&'a ::#libname::collections::#type_name<#key, #value>); impl<'a> ::core::fmt::Debug for #wrapper_name<'a> { @@ -386,12 +399,14 @@ impl ValueTy { /// Returns a newtype wrapper around the ValueTy for nicer debug. /// + /// The generated implementation depends on the `enum_type` feature selection. /// If the contained value is enumeration, it tries to convert it to the variant. If not, it /// just forwards the implementation. - fn debug(&self) -> TokenStream { + fn debug(&self, enum_type: Option) -> TokenStream { match self { - ValueTy::Scalar(ty) => fake_scalar(ty.clone()).debug(quote!(ValueWrapper)), + ValueTy::Scalar(ty) => fake_scalar(ty.clone(), enum_type).debug(quote!(ValueWrapper)), ValueTy::Message => quote!( + #[allow(non_snake_case)] fn ValueWrapper(v: T) -> T { v } diff --git a/prost-derive/src/field/mod.rs b/prost-derive/src/field/mod.rs index 366075e45..fbf8c8738 100644 --- a/prost-derive/src/field/mod.rs +++ b/prost-derive/src/field/mod.rs @@ -11,7 +11,7 @@ use anyhow::{bail, Error}; use proc_macro2::TokenStream; use quote::quote; use syn::punctuated::Punctuated; -use syn::{Attribute, Expr, ExprLit, Lit, LitBool, LitInt, Meta, MetaNameValue, Token}; +use syn::{Attribute, Expr, ExprLit, Ident, Lit, LitBool, LitInt, Meta, MetaNameValue, Token}; #[derive(Clone)] pub enum Field { @@ -224,6 +224,63 @@ impl fmt::Display for Label { } } +#[derive(Clone, Copy, PartialEq, Eq)] +pub enum EnumType { + Closed, + Open, +} + +impl EnumType { + fn as_str(self) -> &'static str { + match self { + EnumType::Closed => "closed", + EnumType::Open => "open", + } + } + + /// Parses a meta attribute into the enum_type feature value. + /// If the attribute name does not match "enum_type", `None` is returned. + fn from_attr(attr: &Meta) -> Result, Error> { + if !attr.path().is_ident("enum_type") { + return Ok(None); + } + match attr { + Meta::NameValue(MetaNameValue { + value: Expr::Lit(ExprLit { lit, .. }), + .. + }) => { + if let Lit::Str(lit) = lit { + match lit.value().as_str() { + "open" => return Ok(Some(EnumType::Open)), + "closed" => return Ok(Some(EnumType::Closed)), + _ => {} + } + } + bail!("invalid value of enum_type attribute: {:?}", lit); + } + Meta::List(meta_list) => { + let ident = meta_list.parse_args::()?; + if ident == "open" { + return Ok(Some(EnumType::Open)); + } else if ident == "closed" { + return Ok(Some(EnumType::Closed)); + } else { + bail!("invalid content of enum_type attribute: {:?}", ident); + } + } + _ => { + bail!("invalid enum_type attribute: {:?}", attr); + } + } + } +} + +impl fmt::Debug for EnumType { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(self.as_str()) + } +} + /// Get the items belonging to the 'prost' list attribute, e.g. `#[prost(foo, bar="baz")]`. fn prost_attrs(attrs: Vec) -> Result, Error> { let mut result = Vec::new(); diff --git a/prost-derive/src/field/scalar.rs b/prost-derive/src/field/scalar.rs index 62a9d4fd7..966400da7 100644 --- a/prost-derive/src/field/scalar.rs +++ b/prost-derive/src/field/scalar.rs @@ -5,13 +5,14 @@ use proc_macro2::{Span, TokenStream}; use quote::{quote, ToTokens, TokenStreamExt}; use syn::{parse_str, Expr, ExprLit, Ident, Index, Lit, LitByteStr, Meta, MetaNameValue, Path}; -use crate::field::{bool_attr, set_option, tag_attr, Label}; +use crate::field::{bool_attr, set_option, tag_attr, EnumType, Label}; /// A scalar protobuf field. #[derive(Clone)] pub struct Field { pub ty: Ty, pub kind: Kind, + pub enum_type: Option, pub tag: u32, } @@ -20,6 +21,7 @@ impl Field { let mut ty = None; let mut label = None; let mut packed = None; + let mut enum_type = None; let mut default = None; let mut tag = None; @@ -34,6 +36,8 @@ impl Field { set_option(&mut tag, t, "duplicate tag attributes")?; } else if let Some(l) = Label::from_attr(attr) { set_option(&mut label, l, "duplicate label attributes")?; + } else if let Some(t) = EnumType::from_attr(attr)? { + set_option(&mut enum_type, t, "duplicate enum_type attributes")?; } else if let Some(d) = DefaultValue::from_attr(attr)? { set_option(&mut default, d, "duplicate default attributes")?; } else { @@ -86,7 +90,12 @@ impl Field { (Some(Label::Repeated), _, false) => Kind::Repeated, }; - Ok(Some(Field { ty, kind, tag })) + Ok(Some(Field { + ty, + kind, + enum_type, + tag, + })) } pub fn new_oneof(attrs: &[Meta]) -> Result, Error> { @@ -106,7 +115,7 @@ impl Field { } pub fn encode(&self, ident: TokenStream) -> TokenStream { - let module = self.ty.module(); + let module = self.ty.module(self.enum_type); let encode_fn = match self.kind { Kind::Plain(..) | Kind::Optional(..) | Kind::Required(..) => quote!(encode), Kind::Repeated => quote!(encode_repeated), @@ -117,7 +126,7 @@ impl Field { match self.kind { Kind::Plain(ref default) => { - let default = default.typed(); + let default = default.typed(self.enum_type); quote! { if #ident != #default { #encode_fn(#tag, &#ident, buf); @@ -138,7 +147,7 @@ impl Field { /// Returns an expression which evaluates to the result of merging a decoded /// scalar value into the field. pub fn merge(&self, ident: TokenStream) -> TokenStream { - let module = self.ty.module(); + let module = self.ty.module(self.enum_type); let merge_fn = match self.kind { Kind::Plain(..) | Kind::Optional(..) | Kind::Required(..) => quote!(merge), Kind::Repeated | Kind::Packed => quote!(merge_repeated), @@ -160,7 +169,7 @@ impl Field { /// Returns an expression which evaluates to the encoded length of the field. pub fn encoded_len(&self, ident: TokenStream) -> TokenStream { - let module = self.ty.module(); + let module = self.ty.module(self.enum_type); let encoded_len_fn = match self.kind { Kind::Plain(..) | Kind::Optional(..) | Kind::Required(..) => quote!(encoded_len), Kind::Repeated => quote!(encoded_len_repeated), @@ -171,7 +180,7 @@ impl Field { match self.kind { Kind::Plain(ref default) => { - let default = default.typed(); + let default = default.typed(self.enum_type); quote! { if #ident != #default { #encoded_len_fn(#tag, &#ident) @@ -192,7 +201,7 @@ impl Field { pub fn clear(&self, ident: TokenStream) -> TokenStream { match self.kind { Kind::Plain(ref default) | Kind::Required(ref default) => { - let default = default.typed(); + let default = default.typed(self.enum_type); match self.ty { Ty::String | Ty::Bytes(..) => quote!(#ident.clear()), _ => quote!(#ident = #default), @@ -206,7 +215,7 @@ impl Field { /// Returns an expression which evaluates to the default value of the field. pub fn default(&self) -> TokenStream { match self.kind { - Kind::Plain(ref value) | Kind::Required(ref value) => value.owned(), + Kind::Plain(ref value) | Kind::Required(ref value) => value.owned(self.enum_type), Kind::Optional(_) => quote!(::core::option::Option::None), Kind::Repeated | Kind::Packed => quote!(::prost::alloc::vec::Vec::new()), } @@ -214,30 +223,41 @@ impl Field { /// An inner debug wrapper, around the base type. fn debug_inner(&self, wrap_name: TokenStream) -> TokenStream { - if let Ty::Enumeration(ref ty) = self.ty { - quote! { + match (&self.ty, self.enum_type) { + (Ty::Enumeration(ty), None) => quote! { + struct #wrap_name<'a>(&'a i32); + impl<'a> ::core::fmt::Debug for #wrap_name<'a> { + fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result { + let res: ::core::result::Result<#ty, _> = ::core::convert::TryFrom::try_from(*self.0); + match res { + Err(_) => ::core::fmt::Debug::fmt(&self.0, f), + Ok(en) => ::core::fmt::Debug::fmt(&en, f), + } + } + } + }, + (Ty::Enumeration(ty), Some(EnumType::Open)) => quote! { struct #wrap_name<'a>(&'a ::prost::OpenEnum<#ty>); impl<'a> ::core::fmt::Debug for #wrap_name<'a> { fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result { - match self.0.known() { - Some(en) => ::core::fmt::Debug::fmt(&en, f), - None => ::core::fmt::Debug::fmt(&self.0, f), + match &self.0 { + ::prost::OpenEnum::Known(en) => ::core::fmt::Debug::fmt(en, f), + ::prost::OpenEnum::Unknown(_) => ::core::fmt::Debug::fmt(&self.0, f), } } } - } - } else { - quote! { + }, + _ => quote! { #[allow(non_snake_case)] fn #wrap_name(v: T) -> T { v } - } + }, } } /// Returns a fragment for formatting the field `ident` in `Debug`. pub fn debug(&self, wrapper_name: TokenStream) -> TokenStream { let wrapper = self.debug_inner(quote!(Inner)); - let inner_ty = self.ty.rust_type(); + let inner_ty = self.ty.owned_type(self.enum_type); match self.kind { Kind::Plain(_) | Kind::Required(_) => self.debug_inner(wrapper_name), Kind::Optional(_) => quote! { @@ -286,68 +306,112 @@ impl Field { if let Ty::Enumeration(ref ty) = self.ty { let set = Ident::new(&format!("set_{}", ident_str), Span::call_site()); let set_doc = format!("Sets `{}` to the provided enum value.", ident_str); - Some(match self.kind { - Kind::Plain(ref default) | Kind::Required(ref default) => { + match &self.kind { + Kind::Plain(default) | Kind::Required(default) if self.enum_type.is_none() => { let get_doc = format!( "Returns the enum value of `{}`, \ or the default if the field is set to an invalid enum value.", ident_str, ); - quote! { + Some(quote! { #[doc=#get_doc] pub fn #get(&self) -> #ty { - self.#ident.unwrap_or(#default) + ::core::convert::TryFrom::try_from(self.#ident).unwrap_or(#default) } #[doc=#set_doc] pub fn #set(&mut self, value: #ty) { - self.#ident = value.into(); + self.#ident = value as i32; } - } + }) } - Kind::Optional(ref default) => { + Kind::Optional(default) => { let get_doc = format!( "Returns the enum value of `{}`, \ - or the default if the field is unset or set to an invalid enum value.", + or the default if the field is unset{}.", ident_str, + if self.enum_type.is_some() { + "" + } else { + " or set to an invalid enum value" + }, ); - quote! { + let get_ty = match self.enum_type { + None | Some(EnumType::Closed) => quote!(#ty), + Some(EnumType::Open) => quote!(::prost::OpenEnum<#ty>), + }; + let (get_body, convert_set) = match self.enum_type { + None => ( + quote! { + self.#ident + .and_then(|value| { + let result: ::core::result::Result<#ty, _> = + ::core::convert::TryFrom::try_from(value); + result.ok() + }) + .unwrap_or(#default) + }, + quote! { + value as i32 + }, + ), + Some(EnumType::Open) => ( + quote! { + self.#ident.unwrap_or(::prost::OpenEnum::Known(#default)) + }, + quote! { + ::prost::OpenEnum::Known(value) + }, + ), + Some(EnumType::Closed) => ( + quote! { + self.#ident.unwrap_or(#default) + }, + quote! { + value + }, + ), + }; + Some(quote! { #[doc=#get_doc] - pub fn #get(&self) -> #ty { - self.#ident.and_then(|x| { x.known() }).unwrap_or(#default) + pub fn #get(&self) -> #get_ty { + #get_body } #[doc=#set_doc] pub fn #set(&mut self, value: #ty) { - self.#ident = ::core::option::Option::Some(value.into()); + self.#ident = ::core::option::Option::Some(#convert_set); } - } + }) } - Kind::Repeated | Kind::Packed => { + Kind::Repeated | Kind::Packed if self.enum_type.is_none() => { let iter_doc = format!( "Returns an iterator which yields the valid enum values contained in `{}`.", ident_str, ); let push = Ident::new(&format!("push_{}", ident_str), Span::call_site()); let push_doc = format!("Appends the provided enum value to `{}`.", ident_str); - let wrapped_ty = quote!(::prost::OpenEnum<#ty>); - quote! { + Some(quote! { #[doc=#iter_doc] pub fn #get(&self) -> ::core::iter::FilterMap< - ::core::iter::Cloned<::core::slice::Iter<#wrapped_ty>>, - fn(#wrapped_ty) -> ::core::option::Option<#ty>, + ::core::iter::Cloned<::core::slice::Iter>, + fn(i32) -> ::core::option::Option<#ty>, > { - self.#ident.iter().cloned().filter_map(|x| { x.known() }) + self.#ident.iter().cloned().filter_map(|x| { + let result: ::core::result::Result<#ty, _> = ::core::convert::TryFrom::try_from(x); + result.ok() + }) } #[doc=#push_doc] pub fn #push(&mut self, value: #ty) { - self.#ident.push(value.into()); + self.#ident.push(value as i32); } - } + }) } - }) + _ => None, + } } else if let Kind::Optional(ref default) = self.kind { - let ty = self.ty.rust_ref_type(); + let ty = self.ty.ref_type(); let match_some = if self.ty.is_numeric() { quote!(::core::option::Option::Some(val) => val,) @@ -522,18 +586,20 @@ impl Ty { } } - // TODO: rename to 'owned_type'. - pub fn rust_type(&self) -> TokenStream { + pub fn owned_type(&self, enum_type: Option) -> TokenStream { match self { Ty::String => quote!(::prost::alloc::string::String), Ty::Bytes(ty) => ty.rust_type(), - Ty::Enumeration(path) => quote!(::prost::OpenEnum<#path>), - _ => self.rust_ref_type(), + Ty::Enumeration(path) => match enum_type { + None => quote!(i32), + Some(EnumType::Open) => quote!(::prost::OpenEnum<#path>), + Some(EnumType::Closed) => quote!(#path), + }, + _ => self.ref_type(), } } - // TODO: rename to 'ref_type' - pub fn rust_ref_type(&self) -> TokenStream { + pub fn ref_type(&self) -> TokenStream { match self { Ty::Double => quote!(f64), Ty::Float => quote!(f32), @@ -550,15 +616,18 @@ impl Ty { Ty::Bool => quote!(bool), Ty::String => quote!(&str), Ty::Bytes(..) => quote!(&[u8]), - Ty::Enumeration(..) => unreachable!("an enum should never be queried for its ref type"), + Ty::Enumeration(..) => panic!("references to enum values are not used in derived code"), } } - pub fn module(&self) -> Ident { - match *self { - Ty::Enumeration(..) => Ident::new("enumeration", Span::call_site()), - _ => Ident::new(self.as_str(), Span::call_site()), - } + pub fn module(&self, enum_type: Option) -> Ident { + let name = match (self, enum_type) { + (Ty::Enumeration(..), None) => "int32", + (Ty::Enumeration(..), Some(EnumType::Open)) => "open_enum", + (Ty::Enumeration(..), Some(EnumType::Closed)) => "closed_enum", + _ => self.as_str(), + }; + Ident::new(name, Span::call_site()) } /// Returns false if the scalar type is length delimited (i.e., `string` or `bytes`). @@ -774,7 +843,7 @@ impl DefaultValue { } } - pub fn owned(&self) -> TokenStream { + pub fn owned(&self, enum_type: Option) -> TokenStream { match *self { DefaultValue::String(ref value) if value.is_empty() => { quote!(::prost::alloc::string::String::new()) @@ -788,15 +857,17 @@ impl DefaultValue { quote!(#lit.as_ref().into()) } - ref other => other.typed(), + ref other => other.typed(enum_type), } } - pub fn typed(&self) -> TokenStream { - if let DefaultValue::Enumeration(_) = *self { - quote!(::prost::OpenEnum::from(#self)) - } else { - quote!(#self) + pub fn typed(&self, enum_type: Option) -> TokenStream { + match (self, enum_type) { + (DefaultValue::Enumeration(..), None) => quote!(#self as i32), + (DefaultValue::Enumeration(..), Some(EnumType::Open)) => { + quote!(::prost::OpenEnum::Known(#self)) + } + _ => quote!(#self), } } } diff --git a/prost-types/src/protobuf.rs b/prost-types/src/protobuf.rs index 6c2b618a0..6f75dfc2b 100644 --- a/prost-types/src/protobuf.rs +++ b/prost-types/src/protobuf.rs @@ -113,11 +113,11 @@ pub struct FieldDescriptorProto { #[prost(int32, optional, tag = "3")] pub number: ::core::option::Option, #[prost(enumeration = "field_descriptor_proto::Label", optional, tag = "4")] - pub label: ::core::option::Option<::prost::OpenEnum>, + pub label: ::core::option::Option, /// If type_name is set, this need not be set. If both this and type_name /// are set, this must be one of TYPE_ENUM, TYPE_MESSAGE or TYPE_GROUP. #[prost(enumeration = "field_descriptor_proto::Type", optional, tag = "5")] - pub r#type: ::core::option::Option<::prost::OpenEnum>, + pub r#type: ::core::option::Option, /// For message and enum types, this is the name of the type. If the name /// starts with a '.', it is fully-qualified. Otherwise, C++-like scoping /// rules are used to find the type (i.e. first the nested types within this @@ -470,9 +470,7 @@ pub struct FileOptions { tag = "9", default = "Speed" )] - pub optimize_for: ::core::option::Option< - ::prost::OpenEnum, - >, + pub optimize_for: ::core::option::Option, /// Sets the Go package where structs generated from this .proto will be /// placed. If omitted, the Go package will be derived from the following: /// @@ -666,7 +664,7 @@ pub struct FieldOptions { tag = "1", default = "String" )] - pub ctype: ::core::option::Option<::prost::OpenEnum>, + pub ctype: ::core::option::Option, /// The packed option can be enabled for repeated primitive fields to enable /// a more efficient representation on the wire. Rather than repeatedly /// writing the tag and type for each element, the entire array is encoded as @@ -691,7 +689,7 @@ pub struct FieldOptions { tag = "6", default = "JsNormal" )] - pub jstype: ::core::option::Option<::prost::OpenEnum>, + pub jstype: ::core::option::Option, /// Should this field be parsed lazily? Lazy applies only to message-type /// fields. It means that when the outer message is initially parsed, the /// inner message's contents will not be parsed but instead stored in encoded @@ -879,9 +877,7 @@ pub struct MethodOptions { tag = "34", default = "IdempotencyUnknown" )] - pub idempotency_level: ::core::option::Option< - ::prost::OpenEnum, - >, + pub idempotency_level: ::core::option::Option, /// The parser stores options it doesn't recognize here. See above. #[prost(message, repeated, tag = "999")] pub uninterpreted_option: ::prost::alloc::vec::Vec, @@ -1306,17 +1302,17 @@ pub struct Type { pub source_context: ::core::option::Option, /// The source syntax. #[prost(enumeration = "Syntax", tag = "6")] - pub syntax: ::prost::OpenEnum, + pub syntax: i32, } /// A single field of a message type. #[derive(Clone, PartialEq, ::prost::Message)] pub struct Field { /// The field type. #[prost(enumeration = "field::Kind", tag = "1")] - pub kind: ::prost::OpenEnum, + pub kind: i32, /// The field cardinality. #[prost(enumeration = "field::Cardinality", tag = "2")] - pub cardinality: ::prost::OpenEnum, + pub cardinality: i32, /// The field number. #[prost(int32, tag = "3")] pub number: i32, @@ -1518,7 +1514,7 @@ pub struct Enum { pub source_context: ::core::option::Option, /// The source syntax. #[prost(enumeration = "Syntax", tag = "5")] - pub syntax: ::prost::OpenEnum, + pub syntax: i32, } /// Enum value definition. #[derive(Clone, PartialEq, ::prost::Message)] @@ -1630,7 +1626,7 @@ pub struct Api { pub mixins: ::prost::alloc::vec::Vec, /// The source syntax of the service. #[prost(enumeration = "Syntax", tag = "7")] - pub syntax: ::prost::OpenEnum, + pub syntax: i32, } /// Method represents a method of an API interface. #[derive(Clone, PartialEq, ::prost::Message)] @@ -1655,7 +1651,7 @@ pub struct Method { pub options: ::prost::alloc::vec::Vec