From a441c7c1fa29aed6e8b92a3365a3322d0935cc11 Mon Sep 17 00:00:00 2001
From: Josh Humphries <2035234+jhump@users.noreply.github.com>
Date: Tue, 14 Nov 2023 08:03:31 -0500
Subject: [PATCH] buf curl: don't allow unrecognized fields in JSON request;
 print count of unrecognized bytes in responses with -v

---
 private/buf/bufcurl/invoker.go                | 40 ++++++++++++++++++-
 private/pkg/protoencoding/json_unmarshaler.go | 16 +++++---
 private/pkg/protoencoding/protoencoding.go    | 14 ++++++-
 3 files changed, 60 insertions(+), 10 deletions(-)

diff --git a/private/buf/bufcurl/invoker.go b/private/buf/bufcurl/invoker.go
index 53f05eebe4..20649a9291 100644
--- a/private/buf/bufcurl/invoker.go
+++ b/private/buf/bufcurl/invoker.go
@@ -290,6 +290,11 @@ func (inv *invoker) handleResponse(data []byte, msg *dynamicpb.Message) error {
 			protoencoding.JSONMarshalerWithEmitUnpopulated(),
 		)
 	}
+	unrecognized := countUnrecognized(msg.ProtoReflect())
+	if unrecognized > 0 {
+		inv.printer.Printf("Response message (%s) contained %d bytes of unrecognized fields.",
+			msg.ProtoReflect().Descriptor().FullName(), unrecognized)
+	}
 	outputBytes, err := protoencoding.NewJSONMarshaler(inv.res, jsonMarshalerOptions...).Marshal(msg)
 	if err != nil {
 		return err
@@ -437,6 +442,37 @@ func (s *streamMessageProvider) next(msg proto.Message) error {
 		}
 		return fmt.Errorf("%s at offset %d: %w", s.name, s.dec.InputOffset(), err)
 	}
-	proto.Reset(msg)
-	return protoencoding.NewJSONUnmarshaler(s.res).Unmarshal(jsonData, msg)
+	return protoencoding.NewJSONUnmarshaler(
+		s.res, protoencoding.JSONUnmarshalerWithDisallowUnknown(),
+	).Unmarshal(jsonData, msg)
+}
+
+func countUnrecognized(msg protoreflect.Message) int {
+	var count int
+	msg.Range(func(field protoreflect.FieldDescriptor, val protoreflect.Value) bool {
+		switch {
+		case field.IsMap() && isMessageKind(field.MapValue().Kind()):
+			// Note: On the wire, each map entry is itself a small message that could
+			// carry unrecognized field bytes, but the runtime discards those. So we
+			// can only count unrecognized fields in the message values of the map.
+			mapVal := val.Map()
+			mapVal.Range(func(_ protoreflect.MapKey, v protoreflect.Value) bool {
+				count += countUnrecognized(v.Message())
+				return true
+			})
+		case field.IsList() && isMessageKind(field.Kind()):
+			listVal := val.List()
+			for i, length := 0, listVal.Len(); i < length; i++ {
+				count += countUnrecognized(listVal.Get(i).Message())
+			}
+		case isMessageKind(field.Kind()):
+			count += countUnrecognized(val.Message())
+		}
+		return true
+	})
+	return count + len(msg.GetUnknown())
+}
+
+func isMessageKind(k protoreflect.Kind) bool {
+	return k == protoreflect.MessageKind || k == protoreflect.GroupKind
 }
diff --git a/private/pkg/protoencoding/json_unmarshaler.go b/private/pkg/protoencoding/json_unmarshaler.go
index 75feedef2c..0daf59e6e6 100644
--- a/private/pkg/protoencoding/json_unmarshaler.go
+++ b/private/pkg/protoencoding/json_unmarshaler.go
@@ -20,20 +20,24 @@ import (
 )
 
 type jsonUnmarshaler struct {
-	resolver Resolver
+	resolver        Resolver
+	disallowUnknown bool
 }
 
-func newJSONUnmarshaler(resolver Resolver) Unmarshaler {
-	return &jsonUnmarshaler{
+func newJSONUnmarshaler(resolver Resolver, options ...JSONUnmarshalerOption) Unmarshaler {
+	jsonUnmarshaler := &jsonUnmarshaler{
 		resolver: resolver,
 	}
+	for _, option := range options {
+		option(jsonUnmarshaler)
+	}
+	return jsonUnmarshaler
 }
 
 func (m *jsonUnmarshaler) Unmarshal(data []byte, message proto.Message) error {
 	options := protojson.UnmarshalOptions{
-		Resolver: m.resolver,
-		// TODO: make this an option
-		DiscardUnknown: true,
+		Resolver:       m.resolver,
+		DiscardUnknown: !m.disallowUnknown,
 	}
 	return options.Unmarshal(data, message)
 }
diff --git a/private/pkg/protoencoding/protoencoding.go b/private/pkg/protoencoding/protoencoding.go
index e1a76a1c9d..d15a872b5f 100644
--- a/private/pkg/protoencoding/protoencoding.go
+++ b/private/pkg/protoencoding/protoencoding.go
@@ -117,8 +117,18 @@ func NewWireUnmarshaler(resolver Resolver) Unmarshaler {
 // NewJSONUnmarshaler returns a new Unmarshaler for json.
 //
 // resolver can be nil if unknown and are only needed for extensions.
-func NewJSONUnmarshaler(resolver Resolver) Unmarshaler {
-	return newJSONUnmarshaler(resolver)
+func NewJSONUnmarshaler(resolver Resolver, options ...JSONUnmarshalerOption) Unmarshaler {
+	return newJSONUnmarshaler(resolver, options...)
+}
+
+// JSONUnmarshalerOption is an option for a new JSONUnmarshaler.
+type JSONUnmarshalerOption func(*jsonUnmarshaler)
+
+// JSONUnmarshalerWithDisallowUnknown returns an option that rejects unrecognized fields instead of discarding them.
+func JSONUnmarshalerWithDisallowUnknown() JSONUnmarshalerOption {
+	return func(jsonUnmarshaler *jsonUnmarshaler) {
+		jsonUnmarshaler.disallowUnknown = true
+	}
 }
 
 // NewTxtpbUnmarshaler returns a new Unmarshaler for txtpb.