From dbd5111f5aaed8cb35cc3c1702e2a8bb6b900d37 Mon Sep 17 00:00:00 2001 From: Kent Ross Date: Sat, 25 Nov 2023 13:06:14 -0800 Subject: [PATCH 01/11] better checking of tag duplicates, avoid discarding invalid variant errors --- prost-derive/src/lib.rs | 59 ++++++++++++++++++++++------------------- 1 file changed, 32 insertions(+), 27 deletions(-) diff --git a/prost-derive/src/lib.rs b/prost-derive/src/lib.rs index 58952fbdc..bd1e8fc25 100644 --- a/prost-derive/src/lib.rs +++ b/prost-derive/src/lib.rs @@ -91,16 +91,19 @@ fn try_message(input: TokenStream) -> Result { fields.sort_by_key(|(_, field)| field.tags().into_iter().min().unwrap()); let fields = fields; - let mut tags = fields + if let Some((duplicate_tag, _)) = fields .iter() .flat_map(|(_, field)| field.tags()) - .collect::>(); - let num_tags = tags.len(); - tags.sort_unstable(); - tags.dedup(); - if tags.len() != num_tags { - bail!("message {} has fields with duplicate tags", ident); - } + .sorted_unstable() + .tuple_windows() + .find(|(a, b)| a == b) + { + bail!( + "message {} has multiple fields with tag {}", + ident, + duplicate_tag + ) + }; let encoded_len = fields .iter() @@ -251,7 +254,7 @@ fn try_message(input: TokenStream) -> Result { #methods }; - Ok(expanded.into()) + Ok(expanded) } #[proc_macro_derive(Message, attributes(prost))] @@ -359,7 +362,7 @@ fn try_enumeration(input: TokenStream) -> Result { } }; - Ok(expanded.into()) + Ok(expanded) } #[proc_macro_derive(Enumeration, attributes(prost))] @@ -412,23 +415,25 @@ fn try_oneof(input: TokenStream) -> Result { } } - let mut tags = fields + if let Some((invalid_variant, _)) = fields.iter().find(|(_, field)| field.tags().len() > 1) { + bail!( + "invalid oneof variant {}::{}: oneof variants may only have a single tag", + ident, + invalid_variant + ); + } + if let Some((duplicate_tag, _)) = fields .iter() - .flat_map(|(variant_ident, field)| -> Result { - if field.tags().len() > 1 { - bail!( - "invalid oneof variant {}::{}: oneof variants may only have a single tag", - ident, - variant_ident - ); - } - Ok(field.tags()[0]) - }) - .collect::>(); - tags.sort_unstable(); - tags.dedup(); - if tags.len() != fields.len() { - panic!("invalid oneof {}: variants have duplicate tags", ident); + .flat_map(|(_, field)| field.tags()) + .sorted_unstable() + .tuple_windows() + .find(|(a, b)| a == b) + { + bail!( + "invalid oneof {}: multiple variants have tag {}", + ident, + duplicate_tag + ); } let encode = fields.iter().map(|(variant_ident, field)| { @@ -519,7 +524,7 @@ fn try_oneof(input: TokenStream) -> Result { } }; - Ok(expanded.into()) + Ok(expanded) } #[proc_macro_derive(Oneof, attributes(prost))] From 4c00d4ef36c912d7528d39792ad78fecb334e74f Mon Sep 17 00:00:00 2001 From: Kent Ross Date: Sat, 25 Nov 2023 14:30:34 -0800 Subject: [PATCH 02/11] add some simple derive tests :) --- prost-derive/src/lib.rs | 71 +++++++++++++++++++++++++++++++---------- 1 file changed, 54 insertions(+), 17 deletions(-) diff --git a/prost-derive/src/lib.rs b/prost-derive/src/lib.rs index bd1e8fc25..e7c7d8d5c 100644 --- a/prost-derive/src/lib.rs +++ b/prost-derive/src/lib.rs @@ -7,8 +7,7 @@ extern crate proc_macro; use anyhow::{bail, Error}; use itertools::Itertools; -use proc_macro::TokenStream; -use proc_macro2::Span; +use proc_macro2::{Span, TokenStream}; use quote::quote; use syn::{ punctuated::Punctuated, Data, DataEnum, DataStruct, DeriveInput, Expr, Fields, FieldsNamed, @@ -19,7 +18,7 @@ mod field; use crate::field::Field; fn try_message(input: TokenStream) -> Result { - let input: DeriveInput = syn::parse(input)?; + let input: DeriveInput = syn::parse2(input)?; let ident = input.ident; @@ -258,12 +257,12 @@ fn try_message(input: TokenStream) -> Result { } #[proc_macro_derive(Message, attributes(prost))] -pub fn message(input: TokenStream) -> TokenStream { - try_message(input).unwrap() +pub fn message(input: proc_macro::TokenStream) -> proc_macro::TokenStream { + try_message(input.into()).unwrap().into() } fn try_enumeration(input: TokenStream) -> Result { - let input: DeriveInput = syn::parse(input)?; + let input: DeriveInput = syn::parse2(input)?; let ident = input.ident; let generics = &input.generics; @@ -366,12 +365,12 @@ fn try_enumeration(input: TokenStream) -> Result { } #[proc_macro_derive(Enumeration, attributes(prost))] -pub fn enumeration(input: TokenStream) -> TokenStream { - try_enumeration(input).unwrap() +pub fn enumeration(input: proc_macro::TokenStream) -> proc_macro::TokenStream { + try_enumeration(input.into()).unwrap().into() } fn try_oneof(input: TokenStream) -> Result { - let input: DeriveInput = syn::parse(input)?; + let input: DeriveInput = syn::parse2(input)?; let ident = input.ident; @@ -415,12 +414,8 @@ fn try_oneof(input: TokenStream) -> Result { } } - if let Some((invalid_variant, _)) = fields.iter().find(|(_, field)| field.tags().len() > 1) { - bail!( - "invalid oneof variant {}::{}: oneof variants may only have a single tag", - ident, - invalid_variant - ); + if fields.iter().any(|(_, field)| field.tags().len() > 1) { + panic!("variant with multiple tags"); // Not clear if this is possible, but good to be safe } if let Some((duplicate_tag, _)) = fields .iter() @@ -528,6 +523,48 @@ fn try_oneof(input: TokenStream) -> Result { } #[proc_macro_derive(Oneof, attributes(prost))] -pub fn oneof(input: TokenStream) -> TokenStream { - try_oneof(input).unwrap() +pub fn oneof(input: proc_macro::TokenStream) -> proc_macro::TokenStream { + try_oneof(input.into()).unwrap().into() +} + +#[cfg(test)] +mod test { + use crate::{try_message, try_oneof}; + use quote::quote; + + #[test] + fn test_rejects_colliding_message_fields() { + let output = try_message(quote!( + struct Invalid { + #[prost(bool, tag = "1")] + a: bool, + #[prost(oneof = "super::Whatever", tags = "4, 5, 1")] + b: Option, + } + )); + assert!(output.is_err()); + assert_eq!( + output.unwrap_err().to_string(), + "message Invalid has multiple fields with tag 1" + ); + } + + #[test] + fn test_rejects_colliding_oneof_variants() { + let output = try_oneof(quote!( + pub enum Invalid { + #[prost(bool, tag = "1")] + A(bool), + #[prost(bool, tag = "3")] + B(bool), + #[prost(bool, tag = "1")] + C(bool), + } + )); + assert!(output.is_err()); + assert_eq!( + output.unwrap_err().to_string(), + "invalid oneof Invalid: multiple variants have tag 1" + ); + } } From 12e65f34fec8f73f61a243fdd6832067f710af5e Mon Sep 17 00:00:00 2001 From: Kent Ross Date: Wed, 22 May 2024 16:30:04 -0700 Subject: [PATCH 03/11] use itertools duplicates() --- prost-derive/Cargo.toml | 2 +- prost-derive/src/lib.rs | 7 +++---- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/prost-derive/Cargo.toml b/prost-derive/Cargo.toml index fbff78b54..87ee65987 100644 --- a/prost-derive/Cargo.toml +++ b/prost-derive/Cargo.toml @@ -20,7 +20,7 @@ proc_macro = true [dependencies] anyhow = "1.0.1" -itertools = { version = ">=0.10, <=0.12", default-features = false, features = ["use_alloc"] } +itertools = ">=0.10, <=0.12" proc-macro2 = "1" quote = "1" syn = { version = "2", features = ["extra-traits"] } diff --git a/prost-derive/src/lib.rs b/prost-derive/src/lib.rs index e7c7d8d5c..d665533ef 100644 --- a/prost-derive/src/lib.rs +++ b/prost-derive/src/lib.rs @@ -417,12 +417,11 @@ fn try_oneof(input: TokenStream) -> Result { if fields.iter().any(|(_, field)| field.tags().len() > 1) { panic!("variant with multiple tags"); // Not clear if this is possible, but good to be safe } - if let Some((duplicate_tag, _)) = fields + if let Some(duplicate_tag) = fields .iter() .flat_map(|(_, field)| field.tags()) - .sorted_unstable() - .tuple_windows() - .find(|(a, b)| a == b) + .duplicates() + .next() { bail!( "invalid oneof {}: multiple variants have tag {}", From 0d0a1dd11afbec378fb95fd778354743e4d58ef5 Mon Sep 17 00:00:00 2001 From: Kent Ross Date: Wed, 22 May 2024 16:30:45 -0700 Subject: [PATCH 04/11] don't print out unreadable syn junk when encountering unknown attributes --- prost-derive/src/field/group.rs | 6 ++---- prost-derive/src/field/message.rs | 9 ++------- prost-derive/src/field/oneof.rs | 9 ++------- prost-derive/src/field/scalar.rs | 6 ++---- 4 files changed, 8 insertions(+), 22 deletions(-) diff --git a/prost-derive/src/field/group.rs b/prost-derive/src/field/group.rs index 076b577d7..4c30668bb 100644 --- a/prost-derive/src/field/group.rs +++ b/prost-derive/src/field/group.rs @@ -38,10 +38,8 @@ impl Field { return Ok(None); } - match unknown_attrs.len() { - 0 => (), - 1 => bail!("unknown attribute for group field: {:?}", unknown_attrs[0]), - _ => bail!("unknown attributes for group field: {:?}", unknown_attrs), + if !unknown_attrs.is_empty() { + bail!("unknown attribute(s) for group field: {}", quote!(#(#unknown_attrs),*)); } let tag = match tag.or(inferred_tag) { diff --git a/prost-derive/src/field/message.rs b/prost-derive/src/field/message.rs index 3bcdddfb1..317fb1530 100644 --- a/prost-derive/src/field/message.rs +++ b/prost-derive/src/field/message.rs @@ -38,13 +38,8 @@ impl Field { return Ok(None); } - match unknown_attrs.len() { - 0 => (), - 1 => bail!( - "unknown attribute for message field: {:?}", - unknown_attrs[0] - ), - _ => bail!("unknown attributes for message field: {:?}", unknown_attrs), + if !unknown_attrs.is_empty() { + bail!("unknown attribute(s) for message field: {}", quote!(#(#unknown_attrs),*)); } let tag = match tag.or(inferred_tag) { diff --git a/prost-derive/src/field/oneof.rs b/prost-derive/src/field/oneof.rs index 78c77eeb1..7b241ac4d 100644 --- a/prost-derive/src/field/oneof.rs +++ b/prost-derive/src/field/oneof.rs @@ -44,13 +44,8 @@ impl Field { None => return Ok(None), }; - match unknown_attrs.len() { - 0 => (), - 1 => bail!( - "unknown attribute for message field: {:?}", - unknown_attrs[0] - ), - _ => bail!("unknown attributes for message field: {:?}", unknown_attrs), + if !unknown_attrs.is_empty() { + bail!("unknown attribute(s) for message field: {}", quote!(#(#unknown_attrs),*)); } let tags = match tags { diff --git a/prost-derive/src/field/scalar.rs b/prost-derive/src/field/scalar.rs index 6be16cd70..4cb07045d 100644 --- a/prost-derive/src/field/scalar.rs +++ b/prost-derive/src/field/scalar.rs @@ -46,10 +46,8 @@ impl Field { None => return Ok(None), }; - match unknown_attrs.len() { - 0 => (), - 1 => bail!("unknown attribute: {:?}", unknown_attrs[0]), - _ => bail!("unknown attributes: {:?}", unknown_attrs), + if !unknown_attrs.is_empty() { + bail!("unknown attribute(s): {}", quote!(#(#unknown_attrs),*)); } let tag = match tag.or(inferred_tag) { From 0b3216a55838452153e10afd90cbe31fba902e67 Mon Sep 17 00:00:00 2001 From: Kent Ross Date: Wed, 22 May 2024 16:31:23 -0700 Subject: [PATCH 05/11] clarify and test backstop for the "multiple tags in a oneof variant" condition --- prost-derive/src/lib.rs | 43 +++++++++++++++++++++++++++++++++++++++-- 1 file changed, 41 insertions(+), 2 deletions(-) diff --git a/prost-derive/src/lib.rs b/prost-derive/src/lib.rs index d665533ef..e5bc55659 100644 --- a/prost-derive/src/lib.rs +++ b/prost-derive/src/lib.rs @@ -414,8 +414,16 @@ fn try_oneof(input: TokenStream) -> Result { } } - if fields.iter().any(|(_, field)| field.tags().len() > 1) { - panic!("variant with multiple tags"); // Not clear if this is possible, but good to be safe + for (variant_ident, field) in &fields { + // Not clear if this condition is reachable since multiple "tag" attributes are already + // rejected, but good to be safe + if field.tags().len() > 1 { + bail!( + "invalid oneof variant {}::{}: oneof variants may only have a single tag", + ident, + variant_ident + ); + } } if let Some(duplicate_tag) = fields .iter() @@ -566,4 +574,35 @@ mod test { "invalid oneof Invalid: multiple variants have tag 1" ); } + + #[test] + fn test_rejects_multiple_tags_oneof_variant() { + let output = try_oneof(quote!( + enum What { + #[prost(bool, tag = "1", tag = "2")] + A(bool), + } + )); + assert!(output.is_err()); + assert_eq!(output.unwrap_err().to_string(), "duplicate tag attributes: 1 and 2"); + + let output = try_oneof(quote!( + enum What { + #[prost(bool, tag = "3")] + #[prost(tag = "4")] + A(bool), + } + )); + assert!(output.is_err()); + assert_eq!(output.unwrap_err().to_string(), "duplicate tag attributes: 3 and 4"); + + let output = try_oneof(quote!( + enum What { + #[prost(bool, tags = "5,6")] + A(bool), + } + )); + assert!(output.is_err()); + assert_eq!(output.unwrap_err().to_string(), "unknown attribute(s): tags = \"5,6\""); + } } From 6cfd80f808a1ce3a522d2e7cc2929fedd1031c0e Mon Sep 17 00:00:00 2001 From: Kent Ross Date: Wed, 22 May 2024 16:32:59 -0700 Subject: [PATCH 06/11] use expect_err --- prost-derive/src/lib.rs | 32 ++++++++++++++++++++++++-------- 1 file changed, 24 insertions(+), 8 deletions(-) diff --git a/prost-derive/src/lib.rs b/prost-derive/src/lib.rs index e5bc55659..068904c73 100644 --- a/prost-derive/src/lib.rs +++ b/prost-derive/src/lib.rs @@ -549,9 +549,10 @@ mod test { b: Option, } )); - assert!(output.is_err()); assert_eq!( - output.unwrap_err().to_string(), + output + .expect_err("did not reject colliding message fields") + .to_string(), "message Invalid has multiple fields with tag 1" ); } @@ -568,9 +569,10 @@ mod test { C(bool), } )); - assert!(output.is_err()); assert_eq!( - output.unwrap_err().to_string(), + output + .expect_err("did not reject colliding oneof variants") + .to_string(), "invalid oneof Invalid: multiple variants have tag 1" ); } @@ -583,8 +585,12 @@ mod test { A(bool), } )); - assert!(output.is_err()); - assert_eq!(output.unwrap_err().to_string(), "duplicate tag attributes: 1 and 2"); + assert_eq!( + output + .expect_err("did not reject multiple tags on oneof variant") + .to_string(), + "duplicate tag attributes: 1 and 2" + ); let output = try_oneof(quote!( enum What { @@ -594,7 +600,12 @@ mod test { } )); assert!(output.is_err()); - assert_eq!(output.unwrap_err().to_string(), "duplicate tag attributes: 3 and 4"); + assert_eq!( + output + .expect_err("did not reject multiple tags on oneof variant") + .to_string(), + "duplicate tag attributes: 3 and 4" + ); let output = try_oneof(quote!( enum What { @@ -603,6 +614,11 @@ mod test { } )); assert!(output.is_err()); - assert_eq!(output.unwrap_err().to_string(), "unknown attribute(s): tags = \"5,6\""); + assert_eq!( + output + .expect_err("did not reject multiple tags on oneof variant") + .to_string(), + "unknown attribute(s): tags = \"5,6\"" + ); } } From 02bc13232e5a3d9aa041e8657456ea8590992863 Mon Sep 17 00:00:00 2001 From: Kent Ross Date: Wed, 22 May 2024 16:52:32 -0700 Subject: [PATCH 07/11] nicer framing around the unknown attribute tokens --- prost-derive/src/field/group.rs | 5 ++++- prost-derive/src/field/message.rs | 5 ++++- prost-derive/src/field/oneof.rs | 5 ++++- prost-derive/src/field/scalar.rs | 5 ++++- prost-derive/src/lib.rs | 2 +- 5 files changed, 17 insertions(+), 5 deletions(-) diff --git a/prost-derive/src/field/group.rs b/prost-derive/src/field/group.rs index 4c30668bb..485ecfc1b 100644 --- a/prost-derive/src/field/group.rs +++ b/prost-derive/src/field/group.rs @@ -39,7 +39,10 @@ impl Field { } if !unknown_attrs.is_empty() { - bail!("unknown attribute(s) for group field: {}", quote!(#(#unknown_attrs),*)); + bail!( + "unknown attribute(s) for group field: #[prost({})]", + quote!(#(#unknown_attrs),*) + ); } let tag = match tag.or(inferred_tag) { diff --git a/prost-derive/src/field/message.rs b/prost-derive/src/field/message.rs index 317fb1530..f6ac391e7 100644 --- a/prost-derive/src/field/message.rs +++ b/prost-derive/src/field/message.rs @@ -39,7 +39,10 @@ impl Field { } if !unknown_attrs.is_empty() { - bail!("unknown attribute(s) for message field: {}", quote!(#(#unknown_attrs),*)); + bail!( + "unknown attribute(s) for message field: #[prost({})]", + quote!(#(#unknown_attrs),*) + ); } let tag = match tag.or(inferred_tag) { diff --git a/prost-derive/src/field/oneof.rs b/prost-derive/src/field/oneof.rs index 7b241ac4d..ad1e32f19 100644 --- a/prost-derive/src/field/oneof.rs +++ b/prost-derive/src/field/oneof.rs @@ -45,7 +45,10 @@ impl Field { }; if !unknown_attrs.is_empty() { - bail!("unknown attribute(s) for message field: {}", quote!(#(#unknown_attrs),*)); + bail!( + "unknown attribute(s) for message field: #[prost({})]", + quote!(#(#unknown_attrs),*) + ); } let tags = match tags { diff --git a/prost-derive/src/field/scalar.rs b/prost-derive/src/field/scalar.rs index 4cb07045d..c2e870524 100644 --- a/prost-derive/src/field/scalar.rs +++ b/prost-derive/src/field/scalar.rs @@ -47,7 +47,10 @@ impl Field { }; if !unknown_attrs.is_empty() { - bail!("unknown attribute(s): {}", quote!(#(#unknown_attrs),*)); + bail!( + "unknown attribute(s): #[prost({})]", + quote!(#(#unknown_attrs),*) + ); } let tag = match tag.or(inferred_tag) { diff --git a/prost-derive/src/lib.rs b/prost-derive/src/lib.rs index 068904c73..8b2f1e717 100644 --- a/prost-derive/src/lib.rs +++ b/prost-derive/src/lib.rs @@ -618,7 +618,7 @@ mod test { output .expect_err("did not reject multiple tags on oneof variant") .to_string(), - "unknown attribute(s): tags = \"5,6\"" + "unknown attribute(s): #[prost(tags = \"5,6\")]" ); } } From be6f1400c3288ef289bc7de7a2f4df38c5f3540f Mon Sep 17 00:00:00 2001 From: Kent Ross Date: Wed, 22 May 2024 17:01:59 -0700 Subject: [PATCH 08/11] express higher minimal versions of itertools and proc-macro2 for cargo hack check --- prost-derive/Cargo.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/prost-derive/Cargo.toml b/prost-derive/Cargo.toml index 87ee65987..d67d4f044 100644 --- a/prost-derive/Cargo.toml +++ b/prost-derive/Cargo.toml @@ -20,7 +20,7 @@ proc_macro = true [dependencies] anyhow = "1.0.1" -itertools = ">=0.10, <=0.12" -proc-macro2 = "1" +itertools = ">=0.10.1, <=0.12" +proc-macro2 = ">=1.0.60, <2" quote = "1" syn = { version = "2", features = ["extra-traits"] } From 25455424f8388bd3d9344331e2a3d19a17bdcb49 Mon Sep 17 00:00:00 2001 From: Kent Ross Date: Thu, 23 May 2024 21:51:17 -0700 Subject: [PATCH 09/11] simplify bounds for proc-macro2 Co-authored-by: Casper Meijn --- prost-derive/Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/prost-derive/Cargo.toml b/prost-derive/Cargo.toml index d67d4f044..19c44a7c2 100644 --- a/prost-derive/Cargo.toml +++ b/prost-derive/Cargo.toml @@ -21,6 +21,6 @@ proc_macro = true [dependencies] anyhow = "1.0.1" itertools = ">=0.10.1, <=0.12" -proc-macro2 = ">=1.0.60, <2" +proc-macro2 = "1.0.60" quote = "1" syn = { version = "2", features = ["extra-traits"] } From aa45a1487da3b33457eb737ad90680548c3929cd Mon Sep 17 00:00:00 2001 From: Kent Ross Date: Thu, 23 May 2024 21:55:36 -0700 Subject: [PATCH 10/11] update the other instance of the .tuple_windows() trick to use .duplicates() --- prost-derive/src/lib.rs | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/prost-derive/src/lib.rs b/prost-derive/src/lib.rs index 8b2f1e717..d90559468 100644 --- a/prost-derive/src/lib.rs +++ b/prost-derive/src/lib.rs @@ -90,12 +90,11 @@ fn try_message(input: TokenStream) -> Result { fields.sort_by_key(|(_, field)| field.tags().into_iter().min().unwrap()); let fields = fields; - if let Some((duplicate_tag, _)) = fields + if let Some(duplicate_tag) = fields .iter() .flat_map(|(_, field)| field.tags()) - .sorted_unstable() - .tuple_windows() - .find(|(a, b)| a == b) + .duplicates() + .next() { bail!( "message {} has multiple fields with tag {}", From af24da173081657df73175d5d8cc3d5cddd10a13 Mon Sep 17 00:00:00 2001 From: Kent Ross Date: Thu, 23 May 2024 22:14:30 -0700 Subject: [PATCH 11/11] clarify & shorten assertion --- prost-derive/src/lib.rs | 15 ++++----------- 1 file changed, 4 insertions(+), 11 deletions(-) diff --git a/prost-derive/src/lib.rs b/prost-derive/src/lib.rs index d90559468..0cd49c57d 100644 --- a/prost-derive/src/lib.rs +++ b/prost-derive/src/lib.rs @@ -413,17 +413,10 @@ fn try_oneof(input: TokenStream) -> Result { } } - for (variant_ident, field) in &fields { - // Not clear if this condition is reachable since multiple "tag" attributes are already - // rejected, but good to be safe - if field.tags().len() > 1 { - bail!( - "invalid oneof variant {}::{}: oneof variants may only have a single tag", - ident, - variant_ident - ); - } - } + // Oneof variants cannot be oneofs themselves, so it's impossible to have a field with multiple + // tags. + assert!(fields.iter().all(|(_, field)| field.tags().len() == 1)); + if let Some(duplicate_tag) = fields .iter() .flat_map(|(_, field)| field.tags())