Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(ws server): fix shutdown on connection closed #1103

Merged
merged 18 commits into from
Apr 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion server/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -654,7 +654,7 @@ impl<L: Logger> hyper::service::Service<hyper::Request<hyper::Body>> for TowerSe
ws_builder.set_max_message_size(data.max_request_body_size as usize);
let (sender, receiver) = ws_builder.finish();

let _ = ws::background_task::<L>(sender, receiver, data).await;
ws::background_task::<L>(sender, receiver, data).await;
}
.in_current_span(),
);
Expand Down
86 changes: 69 additions & 17 deletions server/src/tests/ws.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
// IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
// DEALINGS IN THE SOFTWARE.

use std::time::Duration;

use crate::server::BatchRequestConfig;
use crate::tests::helpers::{deser_call, init_logger, server_with_context};
use crate::types::SubscriptionId;
Expand Down Expand Up @@ -815,37 +817,87 @@ async fn notif_is_ignored() {
}

#[tokio::test]
async fn drop_client_with_pending_calls_works() {
async fn close_client_with_pending_calls_works() {
const MAX_TIMEOUT: Duration = Duration::from_secs(60);
const CONCURRENT_CALLS: usize = 10;
init_logger();

let (handle, addr) = {
let server = ServerBuilder::default().build("127.0.0.1:0").with_default_timeout().await.unwrap().unwrap();

let mut module = RpcModule::new(());

module
.register_async_method("infinite_call", |_, _| async move {
futures_util::future::pending::<()>().await;
"ok"
})
.unwrap();
let addr = server.local_addr().unwrap();

(server.start(module).unwrap(), addr)
};
let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();

let (handle, addr) = server_with_infinite_call(MAX_TIMEOUT.checked_mul(10).unwrap(), tx).await;
let mut client = WebSocketTestClient::new(addr).with_default_timeout().await.unwrap().unwrap();

for _ in 0..10 {
let req = r#"{"jsonrpc":"2.0","method":"infinite_call","id":1}"#;
client.send(req).with_default_timeout().await.unwrap().unwrap();
}

// Assert that the server has received the calls.
for _ in 0..CONCURRENT_CALLS {
assert!(rx.recv().await.is_some());
}

client.close().await.unwrap();
assert!(client.receive().await.is_err());

// Stop the server and ensure that the server doesn't wait for futures to complete
// when the connection has already been closed.
handle.stop().unwrap();
assert!(handle.stopped().with_default_timeout().await.is_ok());
assert!(handle.stopped().with_timeout(MAX_TIMEOUT).await.is_ok());
}

// Checks that stopping the server completes within `MAX_TIMEOUT` even though the
// client was dropped while its `infinite_call` requests were still pending.
#[tokio::test]
async fn drop_client_with_pending_calls_works() {
    const MAX_TIMEOUT: Duration = Duration::from_secs(60);
    const CONCURRENT_CALLS: usize = 10;
    const INFINITE_CALL: &str = r#"{"jsonrpc":"2.0","method":"infinite_call","id":1}"#;
    init_logger();

    let (started_tx, mut started_rx) = tokio::sync::mpsc::unbounded_channel();
    // The server's ping interval is set to ten times the timeout this test waits for.
    let ping_interval = MAX_TIMEOUT.checked_mul(10).unwrap();
    let (server_handle, server_addr) = server_with_infinite_call(ping_interval, started_tx).await;

    let mut ws_client = WebSocketTestClient::new(server_addr).with_default_timeout().await.unwrap().unwrap();

    // Fire off all the calls first...
    for _ in 0..CONCURRENT_CALLS {
        ws_client.send(INFINITE_CALL).with_default_timeout().await.unwrap().unwrap();
    }
    // ...then wait until the server has signalled that each one has started running.
    for _ in 0..CONCURRENT_CALLS {
        let started = started_rx.recv().await;
        assert!(started.is_some());
    }

    // Drop the client without closing the connection explicitly.
    drop(ws_client);

    // Stopping the server must not wait for the never-ending calls once the
    // connection is gone.
    server_handle.stop().unwrap();
    let stopped = server_handle.stopped().with_timeout(MAX_TIMEOUT).await;
    assert!(stopped.is_ok());
}

// Starts a server on `127.0.0.1:0` that exposes a single async method, `infinite_call`.
//
// Each invocation of `infinite_call` sends `()` on `tx` and then awaits a future that
// never resolves. `timeout` is used as the server's ping interval.
//
// Returns the handle of the started server together with its local address.
async fn server_with_infinite_call(
    timeout: Duration,
    tx: tokio::sync::mpsc::UnboundedSender<()>,
) -> (crate::ServerHandle, std::net::SocketAddr) {
    // The ping interval is configurable so that it doesn't force the connection closed.
    let builder = ServerBuilder::default().ping_interval(timeout);
    let server = builder.build("127.0.0.1:0").with_default_timeout().await.unwrap().unwrap();
    let local_addr = server.local_addr().unwrap();

    // The sender is the module's context, so the method callback can reach it.
    let mut rpc = RpcModule::new(tx);
    rpc.register_async_method("infinite_call", |_, mut ctx| async move {
        // Report that the call has started, then park forever.
        let started = std::sync::Arc::make_mut(&mut ctx);
        started.send(()).unwrap();
        futures_util::future::pending::<()>().await;
        "ok"
    })
    .unwrap();

    let handle = server.start(rpc).unwrap();
    (handle, local_addr)
}
98 changes: 71 additions & 27 deletions server/src/transport/ws.rs
Original file line number Diff line number Diff line change
Expand Up @@ -226,11 +226,7 @@ pub(crate) async fn execute_call<'a, L: Logger>(req: Request<'a>, call: CallData
response
}

pub(crate) async fn background_task<L: Logger>(
sender: Sender,
mut receiver: Receiver,
svc: ServiceData<L>,
) -> Result<(), Error> {
pub(crate) async fn background_task<L: Logger>(sender: Sender, mut receiver: Receiver, svc: ServiceData<L>) {
let ServiceData {
methods,
max_request_body_size,
Expand All @@ -250,17 +246,17 @@ pub(crate) async fn background_task<L: Logger>(
} = svc;

let (tx, rx) = mpsc::channel::<String>(message_buffer_capacity as usize);
let (mut conn_tx, conn_rx) = oneshot::channel();
let (conn_tx, conn_rx) = oneshot::channel();
let sink = MethodSink::new_with_limit(tx, max_response_body_size, max_log_length);
let bounded_subscriptions = BoundedSubscriptions::new(max_subscriptions_per_connection);
let pending_calls = FuturesUnordered::new();

// Spawn another task that sends out the responses on the Websocket.
tokio::spawn(send_task(rx, sender, ping_interval, conn_rx));
let send_task_handle = tokio::spawn(send_task(rx, sender, ping_interval, conn_rx));

// Buffer for incoming data.
let mut data = Vec::with_capacity(100);
let stopped = stop_handle.shutdown();
let stopped = stop_handle.clone().shutdown();

tokio::pin!(stopped);

Expand All @@ -272,11 +268,11 @@ pub(crate) async fn background_task<L: Logger>(
stopped = stop;
permit
}
None => break Ok(()),
None => break Ok(Shutdown::ConnectionClosed),
Copy link
Collaborator

@jsdw jsdw Apr 26, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Am I right in thinking that:

  • Stopped means the user has manually stopped the server, so we want to gracefully close the connection
  • ConnectionClosed means the connection was closed for some other reason (eg network issue or whatever)

Copy link
Member Author

@niklasad1 niklasad1 Apr 26, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, this only fails if the send_task has completed and the receiver of the bounded channel has been dropped.

I think this can only occur if send_ping fails, i.e., the connection was closed.

};

match try_recv(&mut receiver, &mut data, stopped).await {
Receive::Shutdown => break Ok(()),
Receive::Shutdown => break Ok(Shutdown::Stopped),
Receive::Ok(stop) => {
stopped = stop;
}
Expand All @@ -286,7 +282,7 @@ pub(crate) async fn background_task<L: Logger>(
match err {
SokettoError::Closed => {
tracing::debug!("WS transport: remote peer terminated the connection: {}", conn_id);
break Ok(());
break Ok(Shutdown::ConnectionClosed);
}
SokettoError::MessageTooLarge { current, maximum } => {
tracing::debug!(
Expand All @@ -300,7 +296,7 @@ pub(crate) async fn background_task<L: Logger>(
}
err => {
tracing::debug!("WS transport error: {}; terminate connection: {}", err, conn_id);
break Err(err.into());
break Err(err);
}
};
}
Expand All @@ -326,22 +322,11 @@ pub(crate) async fn background_task<L: Logger>(
// Drive all running methods to completion.
// **NOTE** Do not return early in this function. This `await` needs to run to guarantee
// proper drop behaviour.
//
// This is not strictly not needed because `tokio::spawn` will drive these the completion
// but it's preferred that the `stop_handle.stopped()` should not return until all methods has been
// executed and the connection has been closed.
tokio::select! {
// All pending calls executed.
_ = pending_calls.for_each(|_| async {}) => {
_ = conn_tx.send(());
}
// The connection was closed, no point of waiting for the pending calls.
_ = conn_tx.closed() => {}
}
graceful_shutdown(result, pending_calls, receiver, data, conn_tx, send_task_handle).await;

logger.on_disconnect(remote_addr, TransportProtocol::WebSocket);
drop(conn);
result
drop(stop_handle);
}

/// A task that waits for new messages via the `rx channel` and sends them out on the `WebSocket`.
Expand All @@ -352,7 +337,11 @@ async fn send_task(
stop: oneshot::Receiver<()>,
) {
// Interval to send out continuously `pings`.
let ping_interval = IntervalStream::new(tokio::time::interval(ping_interval));
let mut ping_interval = tokio::time::interval(ping_interval);
// The first tick completes immediately, so consume it here to make sure the first ping isn't sent before `ping_interval` has elapsed.
ping_interval.tick().await;

let ping_interval = IntervalStream::new(ping_interval);
let rx = ReceiverStream::new(rx);

tokio::pin!(ping_interval, rx, stop);
Expand Down Expand Up @@ -384,15 +373,18 @@ async fn send_task(
}

// Handle timer intervals.
Either::Right((Either::Left((_, stop)), next_rx)) => {
Either::Right((Either::Left((Some(_instant), stop)), next_rx)) => {
if let Err(err) = send_ping(&mut ws_sender).await {
tracing::debug!("WS transport error: send ping failed: {}", err);
break;
}

rx_item = next_rx;
futs = future::select(ping_interval.next(), stop);
}

Either::Right((Either::Left((None, _)), _)) => unreachable!("IntervalStream never terminates"),

// Server is stopped.
Either::Right((Either::Right(_), _)) => {
break;
Expand Down Expand Up @@ -558,3 +550,55 @@ async fn execute_unchecked_call<L: Logger>(params: ExecuteCallParams<L>) {
}
};
}

/// The reason why the receive loop of a WebSocket connection ended without a transport error.
#[derive(Debug, Copy, Clone)]
pub(crate) enum Shutdown {
    /// The receive loop observed the server's shutdown signal (`Receive::Shutdown`).
    Stopped,
    /// The connection is already gone: either the remote peer terminated it
    /// (`SokettoError::Closed`) or the outgoing channel could no longer be reserved.
    ConnectionClosed,
}

/// Enforce a graceful shutdown.
///
/// If `result` says the connection is already closed (`Shutdown::ConnectionClosed` or
/// `SokettoError::Closed`) nothing is awaited. Otherwise this waits until the first of:
///
/// - all `pending_calls` have run to completion,
/// - `receiver` reports `SokettoError::Closed`, or
/// - the receiving end of `conn_tx` has been closed.
///
/// In every case a message is then sent on `conn_tx` and `send_task_handle` is awaited,
/// so this only returns after the send task has finished.
///
/// # Parameters
///
/// - `result`: outcome of the connection's receive loop.
/// - `pending_calls`: the method calls that are still in flight.
/// - `receiver`: the WebSocket receiver, polled only to detect that the peer disconnected.
/// - `data`: scratch buffer handed to `receiver.receive`; its contents are ignored.
/// - `conn_tx`: used to tell the send task to stop.
/// - `send_task_handle`: join handle of the send task.
async fn graceful_shutdown<F: Future>(
    result: Result<Shutdown, SokettoError>,
    pending_calls: FuturesUnordered<F>,
    receiver: Receiver,
    data: Vec<u8>,
    mut conn_tx: oneshot::Sender<()>,
    send_task_handle: tokio::task::JoinHandle<()>,
) {
    match result {
        // The connection is already closed, so there is nobody left to answer.
        Ok(Shutdown::ConnectionClosed) | Err(SokettoError::Closed) => (),
        Ok(Shutdown::Stopped) | Err(_) => {
            // Soketto has no way to signal that the connection was closed, so keep
            // receiving, throw the data away, and end the stream once `receive`
            // reports `SokettoError::Closed`.
            //
            // `receive` is not cancel-safe, so the receiver is owned by a stream which
            // is polled to completion instead of being re-created inside `select!`.
            //
            // NOTE(review): any result other than `Err(SokettoError::Closed)` keeps the
            // stream going, and `data` is never cleared here — confirm that a persistent
            // non-`Closed` error cannot make this spin and that the buffer cannot grow
            // unboundedly while waiting.
            let disconnect_stream = futures_util::stream::unfold((receiver, data), |(mut receiver, mut data)| async {
                if let Err(SokettoError::Closed) = receiver.receive(&mut data).await {
                    None
                } else {
                    Some(((), (receiver, data)))
                }
            });

            let graceful_shutdown = pending_calls.for_each(|_| async {});
            let disconnect = disconnect_stream.for_each(|_| async {});

            // Whichever happens first ends the wait: all pending calls finished, the
            // peer disconnected, or the receiving end of `conn_tx` was closed.
            tokio::select! {
                _ = graceful_shutdown => {}
                _ = disconnect => {}
                _ = conn_tx.closed() => {}
            }
        }
    };

    // Send a message to close down the "send task". The result is ignored on purpose:
    // sending fails if the receiving end is already gone.
    _ = conn_tx.send(());
    // Ensure that the send task has finished before returning.
    _ = send_task_handle.await;
}