From cb5e59c4198fab36ba071421085b8f75d57f0fa8 Mon Sep 17 00:00:00 2001 From: Mikhail Zabaluev Date: Mon, 11 Nov 2024 11:50:03 +0200 Subject: [PATCH] feat: typed enum fields Add typed_enum_fields method to prost-build configuration, which allows type-checked representation of enumerations in fields of message structs and variants of oneof enums. The argument and the invocation order works like with the boxed method. Depending on the syntax (and preparing for the future support of editions), the type-checked representation can be closed (for proto2) or open (for proto3). The former is represented by the generated enum type itself, while the latter is represented by OpenEnum wrapping the enum type. A new enum_type annotation is supported in the prost attribute inside derives, which allows to specify the type-checked representation of enum types in message fields and oneof variants. The accepted values are "open" or "closed". --- prost-build/src/code_generator.rs | 89 +- prost-build/src/config.rs | 26 + prost-derive/src/field/map.rs | 117 +- prost-derive/src/field/mod.rs | 59 +- prost-derive/src/field/scalar.rs | 193 +- prost-types/src/protobuf.rs | 30 +- prost/src/encoding.rs | 136 +- prost/src/error.rs | 4 +- tests/expanded.rs | 24877 ++++++++++++++++++++++++++++ tests/src/custom_debug.rs | 141 +- tests/src/debug.rs | 311 +- tests/src/message_encoding.rs | 682 +- 12 files changed, 26210 insertions(+), 455 deletions(-) create mode 100644 tests/expanded.rs diff --git a/prost-build/src/code_generator.rs b/prost-build/src/code_generator.rs index 4014dcb99..b8e26dc2a 100644 --- a/prost-build/src/code_generator.rs +++ b/prost-build/src/code_generator.rs @@ -402,7 +402,7 @@ impl CodeGenerator<'_> { fn append_field(&mut self, fq_message_name: &str, field: &Field) { let type_ = field.descriptor.r#type(); - let repeated = field.descriptor.label.and_then(|v| v.known()) == Some(Label::Repeated); + let repeated = field.descriptor.label == Some(Label::Repeated as i32); let deprecated = self.deprecated(&field.descriptor); let optional = self.optional(&field.descriptor); let boxed = self.boxed(&field.descriptor, fq_message_name, None); @@ -427,15 +427,19 @@ impl CodeGenerator<'_> { let type_tag = self.field_type_tag(&field.descriptor); self.buf.push_str(&type_tag); - if type_ == Type::Bytes { - let bytes_type = self - .config - .bytes_type - .get_first_field(fq_message_name, field.descriptor.name()) - .copied() - .unwrap_or_default(); - self.buf - .push_str(&format!("={:?}", bytes_type.annotation())); + match type_ { + Type::Bytes => { + let bytes_type = self + .config + .bytes_type + .get_first_field(fq_message_name, field.descriptor.name()) + .copied() + .unwrap_or_default(); + self.buf + .push_str(&format!("={:?}", bytes_type.annotation())); + } + Type::Enum => self.push_enum_type_annotation(fq_message_name, field.descriptor.name()), + _ => {} } match field.descriptor.label() { @@ -555,12 +559,16 @@ impl CodeGenerator<'_> { let value_tag = self.map_value_type_tag(value); self.buf.push_str(&format!( - "#[prost({}=\"{}, {}\", tag=\"{}\")]\n", + "#[prost({}=\"{}, {}\"", map_type.annotation(), key_tag, value_tag, - field.descriptor.number() )); + if value.r#type() == Type::Enum { + self.push_enum_type_annotation(fq_message_name, field.descriptor.name()); + } + self.buf + .push_str(&format!(", tag=\"{}\")]\n", field.descriptor.number())); self.append_field_attributes(fq_message_name, field.descriptor.name()); self.push_indent(); self.buf.push_str(&format!( @@ -639,11 +647,12 @@ impl CodeGenerator<'_> { self.push_indent(); let ty_tag = self.field_type_tag(&field.descriptor); - self.buf.push_str(&format!( - "#[prost({}, tag=\"{}\")]\n", - ty_tag, - field.descriptor.number() - )); + self.buf.push_str(&format!("#[prost({}", ty_tag,)); + if field.descriptor.r#type() == Type::Enum { + self.push_enum_type_annotation(&oneof_name, field.descriptor.name()); + } + self.buf + .push_str(&format!(", tag=\"{}\")]\n", field.descriptor.number())); self.append_field_attributes(&oneof_name, field.descriptor.name()); self.push_indent(); @@ -947,6 +956,14 @@ impl CodeGenerator<'_> { self.buf.push_str("}\n"); } + fn push_enum_type_annotation(&mut self, fq_message_name: &str, field_name: &str) { + match self.enum_field_repr(fq_message_name, field_name) { + EnumRepr::Int => {} + EnumRepr::Open => self.buf.push_str(", enum_type=\"open\""), + EnumRepr::Closed => self.buf.push_str(", enum_type=\"closed\""), + } + } + fn resolve_type(&self, field: &FieldDescriptorProto, fq_message_name: &str) -> String { match field.r#type() { Type::Float => String::from("f32"), @@ -966,11 +983,15 @@ impl CodeGenerator<'_> { .rust_type() .to_owned(), Type::Group | Type::Message => self.resolve_ident(field.type_name()), - Type::Enum => format!( - "{}::OpenEnum<{}>", - prost_path(self.config), - self.resolve_ident(field.type_name()) - ), + Type::Enum => match self.enum_field_repr(fq_message_name, field.name()) { + EnumRepr::Int => String::from("i32"), + EnumRepr::Open => format!( + "{}::OpenEnum<{}>", + prost_path(self.config), + self.resolve_ident(field.type_name()) + ), + EnumRepr::Closed => self.resolve_ident(field.type_name()), + }, } } @@ -1012,6 +1033,22 @@ impl CodeGenerator<'_> { .join("::") } + fn enum_field_repr(&self, fq_message_name: &str, field_name: &str) -> EnumRepr { + if self + .config + .typed_enum_fields + .get_first_field(fq_message_name, field_name) + .is_some() + { + match self.syntax { + Syntax::Proto2 => EnumRepr::Closed, + Syntax::Proto3 => EnumRepr::Open, + } + } else { + EnumRepr::Int + } + } + fn field_type_tag(&self, field: &FieldDescriptorProto) -> Cow<'static, str> { match field.r#type() { Type::Float => Cow::Borrowed("float"), @@ -1074,7 +1111,7 @@ impl CodeGenerator<'_> { fq_message_name: &str, oneof: Option<&str>, ) -> bool { - let repeated = field.label.and_then(|v| v.known()) == Some(Label::Repeated); + let repeated = field.label == Some(Label::Repeated as i32); let fd_type = field.r#type(); if !repeated && (fd_type == Type::Message || fd_type == Type::Group) @@ -1148,6 +1185,12 @@ fn can_pack(field: &FieldDescriptorProto) -> bool { ) } +enum EnumRepr { + Int, + Closed, + Open, +} + struct EnumVariantMapping<'a> { path_idx: usize, proto_name: &'a str, diff --git a/prost-build/src/config.rs b/prost-build/src/config.rs index 896726b16..5d555d12e 100644 --- a/prost-build/src/config.rs +++ b/prost-build/src/config.rs @@ -36,6 +36,7 @@ pub struct Config { pub(crate) enum_attributes: PathMap, pub(crate) field_attributes: PathMap, pub(crate) boxed: PathMap<()>, + pub(crate) typed_enum_fields: PathMap<()>, pub(crate) prost_types: bool, pub(crate) strip_enum_prefix: bool, pub(crate) out_dir: Option, @@ -372,6 +373,30 @@ impl Config { self } + /// Represent Protobuf enum types encountered in matched fields with types + /// bound to their corresponding Rust enum types, rather than the default `i32`. + /// + /// Depending on the proto file syntax, the representation type can be: + /// * For closed enums (in proto2), the corresponding Rust enum type. + /// * For open enums (in proto3), the Rust enum type wrapped in [`OpenEnum`]. + /// + /// # Arguments + /// + /// **`path`** - a path matching any number of fields. These fields will get the type-checked + /// enum representation. + /// For details about matching fields see [`btree_map`](#method.btree_map). + /// + /// # Examples + /// + /// ```rust + /// # let mut config = prost_build::Config::new(); + /// config.typed_enum_fields(".my_messages"); + /// ``` + pub fn typed_enum_fields(&mut self, path: impl AsRef) -> &mut Self { + self.typed_enum_fields.insert(path.as_ref().to_owned(), ()); + self + } + /// Configures the code generator to use the provided service generator. pub fn service_generator(&mut self, service_generator: Box) -> &mut Self { self.service_generator = Some(service_generator); @@ -1158,6 +1183,7 @@ impl default::Default for Config { enum_attributes: PathMap::default(), field_attributes: PathMap::default(), boxed: PathMap::default(), + typed_enum_fields: PathMap::default(), prost_types: true, strip_enum_prefix: true, out_dir: None, diff --git a/prost-derive/src/field/map.rs b/prost-derive/src/field/map.rs index a20a586ad..64842930e 100644 --- a/prost-derive/src/field/map.rs +++ b/prost-derive/src/field/map.rs @@ -4,7 +4,7 @@ use quote::quote; use syn::punctuated::Punctuated; use syn::{Expr, ExprLit, Ident, Lit, Meta, MetaNameValue, Token}; -use crate::field::{scalar, set_option, tag_attr}; +use crate::field::{scalar, set_option, tag_attr, EnumType}; #[derive(Clone, Debug)] pub enum MapTy { @@ -36,11 +36,12 @@ impl MapTy { } } -fn fake_scalar(ty: scalar::Ty) -> scalar::Field { +fn fake_scalar(ty: scalar::Ty, enum_type: Option) -> scalar::Field { let kind = scalar::Kind::Plain(scalar::DefaultValue::new(&ty)); scalar::Field { ty, kind, + enum_type, tag: 0, // Not used here } } @@ -50,6 +51,7 @@ pub struct Field { pub map_ty: MapTy, pub key_ty: scalar::Ty, pub value_ty: ValueTy, + pub enum_type: Option, pub tag: u32, } @@ -57,10 +59,13 @@ impl Field { pub fn new(attrs: &[Meta], inferred_tag: Option) -> Result, Error> { let mut types = None; let mut tag = None; + let mut enum_type = None; for attr in attrs { if let Some(t) = tag_attr(attr)? { set_option(&mut tag, t, "duplicate tag attributes")?; + } else if let Some(et) = EnumType::from_attr(attr)? { + set_option(&mut enum_type, et, "duplicate enum_type attributes")?; } else if let Some(map_ty) = attr .path() .get_ident() @@ -113,6 +118,7 @@ impl Field { map_ty, key_ty, value_ty, + enum_type, tag, }), _ => None, @@ -126,19 +132,19 @@ impl Field { /// Returns a statement which encodes the map field. pub fn encode(&self, ident: TokenStream) -> TokenStream { let tag = self.tag; - let key_mod = self.key_ty.module(); + let key_mod = self.key_ty.module(None); let ke = quote!(::prost::encoding::#key_mod::encode); let kl = quote!(::prost::encoding::#key_mod::encoded_len); let module = self.map_ty.module(); match &self.value_ty { - ValueTy::Scalar(scalar::Ty::Enumeration(ty)) => { - let default = quote!(::prost::OpenEnum::from(#ty::default())); + ValueTy::Scalar(scalar::Ty::Enumeration(ty)) if self.enum_type.is_none() => { + let default = quote!(#ty::default() as i32); quote! { ::prost::encoding::#module::encode_with_default( #ke, #kl, - ::prost::encoding::enumeration::encode, - ::prost::encoding::enumeration::encoded_len, + ::prost::encoding::int32::encode, + ::prost::encoding::int32::encoded_len, &(#default), #tag, &#ident, @@ -147,7 +153,7 @@ impl Field { } } ValueTy::Scalar(value_ty) => { - let val_mod = value_ty.module(); + let val_mod = value_ty.module(self.enum_type); let ve = quote!(::prost::encoding::#val_mod::encode); let vl = quote!(::prost::encoding::#val_mod::encoded_len); quote! { @@ -179,16 +185,16 @@ impl Field { /// Returns an expression which evaluates to the result of merging a decoded key value pair /// into the map. pub fn merge(&self, ident: TokenStream) -> TokenStream { - let key_mod = self.key_ty.module(); + let key_mod = self.key_ty.module(None); let km = quote!(::prost::encoding::#key_mod::merge); let module = self.map_ty.module(); match &self.value_ty { - ValueTy::Scalar(scalar::Ty::Enumeration(ty)) => { - let default = quote!(::prost::OpenEnum::from(#ty::default())); + ValueTy::Scalar(scalar::Ty::Enumeration(ty)) if self.enum_type.is_none() => { + let default = quote!(#ty::default() as i32); quote! { ::prost::encoding::#module::merge_with_default( #km, - ::prost::encoding::enumeration::merge, + ::prost::encoding::int32::merge, #default, &mut #ident, buf, @@ -197,7 +203,7 @@ impl Field { } } ValueTy::Scalar(value_ty) => { - let val_mod = value_ty.module(); + let val_mod = value_ty.module(self.enum_type); let vm = quote!(::prost::encoding::#val_mod::merge); quote!(::prost::encoding::#module::merge(#km, #vm, &mut #ident, buf, ctx)) } @@ -216,16 +222,16 @@ impl Field { /// Returns an expression which evaluates to the encoded length of the map. pub fn encoded_len(&self, ident: TokenStream) -> TokenStream { let tag = self.tag; - let key_mod = self.key_ty.module(); + let key_mod = self.key_ty.module(None); let kl = quote!(::prost::encoding::#key_mod::encoded_len); let module = self.map_ty.module(); match &self.value_ty { - ValueTy::Scalar(scalar::Ty::Enumeration(ty)) => { - let default = quote!(::prost::OpenEnum::from(#ty::default())); + ValueTy::Scalar(scalar::Ty::Enumeration(ty)) if self.enum_type.is_none() => { + let default = quote!(#ty::default() as i32); quote! { ::prost::encoding::#module::encoded_len_with_default( #kl, - ::prost::encoding::enumeration::encoded_len, + ::prost::encoding::int32::encoded_len, &(#default), #tag, &#ident, @@ -233,7 +239,7 @@ impl Field { } } ValueTy::Scalar(value_ty) => { - let val_mod = value_ty.module(); + let val_mod = value_ty.module(self.enum_type); let vl = quote!(::prost::encoding::#val_mod::encoded_len); quote!(::prost::encoding::#module::encoded_len(#kl, #vl, #tag, &#ident)) } @@ -254,36 +260,43 @@ impl Field { /// Returns methods to embed in the message. pub fn methods(&self, ident: &TokenStream) -> Option { - if let ValueTy::Scalar(scalar::Ty::Enumeration(ty)) = &self.value_ty { - let key_ty = self.key_ty.rust_type(); - let key_ref_ty = self.key_ty.rust_ref_type(); + match &self.value_ty { + ValueTy::Scalar(scalar::Ty::Enumeration(ty)) if self.enum_type.is_none() => { + let key_ty = self.key_ty.owned_type(None); + let key_ref_ty = self.key_ty.ref_type(); - let get = Ident::new(&format!("get_{}", ident), Span::call_site()); - let insert = Ident::new(&format!("insert_{}", ident), Span::call_site()); - let take_ref = if self.key_ty.is_numeric() { - quote!(&) - } else { - quote!() - }; + let get = Ident::new(&format!("get_{}", ident), Span::call_site()); + let insert = Ident::new(&format!("insert_{}", ident), Span::call_site()); + let take_ref = if self.key_ty.is_numeric() { + quote!(&) + } else { + quote!() + }; - let get_doc = format!( - "Returns the enum value for the corresponding key in `{}`, \ + let get_doc = format!( + "Returns the enum value for the corresponding key in `{}`, \ or `None` if the entry does not exist or it is not a valid enum value.", - ident, - ); - let insert_doc = format!("Inserts a key value pair into `{}`.", ident); - Some(quote! { - #[doc=#get_doc] - pub fn #get(&self, key: #key_ref_ty) -> ::core::option::Option<#ty> { - self.#ident.get(#take_ref key).cloned().and_then(|x| { x.known() }) - } - #[doc=#insert_doc] - pub fn #insert(&mut self, key: #key_ty, value: #ty) -> ::core::option::Option<#ty> { - self.#ident.insert(key, value.into()).and_then(|x| { x.known() }) - } - }) - } else { - None + ident, + ); + let insert_doc = format!("Inserts a key value pair into `{}`.", ident); + Some(quote! { + #[doc=#get_doc] + pub fn #get(&self, key: #key_ref_ty) -> ::core::option::Option<#ty> { + self.#ident.get(#take_ref key).cloned().and_then(|x| { + let result: ::core::result::Result<#ty, _> = ::core::convert::TryFrom::try_from(x); + result.ok() + }) + } + #[doc=#insert_doc] + pub fn #insert(&mut self, key: #key_ty, value: #ty) -> ::core::option::Option<#ty> { + self.#ident.insert(key, value as i32).and_then(|x| { + let result: ::core::result::Result<#ty, _> = ::core::convert::TryFrom::try_from(x); + result.ok() + }) + } + }) + } + _ => None, } } @@ -298,9 +311,9 @@ impl Field { }; // A fake field for generating the debug wrapper - let key_wrapper = fake_scalar(self.key_ty.clone()).debug(quote!(KeyWrapper)); - let key = self.key_ty.rust_type(); - let value_wrapper = self.value_ty.debug(); + let key_wrapper = fake_scalar(self.key_ty.clone(), None).debug(quote!(KeyWrapper)); + let key = self.key_ty.owned_type(None); + let value_wrapper = self.value_ty.debug(self.enum_type); let libname = self.map_ty.lib(); let fmt = quote! { fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result { @@ -326,7 +339,7 @@ impl Field { }; } - let value = ty.rust_type(); + let value = ty.owned_type(self.enum_type); quote! { struct #wrapper_name<'a>(&'a ::#libname::collections::#type_name<#key, #value>); impl<'a> ::core::fmt::Debug for #wrapper_name<'a> { @@ -386,12 +399,14 @@ impl ValueTy { /// Returns a newtype wrapper around the ValueTy for nicer debug. /// + /// The generated implementation depends on the `enum_type` feature selection. /// If the contained value is enumeration, it tries to convert it to the variant. If not, it /// just forwards the implementation. - fn debug(&self) -> TokenStream { + fn debug(&self, enum_type: Option) -> TokenStream { match self { - ValueTy::Scalar(ty) => fake_scalar(ty.clone()).debug(quote!(ValueWrapper)), + ValueTy::Scalar(ty) => fake_scalar(ty.clone(), enum_type).debug(quote!(ValueWrapper)), ValueTy::Message => quote!( + #[allow(non_snake_case)] fn ValueWrapper(v: T) -> T { v } diff --git a/prost-derive/src/field/mod.rs b/prost-derive/src/field/mod.rs index 366075e45..fbf8c8738 100644 --- a/prost-derive/src/field/mod.rs +++ b/prost-derive/src/field/mod.rs @@ -11,7 +11,7 @@ use anyhow::{bail, Error}; use proc_macro2::TokenStream; use quote::quote; use syn::punctuated::Punctuated; -use syn::{Attribute, Expr, ExprLit, Lit, LitBool, LitInt, Meta, MetaNameValue, Token}; +use syn::{Attribute, Expr, ExprLit, Ident, Lit, LitBool, LitInt, Meta, MetaNameValue, Token}; #[derive(Clone)] pub enum Field { @@ -224,6 +224,63 @@ impl fmt::Display for Label { } } +#[derive(Clone, Copy, PartialEq, Eq)] +pub enum EnumType { + Closed, + Open, +} + +impl EnumType { + fn as_str(self) -> &'static str { + match self { + EnumType::Closed => "closed", + EnumType::Open => "open", + } + } + + /// Parses a meta attribute into the enum_type feature value. + /// If the attribute name does not match "enum_type", `None` is returned. + fn from_attr(attr: &Meta) -> Result, Error> { + if !attr.path().is_ident("enum_type") { + return Ok(None); + } + match attr { + Meta::NameValue(MetaNameValue { + value: Expr::Lit(ExprLit { lit, .. }), + .. + }) => { + if let Lit::Str(lit) = lit { + match lit.value().as_str() { + "open" => return Ok(Some(EnumType::Open)), + "closed" => return Ok(Some(EnumType::Closed)), + _ => {} + } + } + bail!("invalid value of enum_type attribute: {:?}", lit); + } + Meta::List(meta_list) => { + let ident = meta_list.parse_args::()?; + if ident == "open" { + return Ok(Some(EnumType::Open)); + } else if ident == "closed" { + return Ok(Some(EnumType::Closed)); + } else { + bail!("invalid content of enum_type attribute: {:?}", ident); + } + } + _ => { + bail!("invalid enum_type attribute: {:?}", attr); + } + } + } +} + +impl fmt::Debug for EnumType { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(self.as_str()) + } +} + /// Get the items belonging to the 'prost' list attribute, e.g. `#[prost(foo, bar="baz")]`. fn prost_attrs(attrs: Vec) -> Result, Error> { let mut result = Vec::new(); diff --git a/prost-derive/src/field/scalar.rs b/prost-derive/src/field/scalar.rs index 62a9d4fd7..966400da7 100644 --- a/prost-derive/src/field/scalar.rs +++ b/prost-derive/src/field/scalar.rs @@ -5,13 +5,14 @@ use proc_macro2::{Span, TokenStream}; use quote::{quote, ToTokens, TokenStreamExt}; use syn::{parse_str, Expr, ExprLit, Ident, Index, Lit, LitByteStr, Meta, MetaNameValue, Path}; -use crate::field::{bool_attr, set_option, tag_attr, Label}; +use crate::field::{bool_attr, set_option, tag_attr, EnumType, Label}; /// A scalar protobuf field. #[derive(Clone)] pub struct Field { pub ty: Ty, pub kind: Kind, + pub enum_type: Option, pub tag: u32, } @@ -20,6 +21,7 @@ impl Field { let mut ty = None; let mut label = None; let mut packed = None; + let mut enum_type = None; let mut default = None; let mut tag = None; @@ -34,6 +36,8 @@ impl Field { set_option(&mut tag, t, "duplicate tag attributes")?; } else if let Some(l) = Label::from_attr(attr) { set_option(&mut label, l, "duplicate label attributes")?; + } else if let Some(t) = EnumType::from_attr(attr)? { + set_option(&mut enum_type, t, "duplicate enum_type attributes")?; } else if let Some(d) = DefaultValue::from_attr(attr)? { set_option(&mut default, d, "duplicate default attributes")?; } else { @@ -86,7 +90,12 @@ impl Field { (Some(Label::Repeated), _, false) => Kind::Repeated, }; - Ok(Some(Field { ty, kind, tag })) + Ok(Some(Field { + ty, + kind, + enum_type, + tag, + })) } pub fn new_oneof(attrs: &[Meta]) -> Result, Error> { @@ -106,7 +115,7 @@ impl Field { } pub fn encode(&self, ident: TokenStream) -> TokenStream { - let module = self.ty.module(); + let module = self.ty.module(self.enum_type); let encode_fn = match self.kind { Kind::Plain(..) | Kind::Optional(..) | Kind::Required(..) => quote!(encode), Kind::Repeated => quote!(encode_repeated), @@ -117,7 +126,7 @@ impl Field { match self.kind { Kind::Plain(ref default) => { - let default = default.typed(); + let default = default.typed(self.enum_type); quote! { if #ident != #default { #encode_fn(#tag, &#ident, buf); @@ -138,7 +147,7 @@ impl Field { /// Returns an expression which evaluates to the result of merging a decoded /// scalar value into the field. pub fn merge(&self, ident: TokenStream) -> TokenStream { - let module = self.ty.module(); + let module = self.ty.module(self.enum_type); let merge_fn = match self.kind { Kind::Plain(..) | Kind::Optional(..) | Kind::Required(..) => quote!(merge), Kind::Repeated | Kind::Packed => quote!(merge_repeated), @@ -160,7 +169,7 @@ impl Field { /// Returns an expression which evaluates to the encoded length of the field. pub fn encoded_len(&self, ident: TokenStream) -> TokenStream { - let module = self.ty.module(); + let module = self.ty.module(self.enum_type); let encoded_len_fn = match self.kind { Kind::Plain(..) | Kind::Optional(..) | Kind::Required(..) => quote!(encoded_len), Kind::Repeated => quote!(encoded_len_repeated), @@ -171,7 +180,7 @@ impl Field { match self.kind { Kind::Plain(ref default) => { - let default = default.typed(); + let default = default.typed(self.enum_type); quote! { if #ident != #default { #encoded_len_fn(#tag, &#ident) @@ -192,7 +201,7 @@ impl Field { pub fn clear(&self, ident: TokenStream) -> TokenStream { match self.kind { Kind::Plain(ref default) | Kind::Required(ref default) => { - let default = default.typed(); + let default = default.typed(self.enum_type); match self.ty { Ty::String | Ty::Bytes(..) => quote!(#ident.clear()), _ => quote!(#ident = #default), @@ -206,7 +215,7 @@ impl Field { /// Returns an expression which evaluates to the default value of the field. pub fn default(&self) -> TokenStream { match self.kind { - Kind::Plain(ref value) | Kind::Required(ref value) => value.owned(), + Kind::Plain(ref value) | Kind::Required(ref value) => value.owned(self.enum_type), Kind::Optional(_) => quote!(::core::option::Option::None), Kind::Repeated | Kind::Packed => quote!(::prost::alloc::vec::Vec::new()), } @@ -214,30 +223,41 @@ impl Field { /// An inner debug wrapper, around the base type. fn debug_inner(&self, wrap_name: TokenStream) -> TokenStream { - if let Ty::Enumeration(ref ty) = self.ty { - quote! { + match (&self.ty, self.enum_type) { + (Ty::Enumeration(ty), None) => quote! { + struct #wrap_name<'a>(&'a i32); + impl<'a> ::core::fmt::Debug for #wrap_name<'a> { + fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result { + let res: ::core::result::Result<#ty, _> = ::core::convert::TryFrom::try_from(*self.0); + match res { + Err(_) => ::core::fmt::Debug::fmt(&self.0, f), + Ok(en) => ::core::fmt::Debug::fmt(&en, f), + } + } + } + }, + (Ty::Enumeration(ty), Some(EnumType::Open)) => quote! { struct #wrap_name<'a>(&'a ::prost::OpenEnum<#ty>); impl<'a> ::core::fmt::Debug for #wrap_name<'a> { fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result { - match self.0.known() { - Some(en) => ::core::fmt::Debug::fmt(&en, f), - None => ::core::fmt::Debug::fmt(&self.0, f), + match &self.0 { + ::prost::OpenEnum::Known(en) => ::core::fmt::Debug::fmt(en, f), + ::prost::OpenEnum::Unknown(_) => ::core::fmt::Debug::fmt(&self.0, f), } } } - } - } else { - quote! { + }, + _ => quote! { #[allow(non_snake_case)] fn #wrap_name(v: T) -> T { v } - } + }, } } /// Returns a fragment for formatting the field `ident` in `Debug`. pub fn debug(&self, wrapper_name: TokenStream) -> TokenStream { let wrapper = self.debug_inner(quote!(Inner)); - let inner_ty = self.ty.rust_type(); + let inner_ty = self.ty.owned_type(self.enum_type); match self.kind { Kind::Plain(_) | Kind::Required(_) => self.debug_inner(wrapper_name), Kind::Optional(_) => quote! { @@ -286,68 +306,112 @@ impl Field { if let Ty::Enumeration(ref ty) = self.ty { let set = Ident::new(&format!("set_{}", ident_str), Span::call_site()); let set_doc = format!("Sets `{}` to the provided enum value.", ident_str); - Some(match self.kind { - Kind::Plain(ref default) | Kind::Required(ref default) => { + match &self.kind { + Kind::Plain(default) | Kind::Required(default) if self.enum_type.is_none() => { let get_doc = format!( "Returns the enum value of `{}`, \ or the default if the field is set to an invalid enum value.", ident_str, ); - quote! { + Some(quote! { #[doc=#get_doc] pub fn #get(&self) -> #ty { - self.#ident.unwrap_or(#default) + ::core::convert::TryFrom::try_from(self.#ident).unwrap_or(#default) } #[doc=#set_doc] pub fn #set(&mut self, value: #ty) { - self.#ident = value.into(); + self.#ident = value as i32; } - } + }) } - Kind::Optional(ref default) => { + Kind::Optional(default) => { let get_doc = format!( "Returns the enum value of `{}`, \ - or the default if the field is unset or set to an invalid enum value.", + or the default if the field is unset{}.", ident_str, + if self.enum_type.is_some() { + "" + } else { + " or set to an invalid enum value" + }, ); - quote! { + let get_ty = match self.enum_type { + None | Some(EnumType::Closed) => quote!(#ty), + Some(EnumType::Open) => quote!(::prost::OpenEnum<#ty>), + }; + let (get_body, convert_set) = match self.enum_type { + None => ( + quote! { + self.#ident + .and_then(|value| { + let result: ::core::result::Result<#ty, _> = + ::core::convert::TryFrom::try_from(value); + result.ok() + }) + .unwrap_or(#default) + }, + quote! { + value as i32 + }, + ), + Some(EnumType::Open) => ( + quote! { + self.#ident.unwrap_or(::prost::OpenEnum::Known(#default)) + }, + quote! { + ::prost::OpenEnum::Known(value) + }, + ), + Some(EnumType::Closed) => ( + quote! { + self.#ident.unwrap_or(#default) + }, + quote! { + value + }, + ), + }; + Some(quote! { #[doc=#get_doc] - pub fn #get(&self) -> #ty { - self.#ident.and_then(|x| { x.known() }).unwrap_or(#default) + pub fn #get(&self) -> #get_ty { + #get_body } #[doc=#set_doc] pub fn #set(&mut self, value: #ty) { - self.#ident = ::core::option::Option::Some(value.into()); + self.#ident = ::core::option::Option::Some(#convert_set); } - } + }) } - Kind::Repeated | Kind::Packed => { + Kind::Repeated | Kind::Packed if self.enum_type.is_none() => { let iter_doc = format!( "Returns an iterator which yields the valid enum values contained in `{}`.", ident_str, ); let push = Ident::new(&format!("push_{}", ident_str), Span::call_site()); let push_doc = format!("Appends the provided enum value to `{}`.", ident_str); - let wrapped_ty = quote!(::prost::OpenEnum<#ty>); - quote! { + Some(quote! { #[doc=#iter_doc] pub fn #get(&self) -> ::core::iter::FilterMap< - ::core::iter::Cloned<::core::slice::Iter<#wrapped_ty>>, - fn(#wrapped_ty) -> ::core::option::Option<#ty>, + ::core::iter::Cloned<::core::slice::Iter>, + fn(i32) -> ::core::option::Option<#ty>, > { - self.#ident.iter().cloned().filter_map(|x| { x.known() }) + self.#ident.iter().cloned().filter_map(|x| { + let result: ::core::result::Result<#ty, _> = ::core::convert::TryFrom::try_from(x); + result.ok() + }) } #[doc=#push_doc] pub fn #push(&mut self, value: #ty) { - self.#ident.push(value.into()); + self.#ident.push(value as i32); } - } + }) } - }) + _ => None, + } } else if let Kind::Optional(ref default) = self.kind { - let ty = self.ty.rust_ref_type(); + let ty = self.ty.ref_type(); let match_some = if self.ty.is_numeric() { quote!(::core::option::Option::Some(val) => val,) @@ -522,18 +586,20 @@ impl Ty { } } - // TODO: rename to 'owned_type'. - pub fn rust_type(&self) -> TokenStream { + pub fn owned_type(&self, enum_type: Option) -> TokenStream { match self { Ty::String => quote!(::prost::alloc::string::String), Ty::Bytes(ty) => ty.rust_type(), - Ty::Enumeration(path) => quote!(::prost::OpenEnum<#path>), - _ => self.rust_ref_type(), + Ty::Enumeration(path) => match enum_type { + None => quote!(i32), + Some(EnumType::Open) => quote!(::prost::OpenEnum<#path>), + Some(EnumType::Closed) => quote!(#path), + }, + _ => self.ref_type(), } } - // TODO: rename to 'ref_type' - pub fn rust_ref_type(&self) -> TokenStream { + pub fn ref_type(&self) -> TokenStream { match self { Ty::Double => quote!(f64), Ty::Float => quote!(f32), @@ -550,15 +616,18 @@ impl Ty { Ty::Bool => quote!(bool), Ty::String => quote!(&str), Ty::Bytes(..) => quote!(&[u8]), - Ty::Enumeration(..) => unreachable!("an enum should never be queried for its ref type"), + Ty::Enumeration(..) => panic!("references to enum values are not used in derived code"), } } - pub fn module(&self) -> Ident { - match *self { - Ty::Enumeration(..) => Ident::new("enumeration", Span::call_site()), - _ => Ident::new(self.as_str(), Span::call_site()), - } + pub fn module(&self, enum_type: Option) -> Ident { + let name = match (self, enum_type) { + (Ty::Enumeration(..), None) => "int32", + (Ty::Enumeration(..), Some(EnumType::Open)) => "open_enum", + (Ty::Enumeration(..), Some(EnumType::Closed)) => "closed_enum", + _ => self.as_str(), + }; + Ident::new(name, Span::call_site()) } /// Returns false if the scalar type is length delimited (i.e., `string` or `bytes`). @@ -774,7 +843,7 @@ impl DefaultValue { } } - pub fn owned(&self) -> TokenStream { + pub fn owned(&self, enum_type: Option) -> TokenStream { match *self { DefaultValue::String(ref value) if value.is_empty() => { quote!(::prost::alloc::string::String::new()) @@ -788,15 +857,17 @@ impl DefaultValue { quote!(#lit.as_ref().into()) } - ref other => other.typed(), + ref other => other.typed(enum_type), } } - pub fn typed(&self) -> TokenStream { - if let DefaultValue::Enumeration(_) = *self { - quote!(::prost::OpenEnum::from(#self)) - } else { - quote!(#self) + pub fn typed(&self, enum_type: Option) -> TokenStream { + match (self, enum_type) { + (DefaultValue::Enumeration(..), None) => quote!(#self as i32), + (DefaultValue::Enumeration(..), Some(EnumType::Open)) => { + quote!(::prost::OpenEnum::Known(#self)) + } + _ => quote!(#self), } } } diff --git a/prost-types/src/protobuf.rs b/prost-types/src/protobuf.rs index 6c2b618a0..6f75dfc2b 100644 --- a/prost-types/src/protobuf.rs +++ b/prost-types/src/protobuf.rs @@ -113,11 +113,11 @@ pub struct FieldDescriptorProto { #[prost(int32, optional, tag = "3")] pub number: ::core::option::Option, #[prost(enumeration = "field_descriptor_proto::Label", optional, tag = "4")] - pub label: ::core::option::Option<::prost::OpenEnum>, + pub label: ::core::option::Option, /// If type_name is set, this need not be set. If both this and type_name /// are set, this must be one of TYPE_ENUM, TYPE_MESSAGE or TYPE_GROUP. #[prost(enumeration = "field_descriptor_proto::Type", optional, tag = "5")] - pub r#type: ::core::option::Option<::prost::OpenEnum>, + pub r#type: ::core::option::Option, /// For message and enum types, this is the name of the type. If the name /// starts with a '.', it is fully-qualified. Otherwise, C++-like scoping /// rules are used to find the type (i.e. first the nested types within this @@ -470,9 +470,7 @@ pub struct FileOptions { tag = "9", default = "Speed" )] - pub optimize_for: ::core::option::Option< - ::prost::OpenEnum, - >, + pub optimize_for: ::core::option::Option, /// Sets the Go package where structs generated from this .proto will be /// placed. If omitted, the Go package will be derived from the following: /// @@ -666,7 +664,7 @@ pub struct FieldOptions { tag = "1", default = "String" )] - pub ctype: ::core::option::Option<::prost::OpenEnum>, + pub ctype: ::core::option::Option, /// The packed option can be enabled for repeated primitive fields to enable /// a more efficient representation on the wire. Rather than repeatedly /// writing the tag and type for each element, the entire array is encoded as @@ -691,7 +689,7 @@ pub struct FieldOptions { tag = "6", default = "JsNormal" )] - pub jstype: ::core::option::Option<::prost::OpenEnum>, + pub jstype: ::core::option::Option, /// Should this field be parsed lazily? Lazy applies only to message-type /// fields. It means that when the outer message is initially parsed, the /// inner message's contents will not be parsed but instead stored in encoded @@ -879,9 +877,7 @@ pub struct MethodOptions { tag = "34", default = "IdempotencyUnknown" )] - pub idempotency_level: ::core::option::Option< - ::prost::OpenEnum, - >, + pub idempotency_level: ::core::option::Option, /// The parser stores options it doesn't recognize here. See above. #[prost(message, repeated, tag = "999")] pub uninterpreted_option: ::prost::alloc::vec::Vec, @@ -1306,17 +1302,17 @@ pub struct Type { pub source_context: ::core::option::Option, /// The source syntax. #[prost(enumeration = "Syntax", tag = "6")] - pub syntax: ::prost::OpenEnum, + pub syntax: i32, } /// A single field of a message type. #[derive(Clone, PartialEq, ::prost::Message)] pub struct Field { /// The field type. #[prost(enumeration = "field::Kind", tag = "1")] - pub kind: ::prost::OpenEnum, + pub kind: i32, /// The field cardinality. #[prost(enumeration = "field::Cardinality", tag = "2")] - pub cardinality: ::prost::OpenEnum, + pub cardinality: i32, /// The field number. #[prost(int32, tag = "3")] pub number: i32, @@ -1518,7 +1514,7 @@ pub struct Enum { pub source_context: ::core::option::Option, /// The source syntax. #[prost(enumeration = "Syntax", tag = "5")] - pub syntax: ::prost::OpenEnum, + pub syntax: i32, } /// Enum value definition. #[derive(Clone, PartialEq, ::prost::Message)] @@ -1630,7 +1626,7 @@ pub struct Api { pub mixins: ::prost::alloc::vec::Vec, /// The source syntax of the service. #[prost(enumeration = "Syntax", tag = "7")] - pub syntax: ::prost::OpenEnum, + pub syntax: i32, } /// Method represents a method of an API interface. #[derive(Clone, PartialEq, ::prost::Message)] @@ -1655,7 +1651,7 @@ pub struct Method { pub options: ::prost::alloc::vec::Vec