Skip to content

Commit

Permalink
switch to io.Writer
Browse files Browse the repository at this point in the history
  • Loading branch information
aaronc committed Sep 8, 2022
1 parent 8c0619a commit fb2cba3
Show file tree
Hide file tree
Showing 10 changed files with 108 additions and 62 deletions.
4 changes: 2 additions & 2 deletions codec/v2/stablejson/duration.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
package stablejson

import (
"strings"
io "io"

"google.golang.org/protobuf/reflect/protoreflect"
)

const ()

func marshalDuration(writer *strings.Builder, message protoreflect.Message) error {
func marshalDuration(writer io.Writer, message protoreflect.Message) error {
return nil
}
3 changes: 2 additions & 1 deletion codec/v2/stablejson/fieldmask.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package stablejson

import (
"fmt"
io "io"
"strings"

"google.golang.org/protobuf/reflect/protoreflect"
Expand All @@ -11,7 +12,7 @@ const (
pathsName protoreflect.Name = "paths"
)

func marshalFieldMask(writer *strings.Builder, value protoreflect.Message) error {
func marshalFieldMask(writer io.Writer, value protoreflect.Message) error {
field := value.Descriptor().Fields().ByName(pathsName)
if field == nil {
return fmt.Errorf("expected to find field %s", pathsName)
Expand Down
14 changes: 8 additions & 6 deletions codec/v2/stablejson/float.go
Original file line number Diff line number Diff line change
@@ -1,23 +1,25 @@
package stablejson

import (
"io"
"math"
"strconv"
"strings"
)

func marshalFloat(writer *strings.Builder, x float64) {
func marshalFloat(writer io.Writer, x float64) error {
// PROTO3 SPEC:
// JSON value will be a number or one of the special string values "NaN", "Infinity", and "-Infinity".
// Either numbers or strings are accepted. Exponent notation is also accepted.
// -0 is considered equivalent to 0.
var err error
if math.IsInf(x, -1) {
writer.WriteString("-Infinity")
_, err = writer.Write([]byte("-Infinity"))
} else if math.IsInf(x, 1) {
writer.WriteString("Infinity")
_, err = writer.Write([]byte("Infinity"))
} else if math.IsNaN(x) {
writer.WriteString("NaN")
_, err = writer.Write([]byte("NaN"))
} else {
writer.WriteString(strconv.FormatFloat(x, 'f', -1, 64))
_, err = writer.Write([]byte(strconv.FormatFloat(x, 'f', -1, 64)))
}
return err
}
57 changes: 42 additions & 15 deletions codec/v2/stablejson/marshal.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
package stablejson

import (
"bytes"
"fmt"
"strings"
"io"

"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protopath"
Expand All @@ -22,10 +23,15 @@ type MarshalOptions struct {
}

func (opts MarshalOptions) Marshal(message proto.Message) ([]byte, error) {
writer := &strings.Builder{}
buf := &bytes.Buffer{}
err := opts.MarshalTo(message, buf)
return buf.Bytes(), err
}

func (opts MarshalOptions) MarshalTo(message proto.Message, writer io.Writer) error {
firstStack := []bool{true}
skipNext := false
err := protorange.Options{
return protorange.Options{
Stable: true,
}.Range(message.ProtoReflect(),
// push
Expand All @@ -37,7 +43,10 @@ func (opts MarshalOptions) Marshal(message proto.Message) ([]byte, error) {

// Starting printing the value.
if !firstStack[len(firstStack)-1] {
writer.WriteString(",")
_, err := writer.Write([]byte(","))
if err != nil {
return err
}
}
firstStack[len(firstStack)-1] = false

Expand All @@ -48,21 +57,27 @@ func (opts MarshalOptions) Marshal(message proto.Message) ([]byte, error) {
switch last.Step.Kind() {
case protopath.FieldAccessStep:
fd = last.Step.FieldDescriptor()
_, _ = fmt.Fprintf(writer, "%q:", fd.Name())
_, err := fmt.Fprintf(writer, "%q:", fd.Name())
if err != nil {
return err
}

case protopath.ListIndexStep:
fd = beforeLast.Step.FieldDescriptor() // lists always appear in the context of a repeated field

case protopath.MapIndexStep:
fd = beforeLast.Step.FieldDescriptor() // maps always appear in the context of a repeated field
_, _ = fmt.Fprintf(writer, "%q:", last.Step.MapIndex().String())
_, err := fmt.Fprintf(writer, "%q:", last.Step.MapIndex().String())
if err != nil {
return err
}

case protopath.AnyExpandStep:
_, _ = fmt.Fprintf(writer, `"@type":%q`, last.Value.Message().Descriptor().FullName())
return nil
_, err := fmt.Fprintf(writer, `"@type":%q`, last.Value.Message().Descriptor().FullName())
return err

case protopath.UnknownAccessStep:
writer.WriteString("?:")
return fmt.Errorf("unexpected %s", protopath.UnknownAccessStep)
}

switch value := last.Value.Interface().(type) {
Expand All @@ -79,23 +94,36 @@ func (opts MarshalOptions) Marshal(message proto.Message) ([]byte, error) {

firstStack = append(firstStack, true)
case protoreflect.List:
writer.WriteString("[")
_, err := writer.Write([]byte("["))
if err != nil {
return err
}
firstStack = append(firstStack, true)
case protoreflect.Map:
_, _ = fmt.Fprintf(writer, "{")
_, err := fmt.Fprintf(writer, "{")
if err != nil {
return err
}
firstStack = append(firstStack, true)
case protoreflect.EnumNumber:
var ev protoreflect.EnumValueDescriptor
if fd != nil {
ev = fd.Enum().Values().ByNumber(value)
}
var err error
if ev != nil {
_, _ = fmt.Fprintf(writer, "%q", ev.Name())
_, err = fmt.Fprintf(writer, "%q", ev.Name())
} else {
_, _ = fmt.Fprintf(writer, "%v", value)
_, err = fmt.Fprintf(writer, "%v", value)
}
if err != nil {
return err
}
case string:
_, _ = fmt.Fprintf(writer, "%q", value)
_, err := fmt.Fprintf(writer, "%q", value)
if err != nil {
return err
}
default:
return opts.marshalScalar(writer, value)
}
Expand All @@ -119,5 +147,4 @@ func (opts MarshalOptions) Marshal(message proto.Message) ([]byte, error) {
return nil
},
)
return []byte(writer.String()), err
}
13 changes: 8 additions & 5 deletions codec/v2/stablejson/marshal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,14 +87,17 @@ func TestStableJSON(t *testing.T) {
Uint32Value: &wrapperspb.UInt32Value{Value: 37492},
Uint64Value: &wrapperspb.UInt64Value{Value: 1892409137358391},
FieldMask: &fieldmaskpb.FieldMask{Paths: []string{"a.b", "a.c", "b"}},
ListValue: &structpb.ListValue{Values: nil},
Value: &structpb.Value{},
NullValue: structpb.NullValue_NULL_VALUE,
Empty: &emptypb.Empty{},
ListValue: &structpb.ListValue{Values: []*structpb.Value{
structpb.NewNumberValue(1.1),
structpb.NewStringValue("qrs"),
}},
Value: &structpb.Value{},
NullValue: structpb.NullValue_NULL_VALUE,
Empty: &emptypb.Empty{},
}
bz, err := stablejson.Marshal(msg)
assert.NilError(t, err)
assert.Equal(t,
`{}`,
`{"message":{"foo":"test"},"enum":"ONE","str_map":{"bar":"def","foo":"abc"},"int32_map":{"-3":"xyz","0":"abc","10":"qrs"},"bool_map":{"false":"F","true":"T"},"repeated":[3,-7,2,6,4],"str":"abcxyz\"foo\"def","bool":true,"bytes":"AAECAw==","i32":-15,"f32":1001,"u32":1200,"si32":-376,"sf32":-1000,"i64":"14578294827584932","f64":"9572348124213523654","u64":"4759492485","si64":"-59268425823934","sf64":"-659101379604211154","float":1,"double":5235.2941,"any":{"@type":"testpb.ABitOfEverything","str":"abc","i32":10},"timestamp":},"duration":},"struct":{"bool":true,"nested struct":{"a":"abc"},"null":null,"num":3.76,"str":"abc","struct list":["xyz",false,-9]}}},"bool_value":true},"bytes_value":"AAECAw=="},"double_value":1.324},"float_value":-1},"int32_value":10},"int64_value":"-376923457"},"string_value":"gfedcba"},"uint32_value":37492},"uint64_value":"1892409137358391"},"field_mask":"a.b,a.c,b"]},"list_value":[1.1,"qrs"]]},"value":}}}`,
string(bz))
}
12 changes: 6 additions & 6 deletions codec/v2/stablejson/message.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
package stablejson

import (
"strings"
"io"

"google.golang.org/protobuf/reflect/protoreflect"
)

func (opts MarshalOptions) marshalMessage(writer *strings.Builder, value protoreflect.Message) (continueRange bool, error error) {
func (opts MarshalOptions) marshalMessage(writer io.Writer, value protoreflect.Message) (continueRange bool, error error) {
switch value.Descriptor().FullName() {
case timestampFullName:
return false, marshalTimestamp(writer, value)
Expand All @@ -19,17 +19,17 @@ func (opts MarshalOptions) marshalMessage(writer *strings.Builder, value protore
case valueFullName:
return false, marshalValue(writer, value)
case nullValueFullName:
writer.WriteString("null")
return false, nil
_, err := writer.Write([]byte("null"))
return false, err
case boolValueFullName, int32ValueFullName, int64ValueFullName, uint32ValueFullName, uint64ValueFullName,
stringValueFullName, bytesValueFullName, floatValueFullName, doubleValueFullName:
return false, opts.marshalWrapper(writer, value)
case fieldMaskFullName:
return false, marshalFieldMask(writer, value)
}

writer.WriteString("{")
return true, nil
_, err := writer.Write([]byte("{"))
return true, err
}

const (
Expand Down
10 changes: 5 additions & 5 deletions codec/v2/stablejson/scalar.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,21 @@ package stablejson
import (
"encoding/base64"
"fmt"
"strings"
"io"
)

func (opts MarshalOptions) marshalScalar(writer *strings.Builder, value interface{}) error {
func (opts MarshalOptions) marshalScalar(writer io.Writer, value interface{}) error {
switch value := value.(type) {
case string:
_, _ = fmt.Fprintf(writer, "%q", value)
case []byte:
writer.WriteString(`"`)
_, _ = writer.Write([]byte(`"`))
if opts.HexBytes {
_, _ = fmt.Fprintf(writer, "%X", value)
} else {
writer.WriteString(base64.StdEncoding.EncodeToString(value))
_, _ = writer.Write([]byte(base64.StdEncoding.EncodeToString(value)))
}
writer.WriteString(`"`)
_, _ = writer.Write([]byte(`"`))
case bool:
_, _ = fmt.Fprintf(writer, "%t", value)
case int32:
Expand Down
49 changes: 31 additions & 18 deletions codec/v2/stablejson/struct.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,8 @@ package stablejson

import (
"fmt"
io "io"
"sort"
"strconv"
"strings"

"golang.org/x/exp/maps"
"google.golang.org/protobuf/reflect/protoreflect"
Expand All @@ -22,11 +21,14 @@ const (
listValueField protoreflect.Name = "list_value"
)

func marshalStruct(writer *strings.Builder, value protoreflect.Message) error {
func marshalStruct(writer io.Writer, value protoreflect.Message) error {
field := value.Descriptor().Fields().ByName(fieldsField)
m1 := value.Get(field).Map()

writer.WriteString("{")
_, err := writer.Write([]byte("{"))
if err != nil {
return err
}

m2 := map[string]protoreflect.Message{}
m1.Range(func(key protoreflect.MapKey, value protoreflect.Value) bool {
Expand All @@ -39,7 +41,10 @@ func marshalStruct(writer *strings.Builder, value protoreflect.Message) error {
first := true
for _, k := range keys {
if !first {
writer.WriteString(",")
_, err := writer.Write([]byte(","))
if err != nil {
return err
}
}

first = false
Expand All @@ -51,20 +56,27 @@ func marshalStruct(writer *strings.Builder, value protoreflect.Message) error {
}
}

writer.WriteString("}")
return nil
_, err = writer.Write([]byte("}"))
return err
}

func marshalListValue(writer *strings.Builder, value protoreflect.Message) error {
func marshalListValue(writer io.Writer, value protoreflect.Message) error {
field := value.Descriptor().Fields().ByName(valuesField)
list := value.Get(field).List()
n := list.Len()

writer.WriteString("[")
_, err := writer.Write([]byte("["))
if err != nil {
return err
}

first := true
for i := 0; i < n; i++ {
if !first {
writer.WriteString(",")
_, err = writer.Write([]byte(","))
if err != nil {
return err
}
}
first = false

Expand All @@ -73,32 +85,33 @@ func marshalListValue(writer *strings.Builder, value protoreflect.Message) error
return err
}
}
writer.WriteString("]")

return nil
_, err = writer.Write([]byte("]"))
return err
}

func marshalValue(writer *strings.Builder, value protoreflect.Message) error {
func marshalValue(writer io.Writer, value protoreflect.Message) error {
field := value.WhichOneof(value.Descriptor().Oneofs().ByName(kindOneOf))
if field == nil {
return nil
}

var err error
switch field.Name() {
case nullValueField:
writer.WriteString("null")
_, err = writer.Write([]byte("null"))
case numberValueField:
marshalFloat(writer, value.Get(field).Float())
err = marshalFloat(writer, value.Get(field).Float())
case stringValueField:
_, _ = fmt.Fprintf(writer, "%q", value.Get(field).String())
_, err = fmt.Fprintf(writer, "%q", value.Get(field).String())
case boolValueField:
writer.WriteString(strconv.FormatBool(value.Get(field).Bool()))
_, err = fmt.Fprintf(writer, "%t", value.Get(field).Bool())
case structValueField:
return marshalStruct(writer, value.Get(field).Message())
case listValueField:
return marshalListValue(writer, value.Get(field).Message())
default:
return fmt.Errorf("unexpected field in google.protobuf.Value: %v", field)
}
return nil
return err
}
Loading

0 comments on commit fb2cba3

Please sign in to comment.