Skip to content

Commit

Permalink
feat(server): make it possible to disable WS ping
Browse files Browse the repository at this point in the history
  • Loading branch information
niklasad1 committed Nov 29, 2023
1 parent 98675a0 commit c68d9bf
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 36 deletions.
45 changes: 43 additions & 2 deletions server/src/future.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,18 @@

//! Utilities for handling async code.
use futures_util::{Stream, StreamExt};
use jsonrpsee_core::Error;
use std::sync::Arc;
use tokio::sync::{watch, OwnedSemaphorePermit, Semaphore, TryAcquireError};
use pin_project::pin_project;
use std::{
pin::Pin,
sync::Arc,
task::{Context, Poll},
};
use tokio::{
sync::{watch, OwnedSemaphorePermit, Semaphore, TryAcquireError},
time::interval,
};

/// Create channel to determine whether
/// the server shall continue to run or not.
Expand Down Expand Up @@ -119,3 +128,35 @@ impl ConnectionGuard {

/// Connection permit.
pub type ConnectionPermit = OwnedSemaphorePermit;

#[pin_project]
pub(crate) struct IntervalStream(#[pin] Option<tokio_stream::wrappers::IntervalStream>);

impl IntervalStream {
/// Creates a stream which never returns any elements.
pub(crate) fn pending() -> Self {
Self(None)
}

/// Creates a stream which produces elements with `period`.
pub(crate) async fn new(period: std::time::Duration) -> Self {
let mut interval = interval(period);
interval.tick().await;

Self(Some(tokio_stream::wrappers::IntervalStream::new(interval)))
}
}

impl Stream for IntervalStream {
type Item = tokio::time::Instant;

fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
if let Some(mut stream) = self.project().0.as_pin_mut() {
stream.poll_next_unpin(cx)
} else {
// NOTE: this will not be woken up again and it's by design
// to be a pending stream that never returns.
Poll::Pending
}
}
}
27 changes: 15 additions & 12 deletions server/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ use std::sync::Arc;
use std::task::Poll;
use std::time::Duration;

use crate::future::{ConnectionGuard, ServerHandle, StopHandle};
use crate::future::{ConnectionGuard, IntervalStream, ServerHandle, StopHandle};
use crate::middleware::rpc::{RpcService, RpcServiceBuilder, RpcServiceCfg, RpcServiceT};
use crate::transport::ws::BackgroundTaskParams;
use crate::transport::{http, ws};
Expand Down Expand Up @@ -286,28 +286,31 @@ impl ConnectionState {
pub enum PingConfig {
/// The server pings the connected clients continuously at the configured interval but
/// doesn't disconnect them if no pongs are received from the client.
WithoutInactivityCheck(Duration),
OnlyPing(Duration),
/// The server pings the connected clients continuously at the configured interval
/// and terminates the connection if no websocket messages received from client
/// after the max limit is exceeded.
WithInactivityCheck {
Ping {
/// Time interval between consequent pings from server
ping_interval: Duration,
/// Max allowed time for connection to stay idle
inactive_limit: Duration,
},
/// Pings are disabled.
Disabled,
}

impl PingConfig {
pub(crate) fn ping_interval(&self) -> Duration {
pub(crate) async fn ping_interval(&self) -> IntervalStream {
match self {
Self::WithoutInactivityCheck(ping_interval) => *ping_interval,
Self::WithInactivityCheck { ping_interval, .. } => *ping_interval,
Self::OnlyPing(ping_interval) => IntervalStream::new(*ping_interval).await,
Self::Ping { ping_interval, .. } => IntervalStream::new(*ping_interval).await,
Self::Disabled => IntervalStream::pending(),
}
}

pub(crate) fn inactive_limit(&self) -> Option<Duration> {
if let Self::WithInactivityCheck { inactive_limit, .. } = self {
if let Self::Ping { inactive_limit, .. } = self {
Some(*inactive_limit)
} else {
None
Expand All @@ -317,7 +320,7 @@ impl PingConfig {

impl Default for PingConfig {
fn default() -> Self {
Self::WithoutInactivityCheck(Duration::from_secs(60))
Self::OnlyPing(Duration::from_secs(60))
}
}

Expand All @@ -333,7 +336,7 @@ impl Default for ServerConfig {
enable_http: true,
enable_ws: true,
message_buffer_capacity: 1024,
ping_config: PingConfig::WithoutInactivityCheck(Duration::from_secs(60)),
ping_config: PingConfig::OnlyPing(Duration::from_secs(60)),
id_provider: Arc::new(RandomIntegerIdProvider),
}
}
Expand Down Expand Up @@ -423,7 +426,7 @@ impl ServerConfigBuilder {

/// See [`Builder::ping_interval`] for documentation.
pub fn ping_interval(mut self, config: PingConfig) -> Result<Self, Error> {
if let PingConfig::WithInactivityCheck { ping_interval, inactive_limit } = config {
if let PingConfig::Ping { ping_interval, inactive_limit } = config {
if ping_interval >= inactive_limit {
return Err(Error::Custom("`inactive_limit` must be bigger than `ping_interval` to work".into()));
}
Expand Down Expand Up @@ -646,10 +649,10 @@ impl<HttpMiddleware, RpcMiddleware> Builder<HttpMiddleware, RpcMiddleware> {
/// use jsonrpsee_server::{ServerBuilder, PingConfig};
///
/// // Set the ping interval to 10 seconds but terminate the connection if a client is inactive for more than 2 minutes
/// let builder = ServerBuilder::default().ping_interval(PingConfig::WithInactivityCheck { ping_interval: Duration::from_secs(10), inactive_limit: Duration::from_secs(2 * 60) }).unwrap();
/// let builder = ServerBuilder::default().ping_interval(PingConfig::Ping { ping_interval: Duration::from_secs(10), inactive_limit: Duration::from_secs(2 * 60) }).unwrap();
/// ```
pub fn ping_interval(mut self, config: PingConfig) -> Result<Self, Error> {
if let PingConfig::WithInactivityCheck { ping_interval, inactive_limit } = config {
if let PingConfig::Ping { ping_interval, inactive_limit } = config {
if ping_interval >= inactive_limit {
return Err(Error::Custom("`inactive_limit` must be bigger than `ping_interval` to work".into()));
}
Expand Down
2 changes: 1 addition & 1 deletion server/src/tests/ws.rs
Original file line number Diff line number Diff line change
Expand Up @@ -880,7 +880,7 @@ async fn server_with_infinite_call(
) -> (crate::ServerHandle, std::net::SocketAddr) {
let server = ServerBuilder::default()
// Make sure that the ping_interval doesn't force the connection to be closed
.ping_interval(crate::server::PingConfig::WithoutInactivityCheck(timeout))
.ping_interval(crate::server::PingConfig::OnlyPing(timeout))
.unwrap()
.build("127.0.0.1:0")
.with_default_timeout()
Expand Down
38 changes: 17 additions & 21 deletions server/src/transport/ws.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
use std::sync::Arc;
use std::time::{Duration, Instant};
use std::time::Instant;

use crate::future::IntervalStream;
use crate::middleware::rpc::{RpcService, RpcServiceBuilder, RpcServiceCfg, RpcServiceT};
use crate::server::{handle_rpc_call, ConnectionState, ServerConfig};
use crate::PingConfig;

use futures_util::future::{self, Either, Fuse};
use futures_util::future::{self, Either};
use futures_util::io::{BufReader, BufWriter};
use futures_util::{Future, FutureExt, StreamExt, TryStreamExt};
use futures_util::{Future, StreamExt, TryStreamExt};
use hyper::upgrade::Upgraded;
use jsonrpsee_core::server::helpers::MethodSink;
use jsonrpsee_core::server::{BoundedSubscriptions, Methods};
Expand All @@ -18,7 +19,7 @@ use soketto::connection::Error as SokettoError;
use soketto::data::ByteSlice125;

use tokio::sync::{mpsc, oneshot};
use tokio_stream::wrappers::{IntervalStream, ReceiverStream};
use tokio_stream::wrappers::ReceiverStream;
use tokio_util::compat::{Compat, TokioAsyncReadCompatExt};

pub(crate) type Sender = soketto::Sender<BufReader<BufWriter<Compat<Upgraded>>>>;
Expand Down Expand Up @@ -77,7 +78,7 @@ where
let (conn_tx, conn_rx) = oneshot::channel();

// Spawn another task that sends out the responses on the Websocket.
let send_task_handle = tokio::spawn(send_task(rx, ws_sender, ping_config.ping_interval(), conn_rx));
let send_task_handle = tokio::spawn(send_task(rx, ws_sender, ping_config, conn_rx));

let stopped = conn.stop_handle.clone().shutdown();
let rpc_service = Arc::new(rpc_service);
Expand Down Expand Up @@ -174,15 +175,10 @@ where
async fn send_task(
rx: mpsc::Receiver<String>,
mut ws_sender: Sender,
ping_interval: Duration,
ping_config: PingConfig,
stop: oneshot::Receiver<()>,
) {
// Interval to send out continuously `pings`.
let mut ping_interval = tokio::time::interval(ping_interval);
// This returns immediately so make sure it doesn't resolve before the ping_interval has been elapsed.
ping_interval.tick().await;

let ping_interval = IntervalStream::new(ping_interval);
let ping_interval = ping_config.ping_interval().await;
let rx = ReceiverStream::new(rx);

tokio::pin!(ping_interval, rx, stop);
Expand Down Expand Up @@ -250,10 +246,14 @@ where
T: StreamExt<Item = Result<Incoming, SokettoError>> + Unpin,
{
let mut last_active = Instant::now();
let inactivity_check = match ping_config.inactive_limit() {
Some(period) => IntervalStream::new(period).await,
None => IntervalStream::pending(),
};

let inactivity_check =
Box::pin(ping_config.inactive_limit().map(|d| tokio::time::sleep(d).fuse()).unwrap_or_else(Fuse::terminated));
let mut futs = futures_util::future::select(ws_stream.next(), inactivity_check);
tokio::pin!(inactivity_check);

let mut futs = futures_util::future::select(ws_stream.next(), inactivity_check.next());

loop {
match futures_util::future::select(futs, stopped).await {
Expand All @@ -279,12 +279,8 @@ where
}

stopped = s;
// use really large duration instead of Duration::MAX to
// solve the panic issue with interval initialization
let inactivity_check = Box::pin(
ping_config.inactive_limit().map(|d| tokio::time::sleep(d).fuse()).unwrap_or_else(Fuse::terminated),
);
futs = futures_util::future::select(rcv, inactivity_check);

futs = futures_util::future::select(rcv, inactivity_check.next());
}
// Server has been stopped.
Either::Right(_) => break Receive::Stopped,
Expand Down

0 comments on commit c68d9bf

Please sign in to comment.