Skip to content

Commit

Permalink
Packet encoding optimization (#343)
Browse files Browse the repository at this point in the history
* Dynamically allocate buffer for writes if needed

* Remove unused net.Buffer

* Return bytes written to buffer instead of conn

* Dynamic write buffer

* Reduce double write of pk.Payload

* Use memory pool for packet encode

* Pool doesn't guarantee value between Put and Get

* Add benchmark for bufpool

* Fix issue #346

* Change default pool not to have size cap

---------

Co-authored-by: JB <28275108+mochi-co@users.noreply.github.com>
  • Loading branch information
thedevop and mochi-co authored Dec 21, 2023
1 parent 4c68238 commit c6c7c29
Show file tree
Hide file tree
Showing 3 changed files with 227 additions and 26 deletions.
81 changes: 81 additions & 0 deletions mempool/bufpool.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
package mempool

import (
"bytes"
"sync"
)

var bufPool = NewBuffer(0)

// GetBuffer takes a Buffer from the default buffer pool
func GetBuffer() *bytes.Buffer { return bufPool.Get() }

// PutBuffer returns Buffer to the default buffer pool
func PutBuffer(x *bytes.Buffer) { bufPool.Put(x) }

type BufferPool interface {
Get() *bytes.Buffer
Put(x *bytes.Buffer)
}

// NewBuffer returns a buffer pool. The max specify the max capacity of the Buffer the pool will
// return. If the Buffer becoomes large than max, it will no longer be returned to the pool. If
// max <= 0, no limit will be enforced.
func NewBuffer(max int) BufferPool {
if max > 0 {
return newBufferWithCap(max)
}

return newBuffer()
}

// Buffer is a Buffer pool.
type Buffer struct {
pool *sync.Pool
}

func newBuffer() *Buffer {
return &Buffer{
pool: &sync.Pool{
New: func() any { return new(bytes.Buffer) },
},
}
}

// Get a Buffer from the pool.
func (b *Buffer) Get() *bytes.Buffer {
return b.pool.Get().(*bytes.Buffer)
}

// Put the Buffer back into pool. It resets the Buffer for reuse.
func (b *Buffer) Put(x *bytes.Buffer) {
x.Reset()
b.pool.Put(x)
}

// BufferWithCap is a Buffer pool that
type BufferWithCap struct {
bp *Buffer
max int
}

func newBufferWithCap(max int) *BufferWithCap {
return &BufferWithCap{
bp: newBuffer(),
max: max,
}
}

// Get a Buffer from the pool.
func (b *BufferWithCap) Get() *bytes.Buffer {
return b.bp.Get()
}

// Put the Buffer back into the pool if the capacity doesn't exceed the limit. It resets the Buffer
// for reuse.
func (b *BufferWithCap) Put(x *bytes.Buffer) {
if x.Cap() > b.max {
return
}
b.bp.Put(x)
}
96 changes: 96 additions & 0 deletions mempool/bufpool_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
package mempool

import (
"bytes"
"reflect"
"runtime/debug"
"testing"

"github.com/stretchr/testify/require"
)

func TestNewBuffer(t *testing.T) {
defer debug.SetGCPercent(debug.SetGCPercent(-1))
bp := NewBuffer(1000)
require.Equal(t, "*mempool.BufferWithCap", reflect.TypeOf(bp).String())

bp = NewBuffer(0)
require.Equal(t, "*mempool.Buffer", reflect.TypeOf(bp).String())

bp = NewBuffer(-1)
require.Equal(t, "*mempool.Buffer", reflect.TypeOf(bp).String())
}

func TestBuffer(t *testing.T) {
defer debug.SetGCPercent(debug.SetGCPercent(-1))
Size := 101

bp := NewBuffer(0)
buf := bp.Get()

for i := 0; i < Size; i++ {
buf.WriteByte('a')
}

bp.Put(buf)
buf = bp.Get()
require.Equal(t, 0, buf.Len())
}

func TestBufferWithCap(t *testing.T) {
defer debug.SetGCPercent(debug.SetGCPercent(-1))
Size := 101
bp := NewBuffer(100)
buf := bp.Get()

for i := 0; i < Size; i++ {
buf.WriteByte('a')
}

bp.Put(buf)
buf = bp.Get()
require.Equal(t, 0, buf.Len())
require.Equal(t, 0, buf.Cap())
}

func BenchmarkBufferPool(b *testing.B) {
bp := NewBuffer(0)

b.ResetTimer()
for i := 0; i < b.N; i++ {
b := bp.Get()
b.WriteString("this is a test")
bp.Put(b)
}
}

func BenchmarkBufferPoolWithCapLarger(b *testing.B) {
bp := NewBuffer(64 * 1024)

b.ResetTimer()
for i := 0; i < b.N; i++ {
b := bp.Get()
b.WriteString("this is a test")
bp.Put(b)
}
}

func BenchmarkBufferPoolWithCapLesser(b *testing.B) {
bp := NewBuffer(10)

b.ResetTimer()
for i := 0; i < b.N; i++ {
b := bp.Get()
b.WriteString("this is a test")
bp.Put(b)
}
}

func BenchmarkBufferWithoutPool(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
b := new(bytes.Buffer)
b.WriteString("this is a test")
_ = b
}
}
76 changes: 50 additions & 26 deletions packets/packets.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ import (
"strconv"
"strings"
"sync"

"github.com/mochi-mqtt/server/v2/mempool"
)

// All valid packet types and their packet identifiers.
Expand Down Expand Up @@ -298,7 +300,8 @@ func (s *Subscription) decode(b byte) {

// ConnectEncode encodes a connect packet.
func (pk *Packet) ConnectEncode(buf *bytes.Buffer) error {
nb := bytes.NewBuffer([]byte{})
nb := mempool.GetBuffer()
defer mempool.PutBuffer(nb)
nb.Write(encodeBytes(pk.Connect.ProtocolName))
nb.WriteByte(pk.ProtocolVersion)

Expand All @@ -315,7 +318,8 @@ func (pk *Packet) ConnectEncode(buf *bytes.Buffer) error {
nb.Write(encodeUint16(pk.Connect.Keepalive))

if pk.ProtocolVersion == 5 {
pb := bytes.NewBuffer([]byte{})
pb := mempool.GetBuffer()
defer mempool.PutBuffer(pb)
(&pk.Properties).Encode(pk.FixedHeader.Type, pk.Mods, pb, 0)
nb.Write(pb.Bytes())
}
Expand All @@ -324,7 +328,8 @@ func (pk *Packet) ConnectEncode(buf *bytes.Buffer) error {

if pk.Connect.WillFlag {
if pk.ProtocolVersion == 5 {
pb := bytes.NewBuffer([]byte{})
pb := mempool.GetBuffer()
defer mempool.PutBuffer(pb)
(&pk.Connect).WillProperties.Encode(WillProperties, pk.Mods, pb, 0)
nb.Write(pb.Bytes())
}
Expand Down Expand Up @@ -493,12 +498,14 @@ func (pk *Packet) ConnectValidate() Code {

// ConnackEncode encodes a Connack packet.
func (pk *Packet) ConnackEncode(buf *bytes.Buffer) error {
nb := bytes.NewBuffer([]byte{})
nb := mempool.GetBuffer()
defer mempool.PutBuffer(nb)
nb.WriteByte(encodeBool(pk.SessionPresent))
nb.WriteByte(pk.ReasonCode)

if pk.ProtocolVersion == 5 {
pb := bytes.NewBuffer([]byte{})
pb := mempool.GetBuffer()
defer mempool.PutBuffer(pb)
pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len()+2) // +SessionPresent +ReasonCode
nb.Write(pb.Bytes())
}
Expand Down Expand Up @@ -536,12 +543,14 @@ func (pk *Packet) ConnackDecode(buf []byte) error {

// DisconnectEncode encodes a Disconnect packet.
func (pk *Packet) DisconnectEncode(buf *bytes.Buffer) error {
nb := bytes.NewBuffer([]byte{})
nb := mempool.GetBuffer()
defer mempool.PutBuffer(nb)

if pk.ProtocolVersion == 5 {
nb.WriteByte(pk.ReasonCode)

pb := bytes.NewBuffer([]byte{})
pb := mempool.GetBuffer()
defer mempool.PutBuffer(pb)
pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len())
nb.Write(pb.Bytes())
}
Expand Down Expand Up @@ -598,7 +607,8 @@ func (pk *Packet) PingrespDecode(buf []byte) error {

// PublishEncode encodes a Publish packet.
func (pk *Packet) PublishEncode(buf *bytes.Buffer) error {
nb := bytes.NewBuffer([]byte{})
nb := mempool.GetBuffer()
defer mempool.PutBuffer(nb)

nb.Write(encodeString(pk.TopicName)) // [MQTT-3.3.2-1]

Expand All @@ -610,16 +620,16 @@ func (pk *Packet) PublishEncode(buf *bytes.Buffer) error {
}

if pk.ProtocolVersion == 5 {
pb := bytes.NewBuffer([]byte{})
pb := mempool.GetBuffer()
defer mempool.PutBuffer(pb)
pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len()+len(pk.Payload))
nb.Write(pb.Bytes())
}

nb.Write(pk.Payload)

pk.FixedHeader.Remaining = nb.Len()
pk.FixedHeader.Remaining = nb.Len() + len(pk.Payload)
pk.FixedHeader.Encode(buf)
_, _ = nb.WriteTo(buf)
buf.Write(pk.Payload)

return nil
}
Expand Down Expand Up @@ -690,11 +700,13 @@ func (pk *Packet) PublishValidate(topicAliasMaximum uint16) Code {

// encodePubAckRelRecComp encodes a Puback, Pubrel, Pubrec, or Pubcomp packet.
func (pk *Packet) encodePubAckRelRecComp(buf *bytes.Buffer) error {
nb := bytes.NewBuffer([]byte{})
nb := mempool.GetBuffer()
defer mempool.PutBuffer(nb)
nb.Write(encodeUint16(pk.PacketID))

if pk.ProtocolVersion == 5 {
pb := bytes.NewBuffer([]byte{})
pb := mempool.GetBuffer()
defer mempool.PutBuffer(pb)
pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len())
if pk.ReasonCode >= ErrUnspecifiedError.Code || pb.Len() > 1 {
nb.WriteByte(pk.ReasonCode)
Expand Down Expand Up @@ -831,11 +843,13 @@ func (pk *Packet) ReasonCodeValid() bool {

// SubackEncode encodes a Suback packet.
func (pk *Packet) SubackEncode(buf *bytes.Buffer) error {
nb := bytes.NewBuffer([]byte{})
nb := mempool.GetBuffer()
defer mempool.PutBuffer(nb)
nb.Write(encodeUint16(pk.PacketID))

if pk.ProtocolVersion == 5 {
pb := bytes.NewBuffer([]byte{})
pb := mempool.GetBuffer()
defer mempool.PutBuffer(pb)
pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len()+len(pk.ReasonCodes))
nb.Write(pb.Bytes())
}
Expand Down Expand Up @@ -878,10 +892,12 @@ func (pk *Packet) SubscribeEncode(buf *bytes.Buffer) error {
return ErrProtocolViolationNoPacketID
}

nb := bytes.NewBuffer([]byte{})
nb := mempool.GetBuffer()
defer mempool.PutBuffer(nb)
nb.Write(encodeUint16(pk.PacketID))

xb := bytes.NewBuffer([]byte{}) // capture and write filters after length checks
xb := mempool.GetBuffer() // capture and write filters after length checks
defer mempool.PutBuffer(xb)
for _, opts := range pk.Filters {
xb.Write(encodeString(opts.Filter)) // [MQTT-3.8.3-1]
if pk.ProtocolVersion == 5 {
Expand All @@ -892,7 +908,8 @@ func (pk *Packet) SubscribeEncode(buf *bytes.Buffer) error {
}

if pk.ProtocolVersion == 5 {
pb := bytes.NewBuffer([]byte{})
pb := mempool.GetBuffer()
defer mempool.PutBuffer(pb)
pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len()+xb.Len())
nb.Write(pb.Bytes())
}
Expand Down Expand Up @@ -983,11 +1000,13 @@ func (pk *Packet) SubscribeValidate() Code {

// UnsubackEncode encodes an Unsuback packet.
func (pk *Packet) UnsubackEncode(buf *bytes.Buffer) error {
nb := bytes.NewBuffer([]byte{})
nb := mempool.GetBuffer()
defer mempool.PutBuffer(nb)
nb.Write(encodeUint16(pk.PacketID))

if pk.ProtocolVersion == 5 {
pb := bytes.NewBuffer([]byte{})
pb := mempool.GetBuffer()
defer mempool.PutBuffer(pb)
pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len())
nb.Write(pb.Bytes())
}
Expand Down Expand Up @@ -1031,16 +1050,19 @@ func (pk *Packet) UnsubscribeEncode(buf *bytes.Buffer) error {
return ErrProtocolViolationNoPacketID
}

nb := bytes.NewBuffer([]byte{})
nb := mempool.GetBuffer()
defer mempool.PutBuffer(nb)
nb.Write(encodeUint16(pk.PacketID))

xb := bytes.NewBuffer([]byte{}) // capture filters and write after length checks
xb := mempool.GetBuffer() // capture filters and write after length checks
defer mempool.PutBuffer(xb)
for _, sub := range pk.Filters {
xb.Write(encodeString(sub.Filter)) // [MQTT-3.10.3-1]
}

if pk.ProtocolVersion == 5 {
pb := bytes.NewBuffer([]byte{})
pb := mempool.GetBuffer()
defer mempool.PutBuffer(pb)
pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len()+xb.Len())
nb.Write(pb.Bytes())
}
Expand Down Expand Up @@ -1100,10 +1122,12 @@ func (pk *Packet) UnsubscribeValidate() Code {

// AuthEncode encodes an Auth packet.
func (pk *Packet) AuthEncode(buf *bytes.Buffer) error {
nb := bytes.NewBuffer([]byte{})
nb := mempool.GetBuffer()
defer mempool.PutBuffer(nb)
nb.WriteByte(pk.ReasonCode)

pb := bytes.NewBuffer([]byte{})
pb := mempool.GetBuffer()
defer mempool.PutBuffer(pb)
pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len())
nb.Write(pb.Bytes())

Expand Down

0 comments on commit c6c7c29

Please sign in to comment.