Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

better checking of tag duplicates, avoid discarding invalid variant errs #951

Merged
merged 11 commits into from
May 24, 2024
Merged
4 changes: 2 additions & 2 deletions prost-derive/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ proc_macro = true

[dependencies]
anyhow = "1.0.1"
itertools = { version = ">=0.10, <=0.12", default-features = false, features = ["use_alloc"] }
proc-macro2 = "1"
itertools = ">=0.10.1, <=0.12"
caspermeijn marked this conversation as resolved.
Show resolved Hide resolved
proc-macro2 = ">=1.0.60, <2"
mumbleskates marked this conversation as resolved.
Show resolved Hide resolved
quote = "1"
syn = { version = "2", features = ["extra-traits"] }
9 changes: 5 additions & 4 deletions prost-derive/src/field/group.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,11 @@ impl Field {
return Ok(None);
}

match unknown_attrs.len() {
0 => (),
1 => bail!("unknown attribute for group field: {:?}", unknown_attrs[0]),
_ => bail!("unknown attributes for group field: {:?}", unknown_attrs),
if !unknown_attrs.is_empty() {
bail!(
"unknown attribute(s) for group field: #[prost({})]",
quote!(#(#unknown_attrs),*)
);
}

let tag = match tag.or(inferred_tag) {
Expand Down
12 changes: 5 additions & 7 deletions prost-derive/src/field/message.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,11 @@ impl Field {
return Ok(None);
}

match unknown_attrs.len() {
0 => (),
1 => bail!(
"unknown attribute for message field: {:?}",
unknown_attrs[0]
),
_ => bail!("unknown attributes for message field: {:?}", unknown_attrs),
if !unknown_attrs.is_empty() {
bail!(
"unknown attribute(s) for message field: #[prost({})]",
quote!(#(#unknown_attrs),*)
);
}

let tag = match tag.or(inferred_tag) {
Expand Down
12 changes: 5 additions & 7 deletions prost-derive/src/field/oneof.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,11 @@ impl Field {
None => return Ok(None),
};

match unknown_attrs.len() {
0 => (),
1 => bail!(
"unknown attribute for message field: {:?}",
unknown_attrs[0]
),
_ => bail!("unknown attributes for message field: {:?}", unknown_attrs),
if !unknown_attrs.is_empty() {
bail!(
"unknown attribute(s) for message field: #[prost({})]",
quote!(#(#unknown_attrs),*)
);
}

let tags = match tags {
Expand Down
9 changes: 5 additions & 4 deletions prost-derive/src/field/scalar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,11 @@ impl Field {
None => return Ok(None),
};

match unknown_attrs.len() {
0 => (),
1 => bail!("unknown attribute: {:?}", unknown_attrs[0]),
_ => bail!("unknown attributes: {:?}", unknown_attrs),
if !unknown_attrs.is_empty() {
bail!(
"unknown attribute(s): #[prost({})]",
quote!(#(#unknown_attrs),*)
);
}

let tag = match tag.or(inferred_tag) {
Expand Down
172 changes: 134 additions & 38 deletions prost-derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,7 @@ extern crate proc_macro;

use anyhow::{bail, Error};
use itertools::Itertools;
use proc_macro::TokenStream;
use proc_macro2::Span;
use proc_macro2::{Span, TokenStream};
caspermeijn marked this conversation as resolved.
Show resolved Hide resolved
use quote::quote;
use syn::{
punctuated::Punctuated, Data, DataEnum, DataStruct, DeriveInput, Expr, Fields, FieldsNamed,
Expand All @@ -19,7 +18,7 @@ mod field;
use crate::field::Field;

fn try_message(input: TokenStream) -> Result<TokenStream, Error> {
let input: DeriveInput = syn::parse(input)?;
let input: DeriveInput = syn::parse2(input)?;

let ident = input.ident;

Expand Down Expand Up @@ -91,16 +90,19 @@ fn try_message(input: TokenStream) -> Result<TokenStream, Error> {
fields.sort_by_key(|(_, field)| field.tags().into_iter().min().unwrap());
let fields = fields;

let mut tags = fields
if let Some((duplicate_tag, _)) = fields
.iter()
.flat_map(|(_, field)| field.tags())
.collect::<Vec<_>>();
let num_tags = tags.len();
tags.sort_unstable();
tags.dedup();
if tags.len() != num_tags {
bail!("message {} has fields with duplicate tags", ident);
}
.sorted_unstable()
.tuple_windows()
.find(|(a, b)| a == b)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it would kinda be preferable to write this as .duplicates().next() using Itertools, but that method unfortunately requires use_std.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As I understand prost_derive will run on the host, so it is not subject to no_std rules. The code does also use std::fmt;, so std must be available. So I think it is safe to use_std for Itertools. Do you agree?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah from what i've seen since i published this that is totally possible. i'll update it

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems you updated the other case, but not this one. Was that intentional?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it was not.

{
bail!(
"message {} has multiple fields with tag {}",
caspermeijn marked this conversation as resolved.
Show resolved Hide resolved
ident,
duplicate_tag
)
};

let encoded_len = fields
.iter()
Expand Down Expand Up @@ -251,16 +253,16 @@ fn try_message(input: TokenStream) -> Result<TokenStream, Error> {
#methods
};

Ok(expanded.into())
Ok(expanded)
}

#[proc_macro_derive(Message, attributes(prost))]
pub fn message(input: TokenStream) -> TokenStream {
try_message(input).unwrap()
pub fn message(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
try_message(input.into()).unwrap().into()
}

fn try_enumeration(input: TokenStream) -> Result<TokenStream, Error> {
let input: DeriveInput = syn::parse(input)?;
let input: DeriveInput = syn::parse2(input)?;
let ident = input.ident;

let generics = &input.generics;
Expand Down Expand Up @@ -359,16 +361,16 @@ fn try_enumeration(input: TokenStream) -> Result<TokenStream, Error> {
}
};

Ok(expanded.into())
Ok(expanded)
}

#[proc_macro_derive(Enumeration, attributes(prost))]
pub fn enumeration(input: TokenStream) -> TokenStream {
try_enumeration(input).unwrap()
pub fn enumeration(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
try_enumeration(input.into()).unwrap().into()
}

fn try_oneof(input: TokenStream) -> Result<TokenStream, Error> {
let input: DeriveInput = syn::parse(input)?;
let input: DeriveInput = syn::parse2(input)?;

let ident = input.ident;

Expand Down Expand Up @@ -412,23 +414,28 @@ fn try_oneof(input: TokenStream) -> Result<TokenStream, Error> {
}
}

let mut tags = fields
for (variant_ident, field) in &fields {
// Not clear if this condition is reachable since multiple "tag" attributes are already
// rejected, but good to be safe
caspermeijn marked this conversation as resolved.
Show resolved Hide resolved
if field.tags().len() > 1 {
bail!(
"invalid oneof variant {}::{}: oneof variants may only have a single tag",
ident,
variant_ident
);
}
}
if let Some(duplicate_tag) = fields
.iter()
.flat_map(|(variant_ident, field)| -> Result<u32, Error> {
if field.tags().len() > 1 {
bail!(
"invalid oneof variant {}::{}: oneof variants may only have a single tag",
ident,
variant_ident
);
caspermeijn marked this conversation as resolved.
Show resolved Hide resolved
}
Ok(field.tags()[0])
})
.collect::<Vec<_>>();
tags.sort_unstable();
tags.dedup();
if tags.len() != fields.len() {
panic!("invalid oneof {}: variants have duplicate tags", ident);
.flat_map(|(_, field)| field.tags())
.duplicates()
.next()
{
bail!(
"invalid oneof {}: multiple variants have tag {}",
ident,
duplicate_tag
);
}

let encode = fields.iter().map(|(variant_ident, field)| {
Expand Down Expand Up @@ -519,10 +526,99 @@ fn try_oneof(input: TokenStream) -> Result<TokenStream, Error> {
}
};

Ok(expanded.into())
Ok(expanded)
}

#[proc_macro_derive(Oneof, attributes(prost))]
pub fn oneof(input: TokenStream) -> TokenStream {
try_oneof(input).unwrap()
pub fn oneof(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
try_oneof(input.into()).unwrap().into()
}

#[cfg(test)]
mod test {
use crate::{try_message, try_oneof};
use quote::quote;

#[test]
fn test_rejects_colliding_message_fields() {
let output = try_message(quote!(
struct Invalid {
#[prost(bool, tag = "1")]
a: bool,
#[prost(oneof = "super::Whatever", tags = "4, 5, 1")]
b: Option<super::Whatever>,
}
));
assert_eq!(
output
.expect_err("did not reject colliding message fields")
.to_string(),
"message Invalid has multiple fields with tag 1"
);
}

#[test]
fn test_rejects_colliding_oneof_variants() {
let output = try_oneof(quote!(
pub enum Invalid {
#[prost(bool, tag = "1")]
A(bool),
#[prost(bool, tag = "3")]
B(bool),
#[prost(bool, tag = "1")]
C(bool),
}
));
assert_eq!(
output
.expect_err("did not reject colliding oneof variants")
.to_string(),
"invalid oneof Invalid: multiple variants have tag 1"
);
caspermeijn marked this conversation as resolved.
Show resolved Hide resolved
}

#[test]
fn test_rejects_multiple_tags_oneof_variant() {
let output = try_oneof(quote!(
enum What {
#[prost(bool, tag = "1", tag = "2")]
A(bool),
}
));
assert_eq!(
output
.expect_err("did not reject multiple tags on oneof variant")
.to_string(),
"duplicate tag attributes: 1 and 2"
);

let output = try_oneof(quote!(
enum What {
#[prost(bool, tag = "3")]
#[prost(tag = "4")]
A(bool),
}
));
assert!(output.is_err());
assert_eq!(
output
.expect_err("did not reject multiple tags on oneof variant")
.to_string(),
"duplicate tag attributes: 3 and 4"
);

let output = try_oneof(quote!(
enum What {
#[prost(bool, tags = "5,6")]
A(bool),
}
));
assert!(output.is_err());
assert_eq!(
output
.expect_err("did not reject multiple tags on oneof variant")
.to_string(),
"unknown attribute(s): #[prost(tags = \"5,6\")]"
);
}
}
Loading