From e5a1e0c0173a772cd9aa71571dbc6cad2a629ff2 Mon Sep 17 00:00:00 2001 From: Yury Yarashevich Date: Fri, 29 Dec 2023 00:11:04 +0100 Subject: [PATCH] Better transaction support in cluster mode - if all the commands of a transaction are not executed on the same node, the transaction will fail cleanly. The test is done in Rustis, before actually sending the commands to the Redis cluster --- src/network/cluster_connection.rs | 109 ++++++++++++++++++--------- src/network/connection.rs | 3 +- src/network/network_handler.rs | 2 +- src/network/sentinel_connection.rs | 3 +- src/network/standalone_connection.rs | 3 +- src/tests/pipeline.rs | 23 +++++- src/tests/transaction.rs | 39 +++++++++- 7 files changed, 141 insertions(+), 41 deletions(-) diff --git a/src/network/cluster_connection.rs b/src/network/cluster_connection.rs index ef6fe18..ad25e5e 100644 --- a/src/network/cluster_connection.rs +++ b/src/network/cluster_connection.rs @@ -152,15 +152,7 @@ impl ClusterConnection { ))); }; - let command_name = command_info.name.to_string(); - - let request_policy = command_info.command_tips.iter().find_map(|tip| { - if let CommandTip::RequestPolicy(request_policy) = tip { - Some(request_policy) - } else { - None - } - }); + let command_name = command_info.name.clone(); let node_idx = self.get_random_node_index(); let keys = self @@ -171,6 +163,14 @@ impl ClusterConnection { debug!("[{}] keys: {keys:?}, slots: {slots:?}", self.tag); + let request_policy = command_info.command_tips.iter().find_map(|tip| { + if let CommandTip::RequestPolicy(request_policy) = tip { + Some(request_policy) + } else { + None + } + }); + if let Some(request_policy) = request_policy { match request_policy { RequestPolicy::AllNodes => { @@ -205,7 +205,7 @@ impl ClusterConnection { pub async fn write_batch( &mut self, - commands: impl Iterator, + commands: SmallVec<[&mut Command; 10]>, retry_reasons: &[RetryReason], ) -> Result<()> { if retry_reasons.iter().any(|r| { @@ -231,8 +231,39 @@ impl ClusterConnection { }) .collect::>(); - for command in commands { - self.internal_write(command, &ask_reasons).await?; + if commands.len() > 1 && commands[0].name == "MULTI" { + let node_idx = self.get_random_node_index(); + let keys = self + .command_info_manager + .extract_keys(commands[1], &mut self.nodes[node_idx].connection) + .await?; + let slots = Self::hash_slots(&keys); + if slots.is_empty() || !slots.windows(2).all(|s| s[0] == s[1]) { + return Err(Error::Client(format!( + "[{}] Cannot execute transaction with mismatched key slots", + self.tag + ))); + } + let ref_slot = slots[0]; + + for command in commands { + let keys = self + .command_info_manager + .extract_keys(command, &mut self.nodes[node_idx].connection) + .await?; + self.no_request_policy( + command, + command.name.to_string(), + keys, + SmallVec::from_slice(&[ref_slot]), + &ask_reasons, + ) + .await?; + } + } else { + for command in commands { + self.internal_write(command, &ask_reasons).await?; + } } Ok(()) @@ -308,7 +339,7 @@ impl ClusterConnection { Ok(()) } - /// The client should execute the command on several shards. + /// The client should execute the command on multiple shards. /// The shards that execute the command are determined by the hash slots of its input key name arguments. /// Examples for such commands include MSET, MGET and DEL. /// However, note that SUNIONSTORE isn't considered as multi_shard because all of its keys must belong to the same hash slot. @@ -486,24 +517,28 @@ impl ClusterConnection { let node_id = &self.nodes[node_idx].id; - let Some((req_idx, sub_req_idx)) = self - .pending_requests - .iter() - .enumerate() - .find_map(|(req_idx, req)| { - let sub_req_idx = req - .sub_requests - .iter() - .position(|sr| sr.node_id == *node_id && sr.result.is_none())?; - Some((req_idx, sub_req_idx)) - }) else { - log::error!("[{}] Received unexpected message: {result:?} from {}", - self.tag, self.nodes[node_idx].connection.tag()); - return Some(Err(Error::Client(format!( - "[{}] Received unexpected message", - self.tag - )))); - }; + let Some((req_idx, sub_req_idx)) = + self.pending_requests + .iter() + .enumerate() + .find_map(|(req_idx, req)| { + let sub_req_idx = req + .sub_requests + .iter() + .position(|sr| sr.node_id == *node_id && sr.result.is_none())?; + Some((req_idx, sub_req_idx)) + }) + else { + log::error!( + "[{}] Received unexpected message: {result:?} from {}", + self.tag, + self.nodes[node_idx].connection.tag() + ); + return Some(Err(Error::Client(format!( + "[{}] Received unexpected message", + self.tag + )))); + }; self.pending_requests[req_idx].sub_requests[sub_req_idx].result = Some(result); trace!( @@ -768,7 +803,8 @@ impl ClusterConnection { let mut deserializer = RespDeserializer::new(resp_buf); let Ok(chunks) = deserializer.array_chunks() else { return Some(Err(Error::Client(format!( - "[{}] Unexpected result {sub_result:?}", self.tag + "[{}] Unexpected result {sub_result:?}", + self.tag )))); }; @@ -795,7 +831,8 @@ impl ClusterConnection { let mut deserializer = RespDeserializer::new(resp_buf); let Ok(chunks) = deserializer.array_chunks() else { return Some(Err(Error::Client(format!( - "[{}] Unexpected result {sub_result:?}", self.tag + "[{}] Unexpected result {sub_result:?}", + self.tag )))); }; @@ -903,7 +940,8 @@ impl ClusterConnection { let mut slot_ranges = Vec::::new(); for shard_info in shard_info_list.into_iter() { - let Some(master_info) = shard_info.nodes.into_iter().find(|n| n.role == "master") else { + let Some(master_info) = shard_info.nodes.into_iter().find(|n| n.role == "master") + else { return Err(Error::Client("Cluster misconfiguration".to_owned())); }; let master_id: NodeId = master_info.id.as_str().into(); @@ -1015,7 +1053,8 @@ impl ClusterConnection { for mut shard_info in shard_info_list { // ensure that the first node is master if shard_info.nodes[0].role != "master" { - let Some(master_idx) = shard_info.nodes.iter().position(|n| n.role == "master") else { + let Some(master_idx) = shard_info.nodes.iter().position(|n| n.role == "master") + else { return Err(Error::Client("Cluster misconfiguration".to_owned())); }; diff --git a/src/network/connection.rs b/src/network/connection.rs index c1040c5..d836b12 100644 --- a/src/network/connection.rs +++ b/src/network/connection.rs @@ -6,6 +6,7 @@ use crate::{ StandaloneConnection, }; use serde::de::DeserializeOwned; +use smallvec::SmallVec; use std::future::IntoFuture; pub enum Connection { @@ -42,7 +43,7 @@ impl Connection { #[inline] pub async fn write_batch( &mut self, - commands: impl Iterator, + commands: SmallVec::<[&mut Command; 10]>, retry_reasons: &[RetryReason], ) -> Result<()> { match self { diff --git a/src/network/network_handler.rs b/src/network/network_handler.rs index 9369746..7953bc7 100644 --- a/src/network/network_handler.rs +++ b/src/network/network_handler.rs @@ -328,7 +328,7 @@ impl NetworkHandler { if let Err(e) = self .connection - .write_batch(commands_to_write.into_iter(), &retry_reasons) + .write_batch(commands_to_write, &retry_reasons) .await { error!("[{}] Error while writing batch: {e}", self.tag); diff --git a/src/network/sentinel_connection.rs b/src/network/sentinel_connection.rs index 1013a73..bc399f3 100644 --- a/src/network/sentinel_connection.rs +++ b/src/network/sentinel_connection.rs @@ -5,6 +5,7 @@ use crate::{ sleep, Error, Result, RetryReason, StandaloneConnection, }; use log::debug; +use smallvec::SmallVec; pub struct SentinelConnection { sentinel_config: SentinelConfig, @@ -21,7 +22,7 @@ impl SentinelConnection { #[inline] pub async fn write_batch( &mut self, - commands: impl Iterator, + commands: SmallVec::<[&mut Command; 10]>, retry_reasons: &[RetryReason], ) -> Result<()> { self.inner_connection diff --git a/src/network/standalone_connection.rs b/src/network/standalone_connection.rs index 0853efb..1f3bcb0 100644 --- a/src/network/standalone_connection.rs +++ b/src/network/standalone_connection.rs @@ -12,6 +12,7 @@ use bytes::BytesMut; use futures_util::{SinkExt, StreamExt}; use log::{debug, log_enabled, Level}; use serde::de::DeserializeOwned; +use smallvec::SmallVec; use std::future::IntoFuture; use tokio::io::AsyncWriteExt; use tokio_util::codec::{Encoder, FramedRead, FramedWrite}; @@ -99,7 +100,7 @@ impl StandaloneConnection { pub async fn write_batch( &mut self, - commands: impl Iterator, + commands: SmallVec::<[&mut Command; 10]>, _retry_reasons: &[RetryReason], ) -> Result<()> { self.buffer.clear(); diff --git a/src/tests/pipeline.rs b/src/tests/pipeline.rs index bc9e7a9..d2e12ed 100644 --- a/src/tests/pipeline.rs +++ b/src/tests/pipeline.rs @@ -2,7 +2,7 @@ use crate::{ client::BatchPreparedCommand, commands::{FlushingMode, ServerCommands, StringCommands}, resp::{cmd, Value}, - tests::get_test_client, + tests::{get_test_client, get_cluster_test_client}, Result, }; use serial_test::serial; @@ -46,3 +46,24 @@ async fn error() -> Result<()> { Ok(()) } + + +#[cfg_attr(feature = "tokio-runtime", tokio::test)] +#[cfg_attr(feature = "async-std-runtime", async_std::test)] +#[serial] +async fn pipeline_on_cluster() -> Result<()> { + let client = get_cluster_test_client().await?; + client.flushall(FlushingMode::Sync).await?; + + let mut pipeline = client.create_pipeline(); + pipeline.set("key1", "value1").forget(); + pipeline.set("key2", "value2").forget(); + pipeline.get::<_, ()>("key1").queue(); + pipeline.get::<_, ()>("key2").queue(); + + let (value1, value2): (String, String) = pipeline.execute().await?; + assert_eq!("value1", value1); + assert_eq!("value2", value2); + + Ok(()) +} \ No newline at end of file diff --git a/src/tests/transaction.rs b/src/tests/transaction.rs index 48cca7e..bb0d29d 100644 --- a/src/tests/transaction.rs +++ b/src/tests/transaction.rs @@ -2,7 +2,7 @@ use crate::{ client::BatchPreparedCommand, commands::{FlushingMode, ListCommands, ServerCommands, StringCommands, TransactionCommands}, resp::cmd, - tests::get_test_client, + tests::{get_test_client, get_cluster_test_client}, Error, RedisError, RedisErrorKind, Result, }; use serial_test::serial; @@ -160,3 +160,40 @@ async fn transaction_discard() -> Result<()> { Ok(()) } + +#[cfg_attr(feature = "tokio-runtime", tokio::test)] +#[cfg_attr(feature = "async-std-runtime", async_std::test)] +#[serial] +async fn transaction_on_cluster_connection_with_keys_with_same_slot() -> Result<()> { + let client = get_cluster_test_client().await?; + client.flushall(FlushingMode::Sync).await?; + + let mut transaction = client.create_transaction(); + + transaction.mset([("{hash}key1", "value1"), ("{hash}key2", "value2")]).queue(); + transaction.get::<_, String>("{hash}key1").queue(); + transaction.get::<_, String>("{hash}key2").queue(); + let ((), val1, val2): ((), String, String) = transaction.execute().await.unwrap(); + assert_eq!("value1", val1); + assert_eq!("value2", val2); + + Ok(()) +} + +#[cfg_attr(feature = "tokio-runtime", tokio::test)] +#[cfg_attr(feature = "async-std-runtime", async_std::test)] +#[serial] +async fn transaction_on_cluster_connection_with_keys_with_different_slots() -> Result<()> { + let client = get_cluster_test_client().await?; + client.flushall(FlushingMode::Sync).await?; + + let mut transaction = client.create_transaction(); + + transaction.mset([("key1", "value1"), ("key2", "value2")]).queue(); + transaction.get::<_, String>("key1").queue(); + transaction.get::<_, String>("key2").queue(); + let result: Result<((), String, String)> = transaction.execute().await; + assert!(result.is_err()); + + Ok(()) +}