From 7665e74cdb100255feaf5f69b4c9b294dbf0855d Mon Sep 17 00:00:00 2001 From: Max Inden Date: Wed, 11 Jan 2023 15:41:34 +0100 Subject: [PATCH] fix(quic): Trigger Quic as Transport wakeup on dial (#3306) Scenario: rust-libp2p node A dials rust-libp2p node B. B listens on a QUIC address. A dials B via the `libp2p-quic` `Transport` wrapped in a `libp2p-dns` `Transport`. Note that `libp2p-dns` in itself is not relevant here. Only the fact that `libp2p-dns` delays a dial is relevant, i.e. that it first does other async stuff (DNS lookup) before creating the QUIC dial. In fact, dialing an IP address through the DNS `Transport` where no DNS resolution is needed triggers the below just fine. 1. A calls `Swarm::dial` which creates a `libp2p-dns` dial. 2. That dial is spawned onto the connection `Pool`, thus starting the DNS resolution. 3. A continuously calls `Swarm::poll`. 4. `libp2p-quic` `Transport::poll` is called, finding no dialers in `self.dialer` given that the spawned dial is still only resolving the DNS address. 5. On the spawned connection task: 1. The DNS resolution finishes. 2. Thus calling `Transport::dial` on `libp2p-quic` (note that the DNS dial has a clone of the QUIC `Transport` wrapped in an `Arc<Mutex<_>>`). 3. That adds a dialer to `self.dialer`. Note that there are no listeners, i.e. `Swarm::listen_on` was never called. 4. `DialerState::new_dial` is called which adds a message to `self.pending_dials` and wakes `self.waker`. Given that on the last `Transport::poll` there was no `self.dialer`, that waker is empty. Result: The message is stuck in the `DialerState::pending_dials`. The message is never sent to the endpoint driver. The dial never succeeds. This commit fixes the above by waking the task polling the QUIC `Transport::poll` method whenever a new dial is created. 
--- transports/quic/src/transport.rs | 28 +++++--- transports/quic/tests/smoke.rs | 114 +++++++++++++++++++++++++++++++ 2 files changed, 131 insertions(+), 11 deletions(-) diff --git a/transports/quic/src/transport.rs b/transports/quic/src/transport.rs index 9f66fe49724..dea01c74685 100644 --- a/transports/quic/src/transport.rs +++ b/transports/quic/src/transport.rs @@ -71,6 +71,7 @@ pub struct GenTransport { listeners: SelectAll>, /// Dialer for each socket family if no matching listener exists. dialer: HashMap, + dialer_waker: Option, } impl GenTransport

{ @@ -84,6 +85,7 @@ impl GenTransport

{ quinn_config, handshake_timeout, dialer: HashMap::new(), + dialer_waker: None, support_draft_29, } } @@ -178,6 +180,12 @@ impl Transport for GenTransport

{ &mut listeners[index].dialer_state } }; + + // Wakeup the task polling [`Transport::poll`] to drive the new dial. + if let Some(waker) = self.dialer_waker.take() { + waker.wake(); + } + Ok(dialer_state.new_dial(socket_addr, self.handshake_timeout, version)) } @@ -207,10 +215,14 @@ impl Transport for GenTransport

{ // Drop dialer and all pending dials so that the connection receiver is notified. self.dialer.remove(&key); } - match self.listeners.poll_next_unpin(cx) { - Poll::Ready(Some(ev)) => Poll::Ready(ev), - _ => Poll::Pending, + + if let Poll::Ready(Some(ev)) = self.listeners.poll_next_unpin(cx) { + return Poll::Ready(ev); } + + self.dialer_waker = Some(cx.waker().clone()); + + Poll::Pending } } @@ -254,12 +266,11 @@ impl Drop for Dialer { } } -/// Pending dials to be sent to the endpoint was the [`endpoint::Channel`] -/// has capacity +/// Pending dials to be sent to the endpoint once the [`endpoint::Channel`] +/// has capacity. #[derive(Default, Debug)] struct DialerState { pending_dials: VecDeque, - waker: Option, } impl DialerState { @@ -279,10 +290,6 @@ impl DialerState { self.pending_dials.push_back(message); - if let Some(waker) = self.waker.take() { - waker.wake(); - } - async move { // Our oneshot getting dropped means the message didn't make it to the endpoint driver. let connection = tx.await.map_err(|_| Error::EndpointDriverCrashed)??; @@ -307,7 +314,6 @@ impl DialerState { Err(endpoint::Disconnected {}) => return Poll::Ready(Error::EndpointDriverCrashed), } } - self.waker = Some(cx.waker().clone()); Poll::Pending } } diff --git a/transports/quic/tests/smoke.rs b/transports/quic/tests/smoke.rs index a147864528c..649aca09b26 100644 --- a/transports/quic/tests/smoke.rs +++ b/transports/quic/tests/smoke.rs @@ -1,12 +1,15 @@ #![cfg(any(feature = "async-std", feature = "tokio"))] use futures::channel::{mpsc, oneshot}; +use futures::future::BoxFuture; use futures::future::{poll_fn, Either}; use futures::stream::StreamExt; use futures::{future, AsyncReadExt, AsyncWriteExt, FutureExt, SinkExt}; +use futures_timer::Delay; use libp2p_core::either::EitherOutput; use libp2p_core::muxing::{StreamMuxerBox, StreamMuxerExt, SubstreamBox}; use libp2p_core::transport::{Boxed, OrTransport, TransportEvent}; +use libp2p_core::transport::{ListenerId, TransportError}; use 
libp2p_core::{multiaddr::Protocol, upgrade, Multiaddr, PeerId, Transport}; use libp2p_noise as noise; use libp2p_quic as quic; @@ -19,6 +22,10 @@ use std::io; use std::num::NonZeroU8; use std::task::Poll; use std::time::Duration; +use std::{ + pin::Pin, + sync::{Arc, Mutex}, +}; #[cfg(feature = "tokio")] #[tokio::test] @@ -90,6 +97,113 @@ async fn ipv4_dial_ipv6() { assert_eq!(b_connected, a_peer_id); } +/// Tests that a [`Transport::dial`] wakes up the task previously polling [`Transport::poll`]. +/// +/// See https://github.com/libp2p/rust-libp2p/pull/3306 for context. +#[cfg(feature = "async-std")] +#[async_std::test] +async fn wrapped_with_dns() { + let _ = env_logger::try_init(); + + struct DialDelay(Arc>>); + + impl Transport for DialDelay { + type Output = (PeerId, StreamMuxerBox); + type Error = std::io::Error; + type ListenerUpgrade = Pin> + Send>>; + type Dial = BoxFuture<'static, Result>; + + fn listen_on( + &mut self, + addr: Multiaddr, + ) -> Result> { + self.0.lock().unwrap().listen_on(addr) + } + + fn remove_listener(&mut self, id: ListenerId) -> bool { + self.0.lock().unwrap().remove_listener(id) + } + + fn address_translation( + &self, + listen: &Multiaddr, + observed: &Multiaddr, + ) -> Option { + self.0.lock().unwrap().address_translation(listen, observed) + } + + /// Delayed dial, i.e. calling [`Transport::dial`] on the inner [`Transport`] not within the + /// synchronous [`Transport::dial`] method, but within the [`Future`] returned by the outer + /// [`Transport::dial`]. + fn dial(&mut self, addr: Multiaddr) -> Result> { + let t = self.0.clone(); + Ok(async move { + // Simulate DNS lookup. Giving the `Transport::poll` the chance to return + // `Poll::Pending` and thus suspending its task, waiting for a wakeup from the dial + // on the inner transport below. 
+ Delay::new(Duration::from_millis(100)).await; + + let dial = t.lock().unwrap().dial(addr).map_err(|e| match e { + TransportError::MultiaddrNotSupported(_) => { + panic!() + } + TransportError::Other(e) => e, + })?; + dial.await + } + .boxed()) + } + + fn dial_as_listener( + &mut self, + addr: Multiaddr, + ) -> Result> { + self.0.lock().unwrap().dial_as_listener(addr) + } + + fn poll( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + Pin::new(&mut *self.0.lock().unwrap()).poll(cx) + } + } + + let (a_peer_id, mut a_transport) = create_default_transport::(); + let (b_peer_id, mut b_transport) = { + let (id, transport) = create_default_transport::(); + (id, DialDelay(Arc::new(Mutex::new(transport))).boxed()) + }; + + // Spawn a + let a_addr = start_listening(&mut a_transport, "/ip6/::1/udp/0/quic-v1").await; + let listener = async_std::task::spawn(async move { + let (upgrade, _) = a_transport + .select_next_some() + .await + .into_incoming() + .unwrap(); + let (peer_id, _) = upgrade.await.unwrap(); + + peer_id + }); + + // Spawn b + // + // Note that the dial is spawned on a different task than the transport allowing the transport + // task to poll the transport once and then suspend, waiting for the wakeup from the dial. + let dial = async_std::task::spawn({ + let dial = b_transport.dial(a_addr).unwrap(); + async { dial.await.unwrap().0 } + }); + async_std::task::spawn(async move { b_transport.next().await }); + + let (a_connected, b_connected) = future::join(listener, dial).await; + + assert_eq!(a_connected, b_peer_id); + assert_eq!(b_connected, a_peer_id); +} + #[cfg(feature = "async-std")] #[async_std::test] #[ignore] // Transport currently does not validate PeerId. Enable once we make use of PeerId validation in rustls.