From 15893b6e7c791b5c2b92a4251afaba6ab8ff1192 Mon Sep 17 00:00:00 2001 From: Mikhail Zabaluev Date: Fri, 29 Mar 2024 16:13:58 +0200 Subject: [PATCH 1/3] prost-build: consolidate message field data When massaging field data in CodeGenerator::append_message, move it into lists of Field and OneofField structs so that later generation passes can operate on the data with less code duplication. Subsidiary append_* methods are changed to take references to these structs rather than moved data, as generation of lexical tokens does not actually consume any owned data, and we will need more passes over the same field lists for the upcoming builder code. --- prost-build/src/code_generator.rs | 160 ++++++++++++++++++------------ 1 file changed, 99 insertions(+), 61 deletions(-) diff --git a/prost-build/src/code_generator.rs b/prost-build/src/code_generator.rs index 12e35c036..21697b272 100644 --- a/prost-build/src/code_generator.rs +++ b/prost-build/src/code_generator.rs @@ -52,6 +52,44 @@ fn prost_path(config: &Config) -> &str { config.prost_path.as_deref().unwrap_or("::prost") } +struct Field { + rust_name: String, + descriptor: FieldDescriptorProto, + path_index: i32, +} + +impl Field { + fn new(descriptor: FieldDescriptorProto, path_index: i32) -> Self { + Self { + rust_name: to_snake(descriptor.name()), + descriptor, + path_index, + } + } +} + +struct OneofField { + rust_name: String, + descriptor: OneofDescriptorProto, + fields: Vec<(FieldDescriptorProto, i32)>, + path_index: i32, +} + +impl OneofField { + fn new( + descriptor: OneofDescriptorProto, + fields: Vec<(FieldDescriptorProto, i32)>, + path_index: i32, + ) -> Self { + Self { + rust_name: to_snake(descriptor.name()), + descriptor, + fields, + path_index, + } + } +} + impl<'a> CodeGenerator<'a> { pub fn generate( config: &mut Config, @@ -167,21 +205,33 @@ impl<'a> CodeGenerator<'a> { // Split the fields into a vector of the normal fields, and oneof fields. // Path indexes are preserved so that comments can be retrieved. - type Fields = Vec<(FieldDescriptorProto, usize)>; - type OneofFields = MultiMap; - let (fields, mut oneof_fields): (Fields, OneofFields) = message + type OneofFieldsByIndex = MultiMap; + let (fields, mut oneof_map): (Vec, OneofFieldsByIndex) = message .field .into_iter() .enumerate() - .partition_map(|(idx, field)| { - if field.proto3_optional.unwrap_or(false) { - Either::Left((field, idx)) - } else if let Some(oneof_index) = field.oneof_index { - Either::Right((oneof_index, (field, idx))) + .partition_map(|(idx, proto)| { + let idx = idx as i32; + if proto.proto3_optional.unwrap_or(false) { + Either::Left(Field::new(proto, idx)) + } else if let Some(oneof_index) = proto.oneof_index { + Either::Right((oneof_index, (proto, idx))) } else { - Either::Left((field, idx)) + Either::Left(Field::new(proto, idx)) } }); + // Optional fields create a synthetic oneof that we want to skip + let oneof_fields: Vec = message + .oneof_decl + .into_iter() + .enumerate() + .filter_map(move |(idx, proto)| { + let idx = idx as i32; + oneof_map + .remove(&idx) + .map(|fields| OneofField::new(proto, fields, idx)) + }) + .collect(); self.append_doc(&fq_message_name, None); self.append_type_attributes(&fq_message_name); @@ -201,9 +251,10 @@ impl<'a> CodeGenerator<'a> { self.depth += 1; self.path.push(2); - for (field, idx) in fields { - self.path.push(idx as i32); + for field in &fields { + self.path.push(field.path_index); match field + .descriptor .type_name .as_ref() .and_then(|type_name| map_types.get(type_name)) @@ -216,16 +267,9 @@ impl<'a> CodeGenerator<'a> { self.path.pop(); self.path.push(8); - for (idx, oneof) in message.oneof_decl.iter().enumerate() { - let idx = idx as i32; - - let fields = match oneof_fields.get_vec(&idx) { - Some(fields) => fields, - None => continue, - }; - - self.path.push(idx); - self.append_oneof_field(&message_name, &fq_message_name, oneof, fields); + for oneof in &oneof_fields { + self.path.push(oneof.path_index); + self.append_oneof_field(&message_name, &fq_message_name, oneof); self.path.pop(); } self.path.pop(); @@ -252,14 +296,8 @@ impl<'a> CodeGenerator<'a> { } self.path.pop(); - for (idx, oneof) in message.oneof_decl.into_iter().enumerate() { - let idx = idx as i32; - // optional fields create a synthetic oneof that we want to skip - let fields = match oneof_fields.remove(&idx) { - Some(fields) => fields, - None => continue, - }; - self.append_oneof(&fq_message_name, oneof, idx, fields); + for oneof in &oneof_fields { + self.append_oneof(&fq_message_name, oneof); } self.pop_mod(); @@ -368,12 +406,14 @@ impl<'a> CodeGenerator<'a> { } } - fn append_field(&mut self, fq_message_name: &str, field: FieldDescriptorProto) { + fn append_field(&mut self, fq_message_name: &str, field: &Field) { + let rust_name = &field.rust_name; + let field = &field.descriptor; let type_ = field.r#type(); let repeated = field.label == Some(Label::Repeated as i32); - let deprecated = self.deprecated(&field); - let optional = self.optional(&field); - let ty = self.resolve_type(&field, fq_message_name); + let deprecated = self.deprecated(field); + let optional = self.optional(field); + let ty = self.resolve_type(field, fq_message_name); let boxed = !repeated && ((type_ == Type::Message || type_ == Type::Group) @@ -402,7 +442,7 @@ impl<'a> CodeGenerator<'a> { self.push_indent(); self.buf.push_str("#[prost("); - let type_tag = self.field_type_tag(&field); + let type_tag = self.field_type_tag(field); self.buf.push_str(&type_tag); if type_ == Type::Bytes { @@ -425,7 +465,7 @@ impl<'a> CodeGenerator<'a> { Label::Required => self.buf.push_str(", required"), Label::Repeated => { self.buf.push_str(", repeated"); - if can_pack(&field) + if can_pack(field) && !field .options .as_ref() @@ -476,7 +516,7 @@ impl<'a> CodeGenerator<'a> { self.append_field_attributes(fq_message_name, field.name()); self.push_indent(); self.buf.push_str("pub "); - self.buf.push_str(&to_snake(field.name())); + self.buf.push_str(rust_name); self.buf.push_str(": "); let prost_path = prost_path(self.config); @@ -504,10 +544,12 @@ impl<'a> CodeGenerator<'a> { fn append_map_field( &mut self, fq_message_name: &str, - field: FieldDescriptorProto, + field: &Field, key: &FieldDescriptorProto, value: &FieldDescriptorProto, ) { + let rust_name = &field.rust_name; + let field = &field.descriptor; let key_ty = self.resolve_type(key, fq_message_name); let value_ty = self.resolve_type(value, fq_message_name); @@ -541,7 +583,7 @@ impl<'a> CodeGenerator<'a> { self.push_indent(); self.buf.push_str(&format!( "pub {}: {}<{}, {}>,\n", - to_snake(field.name()), + rust_name, map_type.rust_type(), key_ty, value_ty @@ -552,44 +594,40 @@ impl<'a> CodeGenerator<'a> { &mut self, message_name: &str, fq_message_name: &str, - oneof: &OneofDescriptorProto, - fields: &[(FieldDescriptorProto, usize)], + oneof: &OneofField, ) { - let name = format!( + let type_name = format!( "{}::{}", to_snake(message_name), - to_upper_camel(oneof.name()) + to_upper_camel(oneof.descriptor.name()) ); + let field_tags = oneof + .fields + .iter() + .map(|(field, _)| field.number()) + .join(", "); self.append_doc(fq_message_name, None); self.push_indent(); self.buf.push_str(&format!( "#[prost(oneof=\"{}\", tags=\"{}\")]\n", - name, - fields.iter().map(|(field, _)| field.number()).join(", ") + type_name, field_tags, )); - self.append_field_attributes(fq_message_name, oneof.name()); + self.append_field_attributes(fq_message_name, oneof.descriptor.name()); self.push_indent(); self.buf.push_str(&format!( "pub {}: ::core::option::Option<{}>,\n", - to_snake(oneof.name()), - name + oneof.rust_name, type_name )); } - fn append_oneof( - &mut self, - fq_message_name: &str, - oneof: OneofDescriptorProto, - idx: i32, - fields: Vec<(FieldDescriptorProto, usize)>, - ) { + fn append_oneof(&mut self, fq_message_name: &str, oneof: &OneofField) { self.path.push(8); - self.path.push(idx); + self.path.push(oneof.path_index); self.append_doc(fq_message_name, None); self.path.pop(); self.path.pop(); - let oneof_name = format!("{}.{}", fq_message_name, oneof.name()); + let oneof_name = format!("{}.{}", fq_message_name, oneof.descriptor.name()); self.append_type_attributes(&oneof_name); self.append_enum_attributes(&oneof_name); self.push_indent(); @@ -602,20 +640,20 @@ impl<'a> CodeGenerator<'a> { self.append_skip_debug(fq_message_name); self.push_indent(); self.buf.push_str("pub enum "); - self.buf.push_str(&to_upper_camel(oneof.name())); + self.buf.push_str(&to_upper_camel(oneof.descriptor.name())); self.buf.push_str(" {\n"); self.path.push(2); self.depth += 1; - for (field, idx) in fields { + for (field, idx) in &oneof.fields { let type_ = field.r#type(); - self.path.push(idx as i32); + self.path.push(*idx); self.append_doc(fq_message_name, Some(field.name())); self.path.pop(); self.push_indent(); - let ty_tag = self.field_type_tag(&field); + let ty_tag = self.field_type_tag(field); self.buf.push_str(&format!( "#[prost({}, tag=\"{}\")]\n", ty_tag, @@ -624,7 +662,7 @@ impl<'a> CodeGenerator<'a> { self.append_field_attributes(&oneof_name, field.name()); self.push_indent(); - let ty = self.resolve_type(&field, fq_message_name); + let ty = self.resolve_type(field, fq_message_name); let boxed = ((type_ == Type::Message || type_ == Type::Group) && self From 5333e02a874e19c08abfa596d04e056672fed2ab Mon Sep 17 00:00:00 2001 From: Mikhail Zabaluev Date: Sun, 28 Apr 2024 14:54:34 +0300 Subject: [PATCH 2/3] prost-build: compute field tags in place --- prost-build/src/code_generator.rs | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/prost-build/src/code_generator.rs b/prost-build/src/code_generator.rs index 21697b272..4cca824e3 100644 --- a/prost-build/src/code_generator.rs +++ b/prost-build/src/code_generator.rs @@ -601,16 +601,16 @@ impl<'a> CodeGenerator<'a> { to_snake(message_name), to_upper_camel(oneof.descriptor.name()) ); - let field_tags = oneof - .fields - .iter() - .map(|(field, _)| field.number()) - .join(", "); self.append_doc(fq_message_name, None); self.push_indent(); self.buf.push_str(&format!( "#[prost(oneof=\"{}\", tags=\"{}\")]\n", - type_name, field_tags, + type_name, + oneof + .fields + .iter() + .map(|(field, _)| field.number()) + .join(", "), )); self.append_field_attributes(fq_message_name, oneof.descriptor.name()); self.push_indent(); From 1991a11da1c2bdaf10f820234a2d9dab238bcd42 Mon Sep 17 00:00:00 2001 From: Mikhail Zabaluev Date: Sun, 28 Apr 2024 15:21:45 +0300 Subject: [PATCH 3/3] prost-build: address comments on reuse of Field Make rust_field into a method computing the name on the fly. In OneofField, make the vector of fields to have Field members. Don't play reference renaming tricks with field.descriptor. --- prost-build/src/code_generator.rs | 112 +++++++++++++++--------------- 1 file changed, 57 insertions(+), 55 deletions(-) diff --git a/prost-build/src/code_generator.rs b/prost-build/src/code_generator.rs index 4cca824e3..a230cad1e 100644 --- a/prost-build/src/code_generator.rs +++ b/prost-build/src/code_generator.rs @@ -53,7 +53,6 @@ fn prost_path(config: &Config) -> &str { } struct Field { - rust_name: String, descriptor: FieldDescriptorProto, path_index: i32, } @@ -61,33 +60,34 @@ struct Field { impl Field { fn new(descriptor: FieldDescriptorProto, path_index: i32) -> Self { Self { - rust_name: to_snake(descriptor.name()), descriptor, path_index, } } + + fn rust_name(&self) -> String { + to_snake(self.descriptor.name()) + } } struct OneofField { - rust_name: String, descriptor: OneofDescriptorProto, - fields: Vec<(FieldDescriptorProto, i32)>, + fields: Vec, path_index: i32, } impl OneofField { - fn new( - descriptor: OneofDescriptorProto, - fields: Vec<(FieldDescriptorProto, i32)>, - path_index: i32, - ) -> Self { + fn new(descriptor: OneofDescriptorProto, fields: Vec, path_index: i32) -> Self { Self { - rust_name: to_snake(descriptor.name()), descriptor, fields, path_index, } } + + fn rust_name(&self) -> String { + to_snake(self.descriptor.name()) + } } impl<'a> CodeGenerator<'a> { @@ -205,7 +205,7 @@ impl<'a> CodeGenerator<'a> { // Split the fields into a vector of the normal fields, and oneof fields. // Path indexes are preserved so that comments can be retrieved. - type OneofFieldsByIndex = MultiMap; + type OneofFieldsByIndex = MultiMap; let (fields, mut oneof_map): (Vec, OneofFieldsByIndex) = message .field .into_iter() @@ -215,7 +215,7 @@ impl<'a> CodeGenerator<'a> { if proto.proto3_optional.unwrap_or(false) { Either::Left(Field::new(proto, idx)) } else if let Some(oneof_index) = proto.oneof_index { - Either::Right((oneof_index, (proto, idx))) + Either::Right((oneof_index, Field::new(proto, idx))) } else { Either::Left(Field::new(proto, idx)) } @@ -407,33 +407,31 @@ impl<'a> CodeGenerator<'a> { } fn append_field(&mut self, fq_message_name: &str, field: &Field) { - let rust_name = &field.rust_name; - let field = &field.descriptor; - let type_ = field.r#type(); - let repeated = field.label == Some(Label::Repeated as i32); - let deprecated = self.deprecated(field); - let optional = self.optional(field); - let ty = self.resolve_type(field, fq_message_name); + let type_ = field.descriptor.r#type(); + let repeated = field.descriptor.label == Some(Label::Repeated as i32); + let deprecated = self.deprecated(&field.descriptor); + let optional = self.optional(&field.descriptor); + let ty = self.resolve_type(&field.descriptor, fq_message_name); let boxed = !repeated && ((type_ == Type::Message || type_ == Type::Group) && self .message_graph - .is_nested(field.type_name(), fq_message_name)) + .is_nested(field.descriptor.type_name(), fq_message_name)) || (self .config .boxed - .get_first_field(fq_message_name, field.name()) + .get_first_field(fq_message_name, field.descriptor.name()) .is_some()); debug!( " field: {:?}, type: {:?}, boxed: {}", - field.name(), + field.descriptor.name(), ty, boxed ); - self.append_doc(fq_message_name, Some(field.name())); + self.append_doc(fq_message_name, Some(field.descriptor.name())); if deprecated { self.push_indent(); @@ -442,21 +440,21 @@ impl<'a> CodeGenerator<'a> { self.push_indent(); self.buf.push_str("#[prost("); - let type_tag = self.field_type_tag(field); + let type_tag = self.field_type_tag(&field.descriptor); self.buf.push_str(&type_tag); if type_ == Type::Bytes { let bytes_type = self .config .bytes_type - .get_first_field(fq_message_name, field.name()) + .get_first_field(fq_message_name, field.descriptor.name()) .copied() .unwrap_or_default(); self.buf .push_str(&format!("={:?}", bytes_type.annotation())); } - match field.label() { + match field.descriptor.label() { Label::Optional => { if optional { self.buf.push_str(", optional"); @@ -465,8 +463,9 @@ impl<'a> CodeGenerator<'a> { Label::Required => self.buf.push_str(", required"), Label::Repeated => { self.buf.push_str(", repeated"); - if can_pack(field) + if can_pack(&field.descriptor) && !field + .descriptor .options .as_ref() .map_or(self.syntax == Syntax::Proto3, |options| options.packed()) @@ -480,9 +479,9 @@ impl<'a> CodeGenerator<'a> { self.buf.push_str(", boxed"); } self.buf.push_str(", tag=\""); - self.buf.push_str(&field.number().to_string()); + self.buf.push_str(&field.descriptor.number().to_string()); - if let Some(ref default) = field.default_value { + if let Some(ref default) = field.descriptor.default_value { self.buf.push_str("\", default=\""); if type_ == Type::Bytes { self.buf.push_str("b\\\""); @@ -499,6 +498,7 @@ impl<'a> CodeGenerator<'a> { // the last segment and strip it from the left // side of the default value. let enum_type = field + .descriptor .type_name .as_ref() .and_then(|ty| ty.split('.').last()) @@ -513,10 +513,10 @@ impl<'a> CodeGenerator<'a> { } self.buf.push_str("\")]\n"); - self.append_field_attributes(fq_message_name, field.name()); + self.append_field_attributes(fq_message_name, field.descriptor.name()); self.push_indent(); self.buf.push_str("pub "); - self.buf.push_str(rust_name); + self.buf.push_str(&field.rust_name()); self.buf.push_str(": "); let prost_path = prost_path(self.config); @@ -548,25 +548,23 @@ impl<'a> CodeGenerator<'a> { key: &FieldDescriptorProto, value: &FieldDescriptorProto, ) { - let rust_name = &field.rust_name; - let field = &field.descriptor; let key_ty = self.resolve_type(key, fq_message_name); let value_ty = self.resolve_type(value, fq_message_name); debug!( " map field: {:?}, key type: {:?}, value type: {:?}", - field.name(), + field.descriptor.name(), key_ty, value_ty ); - self.append_doc(fq_message_name, Some(field.name())); + self.append_doc(fq_message_name, Some(field.descriptor.name())); self.push_indent(); let map_type = self .config .map_type - .get_first_field(fq_message_name, field.name()) + .get_first_field(fq_message_name, field.descriptor.name()) .copied() .unwrap_or_default(); let key_tag = self.field_type_tag(key); @@ -577,13 +575,13 @@ impl<'a> CodeGenerator<'a> { map_type.annotation(), key_tag, value_tag, - field.number() + field.descriptor.number() )); - self.append_field_attributes(fq_message_name, field.name()); + self.append_field_attributes(fq_message_name, field.descriptor.name()); self.push_indent(); self.buf.push_str(&format!( "pub {}: {}<{}, {}>,\n", - rust_name, + field.rust_name(), map_type.rust_type(), key_ty, value_ty @@ -609,14 +607,15 @@ impl<'a> CodeGenerator<'a> { oneof .fields .iter() - .map(|(field, _)| field.number()) + .map(|field| field.descriptor.number()) .join(", "), )); self.append_field_attributes(fq_message_name, oneof.descriptor.name()); self.push_indent(); self.buf.push_str(&format!( "pub {}: ::core::option::Option<{}>,\n", - oneof.rust_name, type_name + oneof.rust_name(), + type_name )); } @@ -645,38 +644,38 @@ impl<'a> CodeGenerator<'a> { self.path.push(2); self.depth += 1; - for (field, idx) in &oneof.fields { - let type_ = field.r#type(); + for field in &oneof.fields { + let type_ = field.descriptor.r#type(); - self.path.push(*idx); - self.append_doc(fq_message_name, Some(field.name())); + self.path.push(field.path_index); + self.append_doc(fq_message_name, Some(field.descriptor.name())); self.path.pop(); self.push_indent(); - let ty_tag = self.field_type_tag(field); + let ty_tag = self.field_type_tag(&field.descriptor); self.buf.push_str(&format!( "#[prost({}, tag=\"{}\")]\n", ty_tag, - field.number() + field.descriptor.number() )); - self.append_field_attributes(&oneof_name, field.name()); + self.append_field_attributes(&oneof_name, field.descriptor.name()); self.push_indent(); - let ty = self.resolve_type(field, fq_message_name); + let ty = self.resolve_type(&field.descriptor, fq_message_name); let boxed = ((type_ == Type::Message || type_ == Type::Group) && self .message_graph - .is_nested(field.type_name(), fq_message_name)) + .is_nested(field.descriptor.type_name(), fq_message_name)) || (self .config .boxed - .get_first_field(&oneof_name, field.name()) + .get_first_field(&oneof_name, field.descriptor.name()) .is_some()); debug!( " oneof: {:?}, type: {:?}, boxed: {}", - field.name(), + field.descriptor.name(), ty, boxed ); @@ -684,12 +683,15 @@ impl<'a> CodeGenerator<'a> { if boxed { self.buf.push_str(&format!( "{}(::prost::alloc::boxed::Box<{}>),\n", - to_upper_camel(field.name()), + to_upper_camel(field.descriptor.name()), ty )); } else { - self.buf - .push_str(&format!("{}({}),\n", to_upper_camel(field.name()), ty)); + self.buf.push_str(&format!( + "{}({}),\n", + to_upper_camel(field.descriptor.name()), + ty + )); } } self.depth -= 1;