From a8b78c2cb48c4fbdec08bd4c458af3b76146b035 Mon Sep 17 00:00:00 2001 From: Bar Shaul <88437685+barshaul@users.noreply.github.com> Date: Wed, 24 Apr 2024 11:56:57 +0300 Subject: [PATCH] Fixed blocking command to be timed out based on the specified command argument (#1283) --- CHANGELOG.md | 1 + glide-core/src/client/mod.rs | 283 +++++++++++++++++- .../src/client/reconnecting_connection.rs | 2 +- glide-core/tests/test_client.rs | 96 +++++- python/python/tests/test_async_client.py | 12 +- 5 files changed, 369 insertions(+), 25 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 041e0c877c..ca47682d35 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,7 @@ #### Fixes * Python: Fix typing error "‘type’ object is not subscriptable" ([#1203](https://github.com/aws/glide-for-redis/pull/1203)) +* Core: Fixed blocking commands to use the specified timeout from the command argument ([#1283](https://github.com/aws/glide-for-redis/pull/1283)) ## 0.3.3 (2024-03-28) diff --git a/glide-core/src/client/mod.rs b/glide-core/src/client/mod.rs index 645bef1118..c679ad3397 100644 --- a/glide-core/src/client/mod.rs +++ b/glide-core/src/client/mod.rs @@ -8,9 +8,9 @@ use futures::FutureExt; use logger_core::log_info; use redis::aio::ConnectionLike; use redis::cluster_async::ClusterConnection; -use redis::cluster_routing::{RoutingInfo, SingleNodeRoutingInfo}; -use redis::RedisResult; +use redis::cluster_routing::{Routable, RoutingInfo, SingleNodeRoutingInfo}; use redis::{Cmd, ErrorKind, Value}; +use redis::{RedisError, RedisResult}; pub use standalone_client::StandaloneClient; use std::io; use std::ops::Deref; @@ -95,13 +95,122 @@ pub struct Client { } async fn run_with_timeout( - timeout: Duration, + timeout: Option, future: impl futures::Future> + Send, ) -> redis::RedisResult { - tokio::time::timeout(timeout, future) - .await - .map_err(|_| io::Error::from(io::ErrorKind::TimedOut).into()) - .and_then(|res| res) + match timeout { + Some(duration) => tokio::time::timeout(duration, future) + .await + .map_err(|_| io::Error::from(io::ErrorKind::TimedOut).into()) + .and_then(|res| res), + None => future.await, + } +} + +/// Extension to the request timeout for blocking commands to ensure we won't return with timeout error before the server responded +const BLOCKING_CMD_TIMEOUT_EXTENSION: f64 = 0.5; // seconds + +enum TimeUnit { + Milliseconds = 1000, + Seconds = 1, +} + +/// Enumeration representing different request timeout options. +#[derive(Default, PartialEq, Debug)] +enum RequestTimeoutOption { + // Indicates no timeout should be set for the request. + NoTimeout, + // Indicates the request timeout should be based on the client's configured timeout. + #[default] + ClientConfig, + // Indicates the request timeout should be based on the timeout specified in the blocking command. + BlockingCommand(Duration), +} + +/// Helper function for parsing a timeout argument to f64. +/// Attempts to parse the argument found at `timeout_idx` from bytes into an f64. +fn parse_timeout_to_f64(cmd: &Cmd, timeout_idx: usize) -> RedisResult { + let create_err = |err_msg| { + RedisError::from(( + ErrorKind::ResponseError, + err_msg, + format!( + "Expected to find timeout value at index {:?} for command {:?}. Recieved command = {:?}", + timeout_idx, + std::str::from_utf8(&cmd.command().unwrap_or_default()), + std::str::from_utf8(&cmd.get_packed_command()) + ), + )) + }; + let timeout_bytes = cmd + .arg_idx(timeout_idx) + .ok_or(create_err("Couldn't find timeout index"))?; + let timeout_str = std::str::from_utf8(timeout_bytes) + .map_err(|_| create_err("Failed to parse the timeout argument to string"))?; + timeout_str + .parse::() + .map_err(|_| create_err("Failed to parse the timeout argument to f64")) +} + +/// Attempts to get the timeout duration from the command argument at `timeout_idx`. +/// If the argument can be parsed into a duration, it returns the duration in seconds with BlockingCmdTimeout. +/// If the timeout argument value is zero, NoTimeout will be returned. Otherwise, ClientConfigTimeout is returned. +fn get_timeout_from_cmd_arg( + cmd: &Cmd, + timeout_idx: usize, + time_unit: TimeUnit, +) -> RedisResult { + let timeout_secs = parse_timeout_to_f64(cmd, timeout_idx)? / ((time_unit as i32) as f64); + if timeout_secs < 0.0 { + // Timeout cannot be negative, return the client's configured request timeout + Err(RedisError::from(( + ErrorKind::ResponseError, + "Timeout cannot be negative", + format!("Recieved timeout={:?}", timeout_secs), + ))) + } else if timeout_secs == 0.0 { + // `0` means we should set no timeout + Ok(RequestTimeoutOption::NoTimeout) + } else { + // We limit the maximum timeout due to restrictions imposed by Redis and the Duration crate + if timeout_secs > u32::MAX as f64 { + Err(RedisError::from(( + ErrorKind::ResponseError, + "Timeout is out of range, max timeout is 2^32 - 1 (u32::MAX)", + format!("Recieved timeout={:?}", timeout_secs), + ))) + } else { + // Extend the request timeout to ensure we don't timeout before receiving a response from the server. + Ok(RequestTimeoutOption::BlockingCommand( + Duration::from_secs_f64( + (timeout_secs + BLOCKING_CMD_TIMEOUT_EXTENSION).min(u32::MAX as f64), + ), + )) + } + } +} + +fn get_request_timeout(cmd: &Cmd, default_timeout: Duration) -> RedisResult> { + let command = cmd.command().unwrap_or_default(); + let timeout = match command.as_slice() { + b"BLPOP" | b"BRPOP" | b"BLMOVE" | b"BZPOPMAX" | b"BZPOPMIN" | b"BRPOPLPUSH" => { + get_timeout_from_cmd_arg(cmd, cmd.args_iter().len() - 1, TimeUnit::Seconds) + } + b"BLMPOP" | b"BZMPOP" => get_timeout_from_cmd_arg(cmd, 1, TimeUnit::Seconds), + b"XREAD" | b"XREADGROUP" => cmd + .position(b"BLOCK") + .map(|idx| get_timeout_from_cmd_arg(cmd, idx + 1, TimeUnit::Milliseconds)) + .unwrap_or(Ok(RequestTimeoutOption::ClientConfig)), + _ => Ok(RequestTimeoutOption::ClientConfig), + }?; + + match timeout { + RequestTimeoutOption::NoTimeout => Ok(None), + RequestTimeoutOption::ClientConfig => Ok(Some(default_timeout)), + RequestTimeoutOption::BlockingCommand(blocking_cmd_duration) => { + Ok(Some(blocking_cmd_duration)) + } + } } impl Client { @@ -111,7 +220,13 @@ impl Client { routing: Option, ) -> redis::RedisFuture<'a, Value> { let expected_type = expected_type_for_cmd(cmd); - run_with_timeout(self.request_timeout, async move { + let request_timeout = match get_request_timeout(cmd, self.request_timeout) { + Ok(request_timeout) => request_timeout, + Err(err) => { + return async { Err(err) }.boxed(); + } + }; + run_with_timeout(request_timeout, async move { match self.internal_client { ClientWrapper::Standalone(ref mut client) => client.send_command(cmd).await, @@ -189,7 +304,7 @@ impl Client { ) -> redis::RedisFuture<'a, Value> { let command_count = pipeline.cmd_iter().count(); let offset = command_count + 1; - run_with_timeout(self.request_timeout, async move { + run_with_timeout(Some(self.request_timeout), async move { let values = match self.internal_client { ClientWrapper::Standalone(ref mut client) => { client.send_pipeline(pipeline, offset, 1).await @@ -472,3 +587,153 @@ impl GlideClientForTests for StandaloneClient { self.send_command(cmd).boxed() } } + +#[cfg(test)] +mod tests { + use std::time::Duration; + + use redis::Cmd; + + use crate::client::{ + get_request_timeout, RequestTimeoutOption, TimeUnit, BLOCKING_CMD_TIMEOUT_EXTENSION, + }; + + use super::get_timeout_from_cmd_arg; + + #[test] + fn test_get_timeout_from_cmd_returns_correct_duration_int() { + let mut cmd = Cmd::new(); + cmd.arg("BLPOP").arg("key1").arg("key2").arg("5"); + let result = get_timeout_from_cmd_arg(&cmd, cmd.args_iter().len() - 1, TimeUnit::Seconds); + assert!(result.is_ok()); + assert_eq!( + result.unwrap(), + RequestTimeoutOption::BlockingCommand(Duration::from_secs_f64( + 5.0 + BLOCKING_CMD_TIMEOUT_EXTENSION + )) + ); + } + + #[test] + fn test_get_timeout_from_cmd_returns_correct_duration_float() { + let mut cmd = Cmd::new(); + cmd.arg("BLPOP").arg("key1").arg("key2").arg(0.5); + let result = get_timeout_from_cmd_arg(&cmd, cmd.args_iter().len() - 1, TimeUnit::Seconds); + assert!(result.is_ok()); + assert_eq!( + result.unwrap(), + RequestTimeoutOption::BlockingCommand(Duration::from_secs_f64( + 0.5 + BLOCKING_CMD_TIMEOUT_EXTENSION + )) + ); + } + + #[test] + fn test_get_timeout_from_cmd_returns_correct_duration_milliseconds() { + let mut cmd = Cmd::new(); + cmd.arg("XREAD").arg("BLOCK").arg("500").arg("key"); + let result = get_timeout_from_cmd_arg(&cmd, 2, TimeUnit::Milliseconds); + assert!(result.is_ok()); + assert_eq!( + result.unwrap(), + RequestTimeoutOption::BlockingCommand(Duration::from_secs_f64( + 0.5 + BLOCKING_CMD_TIMEOUT_EXTENSION + )) + ); + } + + #[test] + fn test_get_timeout_from_cmd_returns_err_when_timeout_isnt_passed() { + let mut cmd = Cmd::new(); + cmd.arg("BLPOP").arg("key1").arg("key2").arg("key3"); + let result = get_timeout_from_cmd_arg(&cmd, cmd.args_iter().len() - 1, TimeUnit::Seconds); + assert!(result.is_err()); + let err = result.unwrap_err(); + println!("{:?}", err); + assert!(err.to_string().to_lowercase().contains("index"), "{err}"); + } + + #[test] + fn test_get_timeout_from_cmd_returns_err_when_timeout_is_larger_than_u32_max() { + let mut cmd = Cmd::new(); + cmd.arg("BLPOP") + .arg("key1") + .arg("key2") + .arg(u32::MAX as u64 + 1); + let result = get_timeout_from_cmd_arg(&cmd, cmd.args_iter().len() - 1, TimeUnit::Seconds); + assert!(result.is_err()); + let err = result.unwrap_err(); + println!("{:?}", err); + assert!(err.to_string().to_lowercase().contains("u32"), "{err}"); + } + + #[test] + fn test_get_timeout_from_cmd_returns_err_when_timeout_is_negative() { + let mut cmd = Cmd::new(); + cmd.arg("BLPOP").arg("key1").arg("key2").arg(-1); + let result = get_timeout_from_cmd_arg(&cmd, cmd.args_iter().len() - 1, TimeUnit::Seconds); + assert!(result.is_err()); + let err = result.unwrap_err(); + assert!(err.to_string().to_lowercase().contains("negative"), "{err}"); + } + + #[test] + fn test_get_timeout_from_cmd_returns_no_timeout_when_zero_is_passed() { + let mut cmd = Cmd::new(); + cmd.arg("BLPOP").arg("key1").arg("key2").arg(0); + let result = get_timeout_from_cmd_arg(&cmd, cmd.args_iter().len() - 1, TimeUnit::Seconds); + assert!(result.is_ok()); + assert_eq!(result.unwrap(), RequestTimeoutOption::NoTimeout,); + } + + #[test] + fn test_get_request_timeout_with_blocking_command_returns_cmd_arg_timeout() { + let mut cmd = Cmd::new(); + cmd.arg("BLPOP").arg("key1").arg("key2").arg("500"); + let result = get_request_timeout(&cmd, Duration::from_millis(100)); + assert!(result.is_ok()); + assert_eq!( + result.unwrap(), + Some(Duration::from_secs_f64( + 500.0 + BLOCKING_CMD_TIMEOUT_EXTENSION + )) + ); + + let mut cmd = Cmd::new(); + cmd.arg("XREADGROUP").arg("BLOCK").arg("500").arg("key"); + let result = get_request_timeout(&cmd, Duration::from_millis(100)); + assert!(result.is_ok()); + assert_eq!( + result.unwrap(), + Some(Duration::from_secs_f64( + 0.5 + BLOCKING_CMD_TIMEOUT_EXTENSION + )) + ); + + let mut cmd = Cmd::new(); + cmd.arg("BLMPOP").arg("0.857").arg("key"); + let result = get_request_timeout(&cmd, Duration::from_millis(100)); + assert!(result.is_ok()); + assert_eq!( + result.unwrap(), + Some(Duration::from_secs_f64( + 0.857 + BLOCKING_CMD_TIMEOUT_EXTENSION + )) + ); + } + + #[test] + fn test_get_request_timeout_non_blocking_command_returns_default_timeout() { + let mut cmd = Cmd::new(); + cmd.arg("SET").arg("key").arg("value").arg("PX").arg("500"); + let result = get_request_timeout(&cmd, Duration::from_millis(100)); + assert!(result.is_ok()); + assert_eq!(result.unwrap(), Some(Duration::from_millis(100))); + + let mut cmd = Cmd::new(); + cmd.arg("XREADGROUP").arg("key"); + let result = get_request_timeout(&cmd, Duration::from_millis(100)); + assert!(result.is_ok()); + assert_eq!(result.unwrap(), Some(Duration::from_millis(100))); + } +} diff --git a/glide-core/src/client/reconnecting_connection.rs b/glide-core/src/client/reconnecting_connection.rs index c039d347bd..d79a59c574 100644 --- a/glide-core/src/client/reconnecting_connection.rs +++ b/glide-core/src/client/reconnecting_connection.rs @@ -48,7 +48,7 @@ pub(super) struct ReconnectingConnection { async fn get_multiplexed_connection(client: &redis::Client) -> RedisResult { run_with_timeout( - DEFAULT_CONNECTION_ATTEMPT_TIMEOUT, + Some(DEFAULT_CONNECTION_ATTEMPT_TIMEOUT), client.get_multiplexed_async_connection(), ) .await diff --git a/glide-core/tests/test_client.rs b/glide-core/tests/test_client.rs index 945c9db504..f16f05c536 100644 --- a/glide-core/tests/test_client.rs +++ b/glide-core/tests/test_client.rs @@ -6,7 +6,7 @@ mod utilities; #[cfg(test)] pub(crate) mod shared_client_tests { use super::*; - use glide_core::client::Client; + use glide_core::client::{Client, DEFAULT_RESPONSE_TIMEOUT}; use redis::{ cluster_routing::{MultipleNodeRoutingInfo, RoutingInfo}, FromRedisValue, InfoDict, RedisConnectionInfo, Value, @@ -320,7 +320,44 @@ pub(crate) mod shared_client_tests { let mut test_basics = setup_test_basics( use_cluster, TestConfiguration { - request_timeout: Some(1), + request_timeout: Some(1), // milliseconds + shared_server: false, + ..Default::default() + }, + ) + .await; + let mut cmd = redis::Cmd::new(); + // Create a long running command to ensure we get into timeout + cmd.arg("EVAL") + .arg( + r#" + while (true) + do + redis.call('ping') + end + "#, + ) + .arg("0"); + let result = test_basics.client.send_command(&cmd, None).await; + assert!(result.is_err()); + let err = result.unwrap_err(); + assert!(err.is_timeout(), "{err}"); + }); + } + + #[rstest] + #[timeout(SHORT_CLUSTER_TEST_TIMEOUT)] + fn test_blocking_command_doesnt_raise_timeout_error(#[values(false, true)] use_cluster: bool) { + // We test that the request timeout is based on the value specified in the blocking command argument, + // and not on the one set in the client configuration. To achieve this, we execute a command designed to + // be blocked until it reaches the specified command timeout. We set the client's request timeout to + // a shorter duration than the blocking command's timeout. Subsequently, we confirm that we receive + // a response from the server instead of encountering a timeout error. + block_on_all(async { + let mut test_basics = setup_test_basics( + use_cluster, + TestConfiguration { + request_timeout: Some(1), // milliseconds shared_server: true, ..Default::default() }, @@ -328,11 +365,62 @@ pub(crate) mod shared_client_tests { .await; let mut cmd = redis::Cmd::new(); - cmd.arg("BLPOP").arg("foo").arg(0); // 0 timeout blocks indefinitely + cmd.arg("BLPOP").arg(generate_random_string(10)).arg(0.3); // server should return null after 300 millisecond + let result = test_basics.client.send_command(&cmd, None).await; + assert!(result.is_ok()); + assert_eq!(result.unwrap(), Value::Nil); + }); + } + + #[rstest] + #[timeout(SHORT_CLUSTER_TEST_TIMEOUT)] + fn test_blocking_command_with_negative_timeout_returns_error( + #[values(false, true)] use_cluster: bool, + ) { + // We test that when blocking command is passed with a negative timeout the command will return with an error + block_on_all(async { + let mut test_basics = setup_test_basics( + use_cluster, + TestConfiguration { + request_timeout: Some(1), // milliseconds + shared_server: true, + ..Default::default() + }, + ) + .await; + let mut cmd = redis::Cmd::new(); + cmd.arg("BLPOP").arg(generate_random_string(10)).arg(-1); let result = test_basics.client.send_command(&cmd, None).await; assert!(result.is_err()); let err = result.unwrap_err(); - assert!(err.is_timeout(), "{err}"); + assert_eq!(err.kind(), redis::ErrorKind::ResponseError); + assert!(err.to_string().contains("negative")); + }); + } + + #[rstest] + #[timeout(SHORT_CLUSTER_TEST_TIMEOUT)] + fn test_blocking_command_with_zero_timeout_blocks_indefinitely( + #[values(false, true)] use_cluster: bool, + ) { + // We test that when a blocking command is passed with a timeout duration of 0, it will block the client indefinitely + block_on_all(async { + let config = TestConfiguration { + request_timeout: Some(1), // millisecond + shared_server: true, + ..Default::default() + }; + let mut test_basics = setup_test_basics(use_cluster, config).await; + let key = generate_random_string(10); + let future = async move { + let mut cmd = redis::Cmd::new(); + cmd.arg("BLPOP").arg(key).arg(0); // `0` should block indefinitely + test_basics.client.send_command(&cmd, None).await + }; + // We execute the command with Tokio's timeout wrapper to prevent the test from hanging indefinitely. + let tokio_timeout_result = + tokio::time::timeout(DEFAULT_RESPONSE_TIMEOUT * 2, future).await; + assert!(tokio_timeout_result.is_err()); }); } diff --git a/python/python/tests/test_async_client.py b/python/python/tests/test_async_client.py index 5b6d437c59..caed049a64 100644 --- a/python/python/tests/test_async_client.py +++ b/python/python/tests/test_async_client.py @@ -2156,20 +2156,10 @@ def clean_result(value: TResult): async def test_cluster_fail_routing_by_address_if_no_port_is_provided( self, redis_client: RedisClusterClient ): - with pytest.raises(RequestError) as e: + with pytest.raises(RequestError): await redis_client.info(route=ByAddressRoute("foo")) -@pytest.mark.asyncio -class TestExceptions: - @pytest.mark.parametrize("cluster_mode", [True, False]) - @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3]) - async def test_timeout_exception_with_blpop(self, redis_client: TRedisClient): - key = get_random_string(10) - with pytest.raises(TimeoutError): - await redis_client.custom_command(["BLPOP", key, "1"]) - - @pytest.mark.asyncio class TestScripts: @pytest.mark.smoke_test