Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(relay): flush relayed connection once idle #3765

Merged
merged 6 commits into from
Apr 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 7 additions & 0 deletions protocols/relay/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
## 0.15.2 - unreleased

- As a relay, when forwarding data between relay-connection-source and -destination and vice versa, flush write side when read currently has no more data available.
See [PR 3765].

[PR 3765]: https://github.com/libp2p/rust-libp2p/pull/3765

## 0.15.1

- Migrate from `prost` to `quick-protobuf`. This removes `protoc` dependency. See [PR 3312].
Expand Down
2 changes: 1 addition & 1 deletion protocols/relay/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ name = "libp2p-relay"
edition = "2021"
rust-version = "1.62.0"
description = "Communications relaying for libp2p"
version = "0.15.1"
version = "0.15.2"
authors = ["Parity Technologies <admin@parity.io>", "Max Inden <mail@max-inden.de>"]
license = "MIT"
repository = "https://github.com/libp2p/rust-libp2p"
Expand Down
277 changes: 202 additions & 75 deletions protocols/relay/src/copy_future.rs
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,14 @@ fn forward_data<S: AsyncBufRead + Unpin, D: AsyncWrite + Unpin>(
mut dst: &mut D,
cx: &mut Context<'_>,
) -> Poll<io::Result<u64>> {
let buffer = ready!(Pin::new(&mut src).poll_fill_buf(cx))?;
let buffer = match Pin::new(&mut src).poll_fill_buf(cx)? {
Poll::Ready(buffer) => buffer,
Poll::Pending => {
let _ = Pin::new(&mut dst).poll_flush(cx)?;
return Poll::Pending;
}
};

if buffer.is_empty() {
ready!(Pin::new(&mut dst).poll_flush(cx))?;
ready!(Pin::new(&mut dst).poll_close(cx))?;
Expand All @@ -150,95 +157,59 @@ fn forward_data<S: AsyncBufRead + Unpin, D: AsyncWrite + Unpin>(

#[cfg(test)]
mod tests {
use super::CopyFuture;
use super::*;
use futures::executor::block_on;
use futures::io::{AsyncRead, AsyncWrite};
use futures::io::{AsyncRead, AsyncWrite, BufReader, BufWriter};
use quickcheck::QuickCheck;
use std::io::ErrorKind;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::Duration;

struct Connection {
Copy link
Member Author

@mxinden mxinden Apr 11, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To make confusing diff easier to read:

  • Connection is only used in quickckeck test, thus moved into quickcheck test.
  • PendingConnection is only used in max_circuit_duration test, thus moved into max_circuit_duration test.

read: Vec<u8>,
write: Vec<u8>,
}

impl AsyncWrite for Connection {
fn poll_write(
mut self: std::pin::Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<std::io::Result<usize>> {
Pin::new(&mut self.write).poll_write(cx, buf)
}

fn poll_flush(
mut self: std::pin::Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<std::io::Result<()>> {
Pin::new(&mut self.write).poll_flush(cx)
}

fn poll_close(
mut self: std::pin::Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<std::io::Result<()>> {
Pin::new(&mut self.write).poll_close(cx)
}
}

impl AsyncRead for Connection {
fn poll_read(
mut self: Pin<&mut Self>,
_cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<std::io::Result<usize>> {
let n = std::cmp::min(self.read.len(), buf.len());
buf[0..n].copy_from_slice(&self.read[0..n]);
self.read = self.read.split_off(n);
Poll::Ready(Ok(n))
#[test]
fn quickcheck() {
struct Connection {
read: Vec<u8>,
write: Vec<u8>,
}
}

struct PendingConnection {}

impl AsyncWrite for PendingConnection {
fn poll_write(
self: std::pin::Pin<&mut Self>,
_cx: &mut Context<'_>,
_buf: &[u8],
) -> Poll<std::io::Result<usize>> {
Poll::Pending
}
impl AsyncWrite for Connection {
fn poll_write(
mut self: std::pin::Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<std::io::Result<usize>> {
Pin::new(&mut self.write).poll_write(cx, buf)
}

fn poll_flush(
self: std::pin::Pin<&mut Self>,
_cx: &mut Context<'_>,
) -> Poll<std::io::Result<()>> {
Poll::Pending
}
fn poll_flush(
mut self: std::pin::Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<std::io::Result<()>> {
Pin::new(&mut self.write).poll_flush(cx)
}

fn poll_close(
self: std::pin::Pin<&mut Self>,
_cx: &mut Context<'_>,
) -> Poll<std::io::Result<()>> {
Poll::Pending
fn poll_close(
mut self: std::pin::Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<std::io::Result<()>> {
Pin::new(&mut self.write).poll_close(cx)
}
}
}

impl AsyncRead for PendingConnection {
fn poll_read(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
_buf: &mut [u8],
) -> Poll<std::io::Result<usize>> {
Poll::Pending
impl AsyncRead for Connection {
fn poll_read(
mut self: Pin<&mut Self>,
_cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<std::io::Result<usize>> {
let n = std::cmp::min(self.read.len(), buf.len());
buf[0..n].copy_from_slice(&self.read[0..n]);
self.read = self.read.split_off(n);
Poll::Ready(Ok(n))
}
}
}

#[test]
fn quickcheck() {
fn prop(a: Vec<u8>, b: Vec<u8>, max_circuit_bytes: u64) {
let connection_a = Connection {
read: a.clone(),
Expand Down Expand Up @@ -275,6 +246,42 @@ mod tests {

#[test]
fn max_circuit_duration() {
struct PendingConnection {}

impl AsyncWrite for PendingConnection {
fn poll_write(
self: std::pin::Pin<&mut Self>,
_cx: &mut Context<'_>,
_buf: &[u8],
) -> Poll<std::io::Result<usize>> {
Poll::Pending
}

fn poll_flush(
self: std::pin::Pin<&mut Self>,
_cx: &mut Context<'_>,
) -> Poll<std::io::Result<()>> {
Poll::Pending
}

fn poll_close(
self: std::pin::Pin<&mut Self>,
_cx: &mut Context<'_>,
) -> Poll<std::io::Result<()>> {
Poll::Pending
}
}

impl AsyncRead for PendingConnection {
fn poll_read(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
_buf: &mut [u8],
) -> Poll<std::io::Result<usize>> {
Poll::Pending
}
}

let copy_future = CopyFuture::new(
PendingConnection {},
PendingConnection {},
Expand All @@ -288,4 +295,124 @@ mod tests {
block_on(copy_future).expect_err("Expect maximum circuit duration to be reached.");
assert_eq!(error.kind(), ErrorKind::TimedOut);
}

#[test]
fn forward_data_should_flush_on_pending_source() {
struct NeverEndingSource {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
struct NeverEndingSource {
struct NeverEndingStory {

read: Vec<u8>,
}

impl AsyncRead for NeverEndingSource {
fn poll_read(
mut self: Pin<&mut Self>,
_cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<std::io::Result<usize>> {
if let Some(b) = self.read.pop() {
buf[0] = b;
return Poll::Ready(Ok(1));
}

Poll::Pending
}
}

struct RecordingDestination {
method_calls: Vec<Method>,
}

#[derive(Debug, PartialEq)]
enum Method {
Write(Vec<u8>),
Flush,
Close,
}

impl AsyncWrite for RecordingDestination {
fn poll_write(
mut self: std::pin::Pin<&mut Self>,
_cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<std::io::Result<usize>> {
self.method_calls.push(Method::Write(buf.to_vec()));
Poll::Ready(Ok(buf.len()))
}

fn poll_flush(
mut self: std::pin::Pin<&mut Self>,
_cx: &mut Context<'_>,
) -> Poll<std::io::Result<()>> {
self.method_calls.push(Method::Flush);
Poll::Ready(Ok(()))
}

fn poll_close(
mut self: std::pin::Pin<&mut Self>,
_cx: &mut Context<'_>,
) -> Poll<std::io::Result<()>> {
self.method_calls.push(Method::Close);
Poll::Ready(Ok(()))
}
}

// The source has two reads available, handing them out on `AsyncRead::poll_read` one by one.
let mut source = BufReader::new(NeverEndingSource { read: vec![1, 2] });

// The destination is wrapped by a `BufWriter` with a capacity of `3`, i.e. one larger than
// the available reads of the source. Without an explicit `AsyncWrite::poll_flush` the two
// reads would thus never make it to the destination, but instead be stuck in the buffer of
// the `BufWrite`.
let mut destination = BufWriter::with_capacity(
3,
RecordingDestination {
method_calls: vec![],
},
);

let mut cx = Context::from_waker(futures::task::noop_waker_ref());

assert!(
matches!(
forward_data(&mut source, &mut destination, &mut cx),
Poll::Ready(Ok(1)),
),
"Expect `forward_data` to forward one read from the source to the wrapped destination."
);
assert_eq!(
destination.get_ref().method_calls.as_slice(), &[],
"Given that destination is wrapped with a `BufWrite`, the write doesn't (yet) make it to \
the destination. The source might have more data available, thus `forward_data` has not \
yet flushed.",
Comment on lines +383 to +385
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
"Given that destination is wrapped with a `BufWrite`, the write doesn't (yet) make it to \
the destination. The source might have more data available, thus `forward_data` has not \
yet flushed.",
"Expect BufWriter to still have space and thus not forward the poll_write call",

);

assert!(
matches!(
forward_data(&mut source, &mut destination, &mut cx),
Poll::Ready(Ok(1)),
),
"Expect `forward_data` to forward one read from the source to the wrapped destination."
);
assert_eq!(
destination.get_ref().method_calls.as_slice(), &[],
"Given that destination is wrapped with a `BufWrite`, the write doesn't (yet) make it to \
the destination. The source might have more data available, thus `forward_data` has not \
yet flushed.",
);

assert!(
matches!(
forward_data(&mut source, &mut destination, &mut cx),
Poll::Pending,
),
"The source has no more reads available, but does not close i.e. does not return \
`Poll::Ready(Ok(1))` but instead `Poll::Pending`. Thus `forward_data` returns \
`Poll::Pending` as well."
);
assert_eq!(
destination.get_ref().method_calls.as_slice(),
&[Method::Write(vec![2, 1]), Method::Flush],
"Given that source had no more reads, `forward_data` calls flush, thus instructing the \
`BufWriter` to flush the two buffered writes down to the destination."
);
}
}