From c68d9bfb59f0ebec480371c946eead592eb92e83 Mon Sep 17 00:00:00 2001 From: Niklas Adolfsson Date: Wed, 29 Nov 2023 18:11:47 +0100 Subject: [PATCH] feat(server): make it possible to disable WS ping --- server/src/future.rs | 45 ++++++++++++++++++++++++++++++++++++-- server/src/server.rs | 27 +++++++++++++---------- server/src/tests/ws.rs | 2 +- server/src/transport/ws.rs | 38 ++++++++++++++------------------ 4 files changed, 76 insertions(+), 36 deletions(-) diff --git a/server/src/future.rs b/server/src/future.rs index affdf44744..a229a1430b 100644 --- a/server/src/future.rs +++ b/server/src/future.rs @@ -26,9 +26,18 @@ //! Utilities for handling async code. +use futures_util::{Stream, StreamExt}; use jsonrpsee_core::Error; -use std::sync::Arc; -use tokio::sync::{watch, OwnedSemaphorePermit, Semaphore, TryAcquireError}; +use pin_project::pin_project; +use std::{ + pin::Pin, + sync::Arc, + task::{Context, Poll}, +}; +use tokio::{ + sync::{watch, OwnedSemaphorePermit, Semaphore, TryAcquireError}, + time::interval, +}; /// Create channel to determine whether /// the server shall continue to run or not. @@ -119,3 +128,35 @@ impl ConnectionGuard { /// Connection permit. pub type ConnectionPermit = OwnedSemaphorePermit; + +#[pin_project] +pub(crate) struct IntervalStream(#[pin] Option); + +impl IntervalStream { + /// Creates a stream which never returns any elements. + pub(crate) fn pending() -> Self { + Self(None) + } + + /// Creates a stream which produces elements with `period`. + pub(crate) async fn new(period: std::time::Duration) -> Self { + let mut interval = interval(period); + interval.tick().await; + + Self(Some(tokio_stream::wrappers::IntervalStream::new(interval))) + } +} + +impl Stream for IntervalStream { + type Item = tokio::time::Instant; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + if let Some(mut stream) = self.project().0.as_pin_mut() { + stream.poll_next_unpin(cx) + } else { + // NOTE: this will not be woken up again and it's by design + // to be a pending stream that never returns. + Poll::Pending + } + } +} diff --git a/server/src/server.rs b/server/src/server.rs index 0590c9a89c..9212308f1b 100644 --- a/server/src/server.rs +++ b/server/src/server.rs @@ -33,7 +33,7 @@ use std::sync::Arc; use std::task::Poll; use std::time::Duration; -use crate::future::{ConnectionGuard, ServerHandle, StopHandle}; +use crate::future::{ConnectionGuard, IntervalStream, ServerHandle, StopHandle}; use crate::middleware::rpc::{RpcService, RpcServiceBuilder, RpcServiceCfg, RpcServiceT}; use crate::transport::ws::BackgroundTaskParams; use crate::transport::{http, ws}; @@ -286,28 +286,31 @@ impl ConnectionState { pub enum PingConfig { /// The server pings the connected clients continuously at the configured interval but /// doesn't disconnect them if no pongs are received from the client. - WithoutInactivityCheck(Duration), + OnlyPing(Duration), /// The server pings the connected clients continuously at the configured interval /// and terminates the connection if no websocket messages received from client /// after the max limit is exceeded. - WithInactivityCheck { + Ping { /// Time interval between consequent pings from server ping_interval: Duration, /// Max allowed time for connection to stay idle inactive_limit: Duration, }, + /// Pings are disabled. + Disabled, } impl PingConfig { - pub(crate) fn ping_interval(&self) -> Duration { + pub(crate) async fn ping_interval(&self) -> IntervalStream { match self { - Self::WithoutInactivityCheck(ping_interval) => *ping_interval, - Self::WithInactivityCheck { ping_interval, .. } => *ping_interval, + Self::OnlyPing(ping_interval) => IntervalStream::new(*ping_interval).await, + Self::Ping { ping_interval, .. } => IntervalStream::new(*ping_interval).await, + Self::Disabled => IntervalStream::pending(), } } pub(crate) fn inactive_limit(&self) -> Option { - if let Self::WithInactivityCheck { inactive_limit, .. } = self { + if let Self::Ping { inactive_limit, .. } = self { Some(*inactive_limit) } else { None @@ -317,7 +320,7 @@ impl PingConfig { impl Default for PingConfig { fn default() -> Self { - Self::WithoutInactivityCheck(Duration::from_secs(60)) + Self::OnlyPing(Duration::from_secs(60)) } } @@ -333,7 +336,7 @@ impl Default for ServerConfig { enable_http: true, enable_ws: true, message_buffer_capacity: 1024, - ping_config: PingConfig::WithoutInactivityCheck(Duration::from_secs(60)), + ping_config: PingConfig::OnlyPing(Duration::from_secs(60)), id_provider: Arc::new(RandomIntegerIdProvider), } } @@ -423,7 +426,7 @@ impl ServerConfigBuilder { /// See [`Builder::ping_interval`] for documentation. pub fn ping_interval(mut self, config: PingConfig) -> Result { - if let PingConfig::WithInactivityCheck { ping_interval, inactive_limit } = config { + if let PingConfig::Ping { ping_interval, inactive_limit } = config { if ping_interval >= inactive_limit { return Err(Error::Custom("`inactive_limit` must be bigger than `ping_interval` to work".into())); } @@ -646,10 +649,10 @@ impl Builder { /// use jsonrpsee_server::{ServerBuilder, PingConfig}; /// /// // Set the ping interval to 10 seconds but terminate the connection if a client is inactive for more than 2 minutes - /// let builder = ServerBuilder::default().ping_interval(PingConfig::WithInactivityCheck { ping_interval: Duration::from_secs(10), inactive_limit: Duration::from_secs(2 * 60) }).unwrap(); + /// let builder = ServerBuilder::default().ping_interval(PingConfig::Ping { ping_interval: Duration::from_secs(10), inactive_limit: Duration::from_secs(2 * 60) }).unwrap(); /// ``` pub fn ping_interval(mut self, config: PingConfig) -> Result { - if let PingConfig::WithInactivityCheck { ping_interval, inactive_limit } = config { + if let PingConfig::Ping { ping_interval, inactive_limit } = config { if ping_interval >= inactive_limit { return Err(Error::Custom("`inactive_limit` must be bigger than `ping_interval` to work".into())); } diff --git a/server/src/tests/ws.rs b/server/src/tests/ws.rs index 32a9f0a829..14d3f85ed5 100644 --- a/server/src/tests/ws.rs +++ b/server/src/tests/ws.rs @@ -880,7 +880,7 @@ async fn server_with_infinite_call( ) -> (crate::ServerHandle, std::net::SocketAddr) { let server = ServerBuilder::default() // Make sure that the ping_interval doesn't force the connection to be closed - .ping_interval(crate::server::PingConfig::WithoutInactivityCheck(timeout)) + .ping_interval(crate::server::PingConfig::OnlyPing(timeout)) .unwrap() .build("127.0.0.1:0") .with_default_timeout() diff --git a/server/src/transport/ws.rs b/server/src/transport/ws.rs index 8da0cbccf5..e71182daf1 100644 --- a/server/src/transport/ws.rs +++ b/server/src/transport/ws.rs @@ -1,13 +1,14 @@ use std::sync::Arc; -use std::time::{Duration, Instant}; +use std::time::Instant; +use crate::future::IntervalStream; use crate::middleware::rpc::{RpcService, RpcServiceBuilder, RpcServiceCfg, RpcServiceT}; use crate::server::{handle_rpc_call, ConnectionState, ServerConfig}; use crate::PingConfig; -use futures_util::future::{self, Either, Fuse}; +use futures_util::future::{self, Either}; use futures_util::io::{BufReader, BufWriter}; -use futures_util::{Future, FutureExt, StreamExt, TryStreamExt}; +use futures_util::{Future, StreamExt, TryStreamExt}; use hyper::upgrade::Upgraded; use jsonrpsee_core::server::helpers::MethodSink; use jsonrpsee_core::server::{BoundedSubscriptions, Methods}; @@ -18,7 +19,7 @@ use soketto::connection::Error as SokettoError; use soketto::data::ByteSlice125; use tokio::sync::{mpsc, oneshot}; -use tokio_stream::wrappers::{IntervalStream, ReceiverStream}; +use tokio_stream::wrappers::ReceiverStream; use tokio_util::compat::{Compat, TokioAsyncReadCompatExt}; pub(crate) type Sender = soketto::Sender>>>; @@ -77,7 +78,7 @@ where let (conn_tx, conn_rx) = oneshot::channel(); // Spawn another task that sends out the responses on the Websocket. - let send_task_handle = tokio::spawn(send_task(rx, ws_sender, ping_config.ping_interval(), conn_rx)); + let send_task_handle = tokio::spawn(send_task(rx, ws_sender, ping_config, conn_rx)); let stopped = conn.stop_handle.clone().shutdown(); let rpc_service = Arc::new(rpc_service); @@ -174,15 +175,10 @@ where async fn send_task( rx: mpsc::Receiver, mut ws_sender: Sender, - ping_interval: Duration, + ping_config: PingConfig, stop: oneshot::Receiver<()>, ) { - // Interval to send out continuously `pings`. - let mut ping_interval = tokio::time::interval(ping_interval); - // This returns immediately so make sure it doesn't resolve before the ping_interval has been elapsed. - ping_interval.tick().await; - - let ping_interval = IntervalStream::new(ping_interval); + let ping_interval = ping_config.ping_interval().await; let rx = ReceiverStream::new(rx); tokio::pin!(ping_interval, rx, stop); @@ -250,10 +246,14 @@ where T: StreamExt> + Unpin, { let mut last_active = Instant::now(); + let inactivity_check = match ping_config.inactive_limit() { + Some(period) => IntervalStream::new(period).await, + None => IntervalStream::pending(), + }; - let inactivity_check = - Box::pin(ping_config.inactive_limit().map(|d| tokio::time::sleep(d).fuse()).unwrap_or_else(Fuse::terminated)); - let mut futs = futures_util::future::select(ws_stream.next(), inactivity_check); + tokio::pin!(inactivity_check); + + let mut futs = futures_util::future::select(ws_stream.next(), inactivity_check.next()); loop { match futures_util::future::select(futs, stopped).await { @@ -279,12 +279,8 @@ where } stopped = s; - // use really large duration instead of Duration::MAX to - // solve the panic issue with interval initialization - let inactivity_check = Box::pin( - ping_config.inactive_limit().map(|d| tokio::time::sleep(d).fuse()).unwrap_or_else(Fuse::terminated), - ); - futs = futures_util::future::select(rcv, inactivity_check); + + futs = futures_util::future::select(rcv, inactivity_check.next()); } // Server has been stopped. Either::Right(_) => break Receive::Stopped,