diff --git a/src/network/network_handler.rs b/src/network/network_handler.rs index 94bb8f8..2f59ecb 100644 --- a/src/network/network_handler.rs +++ b/src/network/network_handler.rs @@ -3,7 +3,7 @@ use crate::{ client::{Commands, Config, Message}, commands::InternalPubSubCommands, resp::{cmd, Command, RespBuf}, - sleep, spawn, Connection, Error, JoinHandle, ReconnectionState, Result, RetryReason, + spawn, timeout, Connection, Error, JoinHandle, ReconnectionState, Result, RetryReason, }; use futures_channel::{mpsc, oneshot}; use futures_util::{select, FutureExt, SinkExt, StreamExt}; @@ -13,7 +13,7 @@ use std::{ collections::{HashMap, VecDeque}, time::Duration, }; -use tokio::sync::broadcast; +use tokio::{sync::broadcast, time::Instant}; pub(crate) type MsgSender = mpsc::UnboundedSender; pub(crate) type MsgReceiver = mpsc::UnboundedReceiver; @@ -28,7 +28,7 @@ pub(crate) type PushReceiver = mpsc::UnboundedReceiver>; pub(crate) type ReconnectSender = broadcast::Sender<()>; pub(crate) type ReconnectReceiver = broadcast::Receiver<()>; -#[derive(Clone, Copy, Debug)] +#[derive(Clone, Copy, Debug, PartialEq, Eq)] enum Status { Disconnected, Connected, @@ -152,8 +152,8 @@ impl NetworkHandler { msg = self.msg_receiver.next().fuse() => { if !self.handle_message(msg).await { break; } } , - value = self.connection.read().fuse() => { - if !self.handle_result(value).await { break; } + result = self.connection.read().fuse() => { + if !self.handle_result(result).await { break; } } } } @@ -242,11 +242,22 @@ impl NetworkHandler { self.messages_to_send.push_back(MessageToSend::new(msg)); } Status::Disconnected => { - debug!( - "[{}] network disconnected, queuing command: {:?}", - self.tag, msg.commands - ); - self.messages_to_send.push_back(MessageToSend::new(msg)); + if msg.retry_on_error { + debug!( + "[{}] network disconnected, queuing command: {:?}", + self.tag, msg.commands + ); + self.messages_to_send.push_back(MessageToSend::new(msg)); + } else { + debug!( + "[{}] network disconnected, ending command in error: {:?}", + self.tag, msg.commands + ); + msg.commands.send_error( + &self.tag, + Error::Client("Disconnected from server".to_string()), + ); + } } Status::EnteringMonitor => { self.messages_to_send.push_back(MessageToSend::new(msg)) @@ -278,8 +289,7 @@ impl NetworkHandler { } } - if let Status::Disconnected = self.status { - } else { + if self.status != Status::Disconnected { self.send_messages().await } @@ -719,7 +729,22 @@ impl NetworkHandler { loop { if let Some(delay) = self.reconnection_state.next_delay() { debug!("[{}] Waiting {delay} ms before reconnection", self.tag); - sleep(Duration::from_millis(delay)).await; + + // keep on receiving new message during the delay + let start = Instant::now(); + let end = start.checked_add(Duration::from_millis(delay)).unwrap(); + loop { + let delay = end.duration_since(Instant::now()); + let result = timeout(delay, self.msg_receiver.next().fuse()).await; + if let Ok(msg) = result { + if !self.handle_message(msg).await { + return false; + } + } else { + // delay has expired + break; + } + } } else { warn!("[{}] Max reconnection attempts reached", self.tag); while let Some(message_to_receive) = self.messages_to_receive.pop_front() { diff --git a/src/tests/client.rs b/src/tests/client.rs index 1721e5a..ae4845a 100644 --- a/src/tests/client.rs +++ b/src/tests/client.rs @@ -57,7 +57,7 @@ async fn on_reconnect() -> Result<()> { .await?; // send command to be sure that the reconnection has been done - client1.set("key", "value").await?; + client1.set("key", "value").retry_on_error(true).await?; let result = receiver.try_recv(); assert!(result.is_ok()); diff --git a/src/tests/config.rs b/src/tests/config.rs index ab38e88..865fa2a 100644 --- a/src/tests/config.rs +++ b/src/tests/config.rs @@ -48,6 +48,7 @@ async fn password() -> Result<()> { #[cfg_attr(feature = "async-std-runtime", async_std::test)] #[serial] async fn reconnection() -> Result<()> { + log_try_init(); let uri = format!( "redis://{}:{}/1", get_default_host(), @@ -62,7 +63,7 @@ async fn reconnection() -> Result<()> { .client_kill(ClientKillOptions::default().id(client_id)) .await?; - let client_info = client.client_info().await?; + let client_info = client.client_info().retry_on_error(true).await?; assert_eq!(1, client_info.db); Ok(()) diff --git a/src/tests/error.rs b/src/tests/error.rs index 5a4b9a0..e85850e 100644 --- a/src/tests/error.rs +++ b/src/tests/error.rs @@ -1,4 +1,9 @@ -use crate::{resp::cmd, tests::get_test_client, Error, RedisError, RedisErrorKind, Result}; +use crate::{ + commands::{ClientKillOptions, ConnectionCommands, StringCommands}, + resp::cmd, + tests::{get_default_config, get_test_client, get_test_client_with_config}, + Error, RedisError, RedisErrorKind, Result, +}; use serial_test::serial; use std::str::FromStr; @@ -48,6 +53,29 @@ fn ask_error() { )); } +#[cfg_attr(feature = "tokio-runtime", tokio::test)] +#[cfg_attr(feature = "async-std-runtime", async_std::test)] +#[serial] +async fn reconnection() -> Result<()> { + let mut config = get_default_config()?; + config.connection_name = "regular".to_string(); + let regular_client = get_test_client_with_config(config).await?; + + let mut config = get_default_config()?; + config.connection_name = "killer".to_string(); + let killer_client = get_test_client_with_config(config).await?; + + let client_id = regular_client.client_id().await?; + killer_client + .client_kill(ClientKillOptions::default().id(client_id)) + .await?; + + let result = regular_client.set("key", "value").await; + assert!(result.is_err()); + + Ok(()) +} + // #[cfg_attr(feature = "tokio-runtime", tokio::test)] // #[cfg_attr(feature = "async-std-runtime", async_std::test)] // #[serial] @@ -153,48 +181,52 @@ fn ask_error() { // Ok(()) // } -// #[cfg(debug_assertions)] -// #[cfg_attr(feature = "tokio-runtime", tokio::test)] -// #[cfg_attr(feature = "async-std-runtime", async_std::test)] -// #[serial] -// async fn kill_on_write() -> Result<()> { -// let client = get_test_client().await?; - -// // 3 reconnections -// let result = client -// .send( -// cmd("SET") -// .arg("key1") -// .arg("value1") -// .kill_connection_on_write(3), -// Some(true), -// ) -// .await; -// assert!(result.is_err()); - -// // 2 reconnections -// let result = client -// .send( -// cmd("SET") -// .arg("key2") -// .arg("value2") -// .kill_connection_on_write(2), -// Some(true), -// ) -// .await; -// assert!(result.is_ok()); - -// // 2 reconnections / no retry -// let result = client -// .send( -// cmd("SET") -// .arg("key3") -// .arg("value3") -// .kill_connection_on_write(2), -// Some(false), -// ) -// .await; -// assert!(result.is_err()); +#[cfg(debug_assertions)] +#[cfg_attr(feature = "tokio-runtime", tokio::test)] +#[cfg_attr(feature = "async-std-runtime", async_std::test)] +#[serial] +async fn kill_on_write() -> Result<()> { + use crate::client::ReconnectionConfig; + + let mut config = get_default_config()?; + config.reconnection = ReconnectionConfig::new_constant(0, 100); + let client = get_test_client_with_config(config).await?; + + // 3 reconnections + let result = client + .send( + cmd("SET") + .arg("key1") + .arg("value1") + .kill_connection_on_write(3), + Some(true), + ) + .await; + assert!(result.is_ok()); + + // 2 reconnections + let result = client + .send( + cmd("SET") + .arg("key2") + .arg("value2") + .kill_connection_on_write(2), + Some(true), + ) + .await; + assert!(result.is_ok()); + + // 2 reconnections / no retry + let result = client + .send( + cmd("SET") + .arg("key3") + .arg("value3") + .kill_connection_on_write(2), + Some(false), + ) + .await; + assert!(result.is_err()); -// Ok(()) -// } + Ok(()) +} diff --git a/src/tests/pub_sub_commands.rs b/src/tests/pub_sub_commands.rs index e826259..184eea4 100644 --- a/src/tests/pub_sub_commands.rs +++ b/src/tests/pub_sub_commands.rs @@ -1,11 +1,14 @@ use crate::{ - client::{Client, IntoConfig}, + client::{Client, IntoConfig, ReconnectionConfig}, commands::{ ClientKillOptions, ClusterCommands, ClusterShardResult, ConnectionCommands, FlushingMode, ListCommands, PubSubChannelsOptions, PubSubCommands, ServerCommands, StringCommands, }, spawn, - tests::{get_cluster_test_client, get_default_addr, get_test_client, log_try_init}, + tests::{ + get_cluster_test_client, get_default_addr, get_default_config, get_test_client, + get_test_client_with_config, log_try_init, + }, Result, }; use futures_util::{FutureExt, StreamExt, TryStreamExt}; @@ -535,7 +538,9 @@ async fn additional_sub() -> Result<()> { #[cfg_attr(feature = "async-std-runtime", async_std::test)] #[serial] async fn auto_resubscribe() -> Result<()> { - let pub_sub_client = get_test_client().await?; + let mut config = get_default_config()?; + config.reconnection = ReconnectionConfig::new_constant(0, 100); + let pub_sub_client = get_test_client_with_config(config).await?; let regular_client = get_test_client().await?; let pub_sub_client_id = pub_sub_client.client_id().await?; diff --git a/src/tests/server_commands.rs b/src/tests/server_commands.rs index d9e1056..7c75ad5 100644 --- a/src/tests/server_commands.rs +++ b/src/tests/server_commands.rs @@ -1,5 +1,5 @@ use crate::{ - client::Client, + client::{Client, ReconnectionConfig}, commands::{ AclCatOptions, AclDryRunOptions, AclGenPassOptions, AclLogOptions, BlockingCommands, ClientInfo, ClientKillOptions, CommandDoc, CommandHistogram, CommandListOptions, @@ -9,7 +9,7 @@ use crate::{ }, resp::{cmd, Value}, spawn, - tests::{get_sentinel_test_client, get_test_client}, + tests::{get_default_config, get_sentinel_test_client, get_test_client, get_test_client_with_config}, Error, RedisError, RedisErrorKind, Result, }; use futures_util::StreamExt; @@ -966,7 +966,9 @@ async fn monitor() -> Result<()> { #[cfg_attr(feature = "async-std-runtime", async_std::test)] #[serial] async fn auto_remonitor() -> Result<()> { - let client = get_test_client().await?; + let mut config = get_default_config()?; + config.reconnection = ReconnectionConfig::new_constant(0, 100); + let client = get_test_client_with_config(config).await?; client.flushdb(FlushingMode::Sync).await?; let client2 = get_test_client().await?;