Skip to content

Commit

Permalink
fix(quic): Trigger Quic as Transport wakeup on dial (#3306)
Browse files Browse the repository at this point in the history
Scenario: rust-libp2p node A dials rust-libp2p node B. B listens on a QUIC address. A dials B via the `libp2p-quic` `Transport` wrapped in a `libp2p-dns` `Transport`.

Note that `libp2p-dns` in itself is not relevant here. Only the fact that `libp2p-dns` delays a dial is relevant, i.e. that it first does other async stuff (DNS lookup) before creating the QUIC dial. In fact, dialing an IP address through the DNS `Transport` where no DNS resolution is needed triggers the below just fine.

1. A calls `Swarm::dial` which creates a `libp2p-dns` dial.
2. That dial is spawned onto the connection `Pool`, thus starting the DNS resolution.
3. A continuously calls `Swarm::poll`.
4. `libp2p-quic` `Transport::poll` is called, finding no dialers in `self.dialer` given that the spawned dial is still only resolving the DNS address.
5. On the spawned connection task:
1. The DNS resolution finishes.
2. Thus calling `Transport::dial` on `libp1p-quic` (note that the DNS dial has a clone of the QUIC `Transport` wrapped in an `Arc<Mutex<_>>`).
3. That adds a dialer to `self.dialer`. Note that there are no listeners, i.e. `Swarm::listen_on` was never called.
4. `DialerState::new_dial` is called which adds a message to `self.pending_dials` and wakes `self.waker`. Given that on the last `Transport::poll` there was no `self.dialer`, that waker is empty.

Result: The message is stuck in the `DialerState::pending_dials`. The message is never send to the endpoint driver. The dial never succeeds.

This commit fixes the above, waking the `<Quic as Transport>:poll` method.
  • Loading branch information
mxinden authored Jan 11, 2023
1 parent 1b6c915 commit 7665e74
Show file tree
Hide file tree
Showing 2 changed files with 131 additions and 11 deletions.
28 changes: 17 additions & 11 deletions transports/quic/src/transport.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ pub struct GenTransport<P: Provider> {
listeners: SelectAll<Listener<P>>,
/// Dialer for each socket family if no matching listener exists.
dialer: HashMap<SocketFamily, Dialer>,
dialer_waker: Option<Waker>,
}

impl<P: Provider> GenTransport<P> {
Expand All @@ -84,6 +85,7 @@ impl<P: Provider> GenTransport<P> {
quinn_config,
handshake_timeout,
dialer: HashMap::new(),
dialer_waker: None,
support_draft_29,
}
}
Expand Down Expand Up @@ -178,6 +180,12 @@ impl<P: Provider> Transport for GenTransport<P> {
&mut listeners[index].dialer_state
}
};

// Wakeup the task polling [`Transport::poll`] to drive the new dial.
if let Some(waker) = self.dialer_waker.take() {
waker.wake();
}

Ok(dialer_state.new_dial(socket_addr, self.handshake_timeout, version))
}

Expand Down Expand Up @@ -207,10 +215,14 @@ impl<P: Provider> Transport for GenTransport<P> {
// Drop dialer and all pending dials so that the connection receiver is notified.
self.dialer.remove(&key);
}
match self.listeners.poll_next_unpin(cx) {
Poll::Ready(Some(ev)) => Poll::Ready(ev),
_ => Poll::Pending,

if let Poll::Ready(Some(ev)) = self.listeners.poll_next_unpin(cx) {
return Poll::Ready(ev);
}

self.dialer_waker = Some(cx.waker().clone());

Poll::Pending
}
}

Expand Down Expand Up @@ -254,12 +266,11 @@ impl Drop for Dialer {
}
}

/// Pending dials to be sent to the endpoint was the [`endpoint::Channel`]
/// has capacity
/// Pending dials to be sent to the endpoint once the [`endpoint::Channel`]
/// has capacity.
#[derive(Default, Debug)]
struct DialerState {
pending_dials: VecDeque<ToEndpoint>,
waker: Option<Waker>,
}

impl DialerState {
Expand All @@ -279,10 +290,6 @@ impl DialerState {

self.pending_dials.push_back(message);

if let Some(waker) = self.waker.take() {
waker.wake();
}

async move {
// Our oneshot getting dropped means the message didn't make it to the endpoint driver.
let connection = tx.await.map_err(|_| Error::EndpointDriverCrashed)??;
Expand All @@ -307,7 +314,6 @@ impl DialerState {
Err(endpoint::Disconnected {}) => return Poll::Ready(Error::EndpointDriverCrashed),
}
}
self.waker = Some(cx.waker().clone());
Poll::Pending
}
}
Expand Down
114 changes: 114 additions & 0 deletions transports/quic/tests/smoke.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
#![cfg(any(feature = "async-std", feature = "tokio"))]

use futures::channel::{mpsc, oneshot};
use futures::future::BoxFuture;
use futures::future::{poll_fn, Either};
use futures::stream::StreamExt;
use futures::{future, AsyncReadExt, AsyncWriteExt, FutureExt, SinkExt};
use futures_timer::Delay;
use libp2p_core::either::EitherOutput;
use libp2p_core::muxing::{StreamMuxerBox, StreamMuxerExt, SubstreamBox};
use libp2p_core::transport::{Boxed, OrTransport, TransportEvent};
use libp2p_core::transport::{ListenerId, TransportError};
use libp2p_core::{multiaddr::Protocol, upgrade, Multiaddr, PeerId, Transport};
use libp2p_noise as noise;
use libp2p_quic as quic;
Expand All @@ -19,6 +22,10 @@ use std::io;
use std::num::NonZeroU8;
use std::task::Poll;
use std::time::Duration;
use std::{
pin::Pin,
sync::{Arc, Mutex},
};

#[cfg(feature = "tokio")]
#[tokio::test]
Expand Down Expand Up @@ -90,6 +97,113 @@ async fn ipv4_dial_ipv6() {
assert_eq!(b_connected, a_peer_id);
}

/// Tests that a [`Transport::dial`] wakes up the task previously polling [`Transport::poll`].
///
/// See https://github.com/libp2p/rust-libp2p/pull/3306 for context.
#[cfg(feature = "async-std")]
#[async_std::test]
async fn wrapped_with_dns() {
let _ = env_logger::try_init();

struct DialDelay(Arc<Mutex<Boxed<(PeerId, StreamMuxerBox)>>>);

impl Transport for DialDelay {
type Output = (PeerId, StreamMuxerBox);
type Error = std::io::Error;
type ListenerUpgrade = Pin<Box<dyn Future<Output = io::Result<Self::Output>> + Send>>;
type Dial = BoxFuture<'static, Result<Self::Output, Self::Error>>;

fn listen_on(
&mut self,
addr: Multiaddr,
) -> Result<ListenerId, TransportError<Self::Error>> {
self.0.lock().unwrap().listen_on(addr)
}

fn remove_listener(&mut self, id: ListenerId) -> bool {
self.0.lock().unwrap().remove_listener(id)
}

fn address_translation(
&self,
listen: &Multiaddr,
observed: &Multiaddr,
) -> Option<Multiaddr> {
self.0.lock().unwrap().address_translation(listen, observed)
}

/// Delayed dial, i.e. calling [`Transport::dial`] on the inner [`Transport`] not within the
/// synchronous [`Transport::dial`] method, but within the [`Future`] returned by the outer
/// [`Transport::dial`].
fn dial(&mut self, addr: Multiaddr) -> Result<Self::Dial, TransportError<Self::Error>> {
let t = self.0.clone();
Ok(async move {
// Simulate DNS lookup. Giving the `Transport::poll` the chance to return
// `Poll::Pending` and thus suspending its task, waiting for a wakeup from the dial
// on the inner transport below.
Delay::new(Duration::from_millis(100)).await;

let dial = t.lock().unwrap().dial(addr).map_err(|e| match e {
TransportError::MultiaddrNotSupported(_) => {
panic!()
}
TransportError::Other(e) => e,
})?;
dial.await
}
.boxed())
}

fn dial_as_listener(
&mut self,
addr: Multiaddr,
) -> Result<Self::Dial, TransportError<Self::Error>> {
self.0.lock().unwrap().dial_as_listener(addr)
}

fn poll(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<TransportEvent<Self::ListenerUpgrade, Self::Error>> {
Pin::new(&mut *self.0.lock().unwrap()).poll(cx)
}
}

let (a_peer_id, mut a_transport) = create_default_transport::<quic::async_std::Provider>();
let (b_peer_id, mut b_transport) = {
let (id, transport) = create_default_transport::<quic::async_std::Provider>();
(id, DialDelay(Arc::new(Mutex::new(transport))).boxed())
};

// Spawn a
let a_addr = start_listening(&mut a_transport, "/ip6/::1/udp/0/quic-v1").await;
let listener = async_std::task::spawn(async move {
let (upgrade, _) = a_transport
.select_next_some()
.await
.into_incoming()
.unwrap();
let (peer_id, _) = upgrade.await.unwrap();

peer_id
});

// Spawn b
//
// Note that the dial is spawned on a different task than the transport allowing the transport
// task to poll the transport once and then suspend, waiting for the wakeup from the dial.
let dial = async_std::task::spawn({
let dial = b_transport.dial(a_addr).unwrap();
async { dial.await.unwrap().0 }
});
async_std::task::spawn(async move { b_transport.next().await });

let (a_connected, b_connected) = future::join(listener, dial).await;

assert_eq!(a_connected, b_peer_id);
assert_eq!(b_connected, a_peer_id);
}

#[cfg(feature = "async-std")]
#[async_std::test]
#[ignore] // Transport currently does not validate PeerId. Enable once we make use of PeerId validation in rustls.
Expand Down

0 comments on commit 7665e74

Please sign in to comment.