Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tets #1

Merged
merged 4 commits into from
May 19, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 41 additions & 39 deletions .github/workflows/continuous-integration-workflow.yaml
Original file line number Diff line number Diff line change
@@ -18,6 +18,8 @@ jobs:
default: true
profile: minimal
components: rustfmt
- name: build b_tests
run: cargo build --package b_tests
- name: rustfmt
uses: actions-rs/cargo@v1
with:
@@ -82,45 +84,45 @@ jobs:
command: test
args: --no-default-features

no-std:
runs-on: ubuntu-latest
steps:
- name: checkout
uses: actions/checkout@v2
with:
submodules: recursive
- name: install toolchain
uses: actions-rs/toolchain@v1
with:
toolchain: nightly
default: true
profile: minimal
- uses: Swatinem/rust-cache@v1
- name: install cargo-no-std-check
uses: actions-rs/cargo@v1
with:
command: install
args: cargo-no-std-check
- name: prost cargo-no-std-check
uses: actions-rs/cargo@v1
with:
command: no-std-check
args: --manifest-path Cargo.toml --no-default-features
- name: prost-types cargo-no-std-check
uses: actions-rs/cargo@v1
with:
command: no-std-check
args: --manifest-path prost-types/Cargo.toml --no-default-features
# prost-build depends on prost with --no-default-features, but when
# prost-build is built through the workspace, prost typically has default
# features enabled due to vagaries in Cargo workspace feature resolution.
# This additional check ensures that prost-build does not rely on any of
# prost's default features to compile.
- name: prost-build check
uses: actions-rs/cargo@v1
with:
command: check
args: --manifest-path prost-build/Cargo.toml
# no-std:
# runs-on: ubuntu-latest
# steps:
# - name: checkout
# uses: actions/checkout@v2
# with:
# submodules: recursive
# - name: install toolchain
# uses: actions-rs/toolchain@v1
# with:
# toolchain: nightly
# default: true
# profile: minimal
# - uses: Swatinem/rust-cache@v1
# - name: install cargo-no-std-check
# uses: actions-rs/cargo@v1
# with:
# command: install
# args: cargo-no-std-check
# - name: prost cargo-no-std-check
# uses: actions-rs/cargo@v1
# with:
# command: no-std-check
# args: --manifest-path Cargo.toml --no-default-features
# - name: prost-types cargo-no-std-check
# uses: actions-rs/cargo@v1
# with:
# command: no-std-check
# args: --manifest-path prost-types/Cargo.toml --no-default-features
# # prost-build depends on prost with --no-default-features, but when
# # prost-build is built through the workspace, prost typically has default
# # features enabled due to vagaries in Cargo workspace feature resolution.
# # This additional check ensures that prost-build does not rely on any of
# # prost's default features to compile.
# - name: prost-build check
# uses: actions-rs/cargo@v1
# with:
# command: check
# args: --manifest-path prost-build/Cargo.toml

vendored:
runs-on: ubuntu-latest
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -49,6 +49,7 @@ std = []
[dependencies]
bytes = { version = "1", default-features = false }
prost-derive = { version = "0.10.0", path = "prost-derive", optional = true }
uuid = { version = "1", features = ["v4"] }

[dev-dependencies]
criterion = "0.3"
3 changes: 3 additions & 0 deletions b_tests/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
/src/b_generated
/src/generated
/src/protos
6 changes: 6 additions & 0 deletions b_tests/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
[package]
name = "b_tests"
version = "0.1.0"
authors = ["Jasper Visser <jasperav@hotmail.com>"]
edition = "2018"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[build-dependencies]
prost-build = { path = "../prost-build" }
protobuf_strict = { git = "https://github.com/Jasperav/protobuf_strict.git" }

[dependencies]
uuid = { version = "1", features = ["v4"] }
prost = { path = "../" }
66 changes: 66 additions & 0 deletions b_tests/build.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
use prost_build::{Config, CustomType};
use std::io::Write;

fn main() {
let src = std::env::current_dir().unwrap().join("src");

protobuf_strict::write_protos(&src);

macro_rules! initialize_dir {
($name: expr) => {{
// Maybe the dir doesn't exists yet
let _ = std::fs::remove_dir_all(src.join($name));
std::fs::create_dir(src.join($name)).unwrap();
std::fs::File::create(src.join($name).join("mod.rs")).unwrap()
}};
}

let mut b_generated = initialize_dir!("b_generated");
let mut generated = initialize_dir!("generated");
let mut protos = vec![];
let paths = std::fs::read_dir("./src/protos").unwrap();

for path in paths {
let path = path.unwrap();
let file_name = path.file_name();
let s = file_name.to_str().unwrap();

protos.push("src/protos/".to_string() + s);
}

macro_rules! go_generate {
($name: expr, $file: expr, $config: expr) => {
std::env::set_var("OUT_DIR", src.join($name).to_str().unwrap());

$config
.compile_protos(protos.as_slice(), &["src/protos/".to_string()])
.unwrap();

let paths = std::fs::read_dir("./src/".to_string() + $name).unwrap();

for path in paths {
let path = path.unwrap();
let file_name = path.file_name();
let s = file_name.to_str().unwrap();

if s == "mod.rs" {
continue;
}

let m = s.strip_suffix(".rs").unwrap();

writeln!($file, "#[rustfmt::skip]\nmod {};\npub use {}::*;", m, m).unwrap();
}
};
}

let mut config = Config::new();

config
.add_types_mapping(protobuf_strict::get_uuids().to_vec(), CustomType::Uuid)
.strict_messages()
.inline_enums();

go_generate!("b_generated", b_generated, config);
go_generate!("generated", generated, Config::new());
}
191 changes: 187 additions & 4 deletions b_tests/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,191 @@
mod b_generated;
mod generated;

#[cfg(test)]
mod tests {
mod test {
use super::*;
use prost::Message;
use std::str::FromStr;
use uuid::Uuid;

const UUID: &'static str = "cd663747-6cb1-4ddc-bdfe-3dc76db62724";

fn get_uuid() -> Uuid {
uuid::Uuid::from_str(UUID).unwrap()
}

fn get_no_uuid() -> String {
"no_uuid".to_string()
}

fn get_custom() -> String {
"custom".to_string()
}

fn get_amount() -> i32 {
1
}

fn b_generated_order() -> b_generated::Order {
b_generated::Order {
gender: b_generated::Gender::Female,
genders: vec![b_generated::Gender::Female, b_generated::Gender::Other],
currency: Some(b_generated::Currency {
c: Some(b_generated::currency::C::Amount(get_amount())),
}),
o_currency: None,
currencies: vec![
b_generated::Currency {
c: Some(b_generated::currency::C::Custom(get_custom())),
},
b_generated::Currency {
c: Some(b_generated::currency::C::Amount(get_amount())),
},
],
uuid: get_uuid(),
no_uuid: get_no_uuid(),
repeated_uuids: vec![get_uuid(), get_uuid()],
no_uuids: vec![get_custom()],
order_inner: b_generated::order::OrderInner::InnerAnother,
order_inners: vec![
b_generated::order::OrderInner::InnerAnother,
b_generated::order::OrderInner::InnerAnother2,
],
something: Some(b_generated::order::Something::AlsoUuid(get_uuid())),
}
}

fn generated_order() -> generated::Order {
generated::Order {
gender: generated::Gender::Female as i32,
genders: vec![
generated::Gender::Female as i32,
generated::Gender::Other as i32,
],
currency: Some(generated::Currency {
c: Some(generated::currency::C::Amount(get_amount())),
}),
o_currency: None,
currencies: vec![
generated::Currency {
c: Some(generated::currency::C::Custom(get_custom())),
},
generated::Currency {
c: Some(generated::currency::C::Amount(get_amount())),
},
],
uuid: get_uuid().to_string(),
no_uuid: get_no_uuid(),
repeated_uuids: vec![get_uuid().to_string(), get_uuid().to_string()],
no_uuids: vec![get_custom()],
order_inner: generated::order::OrderInner::InnerAnother as i32,
order_inners: vec![
generated::order::OrderInner::InnerAnother as i32,
generated::order::OrderInner::InnerAnother2 as i32,
],
something: Some(generated::order::Something::AlsoUuid(
get_uuid().to_string(),
)),
}
}

#[test]
fn it_works() {
let result = 2 + 2;
assert_eq!(result, 4);
fn equal() {
let g = b_generated_order();
let b = generated_order();
let g = g.encode_buffer().unwrap();
let b = b.encode_buffer().unwrap();

assert_eq!(g, b);
assert_eq!(g.encoded_len(), b.encoded_len());

// Check if encoding works
b_generated::Order::decode(b.as_slice()).unwrap();
generated::Order::decode(g.as_slice()).unwrap();
}

fn check_order(order: generated::Order) {
let b = order.encode_buffer().unwrap();

b_generated::Order::decode(b.as_slice()).unwrap();
}

macro_rules! write_invalid_test {
($method_name: ident, $order: ident, $change: tt) => {
#[test]
#[should_panic]
fn $method_name() {
let mut $order = generated_order();

$change;

check_order($order);
}
};
}

write_invalid_test!(
invalid_gender_zero,
order,
({
order.gender = 0;
})
);
write_invalid_test!(
invalid_gender_over_max,
order,
({
order.gender = 999;
})
);
write_invalid_test!(
invalid_uuids,
order,
({
order.repeated_uuids = vec![get_uuid().to_string(), get_no_uuid().to_string()];
})
);
write_invalid_test!(
empty_currency,
order,
({
order.currency = None;
})
);
write_invalid_test!(
invalid_uuid,
order,
({
order.uuid = order.uuid[1..].to_string();
})
);

write_invalid_test!(
invalid_inner_zero,
order,
({
order.order_inner = 0;
})
);
write_invalid_test!(
invalid_inner_over_max,
order,
({
order.order_inner = 999;
})
);
write_invalid_test!(
empty_something,
order,
({
order.something = None;
})
);
write_invalid_test!(
something_invalid_uuid,
order,
({
order.something = Some(generated::order::Something::AlsoUuid(get_no_uuid()));
})
);
}
159 changes: 112 additions & 47 deletions prost-build/src/code_generator.rs
Original file line number Diff line number Diff line change
@@ -18,7 +18,7 @@ use crate::ast::{Comments, Method, Service};
use crate::extern_paths::ExternPaths;
use crate::ident::{to_snake, to_upper_camel};
use crate::message_graph::MessageGraph;
use crate::{BytesType, Config, MapType};
use crate::{BytesType, Config, CustomType, MapType};

#[derive(PartialEq)]
enum Syntax {
@@ -311,7 +311,7 @@ impl<'a> CodeGenerator<'a> {

self.push_indent();
self.buf.push_str("#[prost(");
let type_tag = self.field_type_tag(&field);
let type_tag = self.field_type_tag(&field, fq_message_name);
self.buf.push_str(&type_tag);

if type_ == Type::Bytes {
@@ -377,6 +377,7 @@ impl<'a> CodeGenerator<'a> {
} else {
&enum_value
};

self.buf.push_str(stripped_prefix);
} else {
self.buf.push_str(&default.escape_default().to_string());
@@ -385,6 +386,32 @@ impl<'a> CodeGenerator<'a> {

self.buf.push_str("\")]\n");
self.append_field_attributes(fq_message_name, field.name());

if self.config.strict_messages {
match field.r#type() {
Type::Message => match field.label() {
Label::Optional => {
if let Some(ref s) = field.name {
if !s.starts_with("o_") {
self.buf.push_str("#[prost(strict)]\n");
}
}
}
_ => {}
},
Type::Enum => {
if !self.config.inline_enums {
if self.config.inline_enums {
self.buf.push_str("#[prost(inlined)]\n");
} else {
self.buf.push_str("#[prost(strict)]\n");
}
}
}
_ => {}
}
}

self.push_indent();
self.buf.push_str("pub ");
self.buf.push_str(&to_snake(field.name()));
@@ -433,8 +460,8 @@ impl<'a> CodeGenerator<'a> {
.get_first_field(fq_message_name, field.name())
.copied()
.unwrap_or_default();
let key_tag = self.field_type_tag(key);
let value_tag = self.map_value_type_tag(value);
let key_tag = self.field_type_tag(key, fq_message_name);
let value_tag = self.map_value_type_tag(value, fq_message_name);

self.buf.push_str(&format!(
"#[prost({}=\"{}, {}\", tag=\"{}\")]\n",
@@ -476,6 +503,9 @@ impl<'a> CodeGenerator<'a> {
.map(|&(ref field, _)| field.number())
.join(", ")
));
if self.config.strict_messages && !oneof.name.clone().unwrap().starts_with("o_") {
self.buf.push_str("#[prost(strict)]\n");
}
self.append_field_attributes(fq_message_name, oneof.name());
self.push_indent();
self.buf.push_str(&format!(
@@ -518,7 +548,7 @@ impl<'a> CodeGenerator<'a> {
self.path.pop();

self.push_indent();
let ty_tag = self.field_type_tag(&field);
let ty_tag = self.field_type_tag(&field, fq_message_name);
self.buf.push_str(&format!(
"#[prost({}, tag=\"{}\")]\n",
ty_tag,
@@ -746,24 +776,36 @@ impl<'a> CodeGenerator<'a> {
}

fn resolve_type(&self, field: &FieldDescriptorProto, fq_message_name: &str) -> String {
match field.r#type() {
Type::Float => String::from("f32"),
Type::Double => String::from("f64"),
Type::Uint32 | Type::Fixed32 => String::from("u32"),
Type::Uint64 | Type::Fixed64 => String::from("u64"),
Type::Int32 | Type::Sfixed32 | Type::Sint32 | Type::Enum => String::from("i32"),
Type::Int64 | Type::Sfixed64 | Type::Sint64 => String::from("i64"),
Type::Bool => String::from("bool"),
Type::String => String::from("::prost::alloc::string::String"),
Type::Bytes => self
.config
.bytes_type
.get_first_field(fq_message_name, field.name())
.copied()
.unwrap_or_default()
.rust_type()
.to_owned(),
Type::Group | Type::Message => self.resolve_ident(field.type_name()),
if let Some(the_type) = self
.config
.custom_type
.get_first_field(fq_message_name, field.name())
{
match the_type {
CustomType::Uuid => "uuid::Uuid".to_string(),
}
} else if self.config.inline_enums && matches!(field.r#type(), Type::Enum) {
self.resolve_ident(field.type_name())
} else {
match field.r#type() {
Type::Float => String::from("f32"),
Type::Double => String::from("f64"),
Type::Uint32 | Type::Fixed32 => String::from("u32"),
Type::Uint64 | Type::Fixed64 => String::from("u64"),
Type::Int32 | Type::Sfixed32 | Type::Sint32 | Type::Enum => String::from("i32"),
Type::Int64 | Type::Sfixed64 | Type::Sint64 => String::from("i64"),
Type::Bool => String::from("bool"),
Type::String => String::from("::prost::alloc::string::String"),
Type::Bytes => self
.config
.bytes_type
.get_first_field(fq_message_name, field.name())
.copied()
.unwrap_or_default()
.rust_type()
.to_owned(),
Type::Group | Type::Message => self.resolve_ident(field.type_name()),
}
}
}

@@ -801,39 +843,62 @@ impl<'a> CodeGenerator<'a> {
.join("::")
}

fn field_type_tag(&self, field: &FieldDescriptorProto) -> Cow<'static, str> {
match field.r#type() {
Type::Float => Cow::Borrowed("float"),
Type::Double => Cow::Borrowed("double"),
Type::Int32 => Cow::Borrowed("int32"),
Type::Int64 => Cow::Borrowed("int64"),
Type::Uint32 => Cow::Borrowed("uint32"),
Type::Uint64 => Cow::Borrowed("uint64"),
Type::Sint32 => Cow::Borrowed("sint32"),
Type::Sint64 => Cow::Borrowed("sint64"),
Type::Fixed32 => Cow::Borrowed("fixed32"),
Type::Fixed64 => Cow::Borrowed("fixed64"),
Type::Sfixed32 => Cow::Borrowed("sfixed32"),
Type::Sfixed64 => Cow::Borrowed("sfixed64"),
Type::Bool => Cow::Borrowed("bool"),
Type::String => Cow::Borrowed("string"),
Type::Bytes => Cow::Borrowed("bytes"),
Type::Group => Cow::Borrowed("group"),
Type::Message => Cow::Borrowed("message"),
Type::Enum => Cow::Owned(format!(
"enumeration={:?}",
fn field_type_tag(
&self,
field: &FieldDescriptorProto,
fq_message_name: &str,
) -> Cow<'static, str> {
if let Some(the_type) = self
.config
.custom_type
.get_first_field(fq_message_name, field.name())
{
match the_type {
CustomType::Uuid => Cow::Borrowed("uuid"),
}
} else if self.config.inline_enums && matches!(field.r#type(), Type::Enum) {
Cow::Owned(format!(
"inlined_enum={:?}",
self.resolve_ident(field.type_name())
)),
))
} else {
match field.r#type() {
Type::Float => Cow::Borrowed("float"),
Type::Double => Cow::Borrowed("double"),
Type::Int32 => Cow::Borrowed("int32"),
Type::Int64 => Cow::Borrowed("int64"),
Type::Uint32 => Cow::Borrowed("uint32"),
Type::Uint64 => Cow::Borrowed("uint64"),
Type::Sint32 => Cow::Borrowed("sint32"),
Type::Sint64 => Cow::Borrowed("sint64"),
Type::Fixed32 => Cow::Borrowed("fixed32"),
Type::Fixed64 => Cow::Borrowed("fixed64"),
Type::Sfixed32 => Cow::Borrowed("sfixed32"),
Type::Sfixed64 => Cow::Borrowed("sfixed64"),
Type::Bool => Cow::Borrowed("bool"),
Type::String => Cow::Borrowed("string"),
Type::Bytes => Cow::Borrowed("bytes"),
Type::Group => Cow::Borrowed("group"),
Type::Message => Cow::Borrowed("message"),
Type::Enum => Cow::Owned(format!(
"enumeration={:?}",
self.resolve_ident(field.type_name())
)),
}
}
}

fn map_value_type_tag(&self, field: &FieldDescriptorProto) -> Cow<'static, str> {
fn map_value_type_tag(
&self,
field: &FieldDescriptorProto,
fq_message_name: &str,
) -> Cow<'static, str> {
match field.r#type() {
Type::Enum => Cow::Owned(format!(
"enumeration({})",
self.resolve_ident(field.type_name())
)),
_ => self.field_type_tag(field),
_ => self.field_type_tag(field, fq_message_name),
}
}

83 changes: 82 additions & 1 deletion prost-build/src/lib.rs
Original file line number Diff line number Diff line change
@@ -226,10 +226,16 @@ impl Default for BytesType {
}
}

#[derive(Clone, Copy, PartialEq)]
pub enum CustomType {
Uuid,
}

/// Configuration options for Protobuf code generation.
///
/// This configuration builder can be used to set non-default code generation options.
pub struct Config {
start_file_with: Vec<String>,
file_descriptor_set_path: Option<PathBuf>,
service_generator: Option<Box<dyn ServiceGenerator>>,
map_type: PathMap<MapType>,
@@ -243,6 +249,9 @@ pub struct Config {
default_package_filename: String,
protoc_args: Vec<OsString>,
disable_comments: PathMap<()>,
custom_type: PathMap<CustomType>,
strict_messages: bool,
inline_enums: bool,
skip_protoc_run: bool,
include_file: Option<PathBuf>,
}
@@ -253,6 +262,12 @@ impl Config {
Config::default()
}

pub fn add_start_to_file<T: ToString>(&mut self, s: T) -> &mut Self {
self.start_file_with.push(s.to_string());

self
}

/// Configure the code generator to generate Rust [`BTreeMap`][1] fields for Protobuf
/// [`map`][2] type fields.
///
@@ -408,6 +423,18 @@ impl Config {
self
}

pub fn fields_attribute<P, A>(&mut self, paths: &[P], attribute: A) -> &mut Self
where
P: AsRef<str>,
A: AsRef<str>,
{
for path in paths.iter() {
self.field_attribute(path, &attribute);
}

self
}

/// Add additional attribute to matched messages, enums and one-ofs.
///
/// # Arguments
@@ -457,6 +484,19 @@ impl Config {
self
}

pub fn types_attribute<P, A>(&mut self, paths: &[P], attribute: A) -> &mut Self
where
P: AsRef<str>,
A: AsRef<str>,
{
for path in paths.iter() {
self.type_attributes
.insert(path.as_ref().to_string(), attribute.as_ref().to_string());
}

self
}

/// Configures the code generator to use the provided service generator.
pub fn service_generator(&mut self, service_generator: Box<dyn ServiceGenerator>) -> &mut Self {
self.service_generator = Some(service_generator);
@@ -754,6 +794,36 @@ impl Config {
self
}

pub fn add_type_mapping<M>(&mut self, to_match: M, custom_type: CustomType) -> &mut Self
where
M: ToString,
{
self.custom_type.insert(to_match.to_string(), custom_type);

self
}

pub fn add_types_mapping<M>(&mut self, to_match: Vec<M>, custom_type: CustomType) -> &mut Self
where
M: ToString,
{
for to_match in to_match.into_iter() {
self.add_type_mapping(to_match, custom_type);
}

self
}

pub fn strict_messages(&mut self) -> &mut Self {
self.strict_messages = true;
self
}

pub fn inline_enums(&mut self) -> &mut Self {
self.inline_enums = true;
self
}

/// Compile `.proto` files into Rust files during a Cargo build with additional code generator
/// configuration options.
///
@@ -894,7 +964,14 @@ impl Config {
trace!("unchanged: {:?}", file_name);
} else {
trace!("writing: {:?}", file_name);
fs::write(output_path, content)?;

let mut file = std::fs::File::create(output_path)?;

for i in &self.start_file_with {
writeln!(file, "{}", i)?;
}

writeln!(file, "{}", content)?;
}
}

@@ -1017,6 +1094,7 @@ impl Config {
impl default::Default for Config {
fn default() -> Config {
Config {
start_file_with: vec![],
file_descriptor_set_path: None,
service_generator: None,
map_type: PathMap::default(),
@@ -1030,6 +1108,9 @@ impl default::Default for Config {
default_package_filename: "_".to_string(),
protoc_args: Vec::new(),
disable_comments: PathMap::default(),
custom_type: PathMap::default(),
strict_messages: false,
inline_enums: false,
skip_protoc_run: false,
include_file: None,
}
16 changes: 13 additions & 3 deletions prost-build/src/path.rs
Original file line number Diff line number Diff line change
@@ -3,17 +3,27 @@
use std::iter;

/// Maps a fully-qualified Protobuf path to a value using path matchers.
#[derive(Debug, Default)]
#[derive(Debug)]
pub(crate) struct PathMap<T> {
// insertion order might actually matter (to avoid warning about legacy-derive-helpers)
// see: https://doc.rust-lang.org/rustc/lints/listing/warn-by-default.html#legacy-derive-helpers
pub(crate) matchers: Vec<(String, T)>,
}

impl<T> PathMap<T> {
impl<T> Default for PathMap<T> {
fn default() -> Self {
Self {
matchers: Default::default(),
}
}
}

impl<T: Clone + PartialEq> PathMap<T> {
/// Inserts a new matcher and associated value to the path map.
pub(crate) fn insert(&mut self, matcher: String, value: T) {
self.matchers.push((matcher, value));
if !self.matchers.contains(&(matcher.clone(), value.clone())) {
self.matchers.push((matcher, value));
}
}

/// Returns a iterator over all the value matching the given fd_path and associated suffix/prefix path
1 change: 1 addition & 0 deletions prost-derive/Cargo.toml
Original file line number Diff line number Diff line change
@@ -21,3 +21,4 @@ itertools = "0.10"
proc-macro2 = "1"
quote = "1"
syn = { version = "1", features = [ "extra-traits" ] }
uuid = "1"
1 change: 1 addition & 0 deletions prost-derive/src/field/map.rs
Original file line number Diff line number Diff line change
@@ -41,6 +41,7 @@ fn fake_scalar(ty: scalar::Ty) -> scalar::Field {
ty,
kind,
tag: 0, // Not used here
strict: false,
}
}

5 changes: 5 additions & 0 deletions prost-derive/src/field/message.rs
Original file line number Diff line number Diff line change
@@ -9,20 +9,24 @@ use crate::field::{set_bool, set_option, tag_attr, word_attr, Label};
pub struct Field {
pub label: Label,
pub tag: u32,
pub strict: bool,
}

impl Field {
pub fn new(attrs: &[Meta], inferred_tag: Option<u32>) -> Result<Option<Field>, Error> {
let mut message = false;
let mut label = None;
let mut tag = None;
let mut strict = false;
let mut boxed = false;

let mut unknown_attrs = Vec::new();

for attr in attrs {
if word_attr("message", attr) {
set_bool(&mut message, "duplicate message attribute")?;
} else if word_attr("strict", attr) {
set_bool(&mut strict, "duplicate strict attribute")?;
} else if word_attr("boxed", attr) {
set_bool(&mut boxed, "duplicate boxed attribute")?;
} else if let Some(t) = tag_attr(attr)? {
@@ -55,6 +59,7 @@ impl Field {
Ok(Some(Field {
label: label.unwrap_or(Label::Optional),
tag,
strict,
}))
}

57 changes: 56 additions & 1 deletion prost-derive/src/field/mod.rs
Original file line number Diff line number Diff line change
@@ -2,11 +2,12 @@ mod group;
mod map;
mod message;
mod oneof;
mod scalar;
pub mod scalar;

use std::fmt;
use std::slice;

pub use crate::field::scalar::{Kind, Ty};
use anyhow::{bail, Error};
use proc_macro2::TokenStream;
use quote::quote;
@@ -53,6 +54,60 @@ impl Field {
Ok(Some(field))
}

pub fn validate(&self, ident: &Ident) -> TokenStream {
let empty = quote! {};
let expect_non_nil = quote! {
if self.#ident.is_none() {
debug_assert!(false, "Unexpected nil value for {}", stringify!(self.#ident));

return Err(::prost::ValidateError::new("Empty non-nil message"))
}
};
let field = match self {
Field::Scalar(s) => s,
Field::Message(f) => return if f.strict { expect_non_nil } else { empty },
Field::Oneof(f) => return if f.strict { expect_non_nil } else { empty },
_ => return empty,
};

match field.kind {
Kind::Plain(_) => {
// Continue
}
_ => return empty,
};
match field.ty {
Ty::Uuid => {
quote! {
if self.#ident == uuid::Uuid::nil() {
debug_assert!(false, "Uuid was nil");

return Err(::prost::ValidateError::new("Uuid was nil"))
}
}
}
Ty::Enumeration(ref path) if field.strict => {
quote! {
if self.#ident == 0 || !#path::is_valid(self.#ident) {
debug_assert!(false, "Invalid case: {}", self.#ident);

return Err(::prost::ValidateError::new("Illegal case found"))
}
}
}
Ty::InlinedEnum(_) => {
quote! {
if self.#ident as i32 == 0 || self.#ident == Default::default() {
debug_assert!(false, "Invalid case");

return Err(::prost::ValidateError::new("Illegal case found"))
}
}
}
_ => empty,
}
}

/// Creates a new oneof `Field` from an iterator of field attributes.
///
/// If the meta items are invalid, an error will be returned.
8 changes: 6 additions & 2 deletions prost-derive/src/field/oneof.rs
Original file line number Diff line number Diff line change
@@ -3,18 +3,20 @@ use proc_macro2::TokenStream;
use quote::quote;
use syn::{parse_str, Lit, Meta, MetaNameValue, NestedMeta, Path};

use crate::field::{set_option, tags_attr};
use crate::field::{set_bool, set_option, tags_attr, word_attr};

#[derive(Clone)]
pub struct Field {
pub ty: Path,
pub tags: Vec<u32>,
pub strict: bool,
}

impl Field {
pub fn new(attrs: &[Meta]) -> Result<Option<Field>, Error> {
let mut ty = None;
let mut tags = None;
let mut strict = false;
let mut unknown_attrs = Vec::new();

for attr in attrs {
@@ -39,6 +41,8 @@ impl Field {
_ => bail!("invalid oneof attribute: {:?}", attr),
};
set_option(&mut ty, t, "duplicate oneof attribute")?;
} else if word_attr("strict", attr) {
set_bool(&mut strict, "duplicate strict attribute")?;
} else if let Some(t) = tags_attr(attr)? {
set_option(&mut tags, t, "duplicate tags attributes")?;
} else {
@@ -65,7 +69,7 @@ impl Field {
None => bail!("oneof field is missing a tags attribute"),
};

Ok(Some(Field { ty, tags }))
Ok(Some(Field { ty, tags, strict }))
}

/// Returns a statement which encodes the oneof field.
114 changes: 109 additions & 5 deletions prost-derive/src/field/scalar.rs
Original file line number Diff line number Diff line change
@@ -6,14 +6,15 @@ use proc_macro2::{Span, TokenStream};
use quote::{quote, ToTokens, TokenStreamExt};
use syn::{parse_str, Ident, Lit, LitByteStr, Meta, MetaList, MetaNameValue, NestedMeta, Path};

use crate::field::{bool_attr, set_option, tag_attr, Label};
use crate::field::{bool_attr, set_bool, set_option, tag_attr, word_attr, Label};

/// A scalar protobuf field.
#[derive(Clone)]
pub struct Field {
pub ty: Ty,
pub kind: Kind,
pub tag: u32,
pub strict: bool,
}

impl Field {
@@ -23,12 +24,15 @@ impl Field {
let mut packed = None;
let mut default = None;
let mut tag = None;
let mut strict = false;

let mut unknown_attrs = Vec::new();

for attr in attrs {
if let Some(t) = Ty::from_attr(attr)? {
set_option(&mut ty, t, "duplicate type attributes")?;
} else if word_attr("strict", attr) {
set_bool(&mut strict, "duplicate strict attribute")?;
} else if let Some(p) = bool_attr("packed", attr)? {
set_option(&mut packed, p, "duplicate packed attributes")?;
} else if let Some(t) = tag_attr(attr)? {
@@ -86,7 +90,12 @@ impl Field {
(Some(Label::Repeated), _, false) => Kind::Repeated,
};

Ok(Some(Field { ty, kind, tag }))
Ok(Some(Field {
ty,
kind,
tag,
strict,
}))
}

pub fn new_oneof(attrs: &[Meta]) -> Result<Option<Field>, Error> {
@@ -106,14 +115,38 @@ impl Field {
}

pub fn encode(&self, ident: TokenStream) -> TokenStream {
let tag = self.tag;
if matches!(self.ty, Ty::InlinedEnum(_)) {
match self.kind {
Kind::Plain(_) => {
return quote! {
if #ident != Default::default() {
::prost::encoding::int32::encode(#tag, &(#ident as i32), buf);
}
}
}
Kind::Packed => {
// This is a vec
return quote! {
::prost::encoding::int32::encode_packed(#tag, #ident.iter().map(|i| (*i as i32)).collect::<Vec<_>>().as_slice(), buf);
};
}
Kind::Required(_) => {
// This is inside a oneof
return quote! {
::prost::encoding::int32::encode(#tag, &(*value as i32), buf);
};
}
_ => panic!("Encode not supported"),
}
}
let module = self.ty.module();
let encode_fn = match self.kind {
Kind::Plain(..) | Kind::Optional(..) | Kind::Required(..) => quote!(encode),
Kind::Repeated => quote!(encode_repeated),
Kind::Packed => quote!(encode_packed),
};
let encode_fn = quote!(::prost::encoding::#module::#encode_fn);
let tag = self.tag;

match self.kind {
Kind::Plain(ref default) => {
@@ -169,6 +202,32 @@ impl Field {
let encoded_len_fn = quote!(::prost::encoding::#module::#encoded_len_fn);
let tag = self.tag;

if matches!(self.ty, Ty::InlinedEnum(_)) {
return match self.kind {
Kind::Plain(ref default) => {
let default = default.typed();
quote! {
if #ident != #default {
::prost::encoding::int32::encoded_len(#tag, &(#ident as i32))
} else {
0
}
}
}
Kind::Packed => {
quote! {
::prost::encoding::int32::encoded_len_packed(#tag, #ident.iter().map(|i| (*i as i32)).collect::<Vec<_>>().as_slice())
}
}
Kind::Required(_) => {
quote! {
::prost::encoding::int32::encoded_len(#tag, &(*value as i32))
}
}
_ => panic!("Encoded len not supported"),
};
}

match self.kind {
Kind::Plain(ref default) => {
let default = default.typed();
@@ -381,6 +440,8 @@ pub enum Ty {
Sfixed64,
Bool,
String,
Uuid,
InlinedEnum(Path),
Bytes(BytesTy),
Enumeration(Path),
}
@@ -425,6 +486,7 @@ impl Ty {
Meta::Path(ref name) if name.is_ident("sfixed64") => Ty::Sfixed64,
Meta::Path(ref name) if name.is_ident("bool") => Ty::Bool,
Meta::Path(ref name) if name.is_ident("string") => Ty::String,
Meta::Path(ref name) if name.is_ident("uuid") => Ty::Uuid,
Meta::Path(ref name) if name.is_ident("bytes") => Ty::Bytes(BytesTy::Vec),
Meta::NameValue(MetaNameValue {
ref path,
@@ -452,6 +514,27 @@ impl Ty {
bail!("invalid enumeration attribute: only a single identifier is supported");
}
}
Meta::NameValue(MetaNameValue {
ref path,
lit: Lit::Str(ref l),
..
}) if path.is_ident("inlined_enum") => Ty::InlinedEnum(parse_str::<Path>(&l.value())?),
Meta::List(MetaList {
ref path,
ref nested,
..
}) if path.is_ident("inlined_enum") => {
// TODO(rustlang/rust#23121): slice pattern matching would make this much nicer.
if nested.len() == 1 {
if let NestedMeta::Meta(Meta::Path(ref path)) = nested[0] {
Ty::InlinedEnum(path.clone())
} else {
bail!("invalid enumeration attribute: item must be an identifier");
}
} else {
bail!("invalid enumeration attribute: only a single identifier is supported");
}
}
_ => return Ok(None),
};
Ok(Some(ty))
@@ -511,6 +594,8 @@ impl Ty {
Ty::Sfixed64 => "sfixed64",
Ty::Bool => "bool",
Ty::String => "string",
Ty::Uuid => "uuid",
Ty::InlinedEnum(_) => "inlined_enum",
Ty::Bytes(..) => "bytes",
Ty::Enumeration(..) => "enum",
}
@@ -521,6 +606,9 @@ impl Ty {
match self {
Ty::String => quote!(::prost::alloc::string::String),
Ty::Bytes(ty) => ty.rust_type(),
Ty::InlinedEnum(path) => quote! {
#path
},
_ => self.rust_ref_type(),
}
}
@@ -542,21 +630,23 @@ impl Ty {
Ty::Sfixed64 => quote!(i64),
Ty::Bool => quote!(bool),
Ty::String => quote!(&str),
Ty::Uuid => quote!(uuid::Uuid),
Ty::InlinedEnum(_) => quote!(i32),
Ty::Bytes(..) => quote!(&[u8]),
Ty::Enumeration(..) => quote!(i32),
}
}

pub fn module(&self) -> Ident {
match *self {
Ty::Enumeration(..) => Ident::new("int32", Span::call_site()),
Ty::Enumeration(..) | Ty::InlinedEnum(_) => Ident::new("int32", Span::call_site()),
_ => Ident::new(self.as_str(), Span::call_site()),
}
}

/// Returns false if the scalar type is length delimited (i.e., `string` or `bytes`).
pub fn is_numeric(&self) -> bool {
!matches!(self, Ty::String | Ty::Bytes(..))
!matches!(self, Ty::String | Ty::Bytes(..) | Ty::Uuid)
}
}

@@ -598,6 +688,8 @@ pub enum DefaultValue {
U64(u64),
Bool(bool),
String(String),
Uuid,
InlinedEnum,
Bytes(Vec<u8>),
Enumeration(TokenStream),
Path(Path),
@@ -758,6 +850,8 @@ impl DefaultValue {

Ty::Bool => DefaultValue::Bool(false),
Ty::String => DefaultValue::String(String::new()),
Ty::Uuid => DefaultValue::Uuid,
Ty::InlinedEnum(_) => DefaultValue::InlinedEnum,
Ty::Bytes(..) => DefaultValue::Bytes(Vec::new()),
Ty::Enumeration(ref path) => DefaultValue::Enumeration(quote!(#path::default())),
}
@@ -801,6 +895,16 @@ impl ToTokens for DefaultValue {
DefaultValue::U64(value) => value.to_tokens(tokens),
DefaultValue::Bool(value) => value.to_tokens(tokens),
DefaultValue::String(ref value) => value.to_tokens(tokens),
DefaultValue::Uuid => {
tokens.append_all(quote! {
uuid::Uuid::nil()
});
}
DefaultValue::InlinedEnum => {
tokens.append_all(quote! {
Default::default()
});
}
DefaultValue::Bytes(ref value) => {
let byte_str = LitByteStr::new(value, Span::call_site());
tokens.append_all(quote!(#byte_str as &[u8]));
131 changes: 128 additions & 3 deletions prost-derive/src/lib.rs
Original file line number Diff line number Diff line change
@@ -16,8 +16,7 @@ use syn::{
};

mod field;
use crate::field::Field;

use crate::field::{Field, Kind, Ty};
fn try_message(input: TokenStream) -> Result<TokenStream, Error> {
let input: DeriveInput = syn::parse(input)?;

@@ -103,8 +102,69 @@ fn try_message(input: TokenStream) -> Result<TokenStream, Error> {
let merge = fields.iter().map(|&(ref field_ident, ref field)| {
let merge = field.merge(quote!(value));
let tags = field.tags().into_iter().map(|tag| quote!(#tag));
let tags = Itertools::intersperse(tags, quote!(|));
let tags = Itertools::intersperse(tags,quote!(|));

let field: &Field = field;
match field {
Field::Scalar(s) => {
match &s.ty {
Ty::InlinedEnum(p) => {
match s.kind {
Kind::Plain(_) => {
assert_eq!(1, field.tags().len());
return quote! {
#(#tags)* => {
let mut owned = Default::default();
let mut value = &mut owned;

#merge.map_err(|mut error| {
error.push(STRUCT_NAME, stringify!(#field_ident));
error
})?;

match #p::from_i32(owned) {
Some(p) => self.#field_ident = p,
None => return Err(::prost::DecodeError::new("Invalid enum case"))
}

Ok(())
},
}
}
Kind::Packed => {
return quote! {
#(#tags)* => {
let mut owned = Default::default();
let mut value = &mut owned;

#merge.map_err(|mut error| {
error.push(STRUCT_NAME, stringify!(#field_ident));
error
})?;

let owned_len = owned.len();

self.#field_ident = owned.into_iter().filter_map(|n| #p::from_i32(n)).collect();

if owned_len == self.#field_ident.len() {
Ok(())
} else {
Err(::prost::DecodeError::new("Mismatch in decoded enum len"))
}
},
}
}
_ => panic!("Merge not supported")
}
},
_ => {

}
}
},
_ => {
}
}
quote! {
#(#tags)* => {
let mut value = &mut self.#field_ident;
@@ -171,6 +231,11 @@ fn try_message(input: TokenStream) -> Result<TokenStream, Error> {
quote!(f.debug_tuple(stringify!(#ident)))
};

// Validations fields:
// - Uuids are not default
// - Enum is not set to the first case
let validate = fields.iter().map(|(ident, field)| field.validate(&ident));

let expanded = quote! {
impl #impl_generics ::prost::Message for #ident #ty_generics #where_clause {
#[allow(unused_variables)]
@@ -199,6 +264,12 @@ fn try_message(input: TokenStream) -> Result<TokenStream, Error> {
0 #(+ #encoded_len)*
}

fn validate(&self) -> Result<(), ::prost::ValidateError> {
#(#validate)*

Ok(())
}

fn clear(&mut self) {
#(#clear;)*
}
@@ -391,6 +462,60 @@ fn try_oneof(input: TokenStream) -> Result<TokenStream, Error> {
let merge = fields.iter().map(|&(ref variant_ident, ref field)| {
let tag = field.tags()[0];
let merge = field.merge(quote!(value));

let field: &Field = field;
match field {
Field::Scalar(s) => {
match &s.ty {
Ty::InlinedEnum(p) => {
return quote! {
#tag => {
match field {
::core::option::Option::Some(#ident::#variant_ident(ref mut value)) => {
let mut v = Default::default();

::prost::encoding::int32::merge(wire_type, &mut v, buf, ctx)?;

match #p::from_i32(v) {
Some(v) => {
*value = v;
}
None => {
return Err(::prost::DecodeError::new("Unknown enum"))
}
}

Ok(())
},
_ => {
let mut owned_value = ::core::default::Default::default();

::prost::encoding::int32::merge(wire_type, &mut owned_value, buf, ctx)?;

match #p::from_i32(owned_value) {
Some(v) => {
*field = ::core::option::Option::Some(#ident::#variant_ident(v))
}
None => {
return Err(::prost::DecodeError::new("Unknown enum"))
}
}

Ok(())
},
}
}
}
}
_ => {}
}
},
_ => {

}
}


quote! {
#tag => {
match field {
84 changes: 84 additions & 0 deletions src/encoding.rs
Original file line number Diff line number Diff line change
@@ -873,6 +873,90 @@ pub mod string {
}
}

pub mod uuid {
use super::*;
use crate::alloc::str::FromStr;
use crate::alloc::string::ToString;

pub fn encode<B>(tag: u32, value: &::uuid::Uuid, buf: &mut B)
where
B: BufMut,
{
super::string::encode(tag, &value.to_string(), buf)
}
pub fn merge<B>(
wire_type: WireType,
value: &mut ::uuid::Uuid,
buf: &mut B,
ctx: DecodeContext,
) -> Result<(), DecodeError>
where
B: Buf,
{
let mut to_merge = String::with_capacity(36);

super::string::merge(wire_type, &mut to_merge, buf, ctx)?;

let uuid = string_to_uuid(&to_merge)?;

*value = uuid;

Ok(())
}

fn string_to_uuid(s: &str) -> Result<::uuid::Uuid, DecodeError> {
// Check if the merged string is an actual uuid
match ::uuid::Uuid::from_str(s) {
Ok(uuid) => Ok(uuid),
Err(err) => Err(DecodeError::new(format!(
"invalid Uuid value: {}, error: {}",
s, err
))),
}
}

encode_repeated!(::uuid::Uuid);

pub fn merge_repeated<B>(
wire_type: WireType,
values: &mut Vec<::uuid::Uuid>,
buf: &mut B,
ctx: DecodeContext,
) -> Result<(), DecodeError>
where
B: Buf,
{
let mut strings = alloc::vec::Vec::new();

super::string::merge_repeated(wire_type, &mut strings, buf, ctx)?;

for string in strings {
let uuid = string_to_uuid(&string)?;

values.push(uuid);
}

Ok(())
}

#[inline]
pub fn encoded_len(tag: u32, value: &::uuid::Uuid) -> usize {
super::string::encoded_len(tag, &value.to_string())
}

#[inline]
pub fn encoded_len_repeated(tag: u32, values: &[::uuid::Uuid]) -> usize {
super::string::encoded_len_repeated(
tag,
values
.iter()
.map(|u| u.to_string())
.collect::<Vec<_>>()
.as_slice(),
)
}
}

pub trait BytesAdapter: sealed::BytesAdapter {}

mod sealed {
31 changes: 30 additions & 1 deletion src/error.rs
Original file line number Diff line number Diff line change
@@ -92,7 +92,7 @@ pub struct EncodeError {

impl EncodeError {
/// Creates a new `EncodeError`.
pub(crate) fn new(required: usize, remaining: usize) -> EncodeError {
pub fn new(required: usize, remaining: usize) -> EncodeError {
EncodeError {
required,
remaining,
@@ -120,6 +120,35 @@ impl fmt::Display for EncodeError {
}
}

#[derive(Clone, PartialEq, Eq)]
pub struct ValidateError {
inner: Box<Inner>,
}

impl ValidateError {
pub fn new(description: impl Into<Cow<'static, str>>) -> Self {
Self {
inner: Box::new(Inner {
description: description.into(),
stack: Vec::new(),
}),
}
}
}

impl From<ValidateError> for DecodeError {
fn from(s: ValidateError) -> Self {
DecodeError { inner: s.inner }
}
}

impl From<ValidateError> for EncodeError {
fn from(_: ValidateError) -> Self {
// This is ugly but whatever
EncodeError::new(0, 0)
}
}

#[cfg(feature = "std")]
impl std::error::Error for EncodeError {}

2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
@@ -16,7 +16,7 @@ mod types;
#[doc(hidden)]
pub mod encoding;

pub use crate::error::{DecodeError, EncodeError};
pub use crate::error::{DecodeError, EncodeError, ValidateError};
pub use crate::message::Message;

use bytes::{Buf, BufMut};
23 changes: 22 additions & 1 deletion src/message.rs
Original file line number Diff line number Diff line change
@@ -9,6 +9,7 @@ use bytes::{Buf, BufMut};
use crate::encoding::{
decode_key, encode_varint, encoded_len_varint, message, DecodeContext, WireType,
};
use crate::error::ValidateError;
use crate::DecodeError;
use crate::EncodeError;

@@ -43,6 +44,17 @@ pub trait Message: Debug + Send + Sync {
/// Returns the encoded length of the message without a length delimiter.
fn encoded_len(&self) -> usize;

fn encode_buffer(&self) -> Result<alloc::vec::Vec<u8>, EncodeError>
where
Self: Sized,
{
let mut buffer = alloc::vec::Vec::with_capacity(self.encoded_len());

self.encode(&mut buffer)?;

Ok(buffer)
}

/// Encodes the message to a buffer.
///
/// An error will be returned if the buffer does not have sufficient capacity.
@@ -58,6 +70,11 @@ pub trait Message: Debug + Send + Sync {
}

self.encode_raw(buf);
self.validate()?;
Ok(())
}

fn validate(&self) -> Result<(), ValidateError> {
Ok(())
}

@@ -113,7 +130,11 @@ pub trait Message: Debug + Send + Sync {
Self: Default,
{
let mut message = Self::default();
Self::merge(&mut message, &mut buf).map(|_| message)
let msg = Self::merge(&mut message, &mut buf).map(|_| message)?;

msg.validate()?;

Ok(msg)
}

/// Decodes a length-delimited instance of the message from the buffer.
3 changes: 3 additions & 0 deletions tests/Cargo.toml
Original file line number Diff line number Diff line change
@@ -21,11 +21,14 @@ cfg-if = "1"
prost = { path = ".." }
prost-types = { path = "../prost-types" }
protobuf = { path = "../protobuf" }
uuid = "*"

[dev-dependencies]
diff = "0.1"
prost-build = { path = "../prost-build" }
tempfile = "3"
remove_dir_all = "0.6"
uuid = "1"

[build-dependencies]
cfg-if = "1"
11 changes: 9 additions & 2 deletions tests/src/bootstrap.rs
Original file line number Diff line number Diff line change
@@ -87,6 +87,13 @@ fn bootstrap() {
.unwrap();
}

assert_eq!(protobuf, bootstrapped_protobuf);
assert_eq!(compiler, bootstrapped_compiler);
// Remove the weird newlines
assert_eq!(
protobuf.replace('\n', ""),
bootstrapped_protobuf.replace('\n', "")
);
assert_eq!(
compiler.replace('\n', ""),
bootstrapped_compiler.replace('\n', "")
);
}