diff --git a/protocols/dcutr/src/handler/relayed.rs b/protocols/dcutr/src/handler/relayed.rs
index ff22f2b18e1..939650f1cb7 100644
--- a/protocols/dcutr/src/handler/relayed.rs
+++ b/protocols/dcutr/src/handler/relayed.rs
@@ -257,10 +257,6 @@ impl ConnectionHandler for Handler {
             return KeepAlive::Yes;
         }

-        if self.inbound_connect.is_some() {
-            return KeepAlive::Yes;
-        }
-
         if self.attempts < MAX_NUMBER_OF_UPGRADE_ATTEMPTS {
             return KeepAlive::Yes;
         }
diff --git a/protocols/gossipsub/src/behaviour.rs b/protocols/gossipsub/src/behaviour.rs
index 69fa36b002f..08e41edc21a 100644
--- a/protocols/gossipsub/src/behaviour.rs
+++ b/protocols/gossipsub/src/behaviour.rs
@@ -3265,7 +3265,6 @@ where
     type ConnectionHandler = Handler;
     type ToSwarm = Event;

-    #[allow(deprecated)]
     fn handle_established_inbound_connection(
         &mut self,
         _: ConnectionId,
@@ -3276,7 +3275,6 @@ where
         Ok(Handler::new(self.config.protocol_config()))
     }

-    #[allow(deprecated)]
     fn handle_established_outbound_connection(
         &mut self,
         _: ConnectionId,
diff --git a/protocols/gossipsub/src/handler.rs b/protocols/gossipsub/src/handler.rs
index 1a50ef88fd5..2e3e986e29f 100644
--- a/protocols/gossipsub/src/handler.rs
+++ b/protocols/gossipsub/src/handler.rs
@@ -431,15 +431,6 @@ impl ConnectionHandler for Handler {
                     return KeepAlive::Yes;
                 }

-                if let Some(
-                    OutboundSubstreamState::PendingSend(_, _)
-                    | OutboundSubstreamState::PendingFlush(_),
-                ) = handler.outbound_substream
-                {
-                    return KeepAlive::Yes;
-                }
-
-                #[allow(deprecated)]
                 KeepAlive::No
             }
             Handler::Disabled(_) => KeepAlive::No,
diff --git a/protocols/identify/src/handler.rs b/protocols/identify/src/handler.rs
index b71f98e8509..33bc686595b 100644
--- a/protocols/identify/src/handler.rs
+++ b/protocols/identify/src/handler.rs
@@ -315,10 +315,6 @@ impl ConnectionHandler for Handler {
     }

     fn connection_keep_alive(&self) -> KeepAlive {
-        if !self.active_streams.is_empty() {
-            return KeepAlive::Yes;
-        }
-
         KeepAlive::No
     }

diff --git a/protocols/kad/src/behaviour/test.rs b/protocols/kad/src/behaviour/test.rs
index 7987e4833f1..debb9e567bc 100644
--- a/protocols/kad/src/behaviour/test.rs
+++ b/protocols/kad/src/behaviour/test.rs
@@ -73,7 +73,7 @@ fn build_node_with_config(cfg: Config) -> (Multiaddr, TestSwarm) {
         behaviour,
         local_id,
         swarm::Config::with_async_std_executor()
-            .with_idle_connection_timeout(Duration::from_secs(5)),
+            .with_idle_connection_timeout(Duration::from_secs(10)),
     );

     let address: Multiaddr = Protocol::Memory(random::<u64>()).into();
diff --git a/protocols/kad/src/handler.rs b/protocols/kad/src/handler.rs
index f994c45a6f5..9c961946824 100644
--- a/protocols/kad/src/handler.rs
+++ b/protocols/kad/src/handler.rs
@@ -703,11 +703,7 @@ impl ConnectionHandler for Handler {
     }

     fn connection_keep_alive(&self) -> KeepAlive {
-        if self.outbound_substreams.is_empty() && self.inbound_substreams.is_empty() {
-            return KeepAlive::No;
-        };
-
-        KeepAlive::Yes
+        KeepAlive::No
     }

     fn poll(
diff --git a/protocols/ping/src/handler.rs b/protocols/ping/src/handler.rs
index 522663196e6..73a801c6fec 100644
--- a/protocols/ping/src/handler.rs
+++ b/protocols/ping/src/handler.rs
@@ -261,9 +261,10 @@ impl ConnectionHandler for Handler {
                     log::debug!("Inbound ping error: {:?}", e);
                     self.inbound = None;
                 }
-                Poll::Ready(Ok(stream)) => {
+                Poll::Ready(Ok(mut stream)) => {
                     log::trace!("answered inbound ping from {}", self.peer);

+                    stream.no_keep_alive();
                     // A ping from a remote peer has been answered, wait for the next.
                     self.inbound = Some(protocol::recv_ping(stream).boxed());
                 }
@@ -294,9 +295,10 @@ impl ConnectionHandler for Handler {
                     self.outbound = Some(OutboundState::Ping(ping));
                     break;
                 }
-                Poll::Ready(Ok((stream, rtt))) => {
+                Poll::Ready(Ok((mut stream, rtt))) => {
                     log::debug!("latency to {} is {}ms", self.peer, rtt.as_millis());

+                    stream.no_keep_alive();
                     self.failures = 0;
                     self.interval.reset(self.config.interval);
                     self.outbound = Some(OutboundState::Idle(stream));
@@ -307,12 +309,14 @@ impl ConnectionHandler for Handler {
                        self.pending_errors.push_front(e);
                    }
                },
-            Some(OutboundState::Idle(stream)) => match self.interval.poll_unpin(cx) {
+            Some(OutboundState::Idle(mut stream)) => match self.interval.poll_unpin(cx) {
                 Poll::Pending => {
+                    stream.no_keep_alive();
                     self.outbound = Some(OutboundState::Idle(stream));
                     break;
                 }
                 Poll::Ready(()) => {
+                    stream.no_keep_alive();
                     self.outbound = Some(OutboundState::Ping(
                         send_ping(stream, self.config.timeout).boxed(),
                     ));
diff --git a/protocols/relay/src/behaviour/handler.rs b/protocols/relay/src/behaviour/handler.rs
index 6fb0a834d2f..f5ea541ba32 100644
--- a/protocols/relay/src/behaviour/handler.rs
+++ b/protocols/relay/src/behaviour/handler.rs
@@ -376,10 +376,6 @@ pub struct Handler {
     ///
     /// Contains a [`futures::future::Future`] for each lend out substream that
     /// resolves once the substream is dropped.
-    ///
-    /// Once all substreams are dropped and this handler has no other work,
-    /// [`KeepAlive::Until`] can be set, allowing the connection to be closed
-    /// eventually.
     alive_lend_out_substreams: FuturesUnordered<oneshot::Receiver<()>>,
     /// Futures relaying data for circuit between two peers.
     circuits: Futures<(CircuitId, PeerId, Result<(), std::io::Error>)>,
@@ -881,13 +877,7 @@ impl ConnectionHandler for Handler {
         {}

         // Check keep alive status.
-        if self.reservation_request_future.is_none()
-            && self.circuit_accept_futures.is_empty()
-            && self.circuit_deny_futures.is_empty()
-            && self.alive_lend_out_substreams.is_empty()
-            && self.circuits.is_empty()
-            && self.active_reservation.is_none()
-        {
+        if self.active_reservation.is_none() {
             if self.idle_at.is_none() {
                 self.idle_at = Some(Instant::now());
             }
diff --git a/protocols/relay/src/priv_client/handler.rs b/protocols/relay/src/priv_client/handler.rs
index b2effdbde56..0a0ded63d2e 100644
--- a/protocols/relay/src/priv_client/handler.rs
+++ b/protocols/relay/src/priv_client/handler.rs
@@ -324,22 +324,6 @@ impl ConnectionHandler for Handler {
             return KeepAlive::Yes;
         }

-        if !self.alive_lend_out_substreams.is_empty() {
-            return KeepAlive::Yes;
-        }
-
-        if !self.circuit_deny_futs.is_empty() {
-            return KeepAlive::Yes;
-        }
-
-        if !self.open_circuit_futs.is_empty() {
-            return KeepAlive::Yes;
-        }
-
-        if !self.outbound_circuits.is_empty() {
-            return KeepAlive::Yes;
-        }
-
         KeepAlive::No
     }

diff --git a/protocols/request-response/src/handler.rs b/protocols/request-response/src/handler.rs
index 69929b77873..9103c346fd6 100644
--- a/protocols/request-response/src/handler.rs
+++ b/protocols/request-response/src/handler.rs
@@ -59,8 +59,6 @@ where
     /// The timeout for inbound and outbound substreams (i.e. request
     /// and response processing).
     substream_timeout: Duration,
-    /// The current connection keep-alive.
-    keep_alive: KeepAlive,
     /// Queue of events to emit in `poll()`.
     pending_events: VecDeque<Event<TCodec>>,
     /// Outbound upgrades waiting to be emitted as an `OutboundSubstreamRequest`.
@@ -94,7 +92,6 @@ where
         Self {
             inbound_protocols,
             codec,
-            keep_alive: KeepAlive::Yes,
             substream_timeout,
             outbound: VecDeque::new(),
             inbound: FuturesUnordered::new(),
@@ -274,12 +271,11 @@ where
     }

     fn on_behaviour_event(&mut self, request: Self::FromBehaviour) {
-        self.keep_alive = KeepAlive::Yes;
         self.outbound.push_back(request);
     }

     fn connection_keep_alive(&self) -> KeepAlive {
-        self.keep_alive
+        KeepAlive::No
     }

     fn poll(
@@ -300,7 +296,6 @@ where
             match result {
                 Ok(((id, rq), rs_sender)) => {
                     // We received an inbound request.
-                    self.keep_alive = KeepAlive::Yes;
                     return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(Event::Request {
                         request_id: id,
                         request: rq,
@@ -330,13 +325,6 @@ where
             self.outbound.shrink_to_fit();
         }

-        if self.inbound.is_empty() && self.keep_alive.is_yes() {
-            // No new inbound or outbound requests. We already check
-            // there is no active streams exist in swarm connection,
-            // so we can set keep-alive to no directly.
-            self.keep_alive = KeepAlive::No;
-        }
-
         Poll::Pending
     }

diff --git a/swarm/CHANGELOG.md b/swarm/CHANGELOG.md
index 53a6220f8d4..e1fac13bcba 100644
--- a/swarm/CHANGELOG.md
+++ b/swarm/CHANGELOG.md
@@ -6,6 +6,8 @@
   See [PR 4225](https://github.com/libp2p/rust-libp2p/pull/4225).
 - Remove deprecated `keep_alive_timeout` in `OneShotHandlerConfig`.
   See [PR 4677](https://github.com/libp2p/rust-libp2p/pull/4677).
+- Add `ConnectionHandler::connection_keep_alive` default implementation that returns `KeepAlive::No`.
+  See [PR 4703](https://github.com/libp2p/rust-libp2p/pull/4703).

 ## 0.43.6

diff --git a/swarm/src/connection.rs b/swarm/src/connection.rs
index a9c56c80d63..0f25d5d6d25 100644
--- a/swarm/src/connection.rs
+++ b/swarm/src/connection.rs
@@ -52,17 +52,20 @@ use libp2p_core::upgrade;
 use libp2p_core::upgrade::{NegotiationError, ProtocolError};
 use libp2p_core::Endpoint;
 use libp2p_identity::PeerId;
-use std::cmp::max;
 use std::collections::HashSet;
 use std::fmt::{Display, Formatter};
 use std::future::Future;
 use std::sync::atomic::{AtomicUsize, Ordering};
+use std::sync::Arc;
 use std::task::Waker;
 use std::time::Duration;
 use std::{fmt, io, mem, pin::Pin, task::Context, task::Poll};

 static NEXT_CONNECTION_ID: AtomicUsize = AtomicUsize::new(1);

+/// Counter of the number of active streams on a connection
+type ActiveStreamCounter = Arc<()>;
+
 /// Connection identifier.
 #[derive(Debug, Copy, Clone, Hash, PartialEq, Eq, PartialOrd, Ord)]
 pub struct ConnectionId(usize);
@@ -157,6 +160,8 @@ where
     local_supported_protocols: HashSet<StreamProtocol>,
     remote_supported_protocols: HashSet<StreamProtocol>,
     idle_timeout: Duration,
+    /// The counter of active streams
+    stream_counter: ActiveStreamCounter,
 }

 impl<THandler> fmt::Debug for Connection<THandler>
@@ -205,6 +210,7 @@ where
             local_supported_protocols: initial_protocols,
             remote_supported_protocols: Default::default(),
             idle_timeout,
+            stream_counter: Arc::new(()),
         }
     }

@@ -237,6 +243,7 @@ where
             local_supported_protocols: supported_protocols,
             remote_supported_protocols,
             idle_timeout,
+            stream_counter,
         } = self.get_mut();

         loop {
@@ -344,19 +351,19 @@
                 }
             }

-            // Compute new shutdown
-            if let Some(new_shutdown) =
-                compute_new_shutdown(handler.connection_keep_alive(), shutdown, *idle_timeout)
-            {
-                *shutdown = new_shutdown;
-            }
-
             // Check if the connection (and handler) should be shut down.
-            // As long as we're still negotiating substreams, shutdown is always postponed.
+            // As long as we're still negotiating substreams or have any active streams shutdown is always postponed.
             if negotiating_in.is_empty()
                 && negotiating_out.is_empty()
                 && requested_substreams.is_empty()
+                && Arc::strong_count(stream_counter) == 1
             {
+                if let Some(new_timeout) =
+                    compute_new_shutdown(handler.connection_keep_alive(), shutdown, *idle_timeout)
+                {
+                    *shutdown = new_timeout;
+                }
+
                 match shutdown {
                     Shutdown::None => {}
                     Shutdown::Asap => return Poll::Ready(Err(ConnectionError::KeepAliveTimeout)),
@@ -391,6 +398,7 @@ where
                         timeout,
                         upgrade,
                         *substream_upgrade_protocol_override,
+                        stream_counter.clone(),
                     ));

                     continue; // Go back to the top, handler can potentially make progress again.
@@ -404,7 +412,11 @@ where
                 Poll::Ready(substream) => {
                     let protocol = handler.listen_protocol();

-                    negotiating_in.push(StreamUpgrade::new_inbound(substream, protocol));
+                    negotiating_in.push(StreamUpgrade::new_inbound(
+                        substream,
+                        protocol,
+                        stream_counter.clone(),
+                    ));

                     continue; // Go back to the top, handler can potentially make progress again.
                 }
@@ -450,44 +462,9 @@ fn compute_new_shutdown(
     current_shutdown: &Shutdown,
     idle_timeout: Duration,
 ) -> Option<Shutdown> {
-    #[allow(deprecated)]
     match (current_shutdown, handler_keep_alive) {
-        (Shutdown::Later(_, deadline), KeepAlive::Until(t)) => {
-            let now = Instant::now();
-
-            if *deadline != t {
-                let deadline = t;
-                if let Some(new_duration) = deadline.checked_duration_since(Instant::now()) {
-                    let effective_keep_alive = max(new_duration, idle_timeout);
-
-                    let safe_keep_alive = checked_add_fraction(now, effective_keep_alive);
-                    return Some(Shutdown::Later(Delay::new(safe_keep_alive), deadline));
-                }
-            }
-            None
-        }
-        (_, KeepAlive::Until(earliest_shutdown)) => {
-            let now = Instant::now();
-
-            if let Some(requested) = earliest_shutdown.checked_duration_since(now) {
-                let effective_keep_alive = max(requested, idle_timeout);
-
-                let safe_keep_alive = checked_add_fraction(now, effective_keep_alive);
-
-                // Important: We store the _original_ `Instant` given by the `ConnectionHandler` in the `Later` instance to ensure we can compare it in the above branch.
-                // This is quite subtle but will hopefully become simpler soon once `KeepAlive::Until` is fully deprecated. See /
-                return Some(Shutdown::Later(
-                    Delay::new(safe_keep_alive),
-                    earliest_shutdown,
-                ));
-            }
-            None
-        }
         (_, KeepAlive::No) if idle_timeout == Duration::ZERO => Some(Shutdown::Asap),
-        (Shutdown::Later(_, _), KeepAlive::No) => {
-            // Do nothing, i.e. let the shutdown timer continue to tick.
-            None
-        }
+        (Shutdown::Later(_, _), KeepAlive::No) => None, // Do nothing, i.e. let the shutdown timer continue to tick.
         (_, KeepAlive::No) => {
             let now = Instant::now();
             let safe_keep_alive = checked_add_fraction(now, idle_timeout);
@@ -547,6 +524,7 @@ impl StreamUpgrade {
         timeout: Delay,
         upgrade: Upgrade,
         version_override: Option<upgrade::Version>,
+        counter: Arc<()>,
     ) -> Self
     where
         Upgrade: OutboundUpgradeSend,
@@ -578,7 +556,7 @@ impl StreamUpgrade {
                 .map_err(to_stream_upgrade_error)?;

             let output = upgrade
-                .upgrade_outbound(Stream::new(stream), info)
+                .upgrade_outbound(Stream::new(stream, counter), info)
                 .await
                 .map_err(StreamUpgradeError::Apply)?;

@@ -592,6 +570,7 @@ impl StreamUpgrade {
     fn new_inbound(
         substream: SubstreamBox,
         protocol: SubstreamProtocol<Upgrade, UserData>,
+        counter: Arc<()>,
     ) -> Self
     where
         Upgrade: InboundUpgradeSend,
@@ -610,7 +589,7 @@ impl StreamUpgrade {
                 .map_err(to_stream_upgrade_error)?;

             let output = upgrade
-                .upgrade_inbound(Stream::new(stream), info)
+                .upgrade_inbound(Stream::new(stream, counter), info)
                 .await
                 .map_err(StreamUpgradeError::Apply)?;

@@ -933,68 +912,6 @@ mod tests {
         ));
     }

-    #[tokio::test]
-    async fn idle_timeout_with_keep_alive_until_greater_than_idle_timeout() {
-        let idle_timeout = Duration::from_millis(100);
-
-        let mut connection = Connection::new(
-            StreamMuxerBox::new(PendingStreamMuxer),
-            KeepAliveUntilConnectionHandler {
-                until: Instant::now() + idle_timeout * 2,
-            },
-            None,
-            0,
-            idle_timeout,
-        );
-
-        assert!(connection.poll_noop_waker().is_pending());
-
-        tokio::time::sleep(idle_timeout).await;
-
-        assert!(
-            connection.poll_noop_waker().is_pending(),
-            "`KeepAlive::Until` is greater than idle-timeout, continue sleeping"
-        );
-
-        tokio::time::sleep(idle_timeout).await;
-
-        assert!(matches!(
-            connection.poll_noop_waker(),
-            Poll::Ready(Err(ConnectionError::KeepAliveTimeout))
-        ));
-    }
-
-    #[tokio::test]
-    async fn idle_timeout_with_keep_alive_until_less_than_idle_timeout() {
-        let idle_timeout = Duration::from_millis(100);
-
-        let mut connection = Connection::new(
-            StreamMuxerBox::new(PendingStreamMuxer),
-            KeepAliveUntilConnectionHandler {
-                until: Instant::now() + idle_timeout / 2,
-            },
-            None,
-            0,
-            idle_timeout,
-        );
-
-        assert!(connection.poll_noop_waker().is_pending());
-
-        tokio::time::sleep(idle_timeout / 2).await;
-
-        assert!(
-            connection.poll_noop_waker().is_pending(),
-            "`KeepAlive::Until` is less than idle-timeout, honor idle-timeout"
-        );
-
-        tokio::time::sleep(idle_timeout / 2).await;
-
-        assert!(matches!(
-            connection.poll_noop_waker(),
-            Poll::Ready(Err(ConnectionError::KeepAliveTimeout))
-        ));
-    }
-
     #[test]
     fn checked_add_fraction_can_add_u64_max() {
         let _ = env_logger::try_init();
@@ -1058,58 +975,6 @@ mod tests {
         QuickCheck::new().quickcheck(prop as fn(_, _, _));
     }

-    struct KeepAliveUntilConnectionHandler {
-        until: Instant,
-    }
-
-    impl ConnectionHandler for KeepAliveUntilConnectionHandler {
-        type FromBehaviour = Void;
-        type ToBehaviour = Void;
-        type Error = Void;
-        type InboundProtocol = DeniedUpgrade;
-        type OutboundProtocol = DeniedUpgrade;
-        type InboundOpenInfo = ();
-        type OutboundOpenInfo = Void;
-
-        fn listen_protocol(
-            &self,
-        ) -> SubstreamProtocol<Self::InboundProtocol, Self::InboundOpenInfo> {
-            SubstreamProtocol::new(DeniedUpgrade, ())
-        }
-
-        fn connection_keep_alive(&self) -> KeepAlive {
-            #[allow(deprecated)]
-            KeepAlive::Until(self.until)
-        }
-
-        fn poll(
-            &mut self,
-            _: &mut Context<'_>,
-        ) -> Poll<
-            ConnectionHandlerEvent<
-                Self::OutboundProtocol,
-                Self::OutboundOpenInfo,
-                Self::ToBehaviour,
-                Self::Error,
-            >,
-        > {
-            Poll::Pending
-        }
-
-        fn on_behaviour_event(&mut self, _: Self::FromBehaviour) {}
-
-        fn on_connection_event(
-            &mut self,
-            _: ConnectionEvent<
-                Self::InboundProtocol,
-                Self::OutboundProtocol,
-                Self::InboundOpenInfo,
-                Self::OutboundOpenInfo,
-            >,
-        ) {
-        }
-    }
-
     struct DummyStreamMuxer {
         counter: Arc<()>,
     }
diff --git a/swarm/src/handler.rs b/swarm/src/handler.rs
index 02eb9f83935..8337ecdbe3b 100644
--- a/swarm/src/handler.rs
+++ b/swarm/src/handler.rs
@@ -55,7 +55,6 @@ pub use select::ConnectionHandlerSelect;

 use crate::StreamProtocol;
 use ::either::Either;
-use instant::Instant;
 use libp2p_core::Multiaddr;
 use once_cell::sync::Lazy;
 use smallvec::SmallVec;
@@ -125,17 +124,13 @@ pub trait ConnectionHandler: Send + 'static {

     /// Returns until when the connection should be kept alive.
     ///
-    /// This method is called by the `Swarm` after each invocation of
-    /// [`ConnectionHandler::poll`] to determine if the connection and the associated
-    /// [`ConnectionHandler`]s should be kept alive as far as this handler is concerned
-    /// and if so, for how long.
+    /// This method is an optional implementation and can be called by the `Swarm` after
+    /// each invocation of [`ConnectionHandler::poll`] to determine if the connection
+    /// and the associated [`ConnectionHandler`]s should be kept alive.
     ///
     /// Returning [`KeepAlive::No`] indicates that the connection should be
     /// closed and this handler destroyed immediately.
     ///
-    /// Returning [`KeepAlive::Until`] indicates that the connection may be closed
-    /// and this handler destroyed after the specified `Instant`.
-    ///
     /// Returning [`KeepAlive::Yes`] indicates that the connection should
     /// be kept alive until the next call to this method.
     ///
@@ -143,7 +138,9 @@
     /// > when [`ConnectionHandler::poll`] returns an error. Furthermore, the
     /// > connection may be closed for reasons outside of the control
     /// > of the handler.
-    fn connection_keep_alive(&self) -> KeepAlive;
+    fn connection_keep_alive(&self) -> KeepAlive {
+        KeepAlive::No
+    }

     /// Should behave like `Stream::poll()`.
     fn poll(
@@ -727,11 +724,6 @@ where
 /// How long the connection should be kept alive.
 #[derive(Debug, Copy, Clone, PartialEq, Eq)]
 pub enum KeepAlive {
-    /// If nothing new happens, the connection should be closed at the given `Instant`.
-    #[deprecated(
-        note = "Use `swarm::Config::with_idle_connection_timeout` instead. See for details."
-    )]
-    Until(Instant),
     /// Keep the connection alive.
     Yes,
     /// Close the connection as soon as possible.
@@ -751,16 +743,14 @@ impl PartialOrd for KeepAlive {
     }
 }

-#[allow(deprecated)]
 impl Ord for KeepAlive {
     fn cmp(&self, other: &KeepAlive) -> Ordering {
         use self::KeepAlive::*;

         match (self, other) {
             (No, No) | (Yes, Yes) => Ordering::Equal,
-            (No, _) | (_, Yes) => Ordering::Less,
-            (_, No) | (Yes, _) => Ordering::Greater,
-            (Until(t1), Until(t2)) => t1.cmp(t2),
+            (Yes, No) => Ordering::Less,
+            (No, Yes) => Ordering::Greater,
         }
     }
 }
@@ -768,18 +758,9 @@ impl Ord for KeepAlive {
 #[cfg(test)]
 impl quickcheck::Arbitrary for KeepAlive {
     fn arbitrary(g: &mut quickcheck::Gen) -> Self {
-        match quickcheck::GenRange::gen_range(g, 1u8..4) {
-            1 =>
-            {
-                #[allow(deprecated)]
-                KeepAlive::Until(
-                    Instant::now()
-                        .checked_add(Duration::arbitrary(g))
-                        .unwrap_or(Instant::now()),
-                )
-            }
-            2 => KeepAlive::Yes,
-            3 => KeepAlive::No,
+        match quickcheck::GenRange::gen_range(g, 1u8..3) {
+            1 => KeepAlive::Yes,
+            2 => KeepAlive::No,
             _ => unreachable!(),
         }
     }
diff --git a/swarm/src/handler/one_shot.rs b/swarm/src/handler/one_shot.rs
index 7f422cfa7d0..660e09a192e 100644
--- a/swarm/src/handler/one_shot.rs
+++ b/swarm/src/handler/one_shot.rs
@@ -43,8 +43,6 @@ where
     dial_queue: SmallVec<[TOutbound; 4]>,
     /// Current number of concurrent outbound substreams being opened.
     dial_negotiated: u32,
-    /// Value to return from `connection_keep_alive`.
-    keep_alive: KeepAlive,
     /// The configuration container for the handler
     config: OneShotHandlerConfig,
 }
@@ -64,7 +62,6 @@ where
             events_out: SmallVec::new(),
             dial_queue: SmallVec::new(),
             dial_negotiated: 0,
-            keep_alive: KeepAlive::Yes,
             config,
         }
     }
@@ -92,7 +89,6 @@ where

     /// Opens an outbound substream with `upgrade`.
     pub fn send_request(&mut self, upgrade: TOutbound) {
-        self.keep_alive = KeepAlive::Yes;
         self.dial_queue.push(upgrade);
     }
 }
@@ -137,7 +133,7 @@ where
     }

     fn connection_keep_alive(&self) -> KeepAlive {
-        self.keep_alive
+        KeepAlive::No
     }

     fn poll(
@@ -174,10 +170,6 @@ where
             }
         } else {
             self.dial_queue.shrink_to_fit();
-
-            if self.dial_negotiated == 0 && self.keep_alive.is_yes() {
-                self.keep_alive = KeepAlive::No;
-            }
         }

         Poll::Pending
@@ -209,7 +201,6 @@ where
             ConnectionEvent::DialUpgradeError(DialUpgradeError { error, .. }) => {
                 if self.pending_error.is_none() {
                     log::debug!("DialUpgradeError: {error}");
-                    self.keep_alive = KeepAlive::No;
                 }
             }
             ConnectionEvent::AddressChange(_)
@@ -230,7 +221,6 @@ pub struct OneShotHandlerConfig {
 }

 impl Default for OneShotHandlerConfig {
-    #[allow(deprecated)]
     fn default() -> Self {
         OneShotHandlerConfig {
             outbound_substream_timeout: Duration::from_secs(10),
@@ -249,7 +239,6 @@ mod tests {
     use void::Void;

     #[test]
-    #[allow(deprecated)]
     fn do_not_keep_idle_connection_alive() {
         let mut handler: OneShotHandler<_, DeniedUpgrade, Void> = OneShotHandler::new(
             SubstreamProtocol::new(DeniedUpgrade {}, ()),
diff --git a/swarm/src/stream.rs b/swarm/src/stream.rs
index 3c4c52afc33..13e6588128f 100644
--- a/swarm/src/stream.rs
+++ b/swarm/src/stream.rs
@@ -1,16 +1,43 @@
 use futures::{AsyncRead, AsyncWrite};
 use libp2p_core::muxing::SubstreamBox;
 use libp2p_core::Negotiated;
-use std::io::{IoSlice, IoSliceMut};
-use std::pin::Pin;
-use std::task::{Context, Poll};
+use std::{
+    io::{IoSlice, IoSliceMut},
+    pin::Pin,
+    sync::{Arc, Weak},
+    task::{Context, Poll},
+};

 #[derive(Debug)]
-pub struct Stream(Negotiated<SubstreamBox>);
+pub struct Stream {
+    stream: Negotiated<SubstreamBox>,
+    counter: StreamCounter,
+}
+
+#[derive(Debug)]
+enum StreamCounter {
+    Arc(Arc<()>),
+    Weak(Weak<()>),
+}

 impl Stream {
-    pub(crate) fn new(stream: Negotiated<SubstreamBox>) -> Self {
-        Self(stream)
+    pub(crate) fn new(stream: Negotiated<SubstreamBox>, counter: Arc<()>) -> Self {
+        let counter = StreamCounter::Arc(counter);
+        Self { stream, counter }
+    }
+
+    /// Opt-out this stream from the [Swarm](crate::Swarm)s connection keep alive algorithm.
+    ///
+    /// By default, any active stream keeps a connection alive. For most protocols,
+    /// this is a good default as it ensures that the protocol is completed before
+    /// a connection is shut down.
+    /// Some protocols like libp2p's [ping](https://github.com/libp2p/specs/blob/master/ping/ping.md)
+    /// for example never complete and are of an auxiliary nature.
+    /// These protocols should opt-out of the keep alive algorithm using this method.
+    pub fn no_keep_alive(&mut self) {
+        if let StreamCounter::Arc(arc_counter) = &self.counter {
+            self.counter = StreamCounter::Weak(Arc::downgrade(arc_counter));
+        }
     }
 }

@@ -20,7 +47,7 @@ impl AsyncRead for Stream {
         cx: &mut Context<'_>,
         buf: &mut [u8],
     ) -> Poll<std::io::Result<usize>> {
-        Pin::new(&mut self.get_mut().0).poll_read(cx, buf)
+        Pin::new(&mut self.get_mut().stream).poll_read(cx, buf)
     }

     fn poll_read_vectored(
@@ -28,7 +55,7 @@ impl AsyncRead for Stream {
         cx: &mut Context<'_>,
         bufs: &mut [IoSliceMut<'_>],
     ) -> Poll<std::io::Result<usize>> {
-        Pin::new(&mut self.get_mut().0).poll_read_vectored(cx, bufs)
+        Pin::new(&mut self.get_mut().stream).poll_read_vectored(cx, bufs)
     }
 }

@@ -38,7 +65,7 @@ impl AsyncWrite for Stream {
         cx: &mut Context<'_>,
         buf: &[u8],
     ) -> Poll<std::io::Result<usize>> {
-        Pin::new(&mut self.get_mut().0).poll_write(cx, buf)
+        Pin::new(&mut self.get_mut().stream).poll_write(cx, buf)
     }

     fn poll_write_vectored(
@@ -46,14 +73,14 @@ impl AsyncWrite for Stream {
         cx: &mut Context<'_>,
         bufs: &[IoSlice<'_>],
     ) -> Poll<std::io::Result<usize>> {
-        Pin::new(&mut self.get_mut().0).poll_write_vectored(cx, bufs)
+        Pin::new(&mut self.get_mut().stream).poll_write_vectored(cx, bufs)
     }

     fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
-        Pin::new(&mut self.get_mut().0).poll_flush(cx)
+        Pin::new(&mut self.get_mut().stream).poll_flush(cx)
     }

     fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
-        Pin::new(&mut self.get_mut().0).poll_close(cx)
+        Pin::new(&mut self.get_mut().stream).poll_close(cx)
     }
 }
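
Editor's note: the sketch below is not part of the patch above. It is a minimal, self-contained illustration (plain std-only Rust, with hypothetical names such as SketchStream) of the stream-counting scheme the diff introduces in swarm/src/connection.rs and swarm/src/stream.rs: the connection owns an Arc<()>, every negotiated stream receives a clone of it, the connection only consults connection_keep_alive() and the idle timeout once Arc::strong_count drops back to 1, and a stream can opt out of the accounting by downgrading its clone to a Weak, which is what Stream::no_keep_alive does for ping above.

use std::sync::{Arc, Weak};

// Hypothetical stand-ins for `swarm::Stream` and its internal `StreamCounter`;
// not the real libp2p types, just the same counting idea.
enum StreamCounter {
    Arc(Arc<()>),
    Weak(Weak<()>),
}

struct SketchStream {
    counter: StreamCounter,
}

impl SketchStream {
    // A new stream holds a strong clone of the connection's counter and therefore keeps it alive.
    fn new(counter: Arc<()>) -> Self {
        Self {
            counter: StreamCounter::Arc(counter),
        }
    }

    // Mirrors `Stream::no_keep_alive`: downgrade the strong reference so this stream
    // no longer contributes to the connection's strong count.
    fn no_keep_alive(&mut self) {
        if let StreamCounter::Arc(arc) = &self.counter {
            self.counter = StreamCounter::Weak(Arc::downgrade(arc));
        }
    }
}

// The connection owns one strong reference itself, so a strong count of 1 means
// no stream is currently keeping it alive (cf. `Arc::strong_count(stream_counter) == 1`
// in `Connection::poll`).
fn connection_has_no_active_streams(counter: &Arc<()>) -> bool {
    Arc::strong_count(counter) == 1
}

fn main() {
    let connection_counter = Arc::new(());

    let mut stream = SketchStream::new(connection_counter.clone());
    assert!(!connection_has_no_active_streams(&connection_counter));

    // Auxiliary protocols (e.g. ping) opt out; the idle timeout may now kick in.
    stream.no_keep_alive();
    assert!(connection_has_no_active_streams(&connection_counter));

    // Dropping a stream has the same effect as opting out.
    drop(stream);
    assert!(connection_has_no_active_streams(&connection_counter));
}

A nice property of counting with an Arc is that the connection never has to enumerate or be notified by its streams: dropping a Stream (or downgrading its clone) updates the count automatically.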