diff --git a/src/config.rs b/src/config.rs index 34d67ca0..c2956021 100644 --- a/src/config.rs +++ b/src/config.rs @@ -30,7 +30,7 @@ use crate::{ }, transport::{ manager::limits::ConnectionLimitsConfig, tcp::config::Config as TcpConfig, - MAX_PARALLEL_DIALS, + KEEP_ALIVE_TIMEOUT, MAX_PARALLEL_DIALS, }, types::protocol::ProtocolName, PeerId, @@ -45,7 +45,7 @@ use crate::transport::websocket::config::Config as WebSocketConfig; use multiaddr::Multiaddr; -use std::{collections::HashMap, sync::Arc}; +use std::{collections::HashMap, sync::Arc, time::Duration}; /// Connection role. #[derive(Debug, Copy, Clone)] @@ -121,6 +121,9 @@ pub struct ConfigBuilder { /// Connection limits config. connection_limits: ConnectionLimitsConfig, + + /// Close the connection if no substreams are open within this time frame. + keep_alive_timeout: Duration, } impl Default for ConfigBuilder { @@ -153,6 +156,7 @@ impl ConfigBuilder { request_response_protocols: HashMap::new(), known_addresses: Vec::new(), connection_limits: ConnectionLimitsConfig::default(), + keep_alive_timeout: KEEP_ALIVE_TIMEOUT, } } @@ -268,6 +272,12 @@ impl ConfigBuilder { self } + /// Set keep alive timeout for connections. + pub fn with_keep_alive_timeout(mut self, timeout: Duration) -> Self { + self.keep_alive_timeout = timeout; + self + } + /// Build [`Litep2pConfig`]. pub fn build(mut self) -> Litep2pConfig { let keypair = match self.keypair { @@ -296,6 +306,7 @@ impl ConfigBuilder { request_response_protocols: self.request_response_protocols, known_addresses: self.known_addresses, connection_limits: self.connection_limits, + keep_alive_timeout: self.keep_alive_timeout, } } } @@ -355,4 +366,7 @@ pub struct Litep2pConfig { /// Connection limits config. pub(crate) connection_limits: ConnectionLimitsConfig, + + /// Close the connection if no substreams are open within this time frame. + pub(crate) keep_alive_timeout: Duration, } diff --git a/src/lib.rs b/src/lib.rs index c63e465e..6222c241 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -171,6 +171,7 @@ impl Litep2p { protocol, config.fallback_names.clone(), config.codec, + litep2p_config.keep_alive_timeout, ); let executor = Arc::clone(&litep2p_config.executor); litep2p_config.executor.run(Box::pin(async move { @@ -190,6 +191,7 @@ impl Litep2p { protocol, config.fallback_names.clone(), config.codec, + litep2p_config.keep_alive_timeout, ); litep2p_config.executor.run(Box::pin(async move { RequestResponseProtocol::new(service, config).run().await @@ -200,8 +202,12 @@ impl Litep2p { for (protocol_name, protocol) in litep2p_config.user_protocols.into_iter() { tracing::debug!(target: LOG_TARGET, protocol = ?protocol_name, "enable user protocol"); - let service = - transport_manager.register_protocol(protocol_name, Vec::new(), protocol.codec()); + let service = transport_manager.register_protocol( + protocol_name, + Vec::new(), + protocol.codec(), + litep2p_config.keep_alive_timeout, + ); litep2p_config.executor.run(Box::pin(async move { let _ = protocol.run(service).await; })); @@ -219,6 +225,7 @@ impl Litep2p { ping_config.protocol.clone(), Vec::new(), ping_config.codec, + litep2p_config.keep_alive_timeout, ); litep2p_config.executor.run(Box::pin(async move { Ping::new(service, ping_config).run().await @@ -241,6 +248,7 @@ impl Litep2p { main_protocol.clone(), fallback_names, kademlia_config.codec, + litep2p_config.keep_alive_timeout, ); litep2p_config.executor.run(Box::pin(async move { let _ = Kademlia::new(service, kademlia_config).run().await; @@ -261,6 +269,7 @@ impl Litep2p { identify_config.protocol.clone(), Vec::new(), identify_config.codec, + litep2p_config.keep_alive_timeout, ); identify_config.public = Some(litep2p_config.keypair.public().into()); @@ -280,6 +289,7 @@ impl Litep2p { bitswap_config.protocol.clone(), Vec::new(), bitswap_config.codec, + litep2p_config.keep_alive_timeout, ); litep2p_config.executor.run(Box::pin(async move { Bitswap::new(service, bitswap_config).run().await diff --git a/src/protocol/libp2p/kademlia/mod.rs b/src/protocol/libp2p/kademlia/mod.rs index f4b3288c..9dc2c347 100644 --- a/src/protocol/libp2p/kademlia/mod.rs +++ b/src/protocol/libp2p/kademlia/mod.rs @@ -874,7 +874,10 @@ mod tests { use crate::{ codec::ProtocolCodec, crypto::ed25519::Keypair, - transport::manager::{limits::ConnectionLimitsConfig, TransportManager}, + transport::{ + manager::{limits::ConnectionLimitsConfig, TransportManager}, + KEEP_ALIVE_TIMEOUT, + }, types::protocol::ProtocolName, BandwidthSink, }; @@ -902,6 +905,7 @@ mod tests { Vec::new(), Default::default(), handle, + KEEP_ALIVE_TIMEOUT, ); let (event_tx, event_rx) = channel(64); let (_cmd_tx, cmd_rx) = channel(64); diff --git a/src/protocol/notification/tests/mod.rs b/src/protocol/notification/tests/mod.rs index 0b275502..4aa48aa4 100644 --- a/src/protocol/notification/tests/mod.rs +++ b/src/protocol/notification/tests/mod.rs @@ -29,7 +29,10 @@ use crate::{ }, InnerTransportEvent, ProtocolCommand, TransportService, }, - transport::manager::{limits::ConnectionLimitsConfig, TransportManager}, + transport::{ + manager::{limits::ConnectionLimitsConfig, TransportManager}, + KEEP_ALIVE_TIMEOUT, + }, types::protocol::ProtocolName, BandwidthSink, PeerId, }; @@ -63,6 +66,7 @@ fn make_notification_protocol() -> ( Vec::new(), std::sync::Arc::new(Default::default()), handle, + KEEP_ALIVE_TIMEOUT, ); let (config, handle) = NotificationConfig::new( ProtocolName::from("/notif/1"), diff --git a/src/protocol/request_response/tests.rs b/src/protocol/request_response/tests.rs index 7c57b4f9..9cb842f6 100644 --- a/src/protocol/request_response/tests.rs +++ b/src/protocol/request_response/tests.rs @@ -29,7 +29,10 @@ use crate::{ InnerTransportEvent, TransportService, }, substream::Substream, - transport::manager::{limits::ConnectionLimitsConfig, TransportManager}, + transport::{ + manager::{limits::ConnectionLimitsConfig, TransportManager}, + KEEP_ALIVE_TIMEOUT, + }, types::{RequestId, SubstreamId}, BandwidthSink, Error, PeerId, ProtocolName, }; @@ -61,6 +64,7 @@ fn protocol() -> ( Vec::new(), std::sync::Arc::new(Default::default()), handle, + KEEP_ALIVE_TIMEOUT, ); let (config, handle) = ConfigBuilder::new(ProtocolName::from("/req/1")).with_max_size(1024).build(); diff --git a/src/protocol/transport_service.rs b/src/protocol/transport_service.rs index 9302f42a..63d5a117 100644 --- a/src/protocol/transport_service.rs +++ b/src/protocol/transport_service.rs @@ -122,8 +122,11 @@ pub struct TransportService { /// Next substream ID. next_substream_id: Arc, + /// Close the connection if no substreams are open within this time frame. + keep_alive_timeout: Duration, + /// Pending keep-alive timeouts. - keep_alive_timeouts: FuturesUnordered>, + pending_keep_alive_timeouts: FuturesUnordered>, } impl TransportService { @@ -134,6 +137,7 @@ impl TransportService { fallback_names: Vec, next_substream_id: Arc, transport_handle: TransportManagerHandle, + keep_alive_timeout: Duration, ) -> (Self, Sender) { let (tx, rx) = channel(DEFAULT_CHANNEL_SIZE); @@ -146,7 +150,8 @@ impl TransportService { transport_handle, next_substream_id, connections: HashMap::new(), - keep_alive_timeouts: FuturesUnordered::new(), + keep_alive_timeout: keep_alive_timeout, + pending_keep_alive_timeouts: FuturesUnordered::new(), }, tx, ) @@ -168,6 +173,7 @@ impl TransportService { ?connection_id, "connection established", ); + let keep_alive_timeout = self.keep_alive_timeout; match self.connections.get_mut(&peer) { Some(context) => match context.secondary { @@ -182,8 +188,8 @@ impl TransportService { None } None => { - self.keep_alive_timeouts.push(Box::pin(async move { - tokio::time::sleep(Duration::from_secs(5)).await; + self.pending_keep_alive_timeouts.push(Box::pin(async move { + tokio::time::sleep(keep_alive_timeout).await; (peer, connection_id) })); context.secondary = Some(handle); @@ -193,8 +199,8 @@ impl TransportService { }, None => { self.connections.insert(peer, ConnectionContext::new(handle)); - self.keep_alive_timeouts.push(Box::pin(async move { - tokio::time::sleep(Duration::from_secs(5)).await; + self.pending_keep_alive_timeouts.push(Box::pin(async move { + tokio::time::sleep(keep_alive_timeout).await; (peer, connection_id) })); @@ -387,7 +393,7 @@ impl Stream for TransportService { } while let Poll::Ready(Some((peer, connection_id))) = - self.keep_alive_timeouts.poll_next_unpin(cx) + self.pending_keep_alive_timeouts.poll_next_unpin(cx) { if let Some(context) = self.connections.get_mut(&peer) { tracing::trace!( @@ -410,7 +416,10 @@ mod tests { use super::*; use crate::{ protocol::TransportService, - transport::manager::{handle::InnerTransportManagerCommand, TransportManagerHandle}, + transport::{ + manager::{handle::InnerTransportManagerCommand, TransportManagerHandle}, + KEEP_ALIVE_TIMEOUT, + }, }; use futures::StreamExt; use parking_lot::RwLock; @@ -439,6 +448,7 @@ mod tests { Vec::new(), Arc::new(AtomicUsize::new(0usize)), handle, + KEEP_ALIVE_TIMEOUT, ); (service, sender, cmd_rx) @@ -780,7 +790,7 @@ mod tests { }; // verify the first connection state is correct - assert_eq!(service.keep_alive_timeouts.len(), 1); + assert_eq!(service.pending_keep_alive_timeouts.len(), 1); match service.connections.get(&peer) { Some(context) => { assert_eq!( @@ -815,7 +825,7 @@ mod tests { // doesn't exist anymore // // the peer is removed because there is no connection to them - assert_eq!(service.keep_alive_timeouts.len(), 1); + assert_eq!(service.pending_keep_alive_timeouts.len(), 1); assert!(service.connections.get(&peer).is_none()); // register new primary connection but verify that there are now two pending keep-alive @@ -843,7 +853,7 @@ mod tests { }; // verify the first connection state is correct - assert_eq!(service.keep_alive_timeouts.len(), 2); + assert_eq!(service.pending_keep_alive_timeouts.len(), 2); match service.connections.get(&peer) { Some(context) => { assert_eq!( diff --git a/src/transport/manager/mod.rs b/src/transport/manager/mod.rs index 4689097c..33cc8f5b 100644 --- a/src/transport/manager/mod.rs +++ b/src/transport/manager/mod.rs @@ -51,6 +51,7 @@ use std::{ Arc, }, task::{Context, Poll}, + time::Duration, }; pub use handle::{TransportHandle, TransportManagerHandle}; @@ -322,6 +323,7 @@ impl TransportManager { protocol: ProtocolName, fallback_names: Vec, codec: ProtocolCodec, + keep_alive_timeout: Duration, ) -> TransportService { assert!(!self.protocol_names.contains(&protocol)); @@ -337,6 +339,7 @@ impl TransportManager { fallback_names.clone(), self.next_substream_id.clone(), self.transport_manager_handle.clone(), + keep_alive_timeout, ); self.protocols.insert( @@ -1756,7 +1759,9 @@ mod tests { use super::*; use crate::{ - crypto::ed25519::Keypair, executor::DefaultExecutor, transport::dummy::DummyTransport, + crypto::ed25519::Keypair, + executor::DefaultExecutor, + transport::{dummy::DummyTransport, KEEP_ALIVE_TIMEOUT}, }; use std::{ net::{Ipv4Addr, Ipv6Addr}, @@ -1793,11 +1798,13 @@ mod tests { ProtocolName::from("/notif/1"), Vec::new(), ProtocolCodec::UnsignedVarint(None), + KEEP_ALIVE_TIMEOUT, ); manager.register_protocol( ProtocolName::from("/notif/1"), Vec::new(), ProtocolCodec::UnsignedVarint(None), + KEEP_ALIVE_TIMEOUT, ); } @@ -1818,6 +1825,7 @@ mod tests { ProtocolName::from("/notif/1"), Vec::new(), ProtocolCodec::UnsignedVarint(None), + KEEP_ALIVE_TIMEOUT, ); manager.register_protocol( ProtocolName::from("/notif/2"), @@ -1826,6 +1834,7 @@ mod tests { ProtocolName::from("/notif/1"), ], ProtocolCodec::UnsignedVarint(None), + KEEP_ALIVE_TIMEOUT, ); } @@ -1849,6 +1858,7 @@ mod tests { ProtocolName::from("/notif/1"), ], ProtocolCodec::UnsignedVarint(None), + KEEP_ALIVE_TIMEOUT, ); manager.register_protocol( ProtocolName::from("/notif/2"), @@ -1857,6 +1867,7 @@ mod tests { ProtocolName::from("/notif/1/new"), ], ProtocolCodec::UnsignedVarint(None), + KEEP_ALIVE_TIMEOUT, ); } diff --git a/src/transport/mod.rs b/src/transport/mod.rs index 0746b9e7..792508cc 100644 --- a/src/transport/mod.rs +++ b/src/transport/mod.rs @@ -47,6 +47,9 @@ pub(crate) const CONNECTION_OPEN_TIMEOUT: Duration = Duration::from_secs(10); /// Timeout for opening a substream. pub(crate) const SUBSTREAM_OPEN_TIMEOUT: Duration = Duration::from_secs(5); +/// Timeout for connection waiting new substreams. +pub(crate) const KEEP_ALIVE_TIMEOUT: Duration = Duration::from_secs(5); + /// Maximum number of parallel dial attempts. pub(crate) const MAX_PARALLEL_DIALS: usize = 8;