diff --git a/decode.go b/decode.go index 5986d061..0b44124d 100644 --- a/decode.go +++ b/decode.go @@ -446,6 +446,26 @@ func (fnbsm FieldNameByteStringMode) valid() bool { return fnbsm >= 0 && fnbsm < maxFieldNameByteStringMode } +// UnrecognizedTagToAnyMode specifies how to decode unrecognized CBOR tag into an empty interface (any). +// Currently, recognized CBOR tag numbers are 0, 1, 2, 3, or registered by TagSet. +type UnrecognizedTagToAnyMode int + +const ( + // UnrecognizedTagNumAndContentToAny decodes CBOR tag number and tag content to cbor.Tag + // when decoding unrecognized CBOR tag into an empty interface. + UnrecognizedTagNumAndContentToAny UnrecognizedTagToAnyMode = iota + + // UnrecognizedTagContentToAny decodes only CBOR tag content (into its default type) + // when decoding unrecognized CBOR tag into an empty interface. + UnrecognizedTagContentToAny + + maxUnrecognizedTagToAny +) + +func (uttam UnrecognizedTagToAnyMode) valid() bool { + return uttam >= 0 && uttam < maxUnrecognizedTagToAny +} + // DecOptions specifies decoding options. type DecOptions struct { // DupMapKey specifies whether to enforce duplicate map key. @@ -514,6 +534,10 @@ type DecOptions struct { // FieldNameByteString specifies the behavior when decoding a CBOR byte string map key as a // Go struct field name. FieldNameByteString FieldNameByteStringMode + + // UnrecognizedTagToAny specifies how to decode unrecognized CBOR tag into an empty interface. + // Currently, recognized CBOR tag numbers are 0, 1, 2, 3, or registered by TagSet. + UnrecognizedTagToAny UnrecognizedTagToAnyMode } // DecMode returns DecMode with immutable options and no tags (safe for concurrency). @@ -637,6 +661,9 @@ func (opts DecOptions) decMode() (*decMode, error) { if !opts.FieldNameByteString.valid() { return nil, errors.New("cbor: invalid FieldNameByteString " + strconv.Itoa(int(opts.FieldNameByteString))) } + if !opts.UnrecognizedTagToAny.valid() { + return nil, errors.New("cbor: invalid UnrecognizedTagToAnyMode " + strconv.Itoa(int(opts.UnrecognizedTagToAny))) + } dm := decMode{ dupMapKey: opts.DupMapKey, timeTag: opts.TimeTag, @@ -655,6 +682,7 @@ func (opts DecOptions) decMode() (*decMode, error) { defaultByteStringType: opts.DefaultByteStringType, byteStringToString: opts.ByteStringToString, fieldNameByteString: opts.FieldNameByteString, + unrecognizedTagToAny: opts.UnrecognizedTagToAny, } return &dm, nil } @@ -724,6 +752,7 @@ type decMode struct { defaultByteStringType reflect.Type byteStringToString ByteStringToStringMode fieldNameByteString FieldNameByteStringMode + unrecognizedTagToAny UnrecognizedTagToAnyMode } var defaultDecMode, _ = DecOptions{}.decMode() @@ -748,6 +777,7 @@ func (dm *decMode) DecOptions() DecOptions { DefaultByteStringType: dm.defaultByteStringType, ByteStringToString: dm.byteStringToString, FieldNameByteString: dm.fieldNameByteString, + UnrecognizedTagToAny: dm.unrecognizedTagToAny, } } @@ -1427,6 +1457,9 @@ func (d *decoder) parse(skipSelfDescribedTag bool) (interface{}, error) { //noli if err != nil { return nil, err } + if d.dm.unrecognizedTagToAny == UnrecognizedTagContentToAny { + return content, nil + } return Tag{tagNum, content}, nil case cborTypePrimitives: _, ai, val := d.getHead() diff --git a/decode_test.go b/decode_test.go index 92cb675f..8afa9139 100644 --- a/decode_test.go +++ b/decode_test.go @@ -8380,6 +8380,192 @@ func TestUnmarshalFieldNameByteString(t *testing.T) { } } +func TestDecModeInvalidReturnTypeForEmptyInterface(t *testing.T) { + for _, tc := range []struct { + name string + opts DecOptions + wantErrorMsg string + }{ + { + name: "below range of valid modes", + opts: DecOptions{UnrecognizedTagToAny: -1}, + wantErrorMsg: "cbor: invalid UnrecognizedTagToAnyMode -1", + }, + { + name: "above range of valid modes", + opts: DecOptions{UnrecognizedTagToAny: 101}, + wantErrorMsg: "cbor: invalid UnrecognizedTagToAnyMode 101", + }, + } { + t.Run(tc.name, func(t *testing.T) { + _, err := tc.opts.DecMode() + if err == nil { + t.Errorf("DecMode() didn't return an error") + } else if err.Error() != tc.wantErrorMsg { + t.Errorf("DecMode() returned error %q, want %q", err.Error(), tc.wantErrorMsg) + } + }) + } +} + +func TestUnmarshalWithUnrecognizedTagToAnyMode(t *testing.T) { + for _, tc := range []struct { + name string + opts DecOptions + in []byte + want interface{} + }{ + { + name: "default to value of type Tag", + opts: DecOptions{}, + in: hexDecode("d8ff00"), + want: Tag{Number: uint64(255), Content: uint64(0)}, + }, + { + name: "Tag's content", + opts: DecOptions{UnrecognizedTagToAny: UnrecognizedTagContentToAny}, + in: hexDecode("d8ff00"), + want: uint64(0), + }, + } { + t.Run(tc.name, func(t *testing.T) { + dm, err := tc.opts.DecMode() + if err != nil { + t.Fatal(err) + } + + var got interface{} + if err := dm.Unmarshal(tc.in, &got); err != nil { + t.Errorf("unexpected error: %v", err) + } + + if tc.want != got { + t.Errorf("got %s, want %s", got, tc.want) + } + }) + } +} + +func TestUnmarshalWithUnrecognizedTagToAnyModeForSupportedTags(t *testing.T) { + for _, tc := range []struct { + name string + opts DecOptions + in []byte + want interface{} + }{ + { + name: "Unmarshal with tag number 0 when UnrecognizedTagContentToAny option is not set", + opts: DecOptions{}, + in: hexDecode("c074323031332d30332d32315432303a30343a30305a"), + want: time.Date(2013, 3, 21, 20, 4, 0, 0, time.UTC), + }, + { + name: "Unmarshal with tag number 0 when UnrecognizedTagContentToAny option is set", + opts: DecOptions{UnrecognizedTagToAny: UnrecognizedTagContentToAny}, + in: hexDecode("c074323031332d30332d32315432303a30343a30305a"), + want: time.Date(2013, 3, 21, 20, 4, 0, 0, time.UTC), + }, + { + name: "Unmarshal with tag number 1 when UnrecognizedTagContentToAny option is not set", + opts: DecOptions{}, + in: hexDecode("c11a514b67b0"), + want: time.Date(2013, 3, 21, 20, 4, 0, 0, time.UTC), + }, + { + name: "Unmarshal with tag number 1 when UnrecognizedTagContentToAny option is set", + opts: DecOptions{UnrecognizedTagToAny: UnrecognizedTagContentToAny}, + in: hexDecode("c11a514b67b0"), + want: time.Date(2013, 3, 21, 20, 4, 0, 0, time.UTC), + }, + { + name: "Unmarshal with tag number 2 when UnrecognizedTagContentToAny option is not set", + opts: DecOptions{}, + in: hexDecode("c249010000000000000000"), + want: bigIntOrPanic("18446744073709551616"), + }, + { + name: "Unmarshal with tag number 2 when UnrecognizedTagContentToAny option is set", + opts: DecOptions{UnrecognizedTagToAny: UnrecognizedTagContentToAny}, + in: hexDecode("c249010000000000000000"), + want: bigIntOrPanic("18446744073709551616"), + }, + { + name: "Unmarshal with tag number 3 when UnrecognizedTagContentToAny option is not set", + opts: DecOptions{}, + in: hexDecode("c349010000000000000000"), + want: bigIntOrPanic("-18446744073709551617"), + }, + { + name: "Unmarshal with tag number 3 when UnrecognizedTagContentToAny option is set", + opts: DecOptions{UnrecognizedTagToAny: UnrecognizedTagContentToAny}, + in: hexDecode("c349010000000000000000"), + want: bigIntOrPanic("-18446744073709551617"), + }, + } { + t.Run(tc.name, func(t *testing.T) { + dm, err := tc.opts.DecMode() + if err != nil { + t.Fatal(err) + } + + var got interface{} + if err := dm.Unmarshal(tc.in, &got); err != nil { + t.Errorf("unexpected error: %v", err) + } + + compareNonFloats(t, tc.in, got, tc.want) + + }) + } +} + +func TestUnmarshalWithUnrecognizedTagToAnyModeForSharedTag(t *testing.T) { + + type myInt int + typ := reflect.TypeOf(myInt(0)) + + tags := NewTagSet() + if err := tags.Add(TagOptions{EncTag: EncTagRequired, DecTag: DecTagRequired}, typ, 125); err != nil { + t.Fatalf("TagSet.Add(%s, %v) returned error %v", typ, 125, err) + } + + for _, tc := range []struct { + name string + opts DecOptions + in []byte + want interface{} + }{ + { + name: "Unmarshal with a shared tag when UnrecognizedTagContentToAny option is not set", + opts: DecOptions{}, + in: hexDecode("d9d9f7d87d01"), // 55799(125(1)) + want: myInt(1), + }, + { + name: "Unmarshal with a shared tag when UnrecognizedTagContentToAny option is set", + opts: DecOptions{UnrecognizedTagToAny: UnrecognizedTagContentToAny}, + in: hexDecode("d9d9f7d87d01"), // 55799(125(1)) + want: myInt(1), + }, + } { + t.Run(tc.name, func(t *testing.T) { + dm, err := tc.opts.DecModeWithTags(tags) + if err != nil { + t.Fatal(err) + } + + var got interface{} + + if err := dm.Unmarshal(tc.in, &got); err != nil { + t.Errorf("unexpected error: %v", err) + } + + compareNonFloats(t, tc.in, got, tc.want) + + }) + } +} + func isCBORNil(data []byte) bool { return len(data) > 0 && (data[0] == 0xf6 || data[0] == 0xf7) }