From 40a47356dc339041206eac30eb3b54ced98a8501 Mon Sep 17 00:00:00 2001 From: vmenge Date: Wed, 29 Jan 2025 16:44:14 +0100 Subject: [PATCH] fix(relay-client): lack of proper connection attempt backoff --- Cargo.lock | 5 ++- Cargo.toml | 1 + relay-client/Cargo.toml | 1 + relay-client/src/actor.rs | 45 +++++++++++++---------- relay-client/src/flume_receiver_stream.rs | 21 +++++++++-- relay-client/src/lib.rs | 2 + relay-client/tests/connect_attempts.rs | 37 ++++++++++++++++++- 7 files changed, 86 insertions(+), 26 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 780fae6..29c6204 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -747,6 +747,7 @@ dependencies = [ "secrecy", "tokio", "tokio-stream", + "tokio-util", "tonic", "tracing", ] @@ -1340,9 +1341,9 @@ dependencies = [ [[package]] name = "tokio-util" -version = "0.7.12" +version = "0.7.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "61e7c3654c13bcd040d4a03abee2c75b1d14a37b423cf5a813ceae1cc903ec6a" +checksum = "d7fcaa8d55a2bdd6b83ace262b016eca0d79ee02818c5c1bcdf0305114081078" dependencies = [ "bytes", "futures-core", diff --git a/Cargo.toml b/Cargo.toml index d0351f6..dda4473 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -24,6 +24,7 @@ clap = { version = "4.5", features = ["derive"] } color-eyre = "0.6.2" tokio = { version = "1", features = ["full"] } tokio-stream = "0.1.15" +tokio-util = "0.7.13" uuid = "1.11.0" derive_more = { version = "0.99" } tracing = "0.1" diff --git a/relay-client/Cargo.toml b/relay-client/Cargo.toml index 50f9485..50a7d68 100644 --- a/relay-client/Cargo.toml +++ b/relay-client/Cargo.toml @@ -12,6 +12,7 @@ rust-version.workspace = true [dependencies] orb-relay-messages = { workspace = true, features = ["client"] } tokio = { workspace = true, fetures = ["full"] } +tokio-util.workspace = true tonic = { workspace = true, features = ["tls-roots"] } derive_more.workspace = true color-eyre.workspace = true diff --git a/relay-client/src/actor.rs 
b/relay-client/src/actor.rs index f2d2b83..48b8de2 100644 --- a/relay-client/src/actor.rs +++ b/relay-client/src/actor.rs @@ -12,6 +12,7 @@ use secrecy::ExposeSecret; use std::collections::HashMap; use tokio::{task, time}; use tokio_stream::StreamExt; +use tokio_util::sync::CancellationToken; use tonic::{ transport::{ClientTlsConfig, Endpoint}, Streaming, @@ -78,31 +79,31 @@ pub fn run(props: Props) -> (flume::Sender, task::JoinHandle { + .await; + + if let Err(e) = result { + cancellation_token.cancel(); + if let Err::StopRequest = e { return Err(Err::StopRequest); - } + } else if props.opts.max_connection_attempts <= conn_attempts { + return Err(e); + } else { + error!( + "RelayClient errored out {e:?}. Retrying in {}s", + props.opts.connection_backoff.as_secs() + ); - Err(e) => { - if props.opts.max_connection_attempts <= conn_attempts { - return Err(e); - } else { - error!( - "RelayClient errored out {e:?}. Retrying in {}s", - props.opts.connection_timeout.as_secs() - ); - } + time::sleep(props.opts.connection_backoff).await; } - - Ok(()) => (), - } + }; } }); @@ -114,10 +115,11 @@ async fn main_loop( props: &Props, relay_actor_tx: flume::Sender, relay_actor_rx: flume::Receiver, + cancellation_token: CancellationToken, ) -> Result<(), Err> { let mut response_stream = time::timeout( props.opts.connection_timeout, - connect(props, &relay_actor_tx), + connect(props, &relay_actor_tx, cancellation_token), ) .await .wrap_err("Timed out trying to establish a connection")??; @@ -361,6 +363,7 @@ fn handle_ack(state: &mut State, seq: Seq) { async fn connect( props: &Props, relay_actor_tx: &flume::Sender, + cancellation_token: CancellationToken, ) -> Result, Err> { let Props { opts, @@ -407,7 +410,11 @@ async fn connect( .wrap_err("Failed to send RelayConnectRequest")?; let mut response_stream: Streaming = relay_client - .relay_connect(flume_receiver_stream::new(tonic_rx.clone(), 4)) + .relay_connect(flume_receiver_stream::new( + tonic_rx.clone(), + 4, + 
cancellation_token, + )) .await? .into_inner(); diff --git a/relay-client/src/flume_receiver_stream.rs b/relay-client/src/flume_receiver_stream.rs index a2151ad..80a7dae 100644 --- a/relay-client/src/flume_receiver_stream.rs +++ b/relay-client/src/flume_receiver_stream.rs @@ -3,6 +3,7 @@ use tokio::sync::mpsc; use tokio_stream::wrappers::ReceiverStream; +use tokio_util::sync::CancellationToken; /// Creates a `tokio_stream::wrappers::ReceiverStream` from a `flume::Receiver<_>` /// ## example @@ -13,15 +14,29 @@ use tokio_stream::wrappers::ReceiverStream; pub fn new( flume_rx: flume::Receiver, tokio_mpsc_receiver_buffer: usize, + cancellation_token: CancellationToken, ) -> ReceiverStream { let (tx, rx) = mpsc::channel(tokio_mpsc_receiver_buffer); tokio::spawn(async move { - while let Ok(msg) = flume_rx.recv_async().await { - if tx.send(msg).await.is_err() { - break; + loop { + tokio::select! { + biased; + + _ = cancellation_token.cancelled() => { + break; + } + + + msg = flume_rx.recv_async() => { + if tx.send(msg?).await.is_err() { + break; + } + } } } + + Ok::<_,flume::RecvError>(()) }); ReceiverStream::new(rx) diff --git a/relay-client/src/lib.rs b/relay-client/src/lib.rs index b35cac8..f9e1a68 100644 --- a/relay-client/src/lib.rs +++ b/relay-client/src/lib.rs @@ -243,6 +243,8 @@ pub struct ClientOpts { auth: Auth, #[builder(default = Duration::from_secs(20))] connection_timeout: Duration, + #[builder(default = Duration::from_secs(20))] + connection_backoff: Duration, #[builder(default = Amount::Infinite)] max_connection_attempts: Amount, #[builder(default = Duration::from_secs(20))] diff --git a/relay-client/tests/connect_attempts.rs b/relay-client/tests/connect_attempts.rs index 246faee..0df1401 100644 --- a/relay-client/tests/connect_attempts.rs +++ b/relay-client/tests/connect_attempts.rs @@ -7,7 +7,7 @@ use orb_relay_client::{Amount, Auth, Client, ClientOpts}; use orb_relay_messages::relay::{ entity::EntityType, relay_connect_request::Msg, ConnectRequest, 
ConnectResponse, }; -use std::time::Duration; +use std::time::{Duration, Instant}; use test_server::{IntoRes, TestServer}; use tokio::time; @@ -53,7 +53,7 @@ async fn connects() { #[tokio::test] async fn tries_to_connect_the_expected_number_of_times_then_gives_up() { // Arrange - let expected_attempts = 2; + let expected_attempts = 3; let sv = TestServer::new(0, |attempts, _conn_req, _| { *attempts += 1; ConnectResponse { @@ -72,6 +72,7 @@ async fn tries_to_connect_the_expected_number_of_times_then_gives_up() { .auth(Auth::Token(Default::default())) .max_connection_attempts(Amount::Val(expected_attempts)) .connection_timeout(Duration::from_millis(10)) + .connection_backoff(Duration::ZERO) .build(); // Act @@ -83,3 +84,35 @@ async fn tries_to_connect_the_expected_number_of_times_then_gives_up() { let actual_attempts = sv.state().await; assert_eq!(*actual_attempts, expected_attempts); } + +#[tokio::test] +async fn sleeps_for_backoff_period_between_connection_attempts() { + // Arrange + let sv = TestServer::new((0, Instant::now()), |attempts, _conn_req, _| { + attempts.0 += 1; + ConnectResponse { + client_id: "doesntmatter".to_string(), + success: false, + error: "nothing".to_string(), + } + .into_res() + }) + .await; + + let opts = ClientOpts::entity(EntityType::App) + .id("foo") + .namespace("bar") + .endpoint(format!("http://{}", sv.addr())) + .auth(Auth::Token(Default::default())) + .max_connection_attempts(Amount::Infinite) + .connection_backoff(Duration::from_millis(50)) + .build(); + + // Act + let (_client, _handle) = Client::connect(opts); + + // Assert + time::sleep(Duration::from_millis(150)).await; + let actual_attempts = sv.state().await; + assert_eq!(actual_attempts.0, 3); +}