diff --git a/swarm/src/connection/pool.rs b/swarm/src/connection/pool.rs index bf21976f749..5a723bcd49f 100644 --- a/swarm/src/connection/pool.rs +++ b/swarm/src/connection/pool.rs @@ -33,12 +33,11 @@ use fnv::FnvHashMap; use futures::prelude::*; use futures::{ channel::{mpsc, oneshot}, - future::{poll_fn, BoxFuture, Either}, - ready, + future::{BoxFuture, Either}, stream::FuturesUnordered, }; use libp2p_core::connection::{ConnectionId, Endpoint, PendingPoint}; -use libp2p_core::muxing::{StreamMuxer, StreamMuxerBox}; +use libp2p_core::muxing::StreamMuxerBox; use std::{ collections::{hash_map, HashMap}, convert::TryFrom as _, @@ -98,10 +97,12 @@ where /// Sender distributed to pending tasks for reporting events back /// to the pool. - pending_connection_events_tx: mpsc::Sender>, + pending_connection_events_tx: + mpsc::Sender>, /// Receiver for events reported from pending tasks. - pending_connection_events_rx: mpsc::Receiver>, + pending_connection_events_rx: + mpsc::Receiver>, /// Sender distributed to established tasks for reporting events back /// to the pool. @@ -485,7 +486,7 @@ where dial_concurrency_factor_override: Option, ) -> Result where - TTrans: Clone + Send, + TTrans: Transport + Clone + Send, TTrans::Dial: Send + 'static, { if let Err(limit) = self.counters.check_max_pending_outgoing() { @@ -541,6 +542,7 @@ where info: IncomingInfo<'_>, ) -> Result where + TTrans: Transport, TFut: Future> + Send + 'static, { let endpoint = info.to_connected_point(); @@ -673,7 +675,9 @@ where match event { task::PendingConnectionEvent::ConnectionEstablished { id, - output: (obtained_peer_id, muxer), + // output: (obtained_peer_id, muxer), + obtained_peer_id, + response, outgoing, } => { let PendingConnectionInfo { @@ -759,20 +763,8 @@ where }); if let Err(error) = error { - self.spawn( - poll_fn(move |cx| { - if let Err(e) = ready!(muxer.close(cx)) { - log::debug!( - "Failed to close connection {:?} to peer {}: {:?}", - id, - obtained_peer_id, - e - ); - } - Poll::Ready(()) - }) - .boxed(), - ); + // send message to PendingConnection + let _ = response.send(task::PendingCommand::Close); match endpoint { ConnectedPoint::Dialer { .. } => { @@ -815,21 +807,22 @@ where }, ); - let connection = super::Connection::new( - muxer, - handler.into_handler(&obtained_peer_id, &endpoint), - self.substream_upgrade_protocol_override, - ); - self.spawn( - task::new_for_established_connection( + // Send message to upgrade pending connection to upgrade to a full connection + let cmd = task::PendingCommand::Upgrade { + handler: handler.into_handler(&obtained_peer_id, &endpoint), + substream_upgrade_protocol_override: self + .substream_upgrade_protocol_override, + command_receiver, + events: self.established_connection_events_tx.clone(), + }; + if response.send(cmd).is_err() { + // TODO: what else do we want to do if the task is gone? + log::debug!( + "Failed to upgrade connection {:?} to peer {}: Task is gone", id, obtained_peer_id, - connection, - command_receiver, - self.established_connection_events_tx.clone(), - ) - .boxed(), - ); + ); + } match self.get(id) { Some(PoolConnection::Established(connection)) => { diff --git a/swarm/src/connection/pool/task.rs b/swarm/src/connection/pool/task.rs index 866049e50da..01cefdfdf8e 100644 --- a/swarm/src/connection/pool/task.rs +++ b/swarm/src/connection/pool/task.rs @@ -24,7 +24,8 @@ use super::concurrent_dial::ConcurrentDial; use crate::{ connection::{ - self, ConnectionError, PendingInboundConnectionError, PendingOutboundConnectionError, + self, Connection, ConnectionError, PendingInboundConnectionError, + PendingOutboundConnectionError, }, transport::{Transport, TransportError}, ConnectionHandler, Multiaddr, PeerId, @@ -34,7 +35,7 @@ use futures::{ future::{poll_fn, Either, Future}, SinkExt, StreamExt, }; -use libp2p_core::connection::ConnectionId; +use libp2p_core::{connection::ConnectionId, muxing::StreamMuxerBox, upgrade, StreamMuxer}; use std::pin::Pin; use void::Void; @@ -48,14 +49,30 @@ pub enum Command { Close, } +/// Commands that can be sent to a task driving a pending connection. #[derive(Debug)] -pub enum PendingConnectionEvent +pub enum PendingCommand { + /// Upgrade from pending to established connection. + Upgrade { + handler: THandler, + substream_upgrade_protocol_override: Option, + command_receiver: mpsc::Receiver>, + events: mpsc::Sender>, + }, + /// Close the connection, due to an error, and terminate the task. + Close, +} + +#[derive(Debug)] +pub enum PendingConnectionEvent where TTrans: Transport, + THandler: ConnectionHandler, { ConnectionEstablished { id: ConnectionId, - output: TTrans::Output, + obtained_peer_id: PeerId, + response: oneshot::Sender>, /// [`Some`] when the new connection is an outgoing connection. /// Addresses are dialed in parallel. Contains the addresses and errors /// of dial attempts that failed before the one successful dial. @@ -97,13 +114,14 @@ pub enum EstablishedConnectionEvent { }, } -pub async fn new_for_pending_outgoing_connection( +pub async fn new_for_pending_outgoing_connection( connection_id: ConnectionId, dial: ConcurrentDial, abort_receiver: oneshot::Receiver, - mut events: mpsc::Sender>, + mut events: mpsc::Sender>, ) where - TTrans: Transport, + TTrans: Transport, + THandler: ConnectionHandler, { match futures::future::select(abort_receiver, Box::pin(dial)).await { Either::Left((Err(oneshot::Canceled), _)) => { @@ -115,14 +133,50 @@ pub async fn new_for_pending_outgoing_connection( .await; } Either::Left((Ok(v), _)) => void::unreachable(v), - Either::Right((Ok((address, output, errors)), _)) => { + Either::Right((Ok((address, (obtained_peer_id, muxer), errors)), _)) => { + let (response, receiver) = oneshot::channel(); let _ = events .send(PendingConnectionEvent::ConnectionEstablished { id: connection_id, - output, + obtained_peer_id, + response, outgoing: Some((address, errors)), }) .await; + + match receiver.await { + Ok(PendingCommand::Upgrade { + handler, + substream_upgrade_protocol_override, + command_receiver, + events, + }) => { + // Upgrade to Connection + let connection = + Connection::new(muxer, handler, substream_upgrade_protocol_override); + new_for_established_connection( + connection_id, + obtained_peer_id, + connection, + command_receiver, + events, + ) + .await + } + Ok(PendingCommand::Close) => { + if let Err(e) = poll_fn(move |cx| muxer.close(cx)).await { + log::debug!( + "Failed to close connection {:?} to peer {}: {:?}", + connection_id, + obtained_peer_id, + e + ); + } + } + Err(_) => { + // Shutting down, nothing we can do about this. + } + } } Either::Right((Err(e), _)) => { let _ = events @@ -135,14 +189,15 @@ pub async fn new_for_pending_outgoing_connection( } } -pub async fn new_for_pending_incoming_connection( +pub async fn new_for_pending_incoming_connection( connection_id: ConnectionId, future: TFut, abort_receiver: oneshot::Receiver, - mut events: mpsc::Sender>, + mut events: mpsc::Sender>, ) where - TTrans: Transport, + TTrans: Transport, TFut: Future> + Send + 'static, + THandler: ConnectionHandler, { match futures::future::select(abort_receiver, Box::pin(future)).await { Either::Left((Err(oneshot::Canceled), _)) => { @@ -154,14 +209,50 @@ pub async fn new_for_pending_incoming_connection( .await; } Either::Left((Ok(v), _)) => void::unreachable(v), - Either::Right((Ok(output), _)) => { + Either::Right((Ok((obtained_peer_id, muxer)), _)) => { + let (response, receiver) = oneshot::channel(); let _ = events .send(PendingConnectionEvent::ConnectionEstablished { id: connection_id, - output, + obtained_peer_id, + response, outgoing: None, }) .await; + + match receiver.await { + Ok(PendingCommand::Upgrade { + handler, + substream_upgrade_protocol_override, + command_receiver, + events, + }) => { + // Upgrade to Connection + let connection = + Connection::new(muxer, handler, substream_upgrade_protocol_override); + new_for_established_connection( + connection_id, + obtained_peer_id, + connection, + command_receiver, + events, + ) + .await + } + Ok(PendingCommand::Close) => { + if let Err(e) = poll_fn(move |cx| muxer.close(cx)).await { + log::debug!( + "Failed to close connection {:?} to peer {}: {:?}", + connection_id, + obtained_peer_id, + e + ); + } + } + Err(_) => { + // Shutting down, nothing we can do about this. + } + } } Either::Right((Err(e), _)) => { let _ = events @@ -176,10 +267,10 @@ pub async fn new_for_pending_incoming_connection( } } -pub async fn new_for_established_connection( +async fn new_for_established_connection( connection_id: ConnectionId, peer_id: PeerId, - mut connection: crate::connection::Connection, + mut connection: Connection, mut command_receiver: mpsc::Receiver>, mut events: mpsc::Sender>, ) where