From b04ede15553344e28d61f45b8390abcab8360e09 Mon Sep 17 00:00:00 2001 From: "Masih H. Derkani" Date: Fri, 13 Dec 2024 13:32:11 +0000 Subject: [PATCH] Implement GPBFT message compression using zstd (#793) Add the ability to compress GPBFT messages controllable via manifest. Implement benchmarks to compare vanilla CBOR and ZSTD encoding. Basic local run: ``` BenchmarkCborEncoding-12 47173 25491 ns/op 135409 B/op 87 allocs/op BenchmarkCborDecoding-12 64550 18078 ns/op 61728 B/op 209 allocs/op BenchmarkZstdEncoding-12 29061 41489 ns/op 193455 B/op 88 allocs/op BenchmarkZstdDecoding-12 66172 17924 ns/op 176517 B/op 211 allocs/op ``` Fixes #786 --- f3_test.go | 26 +++++++ go.mod | 2 +- host.go | 35 +++++---- manifest/manifest.go | 16 +++++ msg_encoding.go | 75 +++++++++++++++++++ msg_encoding_test.go | 166 +++++++++++++++++++++++++++++++++++++++++++ 6 files changed, 306 insertions(+), 14 deletions(-) create mode 100644 msg_encoding.go create mode 100644 msg_encoding_test.go diff --git a/f3_test.go b/f3_test.go index bb4d38d8..f47623af 100644 --- a/f3_test.go +++ b/f3_test.go @@ -235,6 +235,31 @@ func TestF3DynamicManifest_WithPauseAndRebootstrap(t *testing.T) { require.Equal(t, env.manifest.BootstrapEpoch-env.manifest.EC.Finality, cert0.ECChain.Base().Epoch) } +func TestF3DynamicManifest_RebootstrapWithCompression(t *testing.T) { + env := newTestEnvironment(t).withNodes(2).withDynamicManifest().start() + env.waitForInstanceNumber(10, 30*time.Second, true) + + env.manifest.Pause = true + env.updateManifest() + + env.waitForNodesStopped() + + env.manifest.BootstrapEpoch = 956 + env.manifest.PubSub.CompressionEnabled = true + env.manifest.Pause = false + env.updateManifest() + env.waitForManifest() + + env.clock.Add(1 * time.Minute) + + env.waitForInstanceNumber(3, 30*time.Second, true) + env.requireEqualManifests(true) + + cert0, err := env.nodes[0].f3.GetCert(env.testCtx, 0) + require.NoError(t, err) + require.Equal(t, env.manifest.BootstrapEpoch-env.manifest.EC.Finality, cert0.ECChain.Base().Epoch) +} + func TestF3LateBootstrap(t *testing.T) { env := newTestEnvironment(t).withNodes(2).start() @@ -286,6 +311,7 @@ var base = manifest.Manifest{ EC: manifest.DefaultEcConfig, CertificateExchange: manifest.DefaultCxConfig, CatchUpAlignment: manifest.DefaultCatchUpAlignment, + PubSub: manifest.DefaultPubSubConfig, } type testNode struct { diff --git a/go.mod b/go.mod index bb653822..0986ca06 100644 --- a/go.mod +++ b/go.mod @@ -12,6 +12,7 @@ require ( github.com/ipfs/go-datastore v0.6.0 github.com/ipfs/go-ds-leveldb v0.5.0 github.com/ipfs/go-log/v2 v2.5.1 + github.com/klauspost/compress v1.17.11 github.com/libp2p/go-libp2p v0.37.2 github.com/libp2p/go-libp2p-pubsub v0.11.0 github.com/marcboeker/go-duckdb v1.8.2 @@ -67,7 +68,6 @@ require ( github.com/jackpal/go-nat-pmp v1.0.2 // indirect github.com/jbenet/go-temp-err-catcher v0.1.0 // indirect github.com/jbenet/goprocess v0.1.4 // indirect - github.com/klauspost/compress v1.17.11 // indirect github.com/klauspost/cpuid/v2 v2.2.8 // indirect github.com/koron/go-ssdp v0.0.4 // indirect github.com/libp2p/go-buffer-pool v0.1.0 // indirect diff --git a/host.go b/host.go index 586e20c6..c7f8ed60 100644 --- a/host.go +++ b/host.go @@ -1,7 +1,6 @@ package f3 import ( - "bytes" "context" "errors" "fmt" @@ -51,7 +50,8 @@ type gpbftRunner struct { msgsMutex sync.Mutex selfMessages map[uint64]map[roundPhase][]*gpbft.GMessage - inputs gpbftInputs + inputs gpbftInputs + msgEncoding gMessageEncoding } type roundPhase struct { @@ -132,6 +132,15 @@ func newRunner( return nil, fmt.Errorf("creating participant: %w", err) } runner.participant = p + + if runner.manifest.PubSub.CompressionEnabled { + runner.msgEncoding, err = newZstdGMessageEncoding() + if err != nil { + return nil, err + } + } else { + runner.msgEncoding = &cborGMessageEncoding{} + } return runner, nil } @@ -443,13 +452,12 @@ func (h *gpbftRunner) BroadcastMessage(ctx context.Context, msg *gpbft.GMessage) if h.topic == nil { return pubsub.ErrTopicClosed } - var bw bytes.Buffer - err = msg.MarshalCBOR(&bw) + encoded, err := h.msgEncoding.Encode(msg) if err != nil { - return fmt.Errorf("marshalling GMessage for broadcast: %w", err) + return fmt.Errorf("encoding GMessage for broadcast: %w", err) } - err = h.topic.Publish(ctx, bw.Bytes()) + err = h.topic.Publish(ctx, encoded) if err != nil { return fmt.Errorf("publishing message: %w", err) } @@ -464,11 +472,11 @@ func (h *gpbftRunner) rebroadcastMessage(msg *gpbft.GMessage) error { if h.topic == nil { return pubsub.ErrTopicClosed } - var bw bytes.Buffer - if err := msg.MarshalCBOR(&bw); err != nil { - return fmt.Errorf("marshalling GMessage for broadcast: %w", err) + encoded, err := h.msgEncoding.Encode(msg) + if err != nil { + return fmt.Errorf("encoding GMessage for broadcast: %w", err) } - if err := h.topic.Publish(h.runningCtx, bw.Bytes()); err != nil { + if err := h.topic.Publish(h.runningCtx, encoded); err != nil { return fmt.Errorf("publishing message: %w", err) } return nil @@ -481,12 +489,13 @@ func (h *gpbftRunner) validatePubsubMessage(ctx context.Context, _ peer.ID, msg recordValidationTime(ctx, start, _result) }(time.Now()) - var gmsg gpbft.GMessage - if err := gmsg.UnmarshalCBOR(bytes.NewReader(msg.Data)); err != nil { + gmsg, err := h.msgEncoding.Decode(msg.Data) + if err != nil { + log.Debugw("failed to decode message", "from", msg.GetFrom(), "err", err) return pubsub.ValidationReject } - switch validatedMessage, err := h.participant.ValidateMessage(&gmsg); { + switch validatedMessage, err := h.participant.ValidateMessage(gmsg); { case errors.Is(err, gpbft.ErrValidationInvalid): log.Debugf("validation error during validation: %+v", err) return pubsub.ValidationReject diff --git a/manifest/manifest.go b/manifest/manifest.go index d7f7a8a5..e5ab93c2 100644 --- a/manifest/manifest.go +++ b/manifest/manifest.go @@ -52,6 +52,10 @@ var ( MaximumPollInterval: 4 * DefaultEcConfig.Period, } + DefaultPubSubConfig = PubSubConfig{ + CompressionEnabled: false, + } + // Default instance alignment when catching up. DefaultCatchUpAlignment = DefaultEcConfig.Period / 2 ) @@ -194,6 +198,12 @@ func (e *EcConfig) Validate() error { return nil } +type PubSubConfig struct { + CompressionEnabled bool +} + +func (p *PubSubConfig) Validate() error { return nil } + // Manifest identifies the specific configuration for the F3 instance currently running. type Manifest struct { // Pause stops the participation in F3. @@ -227,6 +237,8 @@ type Manifest struct { EC EcConfig // Certificate Exchange specific parameters CertificateExchange CxConfig + // PubSubConfig specifies the pubsub related configuration. + PubSub PubSubConfig } func (m *Manifest) Equal(o *Manifest) bool { @@ -289,6 +301,9 @@ func (m *Manifest) Validate() error { if err := m.CertificateExchange.Validate(); err != nil { return fmt.Errorf("invalid manifest: invalid certificate exchange config: %w", err) } + if err := m.PubSub.Validate(); err != nil { + return fmt.Errorf("invalid manifest: invalid pubsub config: %w", err) + } return nil } @@ -305,6 +320,7 @@ func LocalDevnetManifest() *Manifest { Gpbft: DefaultGpbftConfig, CertificateExchange: DefaultCxConfig, CatchUpAlignment: DefaultCatchUpAlignment, + PubSub: DefaultPubSubConfig, } return m } diff --git a/msg_encoding.go b/msg_encoding.go new file mode 100644 index 00000000..65b903b0 --- /dev/null +++ b/msg_encoding.go @@ -0,0 +1,75 @@ +package f3 + +import ( + "bytes" + + "github.com/filecoin-project/go-f3/gpbft" + "github.com/klauspost/compress/zstd" +) + +var ( + _ gMessageEncoding = (*cborGMessageEncoding)(nil) + _ gMessageEncoding = (*zstdGMessageEncoding)(nil) +) + +type gMessageEncoding interface { + Encode(*gpbft.GMessage) ([]byte, error) + Decode([]byte) (*gpbft.GMessage, error) +} + +type cborGMessageEncoding struct{} + +func (c *cborGMessageEncoding) Encode(m *gpbft.GMessage) ([]byte, error) { + var buf bytes.Buffer + if err := m.MarshalCBOR(&buf); err != nil { + return nil, err + } + return buf.Bytes(), nil +} + +func (c *cborGMessageEncoding) Decode(v []byte) (*gpbft.GMessage, error) { + r := bytes.NewReader(v) + var msg gpbft.GMessage + if err := msg.UnmarshalCBOR(r); err != nil { + return nil, err + } + return &msg, nil +} + +type zstdGMessageEncoding struct { + cborEncoding cborGMessageEncoding + compressor *zstd.Encoder + decompressor *zstd.Decoder +} + +func newZstdGMessageEncoding() (*zstdGMessageEncoding, error) { + writer, err := zstd.NewWriter(nil) + if err != nil { + return nil, err + } + reader, err := zstd.NewReader(nil) + if err != nil { + return nil, err + } + return &zstdGMessageEncoding{ + compressor: writer, + decompressor: reader, + }, nil +} + +func (c *zstdGMessageEncoding) Encode(m *gpbft.GMessage) ([]byte, error) { + cborEncoded, err := c.cborEncoding.Encode(m) + if err != nil { + return nil, err + } + compressed := c.compressor.EncodeAll(cborEncoded, make([]byte, 0, len(cborEncoded))) + return compressed, err +} + +func (c *zstdGMessageEncoding) Decode(v []byte) (*gpbft.GMessage, error) { + cborEncoded, err := c.decompressor.DecodeAll(v, make([]byte, 0, len(v))) + if err != nil { + return nil, err + } + return c.cborEncoding.Decode(cborEncoded) +} diff --git a/msg_encoding_test.go b/msg_encoding_test.go new file mode 100644 index 00000000..fb83e2e2 --- /dev/null +++ b/msg_encoding_test.go @@ -0,0 +1,166 @@ +package f3 + +import ( + "math/rand" + "testing" + + "github.com/filecoin-project/go-bitfield" + "github.com/filecoin-project/go-f3/gpbft" + "github.com/ipfs/go-cid" + "github.com/multiformats/go-multihash" + "github.com/stretchr/testify/require" +) + +const seed = 1413 + +func BenchmarkCborEncoding(b *testing.B) { + rng := rand.New(rand.NewSource(seed)) + encoder := &cborGMessageEncoding{} + msg := generateRandomGMessage(b, rng) + + b.ResetTimer() + b.ReportAllocs() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + if _, err := encoder.Encode(msg); err != nil { + require.NoError(b, err) + } + } + }) +} + +func BenchmarkCborDecoding(b *testing.B) { + rng := rand.New(rand.NewSource(seed)) + encoder := &cborGMessageEncoding{} + msg := generateRandomGMessage(b, rng) + data, err := encoder.Encode(msg) + require.NoError(b, err) + + b.ResetTimer() + b.ReportAllocs() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + if got, err := encoder.Decode(data); err != nil { + require.NoError(b, err) + require.Equal(b, msg, got) + } + } + }) +} + +func BenchmarkZstdEncoding(b *testing.B) { + rng := rand.New(rand.NewSource(seed)) + encoder, err := newZstdGMessageEncoding() + require.NoError(b, err) + msg := generateRandomGMessage(b, rng) + + b.ResetTimer() + b.ReportAllocs() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + if _, err := encoder.Encode(msg); err != nil { + require.NoError(b, err) + } + } + }) +} + +func BenchmarkZstdDecoding(b *testing.B) { + rng := rand.New(rand.NewSource(seed)) + encoder, err := newZstdGMessageEncoding() + require.NoError(b, err) + msg := generateRandomGMessage(b, rng) + data, err := encoder.Encode(msg) + require.NoError(b, err) + + b.ResetTimer() + b.ReportAllocs() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + if got, err := encoder.Decode(data); err != nil { + require.NoError(b, err) + require.Equal(b, msg, got) + } + } + }) +} + +func generateRandomGMessage(b *testing.B, rng *rand.Rand) *gpbft.GMessage { + var maybeTicket []byte + if rng.Float64() < 0.5 { + generateRandomBytes(b, rng, 96) + } + + return &gpbft.GMessage{ + Sender: gpbft.ActorID(rng.Uint64()), + Vote: generateRandomPayload(b, rng), + Signature: generateRandomBytes(b, rng, 96), + Ticket: maybeTicket, + Justification: generateRandomJustification(b, rng), + } +} + +func generateRandomJustification(b *testing.B, rng *rand.Rand) *gpbft.Justification { + return &gpbft.Justification{ + Vote: generateRandomPayload(b, rng), + Signers: generateRandomBitfield(rng), + Signature: generateRandomBytes(b, rng, 96), + } +} + +func generateRandomBytes(b *testing.B, rng *rand.Rand, n int) []byte { + buf := make([]byte, n) + _, err := rng.Read(buf) + require.NoError(b, err) + return buf +} + +func generateRandomPayload(b *testing.B, rng *rand.Rand) gpbft.Payload { + return gpbft.Payload{ + Instance: rng.Uint64(), + Round: rng.Uint64(), + Phase: gpbft.Phase(rng.Intn(int(gpbft.COMMIT_PHASE)) + 1), + Value: generateRandomECChain(b, rng, rng.Intn(gpbft.ChainMaxLen)+1), + SupplementalData: gpbft.SupplementalData{ + PowerTable: generateRandomCID(b, rng), + }, + } +} + +func generateRandomBitfield(rng *rand.Rand) bitfield.BitField { + ids := make([]uint64, rng.Intn(2_000)+1) + for i := range ids { + ids[i] = rng.Uint64() + } + return bitfield.NewFromSet(ids) +} + +func generateRandomECChain(b *testing.B, rng *rand.Rand, length int) gpbft.ECChain { + chain := make(gpbft.ECChain, length) + epoch := int64(rng.Uint64()) + for i := range length { + chain[i] = generateRandomTipSet(b, rng, epoch+int64(i)) + } + return chain +} + +func generateRandomTipSet(b *testing.B, rng *rand.Rand, epoch int64) gpbft.TipSet { + return gpbft.TipSet{ + Epoch: epoch, + Key: generateRandomTipSetKey(b, rng), + PowerTable: generateRandomCID(b, rng), + } +} + +func generateRandomTipSetKey(b *testing.B, rng *rand.Rand) gpbft.TipSetKey { + key := make([]byte, rng.Intn(gpbft.TipsetKeyMaxLen)+1) + _, err := rng.Read(key) + require.NoError(b, err) + return key +} + +func generateRandomCID(b *testing.B, rng *rand.Rand) cid.Cid { + sum, err := multihash.Sum(generateRandomBytes(b, rng, 32), multihash.SHA2_256, -1) + require.NoError(b, err) + return cid.NewCidV1(cid.Raw, sum) +}