Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
elias-orijtech committed Nov 30, 2023
1 parent dbf8abb commit a538009
Show file tree
Hide file tree
Showing 5 changed files with 351 additions and 239 deletions.
198 changes: 125 additions & 73 deletions features/zeropb/zeropb.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,134 +43,191 @@ func (g zeropbFeature) generateMessage(f *protogen.File, m *protogen.Message) {

func (g zeropbFeature) generateMarshal(m *protogen.Message) {
g.gen.P("func (x *", m.GoIdent, ") MarshalZeroPB(buf []byte) (n int, err error) {")
g.gen.P("defer func() {")
g.gen.P(" if e := recover(); e != nil {")
g.gen.P(" err = ", errorsPackage.Ident("New"), `("buffer overflow")`)
g.gen.P(" }")
g.gen.P("}()")
g.gen.P(" defer func() {")
g.gen.P(" if e := recover(); e != nil {")
g.gen.P(" err = ", errorsPackage.Ident("New"), `("buffer overflow")`)
g.gen.P(" }")
g.gen.P(" }()")
g.gen.P(" b := ", runtimePackage.Ident("NewBuffer"), "(buf)")
g.gen.P(" x.marshalZeroPB(b, b.Alloc(", structSize(m), "))")
g.gen.P(" return int(b.Allocated()), nil")
g.gen.P("}")
g.gen.P()
g.gen.P("func (x *", m.GoIdent, ") marshalZeroPB(b *", runtimePackage.Ident("Buffer"), ", buf ", runtimePackage.Ident("Allocation"), ") {")
g.gen.P(" var n uint16")
g.gen.P(" _ = n")
for _, f := range m.Fields {
g.generateMarshalField(f)
}
g.gen.P("return n, nil")
g.gen.P("}")
}

const (
sliceSize = 2 * 2
segmentHeaderSize = 1 + 1 + 2
)

func structSize(m *protogen.Message) int {
n := 0
for _, f := range m.Fields {
d := f.Desc
switch {
case d.IsList(), d.IsMap():
n += sliceSize
default:
n += fieldSize(f)
}
}
return n
}

func fieldSize(f *protogen.Field) int {
d := f.Desc
switch d.Kind() {
case protoreflect.FloatKind:
return 4
case protoreflect.DoubleKind:
return 8
case protoreflect.Sfixed32Kind, protoreflect.Fixed32Kind, protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Uint32Kind, protoreflect.EnumKind:
return 4
case protoreflect.Sfixed64Kind, protoreflect.Fixed64Kind, protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Uint64Kind:
return 8
case protoreflect.BoolKind:
return 4
case protoreflect.StringKind, protoreflect.BytesKind:
return sliceSize
case protoreflect.MessageKind:
return structSize(f.Message)
}
return 0
}

func (g zeropbFeature) generateMarshalField(f *protogen.Field) {
d := f.Desc
switch {
case d.IsList():
g.gen.P("len_", d.Index(), " := uint16(len(x.", f.GoName, "))")
g.gen.P("if len(x.", f.GoName, ") != int(len_", d.Index(), ") {")
g.gen.P(" return n, ", errorsPackage.Ident("New"), `("field `, f.GoName, ` is too long")`)
g.gen.P("}")
g.gen.P(binaryPackage.Ident("LittleEndian"), ".PutUint16(buf[n:], len_", d.Index(), ")")
g.gen.P("n += 2")
g.gen.P("for _, e := range x.", f.GoName, " {")
g.generateMarshalPrimitive(d, "e")
g.gen.P("buf_", d.Index(), " := b.AllocRel(len(x.", f.GoName, ")*", fieldSize(f), " + ", segmentHeaderSize, ", buf, n, uint16(len(x.", f.GoName, ")))")
g.gen.P("n += 4")
g.gen.P("{")
g.gen.P(" var n uint16")
g.gen.P(" buf := buf_", d.Index())
// Write a segment header.
g.gen.P(" buf.Buf[0] = byte(len(x.", f.GoName, "))")
g.gen.P(" buf.Buf[1] = byte(len(x.", f.GoName, "))")
g.gen.P(binaryPackage.Ident("LittleEndian"), ".PutUint16(buf.Buf[2:], 0)")
g.gen.P(" n += 4")
g.gen.P(" for _, e := range x.", f.GoName, " {")
g.generateMarshalPrimitive(f, "e")
g.gen.P(" }")
g.gen.P("}")
case d.IsMap():
g.gen.P("len_", d.Index(), " := uint16(len(x.", f.GoName, "))")
g.gen.P("if len(x.", f.GoName, ") != int(len_", d.Index(), ") {")
g.gen.P(" return n, ", errorsPackage.Ident("New"), `("field `, f.GoName, ` is too long")`)
g.gen.P("}")
g.gen.P("binary.LittleEndian.PutUint16(buf[n:], len_", d.Index(), ")")
g.gen.P("n += 2")
g.gen.P("for k, v := range x.", f.GoName, " {")
g.generateMarshalPrimitive(d.MapKey(), "k")
g.generateMarshalPrimitive(d.MapValue(), "v")
sz := fieldSize(f.Message.Fields[0]) + fieldSize(f.Message.Fields[1])
g.gen.P("buf_", d.Index(), " := b.AllocRel(len(x.", f.GoName, ")*", sz, ", buf, n, uint16(len(x."+f.GoName+")))")
g.gen.P("n += 4")
g.gen.P("{")
g.gen.P(" var n uint16")
g.gen.P(" buf := buf_", d.Index())
g.gen.P(" for k, v := range x.", f.GoName, " {")
g.generateMarshalPrimitive(f.Message.Fields[0], "k")
g.generateMarshalPrimitive(f.Message.Fields[1], "v")
g.gen.P(" }")
g.gen.P("}")
case d.ContainingOneof() != nil:
g.gen.P("// TODO: field ", f.GoName)
return
default:
g.generateMarshalPrimitive(d, "x."+f.GoName)
g.generateMarshalPrimitive(f, "x."+f.GoName)
}
}

func (g zeropbFeature) generateMarshalPrimitive(d protoreflect.FieldDescriptor, name string) {
func (g zeropbFeature) generateMarshalPrimitive(f *protogen.Field, name string) {
d := f.Desc
switch d.Kind() {
case protoreflect.FloatKind:
g.gen.P("binary.LittleEndian.PutUint32(buf[n:], ", mathPackage.Ident("Float32bits"), "(", name, "))")
g.gen.P("binary.LittleEndian.PutUint32(buf.Buf[n:], ", mathPackage.Ident("Float32bits"), "(", name, "))")
g.gen.P("n += 4")
case protoreflect.DoubleKind:
g.gen.P("binary.LittleEndian.PutUint64(buf[n:], ", mathPackage.Ident("Float64bits"), "(", name, "))")
g.gen.P("binary.LittleEndian.PutUint64(buf.Buf[n:], ", mathPackage.Ident("Float64bits"), "(", name, "))")
g.gen.P("n += 8")
case protoreflect.Sfixed32Kind, protoreflect.Fixed32Kind, protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Uint32Kind, protoreflect.EnumKind:
g.gen.P("binary.LittleEndian.PutUint32(buf[n:], uint32(", name, "))")
g.gen.P("binary.LittleEndian.PutUint32(buf.Buf[n:], uint32(", name, "))")
g.gen.P("n += 4")
case protoreflect.Sfixed64Kind, protoreflect.Fixed64Kind, protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Uint64Kind:
g.gen.P("binary.LittleEndian.PutUint64(buf[n:], uint64(", name, "))")
g.gen.P("binary.LittleEndian.PutUint64(buf.Buf[n:], uint64(", name, "))")
g.gen.P("n += 8")
case protoreflect.BoolKind:
g.gen.P("bool_", d.Index(), " := uint32(0)")
g.gen.P("if ", name, " {")
g.gen.P(" bool_", d.Index(), " = 1")
g.gen.P("}")
g.gen.P("binary.LittleEndian.PutUint32(buf[n:], bool_", d.Index(), ")")
g.gen.P("binary.LittleEndian.PutUint32(buf.Buf[n:], bool_", d.Index(), ")")
g.gen.P("n += 4")
case protoreflect.StringKind, protoreflect.BytesKind:
g.gen.P("len_", d.Index(), " := uint16(len(", name, "))")
g.gen.P("if len(", name, ") != int(len_", d.Index(), ") {")
g.gen.P(" return n, ", errorsPackage.Ident("New"), `("field `, name, ` is too long")`)
g.gen.P("}")
g.gen.P("binary.LittleEndian.PutUint16(buf[n:], len_", d.Index(), ")")
g.gen.P("n += 2")
// Reslice buf to convert a truncated write into a buffer overflow error.
g.gen.P("copy(buf[n:n+len(", name, ")], ", name, ")")
g.gen.P("n += len(", name, ")")
g.gen.P("buf_", d.Index(), " := b.AllocRel(len(", name, "), buf, n, uint16(len(", name, ")))")
g.gen.P("n += 4")
g.gen.P("copy(buf_", d.Index(), ".Buf, ", name, ")")
case protoreflect.MessageKind:
g.gen.P("n_", d.Index(), ", err := ", name, ".MarshalZeroPB(buf[n:])")
g.gen.P("n += n_", d.Index())
g.gen.P("if err != nil {")
g.gen.P(" return n, err")
g.gen.P("}")
g.gen.P(name, ".marshalZeroPB(b, buf.Slice(n))")
g.gen.P("n += ", structSize(f.Message))
default:
g.gen.P("// TODO: field ", name)
g.gen.P("_ = ", name)
}
}

func (g zeropbFeature) generateUnmarshal(m *protogen.Message) {
g.gen.P("func (x *", m.GoIdent, ") UnmarshalZeroPB(buf []byte) (n int, err error) {")
g.gen.P("defer func() {")
g.gen.P(" if e := recover(); e != nil {")
g.gen.P(" err = ", errorsPackage.Ident("New"), `("buffer underflow")`)
g.gen.P(" }")
g.gen.P("}()")
g.gen.P("func (x *", m.GoIdent, ") UnmarshalZeroPB(buf []byte) (err error) {")
g.gen.P(" defer func() {")
g.gen.P(" if e := recover(); e != nil {")
g.gen.P(" err = ", errorsPackage.Ident("New"), `("buffer underflow")`)
g.gen.P(" }")
g.gen.P(" }()")
g.gen.P(" x.unmarshalZeroPB(buf, 0)")
g.gen.P(" return nil")
g.gen.P("}")
g.gen.P()
g.gen.P("func (x *", m.GoIdent, ") unmarshalZeroPB(buf []byte, n uint16) {")
for _, f := range m.Fields {
g.generateUnmarshalField(f)
}
g.gen.P("return n, nil")
g.gen.P("}")
}

func (g zeropbFeature) generateUnmarshalField(f *protogen.Field) {
d := f.Desc
switch {
case d.IsList():
g.gen.P("len_", d.Index(), " := int(binary.LittleEndian.Uint16(buf[n:]))")
g.gen.P("n += 2")
g.gen.P("n_", d.Index(), ", len_", d.Index(), " := ", runtimePackage.Ident("ReadSlice"), "(buf, n)")
g.gen.P("n += 4")
typ, pointer := protoc.FieldGoType(g.gen, f)
if pointer {
typ = "*" + typ
}
g.gen.P("x.", f.GoName, " = make(", typ, ", len_", d.Index(), ")")
g.gen.P("for i := range x.", f.GoName, "{")
// Skip segment header.
g.gen.P("n +=", segmentHeaderSize)
g.gen.P("{")
g.gen.P(" n := n_", d.Index())
g.gen.P(" for i := range x.", f.GoName, "{")
g.generateUnmarshalPrimitive(f, "x."+f.GoName+"[i]")
g.gen.P(" }")
g.gen.P("}")
case d.IsMap():
g.gen.P("len_", d.Index(), " := int(", binaryPackage.Ident("LittleEndian"), ".Uint16(buf[n:]))")
g.gen.P("n += 2")
g.gen.P("n_", d.Index(), ", len_", d.Index(), " := ", runtimePackage.Ident("ReadSlice"), "(buf, n)")
g.gen.P("n += 4")
typ, _ := protoc.FieldGoType(g.gen, f)
g.gen.P("x.", f.GoName, " = make(", typ, ", len_", d.Index(), ")")
keyType, _ := protoc.FieldGoType(g.gen, f.Message.Fields[0])
valType, _ := protoc.FieldGoType(g.gen, f.Message.Fields[1])
g.gen.P("for i := 0; i < len_", d.Index(), "; i++ {")
g.gen.P("var k ", keyType)
g.gen.P("var v ", valType)
g.gen.P("{")
g.gen.P(" n := n_", d.Index())
g.gen.P(" for i := uint16(0); i < len_", d.Index(), "; i++ {")
g.gen.P(" var k ", keyType)
g.gen.P(" var v ", valType)
g.generateUnmarshalPrimitive(f.Message.Fields[0], "k")
g.generateUnmarshalPrimitive(f.Message.Fields[1], "v")
g.gen.P(" x.", f.GoName, "[k] = v")
g.gen.P(" x.", f.GoName, "[k] = v")
g.gen.P(" }")
g.gen.P("}")
case d.ContainingOneof() != nil:
g.gen.P("// TODO: field ", f.GoName)
Expand Down Expand Up @@ -211,23 +268,18 @@ func (g zeropbFeature) generateUnmarshalPrimitive(f *protogen.Field, name string
g.gen.P("}")
g.gen.P("n += 4")
case protoreflect.StringKind:
g.gen.P("len_", d.Index(), " := int(", binaryPackage.Ident("LittleEndian"), ".Uint16(buf[n:]))")
g.gen.P("n += 2")
g.gen.P(name, " = string(buf[n:n+len_", d.Index(), "])")
g.gen.P("n += len_", d.Index())
g.gen.P("n_", d.Index(), ", len_", d.Index(), " := ", runtimePackage.Ident("ReadSlice"), "(buf, n)")
g.gen.P("n += 4")
g.gen.P(name, " = string(buf[n_", d.Index(), ":n_", d.Index(), "+len_", d.Index(), "])")
case protoreflect.BytesKind:
g.gen.P("len_", d.Index(), " := int(", binaryPackage.Ident("LittleEndian"), ".Uint16(buf[n:]))")
g.gen.P("n += 2")
g.gen.P(name, " = append([]byte{}, buf[n:n+len_", d.Index(), "]...)")
g.gen.P("n += len_", d.Index())
g.gen.P("n_", d.Index(), ", len_", d.Index(), " := ", runtimePackage.Ident("ReadSlice"), "(buf, n)")
g.gen.P("n += 4")
g.gen.P(name, " = append([]byte{}, buf[n_", d.Index(), ":n_", d.Index(), "+len_", d.Index(), "]...)")
case protoreflect.MessageKind:
typ := g.gen.QualifiedGoIdent(f.Message.GoIdent)
g.gen.P(name, " = new(", typ, ")")
g.gen.P("n_", d.Index(), ", err := ", name, ".UnmarshalZeroPB(buf[n:])")
g.gen.P("n += n_", d.Index())
g.gen.P("if err != nil {")
g.gen.P(" return n, err")
g.gen.P("}")
g.gen.P(name, ".unmarshalZeroPB(buf, n)")
g.gen.P("n += ", structSize(f.Message))
default:
g.gen.P("// TODO: field ", name)
g.gen.P("_ = ", name)
Expand Down
67 changes: 64 additions & 3 deletions runtime/zeropb/zeropb.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,67 @@
package zeropb

type Slice[T any] struct {
offset int16
length uint16
import (
"encoding/binary"
"fmt"
"math"
)

type Buffer struct {
buf []byte
allocated uint16
}

type Allocation struct {
Buf []byte
offset uint16
}

func NewBuffer(b []byte) *Buffer {
if len(b) > math.MaxUint16 {
b = b[:math.MaxUint16]
}
return &Buffer{
buf: b,
}
}

func (b *Buffer) Alloc(n int) Allocation {
n16 := uint16(n)
if int(n16) != n {
panic(fmt.Errorf("allocation %d too large", n))
}
a := Allocation{
Buf: b.buf[:n16],
offset: b.allocated,
}
b.buf = b.buf[n16:]
b.allocated += n16
return a
}

func (b *Buffer) AllocRel(n int, dst Allocation, offset, len uint16) Allocation {
a := b.Alloc(n)
bo := binary.LittleEndian
// Write relative offset and len.
bo.PutUint16(dst.Buf[offset:], a.offset-dst.offset-offset)
bo.PutUint16(dst.Buf[offset+2:], len)
return a
}

func (b *Buffer) Allocated() uint16 {
return b.allocated
}

func (a Allocation) Slice(offset uint16) Allocation {
return Allocation{
Buf: a.Buf[offset:],
offset: a.offset + offset,
}
}

func ReadSlice(buf []byte, offset uint16) (off, len uint16) {
bo := binary.LittleEndian
off = offset + bo.Uint16(buf[offset:])
len = bo.Uint16(buf[offset+2:])
return
}
Loading

0 comments on commit a538009

Please sign in to comment.