Skip to content

Commit

Permalink
protoparse: message-set extensions must be optional (#489)
Browse files Browse the repository at this point in the history
  • Loading branch information
jhump authored Feb 5, 2022
1 parent 4ced19e commit 65315df
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 53 deletions.
58 changes: 57 additions & 1 deletion desc/protoparse/linker.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ func (l *linker) linkFiles() (map[string]*desc.FileDescriptor, error) {
}
// we should now have any message_set_wire_format options parsed
// and can do further validation on tag ranges
if err := checkExtensionsInFile(fd, r); err != nil {
if err := l.checkExtensionsInFile(fd, r); err != nil {
return nil, err
}
}
Expand Down Expand Up @@ -999,3 +999,59 @@ func (l *linker) checkForUnusedImports(filename string) {
}
}
}

func (l *linker) checkExtensionsInFile(fd *desc.FileDescriptor, res *parseResult) error {
for _, fld := range fd.GetExtensions() {
if err := l.checkExtension(fld, res); err != nil {
return err
}
}
for _, md := range fd.GetMessageTypes() {
if err := l.checkExtensionsInMessage(md, res); err != nil {
return err
}
}
return nil
}

func (l *linker) checkExtensionsInMessage(md *desc.MessageDescriptor, res *parseResult) error {
for _, fld := range md.GetNestedExtensions() {
if err := l.checkExtension(fld, res); err != nil {
return err
}
}
for _, nmd := range md.GetNestedMessageTypes() {
if err := l.checkExtensionsInMessage(nmd, res); err != nil {
return err
}
}
return nil
}

func (l *linker) checkExtension(fld *desc.FieldDescriptor, res *parseResult) error {
// NB: It's a little gross that we don't enforce these in validateBasic().
// But requires some minimal linking to resolve the extendee, so we can
// interrogate its descriptor.
if fld.GetOwner().GetMessageOptions().GetMessageSetWireFormat() {
// Message set wire format requires that all extensions be messages
// themselves (no scalar extensions)
if fld.GetType() != dpb.FieldDescriptorProto_TYPE_MESSAGE {
pos := res.getFieldNode(fld.AsFieldDescriptorProto()).FieldType().Start()
return l.errs.handleErrorWithPos(pos, "messages with message-set wire format cannot contain scalar extensions, only messages")
}
if fld.IsRepeated() {
pos := res.getFieldNode(fld.AsFieldDescriptorProto()).FieldLabel().Start()
return l.errs.handleErrorWithPos(pos, "messages with message-set wire format cannot contain repeated extensions, only optional")
}
} else {
// In validateBasic() we just made sure these were within bounds for any message. But
// now that things are linked, we can check if the extendee is messageset wire format
// and, if not, enforce tighter limit.
if fld.GetNumber() > internal.MaxNormalTag {
pos := res.getFieldNode(fld.AsFieldDescriptorProto()).FieldTag().Start()
return l.errs.handleErrorWithPos(pos, "tag number %d is higher than max allowed tag number (%d)", fld.GetNumber(), internal.MaxNormalTag)
}
}

return nil
}
6 changes: 6 additions & 0 deletions desc/protoparse/linker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -413,6 +413,12 @@ func TestLinkerValidation(t *testing.T) {
},
"", // should succeed
},
{
map[string]string{
"foo.proto": "message Foo { option message_set_wire_format = true; extensions 1 to 100; } extend Foo { repeated Foo bar = 1; }",
},
"foo.proto:1:90: messages with message-set wire format cannot contain repeated extensions, only optional",
},
{
map[string]string{
"foo.proto": "message Foo { extensions 1 to max; } extend Foo { optional int32 bar = 536870912; }",
Expand Down
52 changes: 0 additions & 52 deletions desc/protoparse/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -817,58 +817,6 @@ func checkTag(pos *SourcePos, v uint64, maxTag int32) error {
return nil
}

func checkExtensionsInFile(fd *desc.FileDescriptor, res *parseResult) error {
for _, fld := range fd.GetExtensions() {
if err := checkExtension(fld, res); err != nil {
return err
}
}
for _, md := range fd.GetMessageTypes() {
if err := checkExtensionsInMessage(md, res); err != nil {
return err
}
}
return nil
}

func checkExtensionsInMessage(md *desc.MessageDescriptor, res *parseResult) error {
for _, fld := range md.GetNestedExtensions() {
if err := checkExtension(fld, res); err != nil {
return err
}
}
for _, nmd := range md.GetNestedMessageTypes() {
if err := checkExtensionsInMessage(nmd, res); err != nil {
return err
}
}
return nil
}

func checkExtension(fld *desc.FieldDescriptor, res *parseResult) error {
// NB: It's a little gross that we don't enforce these in validateBasic().
// But requires some minimal linking to resolve the extendee, so we can
// interrogate its descriptor.
if fld.GetOwner().GetMessageOptions().GetMessageSetWireFormat() {
// Message set wire format requires that all extensions be messages
// themselves (no scalar extensions)
if fld.GetType() != dpb.FieldDescriptorProto_TYPE_MESSAGE {
pos := res.getFieldNode(fld.AsFieldDescriptorProto()).FieldType().Start()
return errorWithPos(pos, "messages with message-set wire format cannot contain scalar extensions, only messages")
}
} else {
// In validateBasic() we just made sure these were within bounds for any message. But
// now that things are linked, we can check if the extendee is messageset wire format
// and, if not, enforce tighter limit.
if fld.GetNumber() > internal.MaxNormalTag {
pos := res.getFieldNode(fld.AsFieldDescriptorProto()).FieldTag().Start()
return errorWithPos(pos, "tag number %d is higher than max allowed tag number (%d)", fld.GetNumber(), internal.MaxNormalTag)
}
}

return nil
}

func aggToString(agg []*ast.MessageFieldNode, buf *bytes.Buffer) {
buf.WriteString("{")
for _, a := range agg {
Expand Down

0 comments on commit 65315df

Please sign in to comment.