diff --git a/.github/workflows/continuous-integration-workflow.yaml b/.github/workflows/continuous-integration-workflow.yaml index 920646554..92c8a3ed9 100644 --- a/.github/workflows/continuous-integration-workflow.yaml +++ b/.github/workflows/continuous-integration-workflow.yaml @@ -18,6 +18,8 @@ jobs: default: true profile: minimal components: rustfmt + - name: build b_tests + run: cargo build --package b_tests - name: rustfmt uses: actions-rs/cargo@v1 with: @@ -82,45 +84,45 @@ jobs: command: test args: --no-default-features - no-std: - runs-on: ubuntu-latest - steps: - - name: checkout - uses: actions/checkout@v2 - with: - submodules: recursive - - name: install toolchain - uses: actions-rs/toolchain@v1 - with: - toolchain: nightly - default: true - profile: minimal - - uses: Swatinem/rust-cache@v1 - - name: install cargo-no-std-check - uses: actions-rs/cargo@v1 - with: - command: install - args: cargo-no-std-check - - name: prost cargo-no-std-check - uses: actions-rs/cargo@v1 - with: - command: no-std-check - args: --manifest-path Cargo.toml --no-default-features - - name: prost-types cargo-no-std-check - uses: actions-rs/cargo@v1 - with: - command: no-std-check - args: --manifest-path prost-types/Cargo.toml --no-default-features - # prost-build depends on prost with --no-default-features, but when - # prost-build is built through the workspace, prost typically has default - # features enabled due to vagaries in Cargo workspace feature resolution. - # This additional check ensures that prost-build does not rely on any of - # prost's default features to compile. - - name: prost-build check - uses: actions-rs/cargo@v1 - with: - command: check - args: --manifest-path prost-build/Cargo.toml +# no-std: +# runs-on: ubuntu-latest +# steps: +# - name: checkout +# uses: actions/checkout@v2 +# with: +# submodules: recursive +# - name: install toolchain +# uses: actions-rs/toolchain@v1 +# with: +# toolchain: nightly +# default: true +# profile: minimal +# - uses: Swatinem/rust-cache@v1 +# - name: install cargo-no-std-check +# uses: actions-rs/cargo@v1 +# with: +# command: install +# args: cargo-no-std-check +# - name: prost cargo-no-std-check +# uses: actions-rs/cargo@v1 +# with: +# command: no-std-check +# args: --manifest-path Cargo.toml --no-default-features +# - name: prost-types cargo-no-std-check +# uses: actions-rs/cargo@v1 +# with: +# command: no-std-check +# args: --manifest-path prost-types/Cargo.toml --no-default-features +# # prost-build depends on prost with --no-default-features, but when +# # prost-build is built through the workspace, prost typically has default +# # features enabled due to vagaries in Cargo workspace feature resolution. +# # This additional check ensures that prost-build does not rely on any of +# # prost's default features to compile. +# - name: prost-build check +# uses: actions-rs/cargo@v1 +# with: +# command: check +# args: --manifest-path prost-build/Cargo.toml vendored: runs-on: ubuntu-latest diff --git a/Cargo.toml b/Cargo.toml index 2ee0d96b0..af2f12a30 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -49,6 +49,7 @@ std = [] [dependencies] bytes = { version = "1", default-features = false } prost-derive = { version = "0.10.0", path = "prost-derive", optional = true } +uuid = { version = "1", features = ["v4"] } [dev-dependencies] criterion = "0.3" diff --git a/b_tests/.gitignore b/b_tests/.gitignore new file mode 100644 index 000000000..e1001337b --- /dev/null +++ b/b_tests/.gitignore @@ -0,0 +1,3 @@ +/src/b_generated +/src/generated +/src/protos diff --git a/b_tests/Cargo.toml b/b_tests/Cargo.toml index 97fffc47d..d227cfce8 100644 --- a/b_tests/Cargo.toml +++ b/b_tests/Cargo.toml @@ -1,8 +1,14 @@ [package] name = "b_tests" version = "0.1.0" +authors = ["Jasper Visser "] edition = "2018" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html +[build-dependencies] +prost-build = { path = "../prost-build" } +protobuf_strict = { git = "https://github.com/Jasperav/protobuf_strict.git" } [dependencies] +uuid = { version = "1", features = ["v4"] } +prost = { path = "../" } diff --git a/b_tests/build.rs b/b_tests/build.rs new file mode 100644 index 000000000..8f0e2e80a --- /dev/null +++ b/b_tests/build.rs @@ -0,0 +1,66 @@ +use prost_build::{Config, CustomType}; +use std::io::Write; + +fn main() { + let src = std::env::current_dir().unwrap().join("src"); + + protobuf_strict::write_protos(&src); + + macro_rules! initialize_dir { + ($name: expr) => {{ + // Maybe the dir doesn't exists yet + let _ = std::fs::remove_dir_all(src.join($name)); + std::fs::create_dir(src.join($name)).unwrap(); + std::fs::File::create(src.join($name).join("mod.rs")).unwrap() + }}; + } + + let mut b_generated = initialize_dir!("b_generated"); + let mut generated = initialize_dir!("generated"); + let mut protos = vec![]; + let paths = std::fs::read_dir("./src/protos").unwrap(); + + for path in paths { + let path = path.unwrap(); + let file_name = path.file_name(); + let s = file_name.to_str().unwrap(); + + protos.push("src/protos/".to_string() + s); + } + + macro_rules! go_generate { + ($name: expr, $file: expr, $config: expr) => { + std::env::set_var("OUT_DIR", src.join($name).to_str().unwrap()); + + $config + .compile_protos(protos.as_slice(), &["src/protos/".to_string()]) + .unwrap(); + + let paths = std::fs::read_dir("./src/".to_string() + $name).unwrap(); + + for path in paths { + let path = path.unwrap(); + let file_name = path.file_name(); + let s = file_name.to_str().unwrap(); + + if s == "mod.rs" { + continue; + } + + let m = s.strip_suffix(".rs").unwrap(); + + writeln!($file, "#[rustfmt::skip]\nmod {};\npub use {}::*;", m, m).unwrap(); + } + }; + } + + let mut config = Config::new(); + + config + .add_types_mapping(protobuf_strict::get_uuids().to_vec(), CustomType::Uuid) + .strict_messages() + .inline_enums(); + + go_generate!("b_generated", b_generated, config); + go_generate!("generated", generated, Config::new()); +} diff --git a/b_tests/src/lib.rs b/b_tests/src/lib.rs index 1b4a90c93..69467d8ab 100644 --- a/b_tests/src/lib.rs +++ b/b_tests/src/lib.rs @@ -1,8 +1,191 @@ +mod b_generated; +mod generated; + #[cfg(test)] -mod tests { +mod test { + use super::*; + use prost::Message; + use std::str::FromStr; + use uuid::Uuid; + + const UUID: &'static str = "cd663747-6cb1-4ddc-bdfe-3dc76db62724"; + + fn get_uuid() -> Uuid { + uuid::Uuid::from_str(UUID).unwrap() + } + + fn get_no_uuid() -> String { + "no_uuid".to_string() + } + + fn get_custom() -> String { + "custom".to_string() + } + + fn get_amount() -> i32 { + 1 + } + + fn b_generated_order() -> b_generated::Order { + b_generated::Order { + gender: b_generated::Gender::Female, + genders: vec![b_generated::Gender::Female, b_generated::Gender::Other], + currency: Some(b_generated::Currency { + c: Some(b_generated::currency::C::Amount(get_amount())), + }), + o_currency: None, + currencies: vec![ + b_generated::Currency { + c: Some(b_generated::currency::C::Custom(get_custom())), + }, + b_generated::Currency { + c: Some(b_generated::currency::C::Amount(get_amount())), + }, + ], + uuid: get_uuid(), + no_uuid: get_no_uuid(), + repeated_uuids: vec![get_uuid(), get_uuid()], + no_uuids: vec![get_custom()], + order_inner: b_generated::order::OrderInner::InnerAnother, + order_inners: vec![ + b_generated::order::OrderInner::InnerAnother, + b_generated::order::OrderInner::InnerAnother2, + ], + something: Some(b_generated::order::Something::AlsoUuid(get_uuid())), + } + } + + fn generated_order() -> generated::Order { + generated::Order { + gender: generated::Gender::Female as i32, + genders: vec![ + generated::Gender::Female as i32, + generated::Gender::Other as i32, + ], + currency: Some(generated::Currency { + c: Some(generated::currency::C::Amount(get_amount())), + }), + o_currency: None, + currencies: vec![ + generated::Currency { + c: Some(generated::currency::C::Custom(get_custom())), + }, + generated::Currency { + c: Some(generated::currency::C::Amount(get_amount())), + }, + ], + uuid: get_uuid().to_string(), + no_uuid: get_no_uuid(), + repeated_uuids: vec![get_uuid().to_string(), get_uuid().to_string()], + no_uuids: vec![get_custom()], + order_inner: generated::order::OrderInner::InnerAnother as i32, + order_inners: vec![ + generated::order::OrderInner::InnerAnother as i32, + generated::order::OrderInner::InnerAnother2 as i32, + ], + something: Some(generated::order::Something::AlsoUuid( + get_uuid().to_string(), + )), + } + } + #[test] - fn it_works() { - let result = 2 + 2; - assert_eq!(result, 4); + fn equal() { + let g = b_generated_order(); + let b = generated_order(); + let g = g.encode_buffer().unwrap(); + let b = b.encode_buffer().unwrap(); + + assert_eq!(g, b); + assert_eq!(g.encoded_len(), b.encoded_len()); + + // Check if encoding works + b_generated::Order::decode(b.as_slice()).unwrap(); + generated::Order::decode(g.as_slice()).unwrap(); + } + + fn check_order(order: generated::Order) { + let b = order.encode_buffer().unwrap(); + + b_generated::Order::decode(b.as_slice()).unwrap(); + } + + macro_rules! write_invalid_test { + ($method_name: ident, $order: ident, $change: tt) => { + #[test] + #[should_panic] + fn $method_name() { + let mut $order = generated_order(); + + $change; + + check_order($order); + } + }; } + + write_invalid_test!( + invalid_gender_zero, + order, + ({ + order.gender = 0; + }) + ); + write_invalid_test!( + invalid_gender_over_max, + order, + ({ + order.gender = 999; + }) + ); + write_invalid_test!( + invalid_uuids, + order, + ({ + order.repeated_uuids = vec![get_uuid().to_string(), get_no_uuid().to_string()]; + }) + ); + write_invalid_test!( + empty_currency, + order, + ({ + order.currency = None; + }) + ); + write_invalid_test!( + invalid_uuid, + order, + ({ + order.uuid = order.uuid[1..].to_string(); + }) + ); + + write_invalid_test!( + invalid_inner_zero, + order, + ({ + order.order_inner = 0; + }) + ); + write_invalid_test!( + invalid_inner_over_max, + order, + ({ + order.order_inner = 999; + }) + ); + write_invalid_test!( + empty_something, + order, + ({ + order.something = None; + }) + ); + write_invalid_test!( + something_invalid_uuid, + order, + ({ + order.something = Some(generated::order::Something::AlsoUuid(get_no_uuid())); + }) + ); } diff --git a/prost-build/src/code_generator.rs b/prost-build/src/code_generator.rs index a3e9dbd7c..4e6e123a0 100644 --- a/prost-build/src/code_generator.rs +++ b/prost-build/src/code_generator.rs @@ -18,7 +18,7 @@ use crate::ast::{Comments, Method, Service}; use crate::extern_paths::ExternPaths; use crate::ident::{to_snake, to_upper_camel}; use crate::message_graph::MessageGraph; -use crate::{BytesType, Config, MapType}; +use crate::{BytesType, Config, CustomType, MapType}; #[derive(PartialEq)] enum Syntax { @@ -311,7 +311,7 @@ impl<'a> CodeGenerator<'a> { self.push_indent(); self.buf.push_str("#[prost("); - let type_tag = self.field_type_tag(&field); + let type_tag = self.field_type_tag(&field, fq_message_name); self.buf.push_str(&type_tag); if type_ == Type::Bytes { @@ -377,6 +377,7 @@ impl<'a> CodeGenerator<'a> { } else { &enum_value }; + self.buf.push_str(stripped_prefix); } else { self.buf.push_str(&default.escape_default().to_string()); @@ -385,6 +386,32 @@ impl<'a> CodeGenerator<'a> { self.buf.push_str("\")]\n"); self.append_field_attributes(fq_message_name, field.name()); + + if self.config.strict_messages { + match field.r#type() { + Type::Message => match field.label() { + Label::Optional => { + if let Some(ref s) = field.name { + if !s.starts_with("o_") { + self.buf.push_str("#[prost(strict)]\n"); + } + } + } + _ => {} + }, + Type::Enum => { + if !self.config.inline_enums { + if self.config.inline_enums { + self.buf.push_str("#[prost(inlined)]\n"); + } else { + self.buf.push_str("#[prost(strict)]\n"); + } + } + } + _ => {} + } + } + self.push_indent(); self.buf.push_str("pub "); self.buf.push_str(&to_snake(field.name())); @@ -433,8 +460,8 @@ impl<'a> CodeGenerator<'a> { .get_first_field(fq_message_name, field.name()) .copied() .unwrap_or_default(); - let key_tag = self.field_type_tag(key); - let value_tag = self.map_value_type_tag(value); + let key_tag = self.field_type_tag(key, fq_message_name); + let value_tag = self.map_value_type_tag(value, fq_message_name); self.buf.push_str(&format!( "#[prost({}=\"{}, {}\", tag=\"{}\")]\n", @@ -476,6 +503,9 @@ impl<'a> CodeGenerator<'a> { .map(|&(ref field, _)| field.number()) .join(", ") )); + if self.config.strict_messages && !oneof.name.clone().unwrap().starts_with("o_") { + self.buf.push_str("#[prost(strict)]\n"); + } self.append_field_attributes(fq_message_name, oneof.name()); self.push_indent(); self.buf.push_str(&format!( @@ -518,7 +548,7 @@ impl<'a> CodeGenerator<'a> { self.path.pop(); self.push_indent(); - let ty_tag = self.field_type_tag(&field); + let ty_tag = self.field_type_tag(&field, fq_message_name); self.buf.push_str(&format!( "#[prost({}, tag=\"{}\")]\n", ty_tag, @@ -746,24 +776,36 @@ impl<'a> CodeGenerator<'a> { } fn resolve_type(&self, field: &FieldDescriptorProto, fq_message_name: &str) -> String { - match field.r#type() { - Type::Float => String::from("f32"), - Type::Double => String::from("f64"), - Type::Uint32 | Type::Fixed32 => String::from("u32"), - Type::Uint64 | Type::Fixed64 => String::from("u64"), - Type::Int32 | Type::Sfixed32 | Type::Sint32 | Type::Enum => String::from("i32"), - Type::Int64 | Type::Sfixed64 | Type::Sint64 => String::from("i64"), - Type::Bool => String::from("bool"), - Type::String => String::from("::prost::alloc::string::String"), - Type::Bytes => self - .config - .bytes_type - .get_first_field(fq_message_name, field.name()) - .copied() - .unwrap_or_default() - .rust_type() - .to_owned(), - Type::Group | Type::Message => self.resolve_ident(field.type_name()), + if let Some(the_type) = self + .config + .custom_type + .get_first_field(fq_message_name, field.name()) + { + match the_type { + CustomType::Uuid => "uuid::Uuid".to_string(), + } + } else if self.config.inline_enums && matches!(field.r#type(), Type::Enum) { + self.resolve_ident(field.type_name()) + } else { + match field.r#type() { + Type::Float => String::from("f32"), + Type::Double => String::from("f64"), + Type::Uint32 | Type::Fixed32 => String::from("u32"), + Type::Uint64 | Type::Fixed64 => String::from("u64"), + Type::Int32 | Type::Sfixed32 | Type::Sint32 | Type::Enum => String::from("i32"), + Type::Int64 | Type::Sfixed64 | Type::Sint64 => String::from("i64"), + Type::Bool => String::from("bool"), + Type::String => String::from("::prost::alloc::string::String"), + Type::Bytes => self + .config + .bytes_type + .get_first_field(fq_message_name, field.name()) + .copied() + .unwrap_or_default() + .rust_type() + .to_owned(), + Type::Group | Type::Message => self.resolve_ident(field.type_name()), + } } } @@ -801,39 +843,62 @@ impl<'a> CodeGenerator<'a> { .join("::") } - fn field_type_tag(&self, field: &FieldDescriptorProto) -> Cow<'static, str> { - match field.r#type() { - Type::Float => Cow::Borrowed("float"), - Type::Double => Cow::Borrowed("double"), - Type::Int32 => Cow::Borrowed("int32"), - Type::Int64 => Cow::Borrowed("int64"), - Type::Uint32 => Cow::Borrowed("uint32"), - Type::Uint64 => Cow::Borrowed("uint64"), - Type::Sint32 => Cow::Borrowed("sint32"), - Type::Sint64 => Cow::Borrowed("sint64"), - Type::Fixed32 => Cow::Borrowed("fixed32"), - Type::Fixed64 => Cow::Borrowed("fixed64"), - Type::Sfixed32 => Cow::Borrowed("sfixed32"), - Type::Sfixed64 => Cow::Borrowed("sfixed64"), - Type::Bool => Cow::Borrowed("bool"), - Type::String => Cow::Borrowed("string"), - Type::Bytes => Cow::Borrowed("bytes"), - Type::Group => Cow::Borrowed("group"), - Type::Message => Cow::Borrowed("message"), - Type::Enum => Cow::Owned(format!( - "enumeration={:?}", + fn field_type_tag( + &self, + field: &FieldDescriptorProto, + fq_message_name: &str, + ) -> Cow<'static, str> { + if let Some(the_type) = self + .config + .custom_type + .get_first_field(fq_message_name, field.name()) + { + match the_type { + CustomType::Uuid => Cow::Borrowed("uuid"), + } + } else if self.config.inline_enums && matches!(field.r#type(), Type::Enum) { + Cow::Owned(format!( + "inlined_enum={:?}", self.resolve_ident(field.type_name()) - )), + )) + } else { + match field.r#type() { + Type::Float => Cow::Borrowed("float"), + Type::Double => Cow::Borrowed("double"), + Type::Int32 => Cow::Borrowed("int32"), + Type::Int64 => Cow::Borrowed("int64"), + Type::Uint32 => Cow::Borrowed("uint32"), + Type::Uint64 => Cow::Borrowed("uint64"), + Type::Sint32 => Cow::Borrowed("sint32"), + Type::Sint64 => Cow::Borrowed("sint64"), + Type::Fixed32 => Cow::Borrowed("fixed32"), + Type::Fixed64 => Cow::Borrowed("fixed64"), + Type::Sfixed32 => Cow::Borrowed("sfixed32"), + Type::Sfixed64 => Cow::Borrowed("sfixed64"), + Type::Bool => Cow::Borrowed("bool"), + Type::String => Cow::Borrowed("string"), + Type::Bytes => Cow::Borrowed("bytes"), + Type::Group => Cow::Borrowed("group"), + Type::Message => Cow::Borrowed("message"), + Type::Enum => Cow::Owned(format!( + "enumeration={:?}", + self.resolve_ident(field.type_name()) + )), + } } } - fn map_value_type_tag(&self, field: &FieldDescriptorProto) -> Cow<'static, str> { + fn map_value_type_tag( + &self, + field: &FieldDescriptorProto, + fq_message_name: &str, + ) -> Cow<'static, str> { match field.r#type() { Type::Enum => Cow::Owned(format!( "enumeration({})", self.resolve_ident(field.type_name()) )), - _ => self.field_type_tag(field), + _ => self.field_type_tag(field, fq_message_name), } } diff --git a/prost-build/src/lib.rs b/prost-build/src/lib.rs index e8defe4fb..4717941e9 100644 --- a/prost-build/src/lib.rs +++ b/prost-build/src/lib.rs @@ -226,10 +226,16 @@ impl Default for BytesType { } } +#[derive(Clone, Copy, PartialEq)] +pub enum CustomType { + Uuid, +} + /// Configuration options for Protobuf code generation. /// /// This configuration builder can be used to set non-default code generation options. pub struct Config { + start_file_with: Vec, file_descriptor_set_path: Option, service_generator: Option>, map_type: PathMap, @@ -243,6 +249,9 @@ pub struct Config { default_package_filename: String, protoc_args: Vec, disable_comments: PathMap<()>, + custom_type: PathMap, + strict_messages: bool, + inline_enums: bool, skip_protoc_run: bool, include_file: Option, } @@ -253,6 +262,12 @@ impl Config { Config::default() } + pub fn add_start_to_file(&mut self, s: T) -> &mut Self { + self.start_file_with.push(s.to_string()); + + self + } + /// Configure the code generator to generate Rust [`BTreeMap`][1] fields for Protobuf /// [`map`][2] type fields. /// @@ -408,6 +423,18 @@ impl Config { self } + pub fn fields_attribute(&mut self, paths: &[P], attribute: A) -> &mut Self + where + P: AsRef, + A: AsRef, + { + for path in paths.iter() { + self.field_attribute(path, &attribute); + } + + self + } + /// Add additional attribute to matched messages, enums and one-ofs. /// /// # Arguments @@ -457,6 +484,19 @@ impl Config { self } + pub fn types_attribute(&mut self, paths: &[P], attribute: A) -> &mut Self + where + P: AsRef, + A: AsRef, + { + for path in paths.iter() { + self.type_attributes + .insert(path.as_ref().to_string(), attribute.as_ref().to_string()); + } + + self + } + /// Configures the code generator to use the provided service generator. pub fn service_generator(&mut self, service_generator: Box) -> &mut Self { self.service_generator = Some(service_generator); @@ -754,6 +794,36 @@ impl Config { self } + pub fn add_type_mapping(&mut self, to_match: M, custom_type: CustomType) -> &mut Self + where + M: ToString, + { + self.custom_type.insert(to_match.to_string(), custom_type); + + self + } + + pub fn add_types_mapping(&mut self, to_match: Vec, custom_type: CustomType) -> &mut Self + where + M: ToString, + { + for to_match in to_match.into_iter() { + self.add_type_mapping(to_match, custom_type); + } + + self + } + + pub fn strict_messages(&mut self) -> &mut Self { + self.strict_messages = true; + self + } + + pub fn inline_enums(&mut self) -> &mut Self { + self.inline_enums = true; + self + } + /// Compile `.proto` files into Rust files during a Cargo build with additional code generator /// configuration options. /// @@ -894,7 +964,14 @@ impl Config { trace!("unchanged: {:?}", file_name); } else { trace!("writing: {:?}", file_name); - fs::write(output_path, content)?; + + let mut file = std::fs::File::create(output_path)?; + + for i in &self.start_file_with { + writeln!(file, "{}", i)?; + } + + writeln!(file, "{}", content)?; } } @@ -1017,6 +1094,7 @@ impl Config { impl default::Default for Config { fn default() -> Config { Config { + start_file_with: vec![], file_descriptor_set_path: None, service_generator: None, map_type: PathMap::default(), @@ -1030,6 +1108,9 @@ impl default::Default for Config { default_package_filename: "_".to_string(), protoc_args: Vec::new(), disable_comments: PathMap::default(), + custom_type: PathMap::default(), + strict_messages: false, + inline_enums: false, skip_protoc_run: false, include_file: None, } diff --git a/prost-build/src/path.rs b/prost-build/src/path.rs index f6897005d..996a15c91 100644 --- a/prost-build/src/path.rs +++ b/prost-build/src/path.rs @@ -3,17 +3,27 @@ use std::iter; /// Maps a fully-qualified Protobuf path to a value using path matchers. -#[derive(Debug, Default)] +#[derive(Debug)] pub(crate) struct PathMap { // insertion order might actually matter (to avoid warning about legacy-derive-helpers) // see: https://doc.rust-lang.org/rustc/lints/listing/warn-by-default.html#legacy-derive-helpers pub(crate) matchers: Vec<(String, T)>, } -impl PathMap { +impl Default for PathMap { + fn default() -> Self { + Self { + matchers: Default::default(), + } + } +} + +impl PathMap { /// Inserts a new matcher and associated value to the path map. pub(crate) fn insert(&mut self, matcher: String, value: T) { - self.matchers.push((matcher, value)); + if !self.matchers.contains(&(matcher.clone(), value.clone())) { + self.matchers.push((matcher, value)); + } } /// Returns a iterator over all the value matching the given fd_path and associated suffix/prefix path diff --git a/prost-derive/Cargo.toml b/prost-derive/Cargo.toml index 5402c556b..dc98447d0 100644 --- a/prost-derive/Cargo.toml +++ b/prost-derive/Cargo.toml @@ -21,3 +21,4 @@ itertools = "0.10" proc-macro2 = "1" quote = "1" syn = { version = "1", features = [ "extra-traits" ] } +uuid = "1" diff --git a/prost-derive/src/field/map.rs b/prost-derive/src/field/map.rs index 829962e47..e7b755db8 100644 --- a/prost-derive/src/field/map.rs +++ b/prost-derive/src/field/map.rs @@ -41,6 +41,7 @@ fn fake_scalar(ty: scalar::Ty) -> scalar::Field { ty, kind, tag: 0, // Not used here + strict: false, } } diff --git a/prost-derive/src/field/message.rs b/prost-derive/src/field/message.rs index 3bcdddfb1..2479a5821 100644 --- a/prost-derive/src/field/message.rs +++ b/prost-derive/src/field/message.rs @@ -9,6 +9,7 @@ use crate::field::{set_bool, set_option, tag_attr, word_attr, Label}; pub struct Field { pub label: Label, pub tag: u32, + pub strict: bool, } impl Field { @@ -16,6 +17,7 @@ impl Field { let mut message = false; let mut label = None; let mut tag = None; + let mut strict = false; let mut boxed = false; let mut unknown_attrs = Vec::new(); @@ -23,6 +25,8 @@ impl Field { for attr in attrs { if word_attr("message", attr) { set_bool(&mut message, "duplicate message attribute")?; + } else if word_attr("strict", attr) { + set_bool(&mut strict, "duplicate strict attribute")?; } else if word_attr("boxed", attr) { set_bool(&mut boxed, "duplicate boxed attribute")?; } else if let Some(t) = tag_attr(attr)? { @@ -55,6 +59,7 @@ impl Field { Ok(Some(Field { label: label.unwrap_or(Label::Optional), tag, + strict, })) } diff --git a/prost-derive/src/field/mod.rs b/prost-derive/src/field/mod.rs index 09fef830e..b7d9f4767 100644 --- a/prost-derive/src/field/mod.rs +++ b/prost-derive/src/field/mod.rs @@ -2,11 +2,12 @@ mod group; mod map; mod message; mod oneof; -mod scalar; +pub mod scalar; use std::fmt; use std::slice; +pub use crate::field::scalar::{Kind, Ty}; use anyhow::{bail, Error}; use proc_macro2::TokenStream; use quote::quote; @@ -53,6 +54,60 @@ impl Field { Ok(Some(field)) } + pub fn validate(&self, ident: &Ident) -> TokenStream { + let empty = quote! {}; + let expect_non_nil = quote! { + if self.#ident.is_none() { + debug_assert!(false, "Unexpected nil value for {}", stringify!(self.#ident)); + + return Err(::prost::ValidateError::new("Empty non-nil message")) + } + }; + let field = match self { + Field::Scalar(s) => s, + Field::Message(f) => return if f.strict { expect_non_nil } else { empty }, + Field::Oneof(f) => return if f.strict { expect_non_nil } else { empty }, + _ => return empty, + }; + + match field.kind { + Kind::Plain(_) => { + // Continue + } + _ => return empty, + }; + match field.ty { + Ty::Uuid => { + quote! { + if self.#ident == uuid::Uuid::nil() { + debug_assert!(false, "Uuid was nil"); + + return Err(::prost::ValidateError::new("Uuid was nil")) + } + } + } + Ty::Enumeration(ref path) if field.strict => { + quote! { + if self.#ident == 0 || !#path::is_valid(self.#ident) { + debug_assert!(false, "Invalid case: {}", self.#ident); + + return Err(::prost::ValidateError::new("Illegal case found")) + } + } + } + Ty::InlinedEnum(_) => { + quote! { + if self.#ident as i32 == 0 || self.#ident == Default::default() { + debug_assert!(false, "Invalid case"); + + return Err(::prost::ValidateError::new("Illegal case found")) + } + } + } + _ => empty, + } + } + /// Creates a new oneof `Field` from an iterator of field attributes. /// /// If the meta items are invalid, an error will be returned. diff --git a/prost-derive/src/field/oneof.rs b/prost-derive/src/field/oneof.rs index 7e7f08671..29dca1e6b 100644 --- a/prost-derive/src/field/oneof.rs +++ b/prost-derive/src/field/oneof.rs @@ -3,18 +3,20 @@ use proc_macro2::TokenStream; use quote::quote; use syn::{parse_str, Lit, Meta, MetaNameValue, NestedMeta, Path}; -use crate::field::{set_option, tags_attr}; +use crate::field::{set_bool, set_option, tags_attr, word_attr}; #[derive(Clone)] pub struct Field { pub ty: Path, pub tags: Vec, + pub strict: bool, } impl Field { pub fn new(attrs: &[Meta]) -> Result, Error> { let mut ty = None; let mut tags = None; + let mut strict = false; let mut unknown_attrs = Vec::new(); for attr in attrs { @@ -39,6 +41,8 @@ impl Field { _ => bail!("invalid oneof attribute: {:?}", attr), }; set_option(&mut ty, t, "duplicate oneof attribute")?; + } else if word_attr("strict", attr) { + set_bool(&mut strict, "duplicate strict attribute")?; } else if let Some(t) = tags_attr(attr)? { set_option(&mut tags, t, "duplicate tags attributes")?; } else { @@ -65,7 +69,7 @@ impl Field { None => bail!("oneof field is missing a tags attribute"), }; - Ok(Some(Field { ty, tags })) + Ok(Some(Field { ty, tags, strict })) } /// Returns a statement which encodes the oneof field. diff --git a/prost-derive/src/field/scalar.rs b/prost-derive/src/field/scalar.rs index e088dbab6..15b9c29fa 100644 --- a/prost-derive/src/field/scalar.rs +++ b/prost-derive/src/field/scalar.rs @@ -6,7 +6,7 @@ use proc_macro2::{Span, TokenStream}; use quote::{quote, ToTokens, TokenStreamExt}; use syn::{parse_str, Ident, Lit, LitByteStr, Meta, MetaList, MetaNameValue, NestedMeta, Path}; -use crate::field::{bool_attr, set_option, tag_attr, Label}; +use crate::field::{bool_attr, set_bool, set_option, tag_attr, word_attr, Label}; /// A scalar protobuf field. #[derive(Clone)] @@ -14,6 +14,7 @@ pub struct Field { pub ty: Ty, pub kind: Kind, pub tag: u32, + pub strict: bool, } impl Field { @@ -23,12 +24,15 @@ impl Field { let mut packed = None; let mut default = None; let mut tag = None; + let mut strict = false; let mut unknown_attrs = Vec::new(); for attr in attrs { if let Some(t) = Ty::from_attr(attr)? { set_option(&mut ty, t, "duplicate type attributes")?; + } else if word_attr("strict", attr) { + set_bool(&mut strict, "duplicate strict attribute")?; } else if let Some(p) = bool_attr("packed", attr)? { set_option(&mut packed, p, "duplicate packed attributes")?; } else if let Some(t) = tag_attr(attr)? { @@ -86,7 +90,12 @@ impl Field { (Some(Label::Repeated), _, false) => Kind::Repeated, }; - Ok(Some(Field { ty, kind, tag })) + Ok(Some(Field { + ty, + kind, + tag, + strict, + })) } pub fn new_oneof(attrs: &[Meta]) -> Result, Error> { @@ -106,6 +115,31 @@ impl Field { } pub fn encode(&self, ident: TokenStream) -> TokenStream { + let tag = self.tag; + if matches!(self.ty, Ty::InlinedEnum(_)) { + match self.kind { + Kind::Plain(_) => { + return quote! { + if #ident != Default::default() { + ::prost::encoding::int32::encode(#tag, &(#ident as i32), buf); + } + } + } + Kind::Packed => { + // This is a vec + return quote! { + ::prost::encoding::int32::encode_packed(#tag, #ident.iter().map(|i| (*i as i32)).collect::>().as_slice(), buf); + }; + } + Kind::Required(_) => { + // This is inside a oneof + return quote! { + ::prost::encoding::int32::encode(#tag, &(*value as i32), buf); + }; + } + _ => panic!("Encode not supported"), + } + } let module = self.ty.module(); let encode_fn = match self.kind { Kind::Plain(..) | Kind::Optional(..) | Kind::Required(..) => quote!(encode), @@ -113,7 +147,6 @@ impl Field { Kind::Packed => quote!(encode_packed), }; let encode_fn = quote!(::prost::encoding::#module::#encode_fn); - let tag = self.tag; match self.kind { Kind::Plain(ref default) => { @@ -169,6 +202,32 @@ impl Field { let encoded_len_fn = quote!(::prost::encoding::#module::#encoded_len_fn); let tag = self.tag; + if matches!(self.ty, Ty::InlinedEnum(_)) { + return match self.kind { + Kind::Plain(ref default) => { + let default = default.typed(); + quote! { + if #ident != #default { + ::prost::encoding::int32::encoded_len(#tag, &(#ident as i32)) + } else { + 0 + } + } + } + Kind::Packed => { + quote! { + ::prost::encoding::int32::encoded_len_packed(#tag, #ident.iter().map(|i| (*i as i32)).collect::>().as_slice()) + } + } + Kind::Required(_) => { + quote! { + ::prost::encoding::int32::encoded_len(#tag, &(*value as i32)) + } + } + _ => panic!("Encoded len not supported"), + }; + } + match self.kind { Kind::Plain(ref default) => { let default = default.typed(); @@ -381,6 +440,8 @@ pub enum Ty { Sfixed64, Bool, String, + Uuid, + InlinedEnum(Path), Bytes(BytesTy), Enumeration(Path), } @@ -425,6 +486,7 @@ impl Ty { Meta::Path(ref name) if name.is_ident("sfixed64") => Ty::Sfixed64, Meta::Path(ref name) if name.is_ident("bool") => Ty::Bool, Meta::Path(ref name) if name.is_ident("string") => Ty::String, + Meta::Path(ref name) if name.is_ident("uuid") => Ty::Uuid, Meta::Path(ref name) if name.is_ident("bytes") => Ty::Bytes(BytesTy::Vec), Meta::NameValue(MetaNameValue { ref path, @@ -452,6 +514,27 @@ impl Ty { bail!("invalid enumeration attribute: only a single identifier is supported"); } } + Meta::NameValue(MetaNameValue { + ref path, + lit: Lit::Str(ref l), + .. + }) if path.is_ident("inlined_enum") => Ty::InlinedEnum(parse_str::(&l.value())?), + Meta::List(MetaList { + ref path, + ref nested, + .. + }) if path.is_ident("inlined_enum") => { + // TODO(rustlang/rust#23121): slice pattern matching would make this much nicer. + if nested.len() == 1 { + if let NestedMeta::Meta(Meta::Path(ref path)) = nested[0] { + Ty::InlinedEnum(path.clone()) + } else { + bail!("invalid enumeration attribute: item must be an identifier"); + } + } else { + bail!("invalid enumeration attribute: only a single identifier is supported"); + } + } _ => return Ok(None), }; Ok(Some(ty)) @@ -511,6 +594,8 @@ impl Ty { Ty::Sfixed64 => "sfixed64", Ty::Bool => "bool", Ty::String => "string", + Ty::Uuid => "uuid", + Ty::InlinedEnum(_) => "inlined_enum", Ty::Bytes(..) => "bytes", Ty::Enumeration(..) => "enum", } @@ -521,6 +606,9 @@ impl Ty { match self { Ty::String => quote!(::prost::alloc::string::String), Ty::Bytes(ty) => ty.rust_type(), + Ty::InlinedEnum(path) => quote! { + #path + }, _ => self.rust_ref_type(), } } @@ -542,6 +630,8 @@ impl Ty { Ty::Sfixed64 => quote!(i64), Ty::Bool => quote!(bool), Ty::String => quote!(&str), + Ty::Uuid => quote!(uuid::Uuid), + Ty::InlinedEnum(_) => quote!(i32), Ty::Bytes(..) => quote!(&[u8]), Ty::Enumeration(..) => quote!(i32), } @@ -549,14 +639,14 @@ impl Ty { pub fn module(&self) -> Ident { match *self { - Ty::Enumeration(..) => Ident::new("int32", Span::call_site()), + Ty::Enumeration(..) | Ty::InlinedEnum(_) => Ident::new("int32", Span::call_site()), _ => Ident::new(self.as_str(), Span::call_site()), } } /// Returns false if the scalar type is length delimited (i.e., `string` or `bytes`). pub fn is_numeric(&self) -> bool { - !matches!(self, Ty::String | Ty::Bytes(..)) + !matches!(self, Ty::String | Ty::Bytes(..) | Ty::Uuid) } } @@ -598,6 +688,8 @@ pub enum DefaultValue { U64(u64), Bool(bool), String(String), + Uuid, + InlinedEnum, Bytes(Vec), Enumeration(TokenStream), Path(Path), @@ -758,6 +850,8 @@ impl DefaultValue { Ty::Bool => DefaultValue::Bool(false), Ty::String => DefaultValue::String(String::new()), + Ty::Uuid => DefaultValue::Uuid, + Ty::InlinedEnum(_) => DefaultValue::InlinedEnum, Ty::Bytes(..) => DefaultValue::Bytes(Vec::new()), Ty::Enumeration(ref path) => DefaultValue::Enumeration(quote!(#path::default())), } @@ -801,6 +895,16 @@ impl ToTokens for DefaultValue { DefaultValue::U64(value) => value.to_tokens(tokens), DefaultValue::Bool(value) => value.to_tokens(tokens), DefaultValue::String(ref value) => value.to_tokens(tokens), + DefaultValue::Uuid => { + tokens.append_all(quote! { + uuid::Uuid::nil() + }); + } + DefaultValue::InlinedEnum => { + tokens.append_all(quote! { + Default::default() + }); + } DefaultValue::Bytes(ref value) => { let byte_str = LitByteStr::new(value, Span::call_site()); tokens.append_all(quote!(#byte_str as &[u8])); diff --git a/prost-derive/src/lib.rs b/prost-derive/src/lib.rs index ad1aa19cf..824413061 100644 --- a/prost-derive/src/lib.rs +++ b/prost-derive/src/lib.rs @@ -16,8 +16,7 @@ use syn::{ }; mod field; -use crate::field::Field; - +use crate::field::{Field, Kind, Ty}; fn try_message(input: TokenStream) -> Result { let input: DeriveInput = syn::parse(input)?; @@ -103,8 +102,69 @@ fn try_message(input: TokenStream) -> Result { let merge = fields.iter().map(|&(ref field_ident, ref field)| { let merge = field.merge(quote!(value)); let tags = field.tags().into_iter().map(|tag| quote!(#tag)); - let tags = Itertools::intersperse(tags, quote!(|)); + let tags = Itertools::intersperse(tags,quote!(|)); + + let field: &Field = field; + match field { + Field::Scalar(s) => { + match &s.ty { + Ty::InlinedEnum(p) => { + match s.kind { + Kind::Plain(_) => { + assert_eq!(1, field.tags().len()); + return quote! { + #(#tags)* => { + let mut owned = Default::default(); + let mut value = &mut owned; + + #merge.map_err(|mut error| { + error.push(STRUCT_NAME, stringify!(#field_ident)); + error + })?; + + match #p::from_i32(owned) { + Some(p) => self.#field_ident = p, + None => return Err(::prost::DecodeError::new("Invalid enum case")) + } + + Ok(()) + }, + } + } + Kind::Packed => { + return quote! { + #(#tags)* => { + let mut owned = Default::default(); + let mut value = &mut owned; + + #merge.map_err(|mut error| { + error.push(STRUCT_NAME, stringify!(#field_ident)); + error + })?; + + let owned_len = owned.len(); + + self.#field_ident = owned.into_iter().filter_map(|n| #p::from_i32(n)).collect(); + + if owned_len == self.#field_ident.len() { + Ok(()) + } else { + Err(::prost::DecodeError::new("Mismatch in decoded enum len")) + } + }, + } + } + _ => panic!("Merge not supported") + } + }, + _ => { + } + } + }, + _ => { + } + } quote! { #(#tags)* => { let mut value = &mut self.#field_ident; @@ -171,6 +231,11 @@ fn try_message(input: TokenStream) -> Result { quote!(f.debug_tuple(stringify!(#ident))) }; + // Validations fields: + // - Uuids are not default + // - Enum is not set to the first case + let validate = fields.iter().map(|(ident, field)| field.validate(&ident)); + let expanded = quote! { impl #impl_generics ::prost::Message for #ident #ty_generics #where_clause { #[allow(unused_variables)] @@ -199,6 +264,12 @@ fn try_message(input: TokenStream) -> Result { 0 #(+ #encoded_len)* } + fn validate(&self) -> Result<(), ::prost::ValidateError> { + #(#validate)* + + Ok(()) + } + fn clear(&mut self) { #(#clear;)* } @@ -391,6 +462,60 @@ fn try_oneof(input: TokenStream) -> Result { let merge = fields.iter().map(|&(ref variant_ident, ref field)| { let tag = field.tags()[0]; let merge = field.merge(quote!(value)); + + let field: &Field = field; + match field { + Field::Scalar(s) => { + match &s.ty { + Ty::InlinedEnum(p) => { + return quote! { + #tag => { + match field { + ::core::option::Option::Some(#ident::#variant_ident(ref mut value)) => { + let mut v = Default::default(); + + ::prost::encoding::int32::merge(wire_type, &mut v, buf, ctx)?; + + match #p::from_i32(v) { + Some(v) => { + *value = v; + } + None => { + return Err(::prost::DecodeError::new("Unknown enum")) + } + } + + Ok(()) + }, + _ => { + let mut owned_value = ::core::default::Default::default(); + + ::prost::encoding::int32::merge(wire_type, &mut owned_value, buf, ctx)?; + + match #p::from_i32(owned_value) { + Some(v) => { + *field = ::core::option::Option::Some(#ident::#variant_ident(v)) + } + None => { + return Err(::prost::DecodeError::new("Unknown enum")) + } + } + + Ok(()) + }, + } + } + } + } + _ => {} + } + }, + _ => { + + } + } + + quote! { #tag => { match field { diff --git a/src/encoding.rs b/src/encoding.rs index 4953d6603..3a8dfc5f3 100644 --- a/src/encoding.rs +++ b/src/encoding.rs @@ -873,6 +873,90 @@ pub mod string { } } +pub mod uuid { + use super::*; + use crate::alloc::str::FromStr; + use crate::alloc::string::ToString; + + pub fn encode(tag: u32, value: &::uuid::Uuid, buf: &mut B) + where + B: BufMut, + { + super::string::encode(tag, &value.to_string(), buf) + } + pub fn merge( + wire_type: WireType, + value: &mut ::uuid::Uuid, + buf: &mut B, + ctx: DecodeContext, + ) -> Result<(), DecodeError> + where + B: Buf, + { + let mut to_merge = String::with_capacity(36); + + super::string::merge(wire_type, &mut to_merge, buf, ctx)?; + + let uuid = string_to_uuid(&to_merge)?; + + *value = uuid; + + Ok(()) + } + + fn string_to_uuid(s: &str) -> Result<::uuid::Uuid, DecodeError> { + // Check if the merged string is an actual uuid + match ::uuid::Uuid::from_str(s) { + Ok(uuid) => Ok(uuid), + Err(err) => Err(DecodeError::new(format!( + "invalid Uuid value: {}, error: {}", + s, err + ))), + } + } + + encode_repeated!(::uuid::Uuid); + + pub fn merge_repeated( + wire_type: WireType, + values: &mut Vec<::uuid::Uuid>, + buf: &mut B, + ctx: DecodeContext, + ) -> Result<(), DecodeError> + where + B: Buf, + { + let mut strings = alloc::vec::Vec::new(); + + super::string::merge_repeated(wire_type, &mut strings, buf, ctx)?; + + for string in strings { + let uuid = string_to_uuid(&string)?; + + values.push(uuid); + } + + Ok(()) + } + + #[inline] + pub fn encoded_len(tag: u32, value: &::uuid::Uuid) -> usize { + super::string::encoded_len(tag, &value.to_string()) + } + + #[inline] + pub fn encoded_len_repeated(tag: u32, values: &[::uuid::Uuid]) -> usize { + super::string::encoded_len_repeated( + tag, + values + .iter() + .map(|u| u.to_string()) + .collect::>() + .as_slice(), + ) + } +} + pub trait BytesAdapter: sealed::BytesAdapter {} mod sealed { diff --git a/src/error.rs b/src/error.rs index 756ee8172..cb1757e94 100644 --- a/src/error.rs +++ b/src/error.rs @@ -92,7 +92,7 @@ pub struct EncodeError { impl EncodeError { /// Creates a new `EncodeError`. - pub(crate) fn new(required: usize, remaining: usize) -> EncodeError { + pub fn new(required: usize, remaining: usize) -> EncodeError { EncodeError { required, remaining, @@ -120,6 +120,35 @@ impl fmt::Display for EncodeError { } } +#[derive(Clone, PartialEq, Eq)] +pub struct ValidateError { + inner: Box, +} + +impl ValidateError { + pub fn new(description: impl Into>) -> Self { + Self { + inner: Box::new(Inner { + description: description.into(), + stack: Vec::new(), + }), + } + } +} + +impl From for DecodeError { + fn from(s: ValidateError) -> Self { + DecodeError { inner: s.inner } + } +} + +impl From for EncodeError { + fn from(_: ValidateError) -> Self { + // This is ugly but whatever + EncodeError::new(0, 0) + } +} + #[cfg(feature = "std")] impl std::error::Error for EncodeError {} diff --git a/src/lib.rs b/src/lib.rs index 09deda858..1f2990745 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -16,7 +16,7 @@ mod types; #[doc(hidden)] pub mod encoding; -pub use crate::error::{DecodeError, EncodeError}; +pub use crate::error::{DecodeError, EncodeError, ValidateError}; pub use crate::message::Message; use bytes::{Buf, BufMut}; diff --git a/src/message.rs b/src/message.rs index a190f6b47..7486ab6e3 100644 --- a/src/message.rs +++ b/src/message.rs @@ -9,6 +9,7 @@ use bytes::{Buf, BufMut}; use crate::encoding::{ decode_key, encode_varint, encoded_len_varint, message, DecodeContext, WireType, }; +use crate::error::ValidateError; use crate::DecodeError; use crate::EncodeError; @@ -43,6 +44,17 @@ pub trait Message: Debug + Send + Sync { /// Returns the encoded length of the message without a length delimiter. fn encoded_len(&self) -> usize; + fn encode_buffer(&self) -> Result, EncodeError> + where + Self: Sized, + { + let mut buffer = alloc::vec::Vec::with_capacity(self.encoded_len()); + + self.encode(&mut buffer)?; + + Ok(buffer) + } + /// Encodes the message to a buffer. /// /// An error will be returned if the buffer does not have sufficient capacity. @@ -58,6 +70,11 @@ pub trait Message: Debug + Send + Sync { } self.encode_raw(buf); + self.validate()?; + Ok(()) + } + + fn validate(&self) -> Result<(), ValidateError> { Ok(()) } @@ -113,7 +130,11 @@ pub trait Message: Debug + Send + Sync { Self: Default, { let mut message = Self::default(); - Self::merge(&mut message, &mut buf).map(|_| message) + let msg = Self::merge(&mut message, &mut buf).map(|_| message)?; + + msg.validate()?; + + Ok(msg) } /// Decodes a length-delimited instance of the message from the buffer. diff --git a/tests/Cargo.toml b/tests/Cargo.toml index e544dc812..900af483b 100644 --- a/tests/Cargo.toml +++ b/tests/Cargo.toml @@ -21,11 +21,14 @@ cfg-if = "1" prost = { path = ".." } prost-types = { path = "../prost-types" } protobuf = { path = "../protobuf" } +uuid = "*" [dev-dependencies] diff = "0.1" prost-build = { path = "../prost-build" } tempfile = "3" +remove_dir_all = "0.6" +uuid = "1" [build-dependencies] cfg-if = "1" diff --git a/tests/src/bootstrap.rs b/tests/src/bootstrap.rs index f50e583de..9230e3ba9 100644 --- a/tests/src/bootstrap.rs +++ b/tests/src/bootstrap.rs @@ -87,6 +87,13 @@ fn bootstrap() { .unwrap(); } - assert_eq!(protobuf, bootstrapped_protobuf); - assert_eq!(compiler, bootstrapped_compiler); + // Remove the weird newlines + assert_eq!( + protobuf.replace('\n', ""), + bootstrapped_protobuf.replace('\n', "") + ); + assert_eq!( + compiler.replace('\n', ""), + bootstrapped_compiler.replace('\n', "") + ); }