diff --git a/prost-derive/src/lib.rs b/prost-derive/src/lib.rs index 642ad5825..3a78c4aa1 100644 --- a/prost-derive/src/lib.rs +++ b/prost-derive/src/lib.rs @@ -7,8 +7,7 @@ extern crate proc_macro; use anyhow::{bail, Error}; use itertools::Itertools; -use proc_macro::TokenStream; -use proc_macro2::Span; +use proc_macro2::{Span, TokenStream}; use quote::quote; use syn::{ punctuated::Punctuated, Data, DataEnum, DataStruct, DeriveInput, Expr, Fields, FieldsNamed, @@ -19,7 +18,7 @@ mod field; use crate::field::Field; fn try_message(input: TokenStream) -> Result { - let input: DeriveInput = syn::parse(input)?; + let input: DeriveInput = syn::parse2(input)?; let ident = input.ident; @@ -253,12 +252,12 @@ fn try_message(input: TokenStream) -> Result { } #[proc_macro_derive(Message, attributes(prost))] -pub fn message(input: TokenStream) -> TokenStream { - try_message(input).unwrap() +pub fn message(input: proc_macro::TokenStream) -> proc_macro::TokenStream { + try_message(input.into()).unwrap().into() } fn try_enumeration(input: TokenStream) -> Result { - let input: DeriveInput = syn::parse(input)?; + let input: DeriveInput = syn::parse2(input)?; let ident = input.ident; let generics = &input.generics; @@ -363,12 +362,12 @@ fn try_enumeration(input: TokenStream) -> Result { } #[proc_macro_derive(Enumeration, attributes(prost))] -pub fn enumeration(input: TokenStream) -> TokenStream { - try_enumeration(input).unwrap() +pub fn enumeration(input: proc_macro::TokenStream) -> proc_macro::TokenStream { + try_enumeration(input.into()).unwrap().into() } fn try_oneof(input: TokenStream) -> Result { - let input: DeriveInput = syn::parse(input)?; + let input: DeriveInput = syn::parse2(input)?; let ident = input.ident; @@ -412,12 +411,8 @@ fn try_oneof(input: TokenStream) -> Result { } } - if let Some((invalid_variant, _)) = fields.iter().find(|(_, field)| field.tags().len() > 1) { - bail!( - "invalid oneof variant {}::{}: oneof variants may only have a single tag", - ident, - invalid_variant - ); + if fields.iter().any(|(_, field)| field.tags().len() > 1) { + panic!("variant with multiple tags"); // Not clear if this is possible, but good to be safe } if let Some((duplicate_tag, _)) = fields .iter() @@ -520,6 +515,42 @@ fn try_oneof(input: TokenStream) -> Result { } #[proc_macro_derive(Oneof, attributes(prost))] -pub fn oneof(input: TokenStream) -> TokenStream { - try_oneof(input).unwrap() +pub fn oneof(input: proc_macro::TokenStream) -> proc_macro::TokenStream { + try_oneof(input.into()).unwrap().into() +} + +#[cfg(test)] +mod test { + use crate::{try_message, try_oneof}; + use quote::quote; + + #[test] + fn test_rejects_colliding_message_fields() { + let output = try_message(quote!( + struct Invalid { + #[prost(bool, tag = "1")] + a: bool, + #[prost(oneof = "super::Whatever", tags = "4, 5, 1")] + b: Option, + } + ).into()); + assert!(output.is_err()); + assert_eq!(output.unwrap_err().to_string(), + "message Invalid has duplicate tag 1"); + } + + #[test] + fn test_rejects_colliding_oneof_variants() { + let output = try_oneof(quote!( + pub enum Invalid { + #[prost(bool, tag = "1")] + A(bool), + #[prost(bool, tag = "1")] + B(bool), + } + ).into()); + assert!(output.is_err()); + assert_eq!(output.unwrap_err().to_string(), + "invalid oneof Invalid: multiple variants have tag 1"); + } }