Skip to content

Commit

Permalink
feat(quick-protobuf-codec): reduce allocations during encoding
Browse files Browse the repository at this point in the history
We can reduce the number of allocations during the encoding of messages by not depending on the `Encoder` implementation of `unsigned-varint` but doing the encoding ourselves.

This allows us to plug the `BytesMut` directly into the `Writer` and encode straight into the target buffer. Previously, we encoded each message into an additional buffer that was immediately thrown away again.

Related: #4781.

Pull-Request: #4782.
thomaseizinger authored Nov 24, 2023

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature.
1 parent 3b6b74d commit d851d1b
Showing 10 changed files with 339 additions and 24 deletions.
5 changes: 4 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -117,7 +117,7 @@ multiaddr = "0.18.1"
multihash = "0.19.1"
multistream-select = { version = "0.13.0", path = "misc/multistream-select" }
prometheus-client = "0.22.0"
quick-protobuf-codec = { version = "0.3.0", path = "misc/quick-protobuf-codec" }
quick-protobuf-codec = { version = "0.3.1", path = "misc/quick-protobuf-codec" }
quickcheck = { package = "quickcheck-ext", path = "misc/quickcheck-ext" }
rw-stream-sink = { version = "0.4.0", path = "misc/rw-stream-sink" }
unsigned-varint = { version = "0.8.0" }
5 changes: 5 additions & 0 deletions misc/quick-protobuf-codec/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
## 0.3.1

- Reduce allocations during encoding.
See [PR 4782](https://github.com/libp2p/rust-libp2p/pull/4782).

## 0.3.0

- Update to `asynchronous-codec` `v0.7.0`.
13 changes: 11 additions & 2 deletions misc/quick-protobuf-codec/Cargo.toml
Original file line number Diff line number Diff line change
@@ -3,7 +3,7 @@ name = "quick-protobuf-codec"
edition = "2021"
rust-version = { workspace = true }
description = "Asynchronous de-/encoding of Protobuf structs using asynchronous-codec, unsigned-varint and quick-protobuf."
version = "0.3.0"
version = "0.3.1"
authors = ["Max Inden <mail@max-inden.de>"]
license = "MIT"
repository = "https://github.com/libp2p/rust-libp2p"
@@ -14,9 +14,18 @@ categories = ["asynchronous"]
asynchronous-codec = { workspace = true }
bytes = { version = "1" }
thiserror = "1.0"
unsigned-varint = { workspace = true, features = ["asynchronous_codec"] }
unsigned-varint = { workspace = true, features = ["std"] }
quick-protobuf = "0.8"

[dev-dependencies]
criterion = "0.5.1"
futures = "0.3.28"
quickcheck = { workspace = true }

[[bench]]
name = "codec"
harness = false

# Passing arguments to the docsrs builder in order to properly document cfg's.
# More information: https://docs.rs/about/builds#cross-compiling
[package.metadata.docs.rs]
28 changes: 28 additions & 0 deletions misc/quick-protobuf-codec/benches/codec.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
use asynchronous_codec::Encoder;
use bytes::BytesMut;
use criterion::{criterion_group, criterion_main, BatchSize, BenchmarkId, Criterion};
use quick_protobuf_codec::{proto, Codec};

/// Benchmarks [`Codec::encode`] for payloads of 1 kB, 10 kB, 100 kB, 1 MB and 10 MB.
///
/// For every batch element the setup closure builds a fresh codec, output buffer and
/// message; `iter_batched` then passes them to the routine that calls `encode`.
pub fn benchmark(c: &mut Criterion) {
    for size in [1000, 10_000, 100_000, 1_000_000, 10_000_000] {
        c.bench_with_input(
            BenchmarkId::new("encode", size),
            &size,
            |bencher, &payload_len| {
                // The buffer reservation and the codec limit both get 100 bytes on top of the
                // payload (presumably headroom for the length prefix and field framing).
                let limit = payload_len + 100;

                let setup = || {
                    let mut out = BytesMut::new();
                    out.reserve(limit);

                    let codec = Codec::<proto::Message>::new(limit);
                    let msg = proto::Message {
                        data: vec![0; payload_len],
                    };

                    (codec, out, msg)
                };

                bencher.iter_batched(
                    setup,
                    |(mut codec, mut out, msg)| codec.encode(msg, &mut out).unwrap(),
                    BatchSize::SmallInput,
                );
            },
        );
    }
}

criterion_group!(benches, benchmark);
criterion_main!(benches);
2 changes: 2 additions & 0 deletions misc/quick-protobuf-codec/src/generated/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
// Automatically generated mod.rs
pub mod test;
7 changes: 7 additions & 0 deletions misc/quick-protobuf-codec/src/generated/test.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
syntax = "proto3";

package test;

message Message {
bytes data = 1;
}
47 changes: 47 additions & 0 deletions misc/quick-protobuf-codec/src/generated/test.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
// Automatically generated rust module for 'test.proto' file

#![allow(non_snake_case)]
#![allow(non_upper_case_globals)]
#![allow(non_camel_case_types)]
#![allow(unused_imports)]
#![allow(unknown_lints)]
#![allow(clippy::all)]
#![cfg_attr(rustfmt, rustfmt_skip)]


use quick_protobuf::{MessageInfo, MessageRead, MessageWrite, BytesReader, Writer, WriterBackend, Result};
use quick_protobuf::sizeofs::*;
use super::*;

// Rust counterpart of the `Message` type declared in `test.proto`.
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Debug, Default, PartialEq, Clone)]
pub struct Message {
// Payload of the proto field `bytes data = 1`.
pub data: Vec<u8>,
}

// Decodes a `Message` by reading tagged fields from `bytes` until the reader reports EOF.
impl<'a> MessageRead<'a> for Message {
fn from_reader(r: &mut BytesReader, bytes: &'a [u8]) -> Result<Self> {
// Start from the default (empty `data`) and overwrite fields as they are encountered.
let mut msg = Self::default();
while !r.is_eof() {
match r.next_tag(bytes) {
// Tag 10 = (field number 1 << 3) | wire type 2 (length-delimited), i.e. the `data` field.
Ok(10) => msg.data = r.read_bytes(bytes)?.to_owned(),
// Any other tag is handed to `read_unknown`; only an error from that call aborts decoding.
Ok(t) => { r.read_unknown(bytes, t)?; }
Err(e) => return Err(e),
}
}
Ok(msg)
}
}

// Serialisation side of `Message`: size computation and the actual write.
impl MessageWrite for Message {
// Encoded size: 0 for an empty `data`, otherwise one tag byte plus whatever `sizeof_len`
// reports for the payload length (presumably the length varint plus the payload itself).
fn get_size(&self) -> usize {
0
+ if self.data.is_empty() { 0 } else { 1 + sizeof_len((&self.data).len()) }
}

// Writes `data` under tag 10 (field 1, length-delimited); an empty `data` writes nothing.
fn write_message<W: WriterBackend>(&self, w: &mut Writer<W>) -> Result<()> {
if !self.data.is_empty() { w.write_with_tag(10, |w| w.write_bytes(&**&self.data))?; }
Ok(())
}
}

238 changes: 218 additions & 20 deletions misc/quick-protobuf-codec/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,21 @@
#![cfg_attr(docsrs, feature(doc_cfg, doc_auto_cfg))]

use asynchronous_codec::{Decoder, Encoder};
use bytes::{Bytes, BytesMut};
use quick_protobuf::{BytesReader, MessageRead, MessageWrite, Writer};
use bytes::{Buf, BufMut, BytesMut};
use quick_protobuf::{BytesReader, MessageRead, MessageWrite, Writer, WriterBackend};
use std::io;
use std::marker::PhantomData;
use unsigned_varint::codec::UviBytes;

mod generated;

#[doc(hidden)] // NOT public API. Do not use.
pub use generated::test as proto;

/// [`Codec`] implements [`Encoder`] and [`Decoder`], uses [`unsigned_varint`]
/// to prefix messages with their length and uses [`quick_protobuf`] and a provided
/// `struct` implementing [`MessageRead`] and [`MessageWrite`] to do the encoding.
pub struct Codec<In, Out = In> {
// Length-prefix codec from `unsigned_varint`; `new` sets its maximum length.
// NOTE(review): this field and `max_message_len_bytes` look like the before/after of the
// same change — confirm that only one of them is meant to remain.
uvi: UviBytes,
// Upper bound on the length announced by a frame's prefix; `decode` rejects anything larger.
max_message_len_bytes: usize,
// Ties the `In`/`Out` type parameters to the struct, since no other field uses them.
phantom: PhantomData<(In, Out)>,
}

@@ -21,10 +26,8 @@ impl<In, Out> Codec<In, Out> {
/// Protobuf message. The limit does not include the bytes needed for the
/// [`unsigned_varint`].
pub fn new(max_message_len_bytes: usize) -> Self {
// Configure the varint codec with the same limit that is stored on the codec itself.
let mut uvi = UviBytes::default();
uvi.set_max_len(max_message_len_bytes);
Self {
uvi,
max_message_len_bytes,
// Zero-sized marker; carries no data.
phantom: PhantomData,
}
}
@@ -35,16 +38,32 @@ impl<In: MessageWrite, Out> Encoder for Codec<In, Out> {
type Error = Error;

/// Encodes `item` into `dst` as a varint length prefix followed by the Protobuf bytes.
///
/// Both parts are written straight into `dst`; the message is not serialised into an
/// intermediate buffer first.
///
/// # Errors
///
/// Returns an error if serialising the message fails.
fn encode(&mut self, item: Self::Item<'_>, dst: &mut BytesMut) -> Result<(), Self::Error> {
    write_length(&item, dst);
    write_message(&item, dst)?;

    Ok(())
}
}

/// Appends the size of `message`, as reported by `get_size`, to `dst` as an unsigned varint.
fn write_length(message: &impl MessageWrite, dst: &mut BytesMut) {
    // Scratch space for the varint; the encoder returns the sub-slice it actually filled.
    let mut scratch = unsigned_varint::encode::usize_buffer();
    let prefix = unsigned_varint::encode::usize(message.get_size(), &mut scratch);

    dst.extend_from_slice(prefix);
}

/// Write the message itself to `dst`.
fn write_message(item: &impl MessageWrite, dst: &mut BytesMut) -> io::Result<()> {
let mut writer = Writer::new(BytesMutWriterBackend::new(dst));
item.write_message(&mut writer)
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;

Ok(())
}

impl<In, Out> Decoder for Codec<In, Out>
where
Out: for<'a> MessageRead<'a>,
@@ -53,24 +72,203 @@ where
type Error = Error;

fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
let msg = match self.uvi.decode(src)? {
None => return Ok(None),
Some(msg) => msg,
let (message_length, remaining) = match unsigned_varint::decode::usize(src) {
Ok((len, remaining)) => (len, remaining),
Err(unsigned_varint::decode::Error::Insufficient) => return Ok(None),
Err(e) => return Err(Error(io::Error::new(io::ErrorKind::InvalidData, e))),
};

let mut reader = BytesReader::from_bytes(&msg);
let message = Self::Item::from_reader(&mut reader, &msg)
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
if message_length > self.max_message_len_bytes {
return Err(Error(io::Error::new(
io::ErrorKind::PermissionDenied,
format!(
"message with {message_length}b exceeds maximum of {}b",
self.max_message_len_bytes
),
)));
}

// Compute how many bytes the varint itself consumed.
let varint_length = src.len() - remaining.len();

// Ensure we can read an entire message.
if src.len() < (message_length + varint_length) {
return Ok(None);
}

// Safe to advance buffer now.
src.advance(varint_length);

let message = src.split_to(message_length);

let mut reader = BytesReader::from_bytes(&message);
let message = Self::Item::from_reader(&mut reader, &message)
.map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;

Ok(Some(message))
}
}

/// [`WriterBackend`] that writes directly into a borrowed [`BytesMut`].
struct BytesMutWriterBackend<'a> {
// Destination buffer that every `pb_write_*` call appends to.
dst: &'a mut BytesMut,
}

impl<'a> BytesMutWriterBackend<'a> {
fn new(dst: &'a mut BytesMut) -> Self {
Self { dst }
}
}

/// Forwards every [`WriterBackend`] primitive to the matching [`BufMut`] method on the wrapped
/// buffer. Multi-byte values use the little-endian `put_*_le` variants, and every method
/// returns `Ok(())` after the write.
impl WriterBackend for BytesMutWriterBackend<'_> {
    /// Appends a single byte.
    fn pb_write_u8(&mut self, x: u8) -> quick_protobuf::Result<()> {
        self.dst.put_u8(x);
        Ok(())
    }

    /// Appends a `u32` in little-endian byte order.
    fn pb_write_u32(&mut self, x: u32) -> quick_protobuf::Result<()> {
        self.dst.put_u32_le(x);
        Ok(())
    }

    /// Appends an `i32` in little-endian byte order.
    fn pb_write_i32(&mut self, x: i32) -> quick_protobuf::Result<()> {
        self.dst.put_i32_le(x);
        Ok(())
    }

    /// Appends an `f32` in little-endian byte order.
    fn pb_write_f32(&mut self, x: f32) -> quick_protobuf::Result<()> {
        self.dst.put_f32_le(x);
        Ok(())
    }

    /// Appends a `u64` in little-endian byte order.
    fn pb_write_u64(&mut self, x: u64) -> quick_protobuf::Result<()> {
        self.dst.put_u64_le(x);
        Ok(())
    }

    /// Appends an `i64` in little-endian byte order.
    fn pb_write_i64(&mut self, x: i64) -> quick_protobuf::Result<()> {
        self.dst.put_i64_le(x);
        Ok(())
    }

    /// Appends an `f64` in little-endian byte order.
    fn pb_write_f64(&mut self, x: f64) -> quick_protobuf::Result<()> {
        self.dst.put_f64_le(x);
        Ok(())
    }

    /// Appends the whole of `buf` unchanged.
    fn pb_write_all(&mut self, buf: &[u8]) -> quick_protobuf::Result<()> {
        self.dst.put_slice(buf);
        Ok(())
    }
}

#[derive(thiserror::Error, Debug)]
#[error("Failed to encode/decode message")]
pub struct Error(#[from] std::io::Error);
pub struct Error(#[from] io::Error);

impl From<Error> for std::io::Error {
impl From<Error> for io::Error {
fn from(e: Error) -> Self {
e.0
}
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::proto;
    use asynchronous_codec::FramedRead;
    use futures::io::Cursor;
    use futures::{FutureExt, StreamExt};
    use quickcheck::{Arbitrary, Gen, QuickCheck};
    use std::error::Error;

    /// A frame announcing more bytes than the configured maximum is rejected with an error.
    #[test]
    fn honors_max_message_length() {
        let codec = Codec::<Dummy>::new(1);
        let mut src = varint_zeroes(100);

        let mut read = FramedRead::new(Cursor::new(&mut src), codec);
        let err = read.next().now_or_never().unwrap().unwrap().unwrap_err();

        assert_eq!(
            err.source().unwrap().to_string(),
            "message with 100b exceeds maximum of 1b"
        )
    }

    /// A frame whose body is only partially buffered yields `None` and leaves `src` untouched.
    #[test]
    fn only_partial_message_in_bytes_mut_does_not_panic() {
        let mut codec = Codec::<Dummy>::new(100);

        let mut src = varint_zeroes(100);
        src.truncate(50);

        let result = codec.decode(&mut src);

        assert!(result.unwrap().is_none());
        assert_eq!(
            src.len(),
            50,
            "to not modify `src` if we cannot read a full message"
        )
    }

    /// Decoding from a completely empty buffer yields `None`.
    #[test]
    fn empty_bytes_mut_does_not_panic() {
        let mut codec = Codec::<Dummy>::new(100);

        let result = codec.decode(&mut BytesMut::new());

        assert!(result.unwrap().is_none());
    }

    /// Encoding followed by decoding round-trips regardless of the buffer's initial capacity.
    #[test]
    fn handles_arbitrary_initial_capacity() {
        fn prop(message: proto::Message, initial_capacity: u16) {
            let mut buffer = BytesMut::with_capacity(initial_capacity as usize);
            let mut codec = Codec::<proto::Message>::new(u32::MAX as usize);

            codec.encode(message.clone(), &mut buffer).unwrap();
            let decoded = codec.decode(&mut buffer).unwrap().unwrap();

            assert_eq!(message, decoded);
        }

        QuickCheck::new().quickcheck(prop as fn(_, _) -> _)
    }

    /// Builds a frame consisting of the varint encoding of `length` followed by `length`
    /// zero bytes.
    fn varint_zeroes(length: usize) -> BytesMut {
        let mut buf = unsigned_varint::encode::usize_buffer();
        let encoded_length = unsigned_varint::encode::usize(length, &mut buf);

        let mut src = BytesMut::new();
        src.extend_from_slice(encoded_length);
        src.extend(std::iter::repeat(0).take(length));
        src
    }

    impl Arbitrary for proto::Message {
        fn arbitrary(g: &mut Gen) -> Self {
            Self {
                data: Vec::arbitrary(g),
            }
        }
    }

    /// Message type for tests that must bail out before any Protobuf parsing happens.
    #[derive(Debug)]
    struct Dummy;

    impl<'a> MessageRead<'a> for Dummy {
        fn from_reader(_: &mut BytesReader, _: &'a [u8]) -> quick_protobuf::Result<Self> {
            todo!()
        }
    }
}
16 changes: 16 additions & 0 deletions misc/quick-protobuf-codec/tests/large_message.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
use asynchronous_codec::Encoder;
use bytes::BytesMut;
use quick_protobuf_codec::proto;
use quick_protobuf_codec::Codec;

/// Encoding a message with a 1,000,000-byte payload into a pre-reserved buffer succeeds.
#[test]
fn encode_large_message() {
    let message = proto::Message {
        data: vec![0; 1_000_000],
    };
    let mut codec = Codec::<proto::Message>::new(1_001_000);

    let mut dst = BytesMut::new();
    dst.reserve(1_001_000);

    codec.encode(message, &mut dst).unwrap();
}

0 comments on commit d851d1b

Please sign in to comment.