diff --git a/transports/quic/CHANGELOG.md b/transports/quic/CHANGELOG.md index b6a47a93cf2..0db42d12c90 100644 --- a/transports/quic/CHANGELOG.md +++ b/transports/quic/CHANGELOG.md @@ -7,7 +7,10 @@ - Add opt-in support for the `/quic` codepoint, interpreted as QUIC version draft-29. See [PR 3151]. +- Wake the transport's task when a new dialer or listener is added. See [3342]. + [PR 3151]: https://github.com/libp2p/rust-libp2p/pull/3151 +[PR 3342]: https://github.com/libp2p/rust-libp2p/pull/3342 # 0.7.0-alpha diff --git a/transports/quic/src/transport.rs b/transports/quic/src/transport.rs index 9f66fe49724..30dba0909c3 100644 --- a/transports/quic/src/transport.rs +++ b/transports/quic/src/transport.rs @@ -71,6 +71,8 @@ pub struct GenTransport { listeners: SelectAll>, /// Dialer for each socket family if no matching listener exists. dialer: HashMap, + /// Waker to poll the transport again when a new dialer or listener is added. + waker: Option, } impl GenTransport

{ @@ -84,6 +86,7 @@ impl GenTransport

{ quinn_config, handshake_timeout, dialer: HashMap::new(), + waker: None, support_draft_29, } } @@ -108,6 +111,10 @@ impl Transport for GenTransport

{ )?; self.listeners.push(listener); + if let Some(waker) = self.waker.take() { + waker.wake(); + } + // Remove dialer endpoint so that the endpoint is dropped once the last // connection that uses it is closed. // New outbound connections will use the bidirectional (listener) endpoint. @@ -163,6 +170,9 @@ impl Transport for GenTransport

{ let dialer = match self.dialer.entry(socket_family) { Entry::Occupied(occupied) => occupied.into_mut(), Entry::Vacant(vacant) => { + if let Some(waker) = self.waker.take() { + waker.wake(); + } vacant.insert(Dialer::new::

(self.quinn_config.clone(), socket_family)?) } }; @@ -202,15 +212,19 @@ impl Transport for GenTransport

{ errored.push(*key); } } + for key in errored { // Endpoint driver of dialer crashed. // Drop dialer and all pending dials so that the connection receiver is notified. self.dialer.remove(&key); } - match self.listeners.poll_next_unpin(cx) { - Poll::Ready(Some(ev)) => Poll::Ready(ev), - _ => Poll::Pending, + + if let Poll::Ready(Some(ev)) = self.listeners.poll_next_unpin(cx) { + return Poll::Ready(ev); } + + self.waker = Some(cx.waker().clone()); + Poll::Pending } } diff --git a/transports/quic/tests/smoke.rs b/transports/quic/tests/smoke.rs index 7950bbdc2c5..41e7f6ac039 100644 --- a/transports/quic/tests/smoke.rs +++ b/transports/quic/tests/smoke.rs @@ -1,11 +1,14 @@ #![cfg(any(feature = "async-std", feature = "tokio"))] use futures::channel::{mpsc, oneshot}; +use futures::future::BoxFuture; use futures::future::{poll_fn, Either}; use futures::stream::StreamExt; use futures::{future, AsyncReadExt, AsyncWriteExt, FutureExt, SinkExt}; +use futures_timer::Delay; use libp2p_core::muxing::{StreamMuxerBox, StreamMuxerExt, SubstreamBox}; use libp2p_core::transport::{Boxed, OrTransport, TransportEvent}; +use libp2p_core::transport::{ListenerId, TransportError}; use libp2p_core::{multiaddr::Protocol, upgrade, Multiaddr, PeerId, Transport}; use libp2p_noise as noise; use libp2p_quic as quic; @@ -18,6 +21,10 @@ use std::io; use std::num::NonZeroU8; use std::task::Poll; use std::time::Duration; +use std::{ + pin::Pin, + sync::{Arc, Mutex}, +}; #[cfg(feature = "tokio")] #[tokio::test] @@ -89,6 +96,113 @@ async fn ipv4_dial_ipv6() { assert_eq!(b_connected, a_peer_id); } +/// Tests that a [`Transport::dial`] wakes up the task previously polling [`Transport::poll`]. +/// +/// See https://github.com/libp2p/rust-libp2p/pull/3306 for context. +#[cfg(feature = "async-std")] +#[async_std::test] +async fn wrapped_with_delay() { + let _ = env_logger::try_init(); + + struct DialDelay(Arc>>); + + impl Transport for DialDelay { + type Output = (PeerId, StreamMuxerBox); + type Error = std::io::Error; + type ListenerUpgrade = Pin> + Send>>; + type Dial = BoxFuture<'static, Result>; + + fn listen_on( + &mut self, + addr: Multiaddr, + ) -> Result> { + self.0.lock().unwrap().listen_on(addr) + } + + fn remove_listener(&mut self, id: ListenerId) -> bool { + self.0.lock().unwrap().remove_listener(id) + } + + fn address_translation( + &self, + listen: &Multiaddr, + observed: &Multiaddr, + ) -> Option { + self.0.lock().unwrap().address_translation(listen, observed) + } + + /// Delayed dial, i.e. calling [`Transport::dial`] on the inner [`Transport`] not within the + /// synchronous [`Transport::dial`] method, but within the [`Future`] returned by the outer + /// [`Transport::dial`]. + fn dial(&mut self, addr: Multiaddr) -> Result> { + let t = self.0.clone(); + Ok(async move { + // Simulate DNS lookup. Giving the `Transport::poll` the chance to return + // `Poll::Pending` and thus suspending its task, waiting for a wakeup from the dial + // on the inner transport below. + Delay::new(Duration::from_millis(100)).await; + + let dial = t.lock().unwrap().dial(addr).map_err(|e| match e { + TransportError::MultiaddrNotSupported(_) => { + panic!() + } + TransportError::Other(e) => e, + })?; + dial.await + } + .boxed()) + } + + fn dial_as_listener( + &mut self, + addr: Multiaddr, + ) -> Result> { + self.0.lock().unwrap().dial_as_listener(addr) + } + + fn poll( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + Pin::new(&mut *self.0.lock().unwrap()).poll(cx) + } + } + + let (a_peer_id, mut a_transport) = create_default_transport::(); + let (b_peer_id, mut b_transport) = { + let (id, transport) = create_default_transport::(); + (id, DialDelay(Arc::new(Mutex::new(transport))).boxed()) + }; + + // Spawn A + let a_addr = start_listening(&mut a_transport, "/ip6/::1/udp/0/quic-v1").await; + let listener = async_std::task::spawn(async move { + let (upgrade, _) = a_transport + .select_next_some() + .await + .into_incoming() + .unwrap(); + let (peer_id, _) = upgrade.await.unwrap(); + + peer_id + }); + + // Spawn B + // + // Note that the dial is spawned on a different task than the transport allowing the transport + // task to poll the transport once and then suspend, waiting for the wakeup from the dial. + let dial = async_std::task::spawn({ + let dial = b_transport.dial(a_addr).unwrap(); + async { dial.await.unwrap().0 } + }); + async_std::task::spawn(async move { b_transport.next().await }); + + let (a_connected, b_connected) = future::join(listener, dial).await; + + assert_eq!(a_connected, b_peer_id); + assert_eq!(b_connected, a_peer_id); +} + #[cfg(feature = "async-std")] #[async_std::test] #[ignore] // Transport currently does not validate PeerId. Enable once we make use of PeerId validation in rustls.