Skip to content

Commit

Permalink
Better transaction support in cluster mode
Browse files Browse the repository at this point in the history
- if all the commands of a transaction are not executed on the same node, the transaction will fail cleanly. The test is done in Rustis, before actually sending the commands to the Redis cluster
  • Loading branch information
mstyura authored Dec 28, 2023
1 parent a7aafba commit e5a1e0c
Show file tree
Hide file tree
Showing 7 changed files with 141 additions and 41 deletions.
109 changes: 74 additions & 35 deletions src/network/cluster_connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -152,15 +152,7 @@ impl ClusterConnection {
)));
};

let command_name = command_info.name.to_string();

let request_policy = command_info.command_tips.iter().find_map(|tip| {
if let CommandTip::RequestPolicy(request_policy) = tip {
Some(request_policy)
} else {
None
}
});
let command_name = command_info.name.clone();

let node_idx = self.get_random_node_index();
let keys = self
Expand All @@ -171,6 +163,14 @@ impl ClusterConnection {

debug!("[{}] keys: {keys:?}, slots: {slots:?}", self.tag);

let request_policy = command_info.command_tips.iter().find_map(|tip| {
if let CommandTip::RequestPolicy(request_policy) = tip {
Some(request_policy)
} else {
None
}
});

if let Some(request_policy) = request_policy {
match request_policy {
RequestPolicy::AllNodes => {
Expand Down Expand Up @@ -205,7 +205,7 @@ impl ClusterConnection {

pub async fn write_batch(
&mut self,
commands: impl Iterator<Item = &mut Command>,
commands: SmallVec<[&mut Command; 10]>,
retry_reasons: &[RetryReason],
) -> Result<()> {
if retry_reasons.iter().any(|r| {
Expand All @@ -231,8 +231,39 @@ impl ClusterConnection {
})
.collect::<Vec<_>>();

for command in commands {
self.internal_write(command, &ask_reasons).await?;
if commands.len() > 1 && commands[0].name == "MULTI" {
let node_idx = self.get_random_node_index();
let keys = self
.command_info_manager
.extract_keys(commands[1], &mut self.nodes[node_idx].connection)
.await?;
let slots = Self::hash_slots(&keys);
if slots.is_empty() || !slots.windows(2).all(|s| s[0] == s[1]) {
return Err(Error::Client(format!(
"[{}] Cannot execute transaction with mismatched key slots",
self.tag
)));
}
let ref_slot = slots[0];

for command in commands {
let keys = self
.command_info_manager
.extract_keys(command, &mut self.nodes[node_idx].connection)
.await?;
self.no_request_policy(
command,
command.name.to_string(),
keys,
SmallVec::from_slice(&[ref_slot]),
&ask_reasons,
)
.await?;
}
} else {
for command in commands {
self.internal_write(command, &ask_reasons).await?;
}
}

Ok(())
Expand Down Expand Up @@ -308,7 +339,7 @@ impl ClusterConnection {
Ok(())
}

/// The client should execute the command on several shards.
/// The client should execute the command on multiple shards.
/// The shards that execute the command are determined by the hash slots of its input key name arguments.
/// Examples for such commands include MSET, MGET and DEL.
/// However, note that SUNIONSTORE isn't considered as multi_shard because all of its keys must belong to the same hash slot.
Expand Down Expand Up @@ -486,24 +517,28 @@ impl ClusterConnection {

let node_id = &self.nodes[node_idx].id;

let Some((req_idx, sub_req_idx)) = self
.pending_requests
.iter()
.enumerate()
.find_map(|(req_idx, req)| {
let sub_req_idx = req
.sub_requests
.iter()
.position(|sr| sr.node_id == *node_id && sr.result.is_none())?;
Some((req_idx, sub_req_idx))
}) else {
log::error!("[{}] Received unexpected message: {result:?} from {}",
self.tag, self.nodes[node_idx].connection.tag());
return Some(Err(Error::Client(format!(
"[{}] Received unexpected message",
self.tag
))));
};
let Some((req_idx, sub_req_idx)) =
self.pending_requests
.iter()
.enumerate()
.find_map(|(req_idx, req)| {
let sub_req_idx = req
.sub_requests
.iter()
.position(|sr| sr.node_id == *node_id && sr.result.is_none())?;
Some((req_idx, sub_req_idx))
})
else {
log::error!(
"[{}] Received unexpected message: {result:?} from {}",
self.tag,
self.nodes[node_idx].connection.tag()
);
return Some(Err(Error::Client(format!(
"[{}] Received unexpected message",
self.tag
))));
};

self.pending_requests[req_idx].sub_requests[sub_req_idx].result = Some(result);
trace!(
Expand Down Expand Up @@ -768,7 +803,8 @@ impl ClusterConnection {
let mut deserializer = RespDeserializer::new(resp_buf);
let Ok(chunks) = deserializer.array_chunks() else {
return Some(Err(Error::Client(format!(
"[{}] Unexpected result {sub_result:?}", self.tag
"[{}] Unexpected result {sub_result:?}",
self.tag
))));
};

Expand All @@ -795,7 +831,8 @@ impl ClusterConnection {
let mut deserializer = RespDeserializer::new(resp_buf);
let Ok(chunks) = deserializer.array_chunks() else {
return Some(Err(Error::Client(format!(
"[{}] Unexpected result {sub_result:?}", self.tag
"[{}] Unexpected result {sub_result:?}",
self.tag
))));
};

Expand Down Expand Up @@ -903,7 +940,8 @@ impl ClusterConnection {
let mut slot_ranges = Vec::<SlotRange>::new();

for shard_info in shard_info_list.into_iter() {
let Some(master_info) = shard_info.nodes.into_iter().find(|n| n.role == "master") else {
let Some(master_info) = shard_info.nodes.into_iter().find(|n| n.role == "master")
else {
return Err(Error::Client("Cluster misconfiguration".to_owned()));
};
let master_id: NodeId = master_info.id.as_str().into();
Expand Down Expand Up @@ -1015,7 +1053,8 @@ impl ClusterConnection {
for mut shard_info in shard_info_list {
// ensure that the first node is master
if shard_info.nodes[0].role != "master" {
let Some(master_idx) = shard_info.nodes.iter().position(|n| n.role == "master") else {
let Some(master_idx) = shard_info.nodes.iter().position(|n| n.role == "master")
else {
return Err(Error::Client("Cluster misconfiguration".to_owned()));
};

Expand Down
3 changes: 2 additions & 1 deletion src/network/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use crate::{
StandaloneConnection,
};
use serde::de::DeserializeOwned;
use smallvec::SmallVec;
use std::future::IntoFuture;

pub enum Connection {
Expand Down Expand Up @@ -42,7 +43,7 @@ impl Connection {
#[inline]
pub async fn write_batch(
&mut self,
commands: impl Iterator<Item = &mut Command>,
commands: SmallVec::<[&mut Command; 10]>,
retry_reasons: &[RetryReason],
) -> Result<()> {
match self {
Expand Down
2 changes: 1 addition & 1 deletion src/network/network_handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,7 @@ impl NetworkHandler {

if let Err(e) = self
.connection
.write_batch(commands_to_write.into_iter(), &retry_reasons)
.write_batch(commands_to_write, &retry_reasons)
.await
{
error!("[{}] Error while writing batch: {e}", self.tag);
Expand Down
3 changes: 2 additions & 1 deletion src/network/sentinel_connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use crate::{
sleep, Error, Result, RetryReason, StandaloneConnection,
};
use log::debug;
use smallvec::SmallVec;

pub struct SentinelConnection {
sentinel_config: SentinelConfig,
Expand All @@ -21,7 +22,7 @@ impl SentinelConnection {
#[inline]
pub async fn write_batch(
&mut self,
commands: impl Iterator<Item = &mut Command>,
commands: SmallVec::<[&mut Command; 10]>,
retry_reasons: &[RetryReason],
) -> Result<()> {
self.inner_connection
Expand Down
3 changes: 2 additions & 1 deletion src/network/standalone_connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use bytes::BytesMut;
use futures_util::{SinkExt, StreamExt};
use log::{debug, log_enabled, Level};
use serde::de::DeserializeOwned;
use smallvec::SmallVec;
use std::future::IntoFuture;
use tokio::io::AsyncWriteExt;
use tokio_util::codec::{Encoder, FramedRead, FramedWrite};
Expand Down Expand Up @@ -99,7 +100,7 @@ impl StandaloneConnection {

pub async fn write_batch(
&mut self,
commands: impl Iterator<Item = &mut Command>,
commands: SmallVec::<[&mut Command; 10]>,
_retry_reasons: &[RetryReason],
) -> Result<()> {
self.buffer.clear();
Expand Down
23 changes: 22 additions & 1 deletion src/tests/pipeline.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use crate::{
client::BatchPreparedCommand,
commands::{FlushingMode, ServerCommands, StringCommands},
resp::{cmd, Value},
tests::get_test_client,
tests::{get_test_client, get_cluster_test_client},
Result,
};
use serial_test::serial;
Expand Down Expand Up @@ -46,3 +46,24 @@ async fn error() -> Result<()> {

Ok(())
}


#[cfg_attr(feature = "tokio-runtime", tokio::test)]
#[cfg_attr(feature = "async-std-runtime", async_std::test)]
#[serial]
async fn pipeline_on_cluster() -> Result<()> {
let client = get_cluster_test_client().await?;
client.flushall(FlushingMode::Sync).await?;

let mut pipeline = client.create_pipeline();
pipeline.set("key1", "value1").forget();
pipeline.set("key2", "value2").forget();
pipeline.get::<_, ()>("key1").queue();
pipeline.get::<_, ()>("key2").queue();

let (value1, value2): (String, String) = pipeline.execute().await?;
assert_eq!("value1", value1);
assert_eq!("value2", value2);

Ok(())
}
39 changes: 38 additions & 1 deletion src/tests/transaction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use crate::{
client::BatchPreparedCommand,
commands::{FlushingMode, ListCommands, ServerCommands, StringCommands, TransactionCommands},
resp::cmd,
tests::get_test_client,
tests::{get_test_client, get_cluster_test_client},
Error, RedisError, RedisErrorKind, Result,
};
use serial_test::serial;
Expand Down Expand Up @@ -160,3 +160,40 @@ async fn transaction_discard() -> Result<()> {

Ok(())
}

#[cfg_attr(feature = "tokio-runtime", tokio::test)]
#[cfg_attr(feature = "async-std-runtime", async_std::test)]
#[serial]
async fn transaction_on_cluster_connection_with_keys_with_same_slot() -> Result<()> {
let client = get_cluster_test_client().await?;
client.flushall(FlushingMode::Sync).await?;

let mut transaction = client.create_transaction();

transaction.mset([("{hash}key1", "value1"), ("{hash}key2", "value2")]).queue();
transaction.get::<_, String>("{hash}key1").queue();
transaction.get::<_, String>("{hash}key2").queue();
let ((), val1, val2): ((), String, String) = transaction.execute().await.unwrap();
assert_eq!("value1", val1);
assert_eq!("value2", val2);

Ok(())
}

#[cfg_attr(feature = "tokio-runtime", tokio::test)]
#[cfg_attr(feature = "async-std-runtime", async_std::test)]
#[serial]
async fn transaction_on_cluster_connection_with_keys_with_different_slots() -> Result<()> {
let client = get_cluster_test_client().await?;
client.flushall(FlushingMode::Sync).await?;

let mut transaction = client.create_transaction();

transaction.mset([("key1", "value1"), ("key2", "value2")]).queue();
transaction.get::<_, String>("key1").queue();
transaction.get::<_, String>("key2").queue();
let result: Result<((), String, String)> = transaction.execute().await;
assert!(result.is_err());

Ok(())
}

0 comments on commit e5a1e0c

Please sign in to comment.