Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

transport: Add accept_pending/reject_pending for inbound connections and introduce inbound limits #194

Merged
merged 14 commits into from
Aug 7, 2024
Merged
8 changes: 8 additions & 0 deletions src/transport/dummy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,14 @@ impl Transport for DummyTransport {
Ok(())
}

fn accept_pending(&mut self, _connection_id: ConnectionId) -> crate::Result<()> {
Ok(())
}

fn reject_pending(&mut self, _connection_id: ConnectionId) -> crate::Result<()> {
Ok(())
}

fn reject(&mut self, _: ConnectionId) -> crate::Result<()> {
Ok(())
}
Expand Down
25 changes: 16 additions & 9 deletions src/transport/manager/limits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,17 @@ impl ConnectionLimits {
Ok(usize::MAX)
}

/// Called before accepting a new incoming connection.
pub fn on_incoming(&mut self) -> Result<(), ConnectionLimitsError> {
if let Some(max_incoming_connections) = self.config.max_incoming_connections {
if self.incoming_connections.len() >= max_incoming_connections {
return Err(ConnectionLimitsError::MaxIncomingConnectionsExceeded);
}
}

Ok(())
}

/// Called when a new connection is established.
pub fn on_connection_established(
&mut self,
Expand All @@ -114,11 +125,9 @@ impl ConnectionLimits {
return Err(ConnectionLimitsError::MaxIncomingConnectionsExceeded);
}
}
} else {
if let Some(max_outgoing_connections) = self.config.max_outgoing_connections {
if self.outgoing_connections.len() >= max_outgoing_connections {
return Err(ConnectionLimitsError::MaxOutgoingConnectionsExceeded);
}
} else if let Some(max_outgoing_connections) = self.config.max_outgoing_connections {
if self.outgoing_connections.len() >= max_outgoing_connections {
return Err(ConnectionLimitsError::MaxOutgoingConnectionsExceeded);
}
}

Expand All @@ -127,10 +136,8 @@ impl ConnectionLimits {
if self.config.max_incoming_connections.is_some() {
self.incoming_connections.insert(connection_id);
}
} else {
if self.config.max_outgoing_connections.is_some() {
self.outgoing_connections.insert(connection_id);
}
} else if self.config.max_outgoing_connections.is_some() {
self.outgoing_connections.insert(connection_id);
}

Ok(())
Expand Down
44 changes: 36 additions & 8 deletions src/transport/manager/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -508,18 +508,14 @@ impl TransportManager {
record.set_connection_id(connection_id);

#[cfg(feature = "quic")]
if address.iter().find(|p| std::matches!(p, Protocol::QuicV1)).is_some() {
if address.iter().any(|p| std::matches!(&p, Protocol::QuicV1)) {
quic.push(address.clone());
transports.insert(SupportedTransport::Quic);
continue;
}

#[cfg(feature = "websocket")]
if address
.iter()
.find(|p| std::matches!(p, Protocol::Ws(_) | Protocol::Wss(_)))
.is_some()
{
if address.iter().any(|p| std::matches!(&p, Protocol::Ws(_) | Protocol::Wss(_))) {
websocket.push(address.clone());
transports.insert(SupportedTransport::WebSocket);
continue;
Expand Down Expand Up @@ -839,6 +835,11 @@ impl TransportManager {
}
}

fn on_pending_incoming_connection(&mut self) -> crate::Result<()> {
self.connection_limits.on_incoming()?;
Ok(())
}

/// Handle closed connection.
fn on_connection_closed(
&mut self,
Expand Down Expand Up @@ -1713,7 +1714,34 @@ impl TransportManager {
}
Ok(None) => {}
}
}
},
TransportEvent::PendingInboundConnection { connection_id } => {
if self.on_pending_incoming_connection().is_ok() {
tracing::trace!(
target: LOG_TARGET,
?connection_id,
"accept pending incoming connection",
);

let _ = self
.transports
.get_mut(&transport)
.expect("transport to exist")
.accept_pending(connection_id);
} else {
tracing::debug!(
target: LOG_TARGET,
?connection_id,
"reject pending incoming connection",
);

let _ = self
.transports
.get_mut(&transport)
.expect("transport to exist")
.reject_pending(connection_id);
}
},
event => panic!("event not supported: {event:?}"),
}
},
Expand Down Expand Up @@ -2563,7 +2591,7 @@ mod tests {

peer_context.state = PeerState::Connected {
record,
dial_record: dial_record,
dial_record,
};
}

Expand Down
13 changes: 13 additions & 0 deletions src/transport/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ pub mod websocket;
pub(crate) mod dummy;
pub(crate) mod manager;

pub use manager::limits::{ConnectionLimitsConfig, ConnectionLimitsError};

/// Timeout for opening a connection.
pub(crate) const CONNECTION_OPEN_TIMEOUT: Duration = Duration::from_secs(10);

Expand Down Expand Up @@ -121,6 +123,11 @@ pub(crate) enum TransportEvent {
endpoint: Endpoint,
},

PendingInboundConnection {
/// Connection ID.
connection_id: ConnectionId,
},

/// Connection opened to remote but not yet negotiated.
ConnectionOpened {
/// Connection ID.
Expand Down Expand Up @@ -176,6 +183,12 @@ pub(crate) trait Transport: Stream + Unpin + Send {
/// Accept negotiated connection.
fn accept(&mut self, connection_id: ConnectionId) -> crate::Result<()>;

/// Accept pending connection.
fn accept_pending(&mut self, connection_id: ConnectionId) -> crate::Result<()>;

/// Reject pending connection.
fn reject_pending(&mut self, connection_id: ConnectionId) -> crate::Result<()>;

/// Reject negotiated connection.
fn reject(&mut self, connection_id: ConnectionId) -> crate::Result<()>;

Expand Down
65 changes: 52 additions & 13 deletions src/transport/quic/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ use crate::{

use futures::{future::BoxFuture, stream::FuturesUnordered, Stream, StreamExt};
use multiaddr::{Multiaddr, Protocol};
use quinn::{ClientConfig, Connection, Endpoint, IdleTimeout};
use quinn::{ClientConfig, Connecting, Connection, Endpoint, IdleTimeout};

use std::{
collections::{HashMap, HashSet},
Expand Down Expand Up @@ -80,6 +80,9 @@ pub(crate) struct QuicTransport {
/// Pending dials.
pending_dials: HashMap<ConnectionId, Multiaddr>,

/// Pending inbound connections.
pending_inbound_connections: HashMap<ConnectionId, Connecting>,

/// Pending connections.
pending_connections:
FuturesUnordered<BoxFuture<'static, (ConnectionId, Result<NegotiatedConnection, Error>)>>,
Expand Down Expand Up @@ -110,6 +113,22 @@ impl QuicTransport {
Some(p2p_cert.peer_id())
}

/// Handle inbound accepted connection.
fn on_inbound_connection(&mut self, connection_id: ConnectionId, connection: Connecting) {
self.pending_connections.push(Box::pin(async move {
let connection = match connection.await {
Ok(connection) => connection,
Err(error) => return (connection_id, Err(error.into())),
};

let Some(peer) = Self::extract_peer_id(&connection) else {
return (connection_id, Err(Error::InvalidCertificate));
};

(connection_id, Ok(NegotiatedConnection { peer, connection }))
}));
}

/// Handle established connection.
fn on_connection_established(
&mut self,
Expand Down Expand Up @@ -193,6 +212,7 @@ impl TransportBuilder for QuicTransport {
opened_raw: HashMap::new(),
pending_open: HashMap::new(),
pending_dials: HashMap::new(),
pending_inbound_connections: HashMap::new(),
pending_raw_connections: FuturesUnordered::new(),
pending_connections: FuturesUnordered::new(),
},
Expand Down Expand Up @@ -291,6 +311,23 @@ impl Transport for QuicTransport {
.map_or(Err(Error::ConnectionDoesntExist(connection_id)), |_| Ok(()))
}

fn accept_pending(&mut self, connection_id: ConnectionId) -> crate::Result<()> {
let connection = self
.pending_inbound_connections
.remove(&connection_id)
.ok_or(Error::ConnectionDoesntExist(connection_id))?;

self.on_inbound_connection(connection_id, connection);

Ok(())
}

fn reject_pending(&mut self, connection_id: ConnectionId) -> crate::Result<()> {
self.pending_inbound_connections
.remove(&connection_id)
.map_or(Err(Error::ConnectionDoesntExist(connection_id)), |_| Ok(()))
}

fn open(
&mut self,
connection_id: ConnectionId,
Expand Down Expand Up @@ -406,26 +443,19 @@ impl Stream for QuicTransport {
type Item = TransportEvent;

fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
while let Poll::Ready(Some(connection)) = self.listener.poll_next_unpin(cx) {
if let Poll::Ready(Some(connection)) = self.listener.poll_next_unpin(cx) {
let connection_id = self.context.next_connection_id();

tracing::trace!(
target: LOG_TARGET,
?connection_id,
"accept connection",
"pending inbound connection",
);

self.pending_connections.push(Box::pin(async move {
let connection = match connection.await {
Ok(connection) => connection,
Err(error) => return (connection_id, Err(error.into())),
};
self.pending_inbound_connections.insert(connection_id, connection);

let Some(peer) = Self::extract_peer_id(&connection) else {
return (connection_id, Err(Error::InvalidCertificate));
};

(connection_id, Ok(NegotiatedConnection { peer, connection }))
return Poll::Ready(Some(TransportEvent::PendingInboundConnection {
connection_id,
}));
}

Expand Down Expand Up @@ -545,6 +575,15 @@ mod tests {
));

transport2.dial(ConnectionId::new(), listen_address).unwrap();

let event = transport1.next().await.unwrap();
match event {
TransportEvent::PendingInboundConnection { connection_id } => {
transport1.accept_pending(connection_id).unwrap();
}
_ => panic!("unexpected event"),
}

let (res1, res2) = tokio::join!(transport1.next(), transport2.next());

assert!(std::matches!(
Expand Down
5 changes: 4 additions & 1 deletion src/transport/tcp/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,10 @@ impl TcpConnection {
})
.await
{
Err(_) => Err(Error::Timeout),
Err(_) => {
tracing::trace!(target: LOG_TARGET, ?connection_id, "connection timed out during negotiation");
Err(Error::Timeout)
}
Ok(result) => result,
}
}
Expand Down
Loading