diff --git a/.github/workflows/python.yml b/.github/workflows/python.yml index 6c0cc45b7d..699033cf1a 100644 --- a/.github/workflows/python.yml +++ b/.github/workflows/python.yml @@ -115,6 +115,7 @@ jobs: working-directory: ./python run: | source .env/bin/activate + pip install -r dev_requirements.txt cd python/tests/ pytest --asyncio-mode=auto --html=pytest_report.html --self-contained-html @@ -177,6 +178,7 @@ jobs: working-directory: ./python run: | source .env/bin/activate + pip install -r dev_requirements.txt cd python/tests/ pytest --asyncio-mode=auto -k test_pubsub --html=pytest_report.html --self-contained-html diff --git a/CHANGELOG.md b/CHANGELOG.md index 0bf932eb5d..e6d5968117 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,4 +1,5 @@ #### Changes +* Node, Python: Adding support for replacing connection configured password ([#2651](https://github.com/valkey-io/valkey-glide/pull/2651)) * Node: Add FT._ALIASLIST command([#2652](https://github.com/valkey-io/valkey-glide/pull/2652)) * Python: Python: `FT._ALIASLIST` command added([#2638](https://github.com/valkey-io/valkey-glide/pull/2638)) * Node: alias commands added: FT.ALIASADD, FT.ALIADDEL, FT.ALIASUPDATE([#2596](https://github.com/valkey-io/valkey-glide/pull/2596)) diff --git a/glide-core/redis-rs/redis/Cargo.toml b/glide-core/redis-rs/redis/Cargo.toml index 3320ba8ec7..579c1da799 100644 --- a/glide-core/redis-rs/redis/Cargo.toml +++ b/glide-core/redis-rs/redis/Cargo.toml @@ -8,7 +8,7 @@ repository = "https://github.com/redis-rs/redis-rs" documentation = "https://docs.rs/redis" license = "BSD-3-Clause" edition = "2021" -rust-version = "1.65" +rust-version = "1.67" readme = "../README.md" [package.metadata.docs.rs] @@ -47,7 +47,6 @@ pin-project-lite = { version = "0.2", optional = true } tokio-util = { version = "0.7", optional = true } tokio = { version = "1", features = ["rt", "net", "time", "sync"] } socket2 = { version = "0.5", features = ["all"], optional = true } -fast-math = { version = "0.1.1", optional = true } dispose = { version = "0.5.0", optional = true } # Only needed for the connection manager @@ -67,7 +66,7 @@ dashmap = { version = "6.0", optional = true } async-trait = { version = "0.1.24", optional = true } # Only needed for tokio support -tokio-retry2 = {version = "0.5", features = ["jitter"], optional = true} +tokio-retry2 = { version = "0.5", features = ["jitter"], optional = true } # Only needed for native tls native-tls = { version = "0.2", optional = true } @@ -125,7 +124,6 @@ aio = [ "tokio-util/codec", "combine/tokio", "async-trait", - "fast-math", "dispose", ] geospatial = [] diff --git a/glide-core/redis-rs/redis/src/aio/multiplexed_connection.rs b/glide-core/redis-rs/redis/src/aio/multiplexed_connection.rs index 15df4e9aa8..b31c817817 100644 --- a/glide-core/redis-rs/redis/src/aio/multiplexed_connection.rs +++ b/glide-core/redis-rs/redis/src/aio/multiplexed_connection.rs @@ -31,6 +31,9 @@ use std::time::Duration; #[cfg(feature = "tokio-comp")] use tokio_util::codec::Decoder; +// Default connection timeout in ms +const DEFAULT_CONNECTION_ATTEMPT_TIMEOUT: Duration = Duration::from_millis(250); + // Senders which the result of a single request are sent through type PipelineOutput = oneshot::Sender>; @@ -76,7 +79,7 @@ struct PipelineMessage { /// interface provided by `Pipeline` an easy interface of request to response, hiding the `Stream` /// and `Sink`. #[derive(Clone)] -struct Pipeline { +pub(crate) struct Pipeline { sender: mpsc::Sender>, push_manager: Arc>, is_stream_closed: Arc, @@ -399,6 +402,7 @@ where self.push_manager.store(Arc::new(push_manager)); } + /// Checks if the pipeline is closed. pub fn is_closed(&self) -> bool { self.is_stream_closed.load(Ordering::Relaxed) } @@ -413,6 +417,7 @@ pub struct MultiplexedConnection { response_timeout: Duration, protocol: ProtocolVersion, push_manager: PushManager, + password: Option, } impl Debug for MultiplexedConnection { @@ -455,35 +460,28 @@ impl MultiplexedConnection { where C: Unpin + AsyncRead + AsyncWrite + Send + 'static, { - fn boxed( - f: impl Future + Send + 'static, - ) -> Pin + Send>> { - Box::pin(f) - } - - #[cfg(not(feature = "tokio-comp"))] - compile_error!("tokio-comp feature is required for aio feature"); - - let redis_connection_info = &connection_info.redis; let codec = ValueCodec::default() .framed(stream) .and_then(|msg| async move { msg }); let (mut pipeline, driver) = Pipeline::new(codec, glide_connection_options.disconnect_notifier); - let driver = boxed(driver); + let driver = Box::pin(driver); let pm = PushManager::default(); if let Some(sender) = glide_connection_options.push_sender { pm.replace_sender(sender); } pipeline.set_push_manager(pm.clone()).await; - let mut con = MultiplexedConnection { - pipeline, - db: connection_info.redis.db, - response_timeout, - push_manager: pm, - protocol: redis_connection_info.protocol, - }; + + let mut con = MultiplexedConnection::builder(pipeline) + .with_db(connection_info.redis.db) + .with_response_timeout(response_timeout) + .with_push_manager(pm) + .with_protocol(connection_info.redis.protocol) + .with_password(connection_info.redis.password.clone()) + .build() + .await?; + let driver = { let auth = setup_connection(&connection_info.redis, &mut con); @@ -502,6 +500,7 @@ impl MultiplexedConnection { } } }; + Ok((con, driver)) } @@ -575,6 +574,97 @@ impl MultiplexedConnection { self.push_manager = push_manager.clone(); self.pipeline.set_push_manager(push_manager).await; } + + /// Replace the password used to authenticate with the server. + /// If `None` is provided, the password will be removed. + pub async fn update_connection_password( + &mut self, + password: Option, + ) -> RedisResult { + self.password = password; + Ok(Value::Okay) + } + + /// Creates a new `MultiplexedConnectionBuilder` for constructing a `MultiplexedConnection`. + pub(crate) fn builder(pipeline: Pipeline>) -> MultiplexedConnectionBuilder { + MultiplexedConnectionBuilder::new(pipeline) + } +} + +/// A builder for creating `MultiplexedConnection` instances. +pub struct MultiplexedConnectionBuilder { + pipeline: Pipeline>, + db: Option, + response_timeout: Option, + push_manager: Option, + protocol: Option, + password: Option, +} + +impl MultiplexedConnectionBuilder { + /// Creates a new builder with the required pipeline + pub(crate) fn new(pipeline: Pipeline>) -> Self { + Self { + pipeline, + db: None, + response_timeout: None, + push_manager: None, + protocol: None, + password: None, + } + } + + /// Sets the database index for the `MultiplexedConnectionBuilder`. + pub fn with_db(mut self, db: i64) -> Self { + self.db = Some(db); + self + } + + /// Sets the response timeout for the `MultiplexedConnectionBuilder`. + pub fn with_response_timeout(mut self, timeout: Duration) -> Self { + self.response_timeout = Some(timeout); + self + } + + /// Sets the push manager for the `MultiplexedConnectionBuilder`. + pub fn with_push_manager(mut self, push_manager: PushManager) -> Self { + self.push_manager = Some(push_manager); + self + } + + /// Sets the protocol version for the `MultiplexedConnectionBuilder`. + pub fn with_protocol(mut self, protocol: ProtocolVersion) -> Self { + self.protocol = Some(protocol); + self + } + + /// Sets the password for the `MultiplexedConnectionBuilder`. + pub fn with_password(mut self, password: Option) -> Self { + self.password = password; + self + } + + /// Builds and returns a new `MultiplexedConnection` instance using the configured settings. + pub async fn build(self) -> RedisResult { + let db = self.db.unwrap_or_default(); + let response_timeout = self + .response_timeout + .unwrap_or(DEFAULT_CONNECTION_ATTEMPT_TIMEOUT); + let push_manager = self.push_manager.unwrap_or_default(); + let protocol = self.protocol.unwrap_or_default(); + let password = self.password; + + let con = MultiplexedConnection { + pipeline: self.pipeline, + db, + response_timeout, + push_manager, + protocol, + password, + }; + + Ok(con) + } } impl ConnectionLike for MultiplexedConnection { diff --git a/glide-core/redis-rs/redis/src/cluster.rs b/glide-core/redis-rs/redis/src/cluster.rs index 2846fc1137..1107965bf3 100644 --- a/glide-core/redis-rs/redis/src/cluster.rs +++ b/glide-core/redis-rs/redis/src/cluster.rs @@ -306,7 +306,7 @@ where /// Returns the connection status. /// - /// The connection is open until any `read_response` call recieved an + /// The connection is open until any `read_response` call received an /// invalid response from the server (most likely a closed or dropped /// connection, otherwise a Redis protocol error). When using unix /// sockets the connection is open until writing a command failed with a @@ -808,7 +808,7 @@ where self.refresh_slots()?; // Given that there are commands that need to be retried, it means something in the cluster - // topology changed. Execute each command seperately to take advantage of the existing + // topology changed. Execute each command separately to take advantage of the existing // retry logic that handles these cases. for retry_idx in to_retry { let cmd = &cmds[retry_idx]; diff --git a/glide-core/redis-rs/redis/src/cluster_async/mod.rs b/glide-core/redis-rs/redis/src/cluster_async/mod.rs index 426601ca02..35997b2282 100644 --- a/glide-core/redis-rs/redis/src/cluster_async/mod.rs +++ b/glide-core/redis-rs/redis/src/cluster_async/mod.rs @@ -55,6 +55,7 @@ use std::{ task::{self, Poll}, time::SystemTime, }; +use strum_macros::Display; #[cfg(feature = "tokio-comp")] use tokio::task::JoinHandle; @@ -106,10 +107,10 @@ use self::{ connections_container::{ConnectionAndAddress, ConnectionType, ConnectionsMap}, connections_logic::connect_and_check, }; +use crate::types::RetryMethod; -pub(crate) const MUTEX_WRITE_ERR: &str = "Failed to obtain write lock for mutex. Poisoned mutex"; -pub(crate) const MUTEX_READ_ERR: &str = "Failed to obtain read lock for mutex. Poisoned mutex"; - +pub(crate) const MUTEX_READ_ERR: &str = "Failed to obtain read lock. Poisoned mutex?"; +const MUTEX_WRITE_ERR: &str = "Failed to obtain write lock. Poisoned mutex?"; /// This represents an async Redis Cluster connection. It stores the /// underlying connections maintained for each node in the cluster, as well /// as common parameters for connecting to nodes and executing commands. @@ -294,8 +295,7 @@ where }) .map(|response| match response { Response::ClusterScanResult(new_scan_state_ref, key) => (new_scan_state_ref, key), - Response::Single(_) => unreachable!(), - Response::Multiple(_) => unreachable!(), + Response::Single(_) | Response::Multiple(_) => unreachable!(), }) } @@ -332,8 +332,7 @@ where }) .map(|response| match response { Response::Single(value) => value, - Response::Multiple(_) => unreachable!(), - Response::ClusterScanResult(_, _) => unreachable!(), + Response::ClusterScanResult(..) | Response::Multiple(_) => unreachable!(), }) } @@ -357,15 +356,57 @@ where sender, }) .await - .map_err(|_| RedisError::from(io::Error::from(io::ErrorKind::BrokenPipe)))?; + .map_err(|err| { + RedisError::from(io::Error::new(io::ErrorKind::BrokenPipe, err.to_string())) + })?; receiver .await - .unwrap_or_else(|_| Err(RedisError::from(io::Error::from(io::ErrorKind::BrokenPipe)))) + .unwrap_or_else(|err| { + Err(RedisError::from(io::Error::new( + io::ErrorKind::BrokenPipe, + err.to_string(), + ))) + }) .map(|response| match response { Response::Multiple(values) => values, - Response::Single(_) => unreachable!(), - Response::ClusterScanResult(_, _) => unreachable!(), + Response::ClusterScanResult(..) | Response::Single(_) => unreachable!(), + }) + } + /// Update the password used to authenticate with all cluster servers + pub async fn update_connection_password( + &mut self, + password: Option, + ) -> RedisResult { + self.route_operation_request(Operation::UpdateConnectionPassword(password)) + .await + } + + /// Routes an operation request to the appropriate handler. + async fn route_operation_request( + &mut self, + operation_request: Operation, + ) -> RedisResult { + let (sender, receiver) = oneshot::channel(); + self.0 + .send(Message { + cmd: CmdArg::OperationRequest(operation_request), + sender, + }) + .await + .map_err(|_| RedisError::from(io::Error::from(io::ErrorKind::BrokenPipe)))?; + + receiver + .await + .unwrap_or_else(|err| { + Err(RedisError::from(io::Error::new( + io::ErrorKind::BrokenPipe, + err.to_string(), + ))) + }) + .map(|response| match response { + Response::Single(values) => values, + Response::ClusterScanResult(..) | Response::Multiple(_) => unreachable!(), }) } } @@ -410,7 +451,7 @@ type ConnectionsContainer = pub(crate) struct InnerCore { pub(crate) conn_lock: StdRwLock>, - cluster_params: ClusterParams, + cluster_params: StdRwLock, pending_requests: Mutex>>, slot_refresh_state: SlotRefreshState, initial_nodes: Vec, @@ -425,6 +466,29 @@ impl InnerCore where C: ConnectionLike + Connect + Clone + Send + Sync + 'static, { + fn get_cluster_param(&self, f: F) -> Result + where + F: FnOnce(&ClusterParams) -> T, + T: Clone, + { + self.cluster_params + .read() + .map(|guard| f(&guard).clone()) + .map_err(|_| RedisError::from((ErrorKind::ClientError, MUTEX_READ_ERR))) + } + + fn set_cluster_param(&self, f: F) -> Result<(), RedisError> + where + F: FnOnce(&mut ClusterParams), + { + self.cluster_params + .write() + .map(|mut params| { + f(&mut params); + }) + .map_err(|_| RedisError::from((ErrorKind::ClientError, MUTEX_WRITE_ERR))) + } + // return address of node for slot pub(crate) async fn get_address_from_slot( &self, @@ -615,6 +679,14 @@ enum CmdArg { // struct containing the arguments for the cluster scan command - scan state cursor, match pattern, count and object type. cluster_scan_args: ClusterScanArgs, }, + // Operational requests which are connected to the internal state of the connection and not send as a command to the server. + OperationRequest(Operation), +} + +// Operation requests which are connected to the internal state of the connection and not send as a command to the server. +#[derive(Clone)] +enum Operation { + UpdateConnectionPassword(Option), } fn route_for_pipeline(pipeline: &crate::Pipeline) -> RedisResult> { @@ -656,16 +728,17 @@ fn route_for_pipeline(pipeline: &crate::Pipeline) -> RedisResult> } fn boxed_sleep(duration: Duration) -> BoxFuture<'static, ()> { - #[cfg(feature = "tokio-comp")] - return Box::pin(tokio::time::sleep(duration)); + Box::pin(tokio::time::sleep(duration)) } +#[derive(Debug, Display)] pub(crate) enum Response { Single(Value), ClusterScanResult(ScanStateRC, Vec), Multiple(Vec), } +#[derive(Debug)] pub(crate) enum OperationTarget { Node { address: String }, FanOut, @@ -679,7 +752,7 @@ impl From for OperationTarget { } } -struct Message { +struct Message { cmd: CmdArg, sender: oneshot::Sender>, } @@ -740,6 +813,10 @@ impl RequestInfo { CmdArg::ClusterScan { .. } => { unreachable!() } + // Operation requests are not routed. + CmdArg::OperationRequest(_) => { + unreachable!() + } } } } @@ -774,6 +851,10 @@ impl RequestInfo { CmdArg::ClusterScan { .. } => { unreachable!() } + // Operation requests are not routed. + CmdArg::OperationRequest { .. } => { + unreachable!() + } } } } @@ -867,7 +948,7 @@ impl Future for Request { let retry_method = err.retry_method(); let next = if err.kind() == ErrorKind::AllConnectionsUnavailable { Next::ReconnectToInitialNodes { request: None }.into() - } else if matches!(err.retry_method(), crate::types::RetryMethod::MovedRedirect) + } else if matches!(err.retry_method(), RetryMethod::MovedRedirect) || matches!(target, OperationTarget::NotFound) { Next::RefreshSlots { @@ -875,8 +956,8 @@ impl Future for Request { sleep_duration: None, } .into() - } else if matches!(retry_method, crate::types::RetryMethod::Reconnect) - || matches!(retry_method, crate::types::RetryMethod::ReconnectAndRetry) + } else if matches!(retry_method, RetryMethod::Reconnect) + || matches!(retry_method, RetryMethod::ReconnectAndRetry) { if let OperationTarget::Node { address } = target { Next::Reconnect { @@ -928,7 +1009,7 @@ impl Future for Request { warn!("Received request error {} on node {:?}.", err, address); match err.retry_method() { - crate::types::RetryMethod::AskRedirect => { + RetryMethod::AskRedirect => { let mut request = this.request.take().unwrap(); request.info.set_redirect( err.redirect_node() @@ -936,7 +1017,7 @@ impl Future for Request { ); Next::Retry { request }.into() } - crate::types::RetryMethod::MovedRedirect => { + RetryMethod::MovedRedirect => { let mut request = this.request.take().unwrap(); request.info.set_redirect( err.redirect_node() @@ -948,7 +1029,7 @@ impl Future for Request { } .into() } - crate::types::RetryMethod::WaitAndRetry => { + RetryMethod::WaitAndRetry => { let sleep_duration = this.retry_params.wait_time_for_retry(request.retry); // Sleep and retry. this.future.set(RequestState::Sleep { @@ -956,34 +1037,31 @@ impl Future for Request { }); self.poll(cx) } - crate::types::RetryMethod::Reconnect - | crate::types::RetryMethod::ReconnectAndRetry => { + RetryMethod::Reconnect | RetryMethod::ReconnectAndRetry => { let mut request = this.request.take().unwrap(); // TODO should we reset the redirect here? request.info.reset_routing(); warn!("disconnected from {:?}", address); - let should_retry = matches!( - err.retry_method(), - crate::types::RetryMethod::ReconnectAndRetry - ); + let should_retry = + matches!(err.retry_method(), RetryMethod::ReconnectAndRetry); Next::Reconnect { request: should_retry.then_some(request), target: address, } .into() } - crate::types::RetryMethod::WaitAndRetryOnPrimaryRedirectOnReplica => { + RetryMethod::WaitAndRetryOnPrimaryRedirectOnReplica => { Next::RetryBusyLoadingError { request: this.request.take().unwrap(), address, } .into() } - crate::types::RetryMethod::RetryImmediately => Next::Retry { + RetryMethod::RetryImmediately => Next::Retry { request: this.request.take().unwrap(), } .into(), - crate::types::RetryMethod::NoRetry => { + RetryMethod::NoRetry => { self.respond(Err(err)); Next::Done.into() } @@ -1051,7 +1129,7 @@ where cluster_params.read_from_replicas, 0, )), - cluster_params: cluster_params.clone(), + cluster_params: StdRwLock::new(cluster_params.clone()), pending_requests: Mutex::new(Vec::new()), slot_refresh_state: SlotRefreshState::new(slots_refresh_rate_limiter), initial_nodes: initial_nodes.to_vec(), @@ -1211,10 +1289,17 @@ where // Being used when all cluster connections are unavailable. fn reconnect_to_initial_nodes(inner: Arc>) -> impl Future { let inner = inner.clone(); - async move { + let cluster_params = match inner.get_cluster_param(|params| params.clone()) { + Ok(params) => params, + Err(err) => { + warn!("Failed to get cluster params: {}", err); + return async {}.boxed(); + } + }; + Box::pin(async move { let connection_map = match Self::create_initial_connections( &inner.initial_nodes, - &inner.cluster_params, + &cluster_params, inner.glide_connection_options.clone(), ) .await @@ -1238,7 +1323,7 @@ where { warn!("Can't refresh slots with initial nodes: `{err}`"); }; - } + }) } // Validate all existing user connections and try to reconnect if necessary. @@ -1329,7 +1414,7 @@ where }; // Override subscriptions for this connection - let mut cluster_params = inner.cluster_params.clone(); + let mut cluster_params = inner.cluster_params.read().expect(MUTEX_READ_ERR).clone(); let subs_guard = inner.subscriptions_by_address.read().await; cluster_params.pubsub_subscriptions = subs_guard.get(&address).cloned(); drop(subs_guard); @@ -1591,7 +1676,9 @@ where } async fn refresh_pubsub_subscriptions(inner: Arc>) { - if inner.cluster_params.protocol != crate::types::ProtocolVersion::RESP3 { + if inner.cluster_params.read().expect(MUTEX_READ_ERR).protocol + != crate::types::ProtocolVersion::RESP3 + { return; } @@ -1679,10 +1766,7 @@ where /// Returns true if change was detected, otherwise false. async fn check_for_topology_diff(inner: Arc>) -> bool { let num_of_nodes = inner.conn_lock.read().expect(MUTEX_READ_ERR).len(); - // TODO: Starting from Rust V1.67, integers has logarithms support. - // When we no longer need to support Rust versions < 1.67, remove fast_math and transition to the ilog2 function. - let num_of_nodes_to_query = - std::cmp::max(fast_math::log2_raw(num_of_nodes as f32) as usize, 1); + let num_of_nodes_to_query = std::cmp::max(num_of_nodes.ilog2() as usize, 1); let (res, failed_connections) = calculate_topology_from_random_nodes( &inner, num_of_nodes_to_query, @@ -1790,7 +1874,9 @@ where .fold( ConnectionsMap(DashMap::with_capacity(nodes_len)), |connections, (addr, node)| async { - let mut cluster_params = inner.cluster_params.clone(); + let mut cluster_params = inner + .get_cluster_param(|params| params.clone()) + .expect(MUTEX_READ_ERR); let subs_guard = inner.subscriptions_by_address.read().await; cluster_params.pubsub_subscriptions = subs_guard.get(addr).cloned(); drop(subs_guard); @@ -1811,12 +1897,15 @@ where .await; info!("refresh_slots found nodes:\n{new_connections}"); - // Replace the current slot map and connection vector with the new ones + // Reset the current slot map and connection vector with the new ones let mut write_guard = inner.conn_lock.write().expect(MUTEX_WRITE_ERR); + let read_from_replicas = inner + .get_cluster_param(|params| params.read_from_replicas) + .expect(MUTEX_READ_ERR); *write_guard = ConnectionsContainer::new( new_slots, new_connections, - inner.cluster_params.read_from_replicas, + read_from_replicas, topology_hash, ); Ok(()) @@ -2005,6 +2094,13 @@ where Err(err) => Err((OperationTarget::FanOut, err)), } } + CmdArg::OperationRequest(operation_request) => match operation_request { + Operation::UpdateConnectionPassword(password) => { + core.set_cluster_param(|params| params.password = password) + .expect(MUTEX_WRITE_ERR); + Ok(Response::Single(Value::Okay)) + } + }, } } @@ -2099,7 +2195,7 @@ where let (address, mut conn) = match conn_check { ConnectionCheck::Found((address, connection)) => (address, connection.await), ConnectionCheck::OnlyAddress(addr) => { - let mut this_conn_params = core.cluster_params.clone(); + let mut this_conn_params = core.get_cluster_param(|params| params.clone())?; let subs_guard = core.subscriptions_by_address.read().await; this_conn_params.pubsub_subscriptions = subs_guard.get(addr.as_str()).cloned(); drop(subs_guard); @@ -2197,6 +2293,7 @@ where info: RequestInfo, address: String, retry: u32, + retry_params: RetryParams, ) -> OperationResult { let is_primary = core .conn_lock @@ -2213,7 +2310,7 @@ where .remove_node(&address); } else { // If the connection is primary, just sleep and retry - let sleep_duration = core.cluster_params.retry_params.wait_time_for_retry(retry); + let sleep_duration = retry_params.wait_time_for_retry(retry); boxed_sleep(sleep_duration).await; } @@ -2221,8 +2318,11 @@ where } fn poll_complete(&mut self, cx: &mut task::Context<'_>) -> Poll { + let retry_params = self + .inner + .get_cluster_param(|params| params.retry_params.clone()) + .expect(MUTEX_READ_ERR); let mut poll_flush_action = PollFlushAction::None; - let mut pending_requests_guard = self.inner.pending_requests.lock().unwrap(); if !pending_requests_guard.is_empty() { let mut pending_requests = mem::take(&mut *pending_requests_guard); @@ -2236,7 +2336,7 @@ where let future = Self::try_request(request.info.clone(), self.inner.clone()).boxed(); self.in_flight_requests.push(Box::pin(Request { - retry_params: self.inner.cluster_params.retry_params.clone(), + retry_params: retry_params.clone(), request: Some(request), future: RequestState::Future { future }, })); @@ -2246,6 +2346,7 @@ where drop(pending_requests_guard); loop { + let retry_params = retry_params.clone(); let result = match Pin::new(&mut self.in_flight_requests).poll_next(cx) { Poll::Ready(Some(result)) => result, Poll::Ready(None) | Poll::Pending => break, @@ -2255,7 +2356,7 @@ where Next::Retry { request } => { let future = Self::try_request(request.info.clone(), self.inner.clone()); self.in_flight_requests.push(Box::pin(Request { - retry_params: self.inner.cluster_params.retry_params.clone(), + retry_params: retry_params.clone(), request: Some(request), future: RequestState::Future { future: Box::pin(future), @@ -2269,9 +2370,10 @@ where request.info.clone(), address, request.retry, + retry_params.clone(), ); self.in_flight_requests.push(Box::pin(Request { - retry_params: self.inner.cluster_params.retry_params.clone(), + retry_params: retry_params.clone(), request: Some(request), future: RequestState::Future { future: Box::pin(future), @@ -2299,7 +2401,7 @@ where }, }; self.in_flight_requests.push(Box::pin(Request { - retry_params: self.inner.cluster_params.retry_params.clone(), + retry_params, request: Some(request), future, })); @@ -2524,13 +2626,20 @@ where .ok() .and_then(|value| get_host_and_port_from_addr(addr).map(|(host, _)| (host, value))) }); + let tls_mode = inner + .get_cluster_param(|params| params.tls) + .expect(MUTEX_READ_ERR); + + let read_from_replicas = inner + .get_cluster_param(|params| params.read_from_replicas) + .expect(MUTEX_READ_ERR); ( calculate_topology( topology_values, curr_retry, - inner.cluster_params.tls, + tls_mode, num_of_nodes_to_query, - inner.cluster_params.read_from_replicas, + read_from_replicas, ), failed_addresses, ) diff --git a/glide-core/redis-rs/redis/src/cmd.rs b/glide-core/redis-rs/redis/src/cmd.rs index 979bc7987b..3e248dad6f 100644 --- a/glide-core/redis-rs/redis/src/cmd.rs +++ b/glide-core/redis-rs/redis/src/cmd.rs @@ -57,8 +57,8 @@ impl<'a, T: FromRedisValue> Iterator for Iter<'a, T> { return None; } - let pcmd = self.cmd.get_packed_command_with_cursor(self.cursor)?; - let rv = self.con.req_packed_command(&pcmd).ok()?; + let packed_cmd = self.cmd.get_packed_command_with_cursor(self.cursor)?; + let rv = self.con.req_packed_command(&packed_cmd).ok()?; let (cur, batch): (u64, Vec) = from_owned_redis_value(rv).ok()?; self.cursor = cur; @@ -204,14 +204,14 @@ fn args_len<'a, I>(args: I, cursor: u64) -> usize where I: IntoIterator> + ExactSizeIterator, { - let mut totlen = 1 + countdigits(args.len()) + 2; + let mut total_len = countdigits(args.len()).saturating_add(3); for item in args { - totlen += bulklen(match item { + total_len += bulklen(match item { Arg::Cursor => countdigits(cursor as usize), Arg::Simple(val) => val.len(), }); } - totlen + total_len } pub(crate) fn cmd_len(cmd: &Cmd) -> usize { @@ -231,9 +231,9 @@ fn write_command_to_vec<'a, I>(cmd: &mut Vec, args: I, cursor: u64) where I: IntoIterator> + Clone + ExactSizeIterator, { - let totlen = args_len(args.clone(), cursor); + let total_len = args_len(args.clone(), cursor); - cmd.reserve(totlen); + cmd.reserve(total_len); write_command(cmd, args, cursor).unwrap() } @@ -287,7 +287,7 @@ impl Default for Cmd { } /// A command acts as a builder interface to creating encoded redis -/// requests. This allows you to easiy assemble a packed command +/// requests. This allows you to easily assemble a packed command /// by chaining arguments together. /// /// Basic example: @@ -324,7 +324,7 @@ impl Cmd { } } - /// Creates a new empty command, with at least the requested capcity. + /// Creates a new empty command, with at least the requested capacity. pub fn with_capacity(arg_count: usize, size_of_data: usize) -> Cmd { Cmd { data: Vec::with_capacity(size_of_data), @@ -448,7 +448,7 @@ impl Cmd { /// /// This is useful for commands such as `SSCAN`, `SCAN` and others. /// - /// One speciality of this function is that it will check if the response + /// One specialty of this function is that it will check if the response /// looks like a cursor or not and always just looks at the payload. /// This way you can use the function the same for responses in the /// format of `KEYS` (just a list) as well as `SSCAN` (which returns a @@ -481,7 +481,7 @@ impl Cmd { /// /// This is useful for commands such as `SSCAN`, `SCAN` and others in async contexts. /// - /// One speciality of this function is that it will check if the response + /// One specialty of this function is that it will check if the response /// looks like a cursor or not and always just looks at the payload. /// This way you can use the function the same for responses in the /// format of `KEYS` (just a list) as well as `SSCAN` (which returns a diff --git a/glide-core/redis-rs/redis/src/parser.rs b/glide-core/redis-rs/redis/src/parser.rs index 96e0bcd8f1..1f42a774f1 100644 --- a/glide-core/redis-rs/redis/src/parser.rs +++ b/glide-core/redis-rs/redis/src/parser.rs @@ -633,11 +633,11 @@ mod tests { #[test] fn decode_resp3_push() { - let val = parse_redis_value(b">3\r\n+message\r\n+somechannel\r\n+this is the message\r\n") + let val = parse_redis_value(b">3\r\n+message\r\n+some_channel\r\n+this is the message\r\n") .unwrap(); if let Value::Push { ref kind, ref data } = val { assert_eq!(&PushKind::Message, kind); - assert_eq!(Value::SimpleString("somechannel".to_string()), data[0]); + assert_eq!(Value::SimpleString("some_channel".to_string()), data[0]); assert_eq!( Value::SimpleString("this is the message".to_string()), data[1] diff --git a/glide-core/redis-rs/redis/src/types.rs b/glide-core/redis-rs/redis/src/types.rs index 6fd564b203..4b6cdbb150 100644 --- a/glide-core/redis-rs/redis/src/types.rs +++ b/glide-core/redis-rs/redis/src/types.rs @@ -8,10 +8,7 @@ use std::io; use std::str::{from_utf8, Utf8Error}; use std::string::FromUtf8Error; -#[cfg(feature = "ahash")] -pub(crate) use ahash::{AHashMap as HashMap, AHashSet as HashSet}; use num_bigint::BigInt; -#[cfg(not(feature = "ahash"))] pub(crate) use std::collections::{HashMap, HashSet}; use std::ops::Deref; @@ -139,7 +136,7 @@ pub enum ErrorKind { NoValidReplicasFoundBySentinel, /// At least one sentinel connection info is required EmptySentinelList, - /// Attempted to kill a script/function while they werent' executing + /// Attempted to kill a script/function while they weren't executing NotBusy, /// Used when no valid node connections remain in the cluster connection AllConnectionsUnavailable, @@ -983,7 +980,6 @@ impl RedisError { match self.retry_method() { RetryMethod::Reconnect => true, RetryMethod::ReconnectAndRetry => true, - RetryMethod::NoRetry => false, RetryMethod::RetryImmediately => false, RetryMethod::WaitAndRetry => false, @@ -1149,9 +1145,9 @@ impl InfoDict { /// the INFO command. Each line is a key, value pair with the /// key and value separated by a colon (`:`). Lines starting with a /// hash (`#`) are ignored. - pub fn new(kvpairs: &str) -> InfoDict { + pub fn new(key_val_pairs: &str) -> InfoDict { let mut map = HashMap::new(); - for line in kvpairs.lines() { + for line in key_val_pairs.lines() { if line.is_empty() || line.starts_with('#') { continue; } @@ -1179,7 +1175,7 @@ impl InfoDict { self.map.get(*key) } - /// Checks if a key is contained in the info dicf. + /// Checks if a key is contained in the info dict. pub fn contains_key(&self, key: &&str) -> bool { self.find(key).is_some() } @@ -1255,7 +1251,7 @@ pub trait ToRedisArgs: Sized { NumericBehavior::NonNumeric } - /// Returns an indiciation if the value contained is exactly one + /// Returns an indication if the value contained is exactly one /// argument. It returns false if it's zero or more than one. This /// is used in some high level functions to intelligently switch /// between `GET` and `MGET` variants. @@ -1401,7 +1397,7 @@ ryu_based_to_redis_impl!(f64, NumericBehavior::NumberIsFloat); feature = "bigdecimal", feature = "num-bigint" ))] -macro_rules! bignum_to_redis_impl { +macro_rules! big_num_to_redis_impl { ($t:ty) => { impl ToRedisArgs for $t { fn write_redis_args(&self, out: &mut W) @@ -1415,13 +1411,13 @@ macro_rules! bignum_to_redis_impl { } #[cfg(feature = "rust_decimal")] -bignum_to_redis_impl!(rust_decimal::Decimal); +big_num_to_redis_impl!(rust_decimal::Decimal); #[cfg(feature = "bigdecimal")] -bignum_to_redis_impl!(bigdecimal::BigDecimal); +big_num_to_redis_impl!(bigdecimal::BigDecimal); #[cfg(feature = "num-bigint")] -bignum_to_redis_impl!(num_bigint::BigInt); +big_num_to_redis_impl!(num_bigint::BigInt); #[cfg(feature = "num-bigint")] -bignum_to_redis_impl!(num_bigint::BigUint); +big_num_to_redis_impl!(num_bigint::BigUint); impl ToRedisArgs for bool { fn write_redis_args(&self, out: &mut W) @@ -1969,7 +1965,7 @@ impl FromRedisValue for String { /// Implement `FromRedisValue` for `$Type` (which should use the generic parameter `$T`). /// /// The implementation parses the value into a vec, and then passes the value through `$convert`. -/// If `$convert` is ommited, it defaults to `Into::into`. +/// If `$convert` is omitted, it defaults to `Into::into`. macro_rules! from_vec_from_redis_value { (<$T:ident> $Type:ty) => { from_vec_from_redis_value!(<$T> $Type; Into::into); @@ -2169,16 +2165,16 @@ where { fn from_redis_value(v: &Value) -> RedisResult> { let v = get_inner_value(v); - let items = v - .as_sequence() - .ok_or_else(|| invalid_type_error_inner!(v, "Response type not btreeset compatible"))?; + let items = v.as_sequence().ok_or_else(|| { + invalid_type_error_inner!(v, "Response type not btree_set compatible") + })?; items.iter().map(|item| from_redis_value(item)).collect() } fn from_owned_redis_value(v: Value) -> RedisResult> { let v = get_owned_inner_value(v); let items = v .into_sequence() - .map_err(|v| invalid_type_error_inner!(v, "Response type not btreeset compatible"))?; + .map_err(|v| invalid_type_error_inner!(v, "Response type not btree_set compatible"))?; items .into_iter() .map(|item| from_owned_redis_value(item)) diff --git a/glide-core/redis-rs/redis/tests/auth.rs b/glide-core/redis-rs/redis/tests/auth.rs new file mode 100644 index 0000000000..2bf9d250a2 --- /dev/null +++ b/glide-core/redis-rs/redis/tests/auth.rs @@ -0,0 +1,304 @@ +mod support; + +#[cfg(test)] +mod auth { + use crate::support::*; + use redis::{ + aio::MultiplexedConnection, + cluster::ClusterClientBuilder, + cluster_async::ClusterConnection, + cluster_routing::{MultipleNodeRoutingInfo, ResponsePolicy, RoutingInfo}, + cmd, ConnectionInfo, GlideConnectionOptions, RedisConnectionInfo, RedisResult, Value, + }; + + const ALL_SUCCESS_ROUTE: RoutingInfo = RoutingInfo::MultiNode(( + MultipleNodeRoutingInfo::AllNodes, + Some(ResponsePolicy::AllSucceeded), + )); + + const PASSWORD: &str = "password"; + const NEW_PASSWORD: &str = "new_password"; + + enum ConnectionType { + Cluster, + Standalone, + } + + enum Connection { + Cluster(ClusterConnection), + Standalone(MultiplexedConnection), + } + + async fn create_connection( + password: Option, + connection_type: ConnectionType, + cluster_context: Option<&TestClusterContext>, + standalone_context: Option<&TestContext>, + ) -> RedisResult { + match connection_type { + ConnectionType::Cluster => { + let cluster_context = + cluster_context.expect("ClusterContext is required for Cluster connection"); + let builder = get_builder(cluster_context, password); + let connection = builder.build()?.get_async_connection(None).await?; + Ok(Connection::Cluster(connection)) + } + ConnectionType::Standalone => { + let standalone_context = + standalone_context.expect("TestContext is required for Standalone connection"); + let info = get_connection_info(standalone_context, password); + let client = redis::Client::open(info)?; + let connection = client + .get_multiplexed_tokio_connection(GlideConnectionOptions::default()) + .await?; + Ok(Connection::Standalone(connection)) + } + } + } + + fn get_connection_info(cluster: &TestContext, password: Option) -> ConnectionInfo { + let addr = cluster.server.connection_info().addr.clone(); + ConnectionInfo { + addr, + redis: RedisConnectionInfo { + password, + ..Default::default() + }, + } + } + + fn get_builder(cluster: &TestClusterContext, password: Option) -> ClusterClientBuilder { + let mut builder = ClusterClientBuilder::new(cluster.nodes.clone()); + if let Some(password) = password { + builder = builder.password(password); + } + builder + } + + async fn set_password(password: &str, conn: &mut Connection) -> RedisResult<()> { + let mut set_auth_cmd = cmd("config"); + set_auth_cmd.arg("set").arg("requirepass").arg(password); + match conn { + Connection::Cluster(cluster_conn) => cluster_conn + .route_command(&set_auth_cmd, ALL_SUCCESS_ROUTE) + .await + .map(|_| ()), + Connection::Standalone(standalone_conn) => set_auth_cmd + .query_async::<_, ()>(standalone_conn) + .await + .map(|_| ()), + } + } + + async fn kill_non_management_connections(con: &mut Connection) { + let mut kill_cmd = cmd("client"); + kill_cmd.arg("kill").arg("type").arg("normal"); + match con { + Connection::Cluster(cluster_conn) => { + cluster_conn + .route_command(&kill_cmd, ALL_SUCCESS_ROUTE) + .await + .unwrap(); + } + Connection::Standalone(standalone_conn) => { + kill_cmd.arg("skipme").arg("no"); + kill_cmd + .query_async::<_, ()>(standalone_conn) + .await + .unwrap(); + } + } + } + + #[tokio::test] + #[serial_test::serial] + async fn test_replace_password_cluster() { + let cluster_context = TestClusterContext::new(3, 0); + + // Create a management connection to set the password + let management_connection = + match create_connection(None, ConnectionType::Cluster, Some(&cluster_context), None) + .await + .unwrap() + { + Connection::Cluster(conn) => conn, + _ => panic!("Expected ClusterConnection"), + }; + + // Set the password using the unified function + let mut management_conn = Connection::Cluster(management_connection.clone()); + set_password(PASSWORD, &mut management_conn).await.unwrap(); + + // Test that we can't connect without password + let connection_should_fail = + create_connection(None, ConnectionType::Cluster, Some(&cluster_context), None).await; + assert!(connection_should_fail.is_err()); + let err = connection_should_fail.err().unwrap(); + println!("{}", err.to_string()); + assert!(err.to_string().contains("Authentication required.")); + + // Test that we can connect with password + let mut connection_should_succeed = match create_connection( + Some(PASSWORD.to_string()), + ConnectionType::Cluster, + Some(&cluster_context), + None, + ) + .await + .unwrap() + { + Connection::Cluster(conn) => conn, + _ => panic!("Expected ClusterConnection"), + }; + + let res: RedisResult = cmd("set") + .arg("foo") + .arg("bar") + .query_async(&mut connection_should_succeed) + .await; + assert_eq!(res.unwrap(), Value::Okay); + + // Verify that we can retrieve the set value + let res: RedisResult = cmd("get") + .arg("foo") + .query_async(&mut connection_should_succeed) + .await; + assert_eq!(res.unwrap(), Value::BulkString(b"bar".to_vec())); + + // Kill the connection to force reconnection + kill_non_management_connections(&mut Connection::Cluster(management_connection.clone())) + .await; + + // Attempt to get the value again to ensure reconnection works + let should_be_ok: RedisResult = cmd("get") + .arg("foo") + .query_async(&mut connection_should_succeed) + .await; + assert_eq!(should_be_ok.unwrap(), Value::BulkString(b"bar".to_vec())); + + // Update the password in the connection + connection_should_succeed + .update_connection_password(Some(NEW_PASSWORD.to_string())) + .await + .unwrap(); + + // Update the password on the server + let mut management_conn = Connection::Cluster(management_connection.clone()); + set_password(NEW_PASSWORD, &mut management_conn) + .await + .unwrap(); + + // Test that we can't connect with the old password + let connection_should_fail = create_connection( + Some(PASSWORD.to_string()), + ConnectionType::Cluster, + Some(&cluster_context), + None, + ) + .await; + assert!(connection_should_fail.is_err()); + let err = connection_should_fail.err().unwrap(); + assert!(err + .to_string() + .contains("Password authentication failed- AuthenticationFailed")); + + // Kill the connection to force reconnection + let mut management_conn = Connection::Cluster(management_connection); + kill_non_management_connections(&mut management_conn).await; + + // Verify that the connection with new password still works + let result_should_succeed: RedisResult = cmd("get") + .arg("foo") + .query_async(&mut connection_should_succeed) + .await; + assert!(result_should_succeed.is_ok()); + assert_eq!( + result_should_succeed.unwrap(), + Value::BulkString(b"bar".to_vec()) + ); + } + + #[tokio::test] + #[serial_test::serial] + async fn test_replace_password_standalone() { + let standalone_context = TestContext::new(); + + // Create a management connection to set the password + let management_connection = match create_connection( + None, + ConnectionType::Standalone, + None, + Some(&standalone_context), + ) + .await + .unwrap() + { + Connection::Standalone(conn) => conn, + _ => panic!("Expected Standalone connection"), + }; + + // Set the password using the unified function + let mut management_conn = Connection::Standalone(management_connection.clone()); + set_password(PASSWORD, &mut management_conn).await.unwrap(); + + // Test that we can't send commands with new connection without password + let connection_should_fail = create_connection( + None, + ConnectionType::Standalone, + None, + Some(&standalone_context), + ) + .await; + let res_should_fail: RedisResult = match connection_should_fail.unwrap() { + Connection::Cluster(mut conn) => cmd("get").arg("foo").query_async(&mut conn).await, + Connection::Standalone(mut conn) => cmd("get").arg("foo").query_async(&mut conn).await, + }; + assert!(res_should_fail.is_err()); + + // Test that we can connect with password + let mut connection_should_succeed = match create_connection( + Some(PASSWORD.to_string()), + ConnectionType::Standalone, + None, + Some(&standalone_context), + ) + .await + .unwrap() + { + Connection::Standalone(conn) => conn, + _ => panic!("Expected Standalone connection"), + }; + + let res: RedisResult = cmd("set") + .arg("foo") + .arg("bar") + .query_async(&mut connection_should_succeed) + .await; + assert_eq!(res.unwrap(), Value::Okay); + + // Update the password in the connection + connection_should_succeed + .update_connection_password(Some(NEW_PASSWORD.to_string())) + .await + .unwrap(); + + // Update the password on the server + let mut management_conn = Connection::Standalone(management_connection.clone()); + set_password(NEW_PASSWORD, &mut management_conn) + .await + .unwrap(); + + // Reset the management connection + kill_non_management_connections(&mut management_conn).await; + + // Test that we can't connect with the old password + let connection_should_fail = create_connection( + Some(PASSWORD.to_string()), + ConnectionType::Standalone, + None, + Some(&standalone_context), + ) + .await; + assert!(connection_should_fail.is_err()); + } +} diff --git a/glide-core/src/client/mod.rs b/glide-core/src/client/mod.rs index 0ed66c4217..ffbdc60d4e 100644 --- a/glide-core/src/client/mod.rs +++ b/glide-core/src/client/mod.rs @@ -9,7 +9,9 @@ use futures::FutureExt; use logger_core::{log_info, log_warn}; use redis::aio::ConnectionLike; use redis::cluster_async::ClusterConnection; -use redis::cluster_routing::{Routable, RoutingInfo, SingleNodeRoutingInfo}; +use redis::cluster_routing::{ + MultipleNodeRoutingInfo, ResponsePolicy, Routable, RoutingInfo, SingleNodeRoutingInfo, +}; use redis::{Cmd, ErrorKind, ObjectType, PushInfo, RedisError, RedisResult, ScanStateRC, Value}; pub use standalone_client::StandaloneClient; use std::io; @@ -259,9 +261,9 @@ impl Client { if let Some(RoutingInfo::SingleNode(SingleNodeRoutingInfo::Random)) = routing { - let cmdname = cmd.command().unwrap_or_default(); - let cmdname = String::from_utf8_lossy(&cmdname); - if redis::cluster_routing::is_readonly_cmd(cmdname.as_bytes()) { + let cmd_name = cmd.command().unwrap_or_default(); + let cmd_name = String::from_utf8_lossy(&cmd_name); + if redis::cluster_routing::is_readonly_cmd(cmd_name.as_bytes()) { // A read-only command, go ahead and send it to a random node RoutingInfo::SingleNode(SingleNodeRoutingInfo::Random) } else { @@ -270,7 +272,7 @@ impl Client { log_warn( "send_command", format!( - "User provided 'Random' routing which is not suitable for the writeable command '{cmdname}'. Changing it to 'RandomPrimary'" + "User provided 'Random' routing which is not suitable for the writeable command '{cmd_name}'. Changing it to 'RandomPrimary'" ), ); RoutingInfo::SingleNode(SingleNodeRoutingInfo::RandomPrimary) @@ -474,6 +476,34 @@ impl Client { self.inflight_requests_allowed .fetch_add(1, Ordering::SeqCst) } + + /// Update the password used to authenticate with the servers. + /// If None is passed, the password will be removed. + /// If `re_auth` is true, the new password will be used to re-authenticate with all of the nodes. + pub async fn update_connection_password( + &mut self, + password: Option, + re_auth: bool, + ) -> RedisResult { + if re_auth { + let routing = RoutingInfo::MultiNode(( + MultipleNodeRoutingInfo::AllNodes, + Some(ResponsePolicy::AllSucceeded), + )); + let mut cmd = redis::cmd("AUTH"); + cmd.arg(&password); + self.send_command(&cmd, Some(routing)).await?; + } + + match self.internal_client { + ClientWrapper::Standalone(ref mut client) => { + client.update_connection_password(password).await + } + ClientWrapper::Cluster { ref mut client } => { + client.update_connection_password(password).await + } + } + } } fn load_cmd(code: &[u8]) -> Cmd { diff --git a/glide-core/src/client/standalone_client.rs b/glide-core/src/client/standalone_client.rs index 961f67e516..2a6dbd0e77 100644 --- a/glide-core/src/client/standalone_client.rs +++ b/glide-core/src/client/standalone_client.rs @@ -470,6 +470,19 @@ impl StandaloneClient { } }); } + + /// Update the password used to authenticate with the servers. + /// If the password is `None`, the password will be removed. + pub async fn update_connection_password( + &mut self, + password: Option, + ) -> RedisResult { + self.get_connection(false) + .get_connection() + .await? + .update_connection_password(password.clone()) + .await + } } async fn get_connection_and_replication_info( diff --git a/glide-core/src/protobuf/command_request.proto b/glide-core/src/protobuf/command_request.proto index 5b2b826acc..e50cdc8b3c 100644 --- a/glide-core/src/protobuf/command_request.proto +++ b/glide-core/src/protobuf/command_request.proto @@ -508,6 +508,11 @@ message ClusterScan { optional string object_type = 4; } +message UpdateConnectionPassword { + optional string password = 1; + bool re_auth = 2; +} + message CommandRequest { uint32 callback_idx = 1; @@ -517,6 +522,7 @@ message CommandRequest { ScriptInvocation script_invocation = 4; ScriptInvocationPointers script_invocation_pointers = 5; ClusterScan cluster_scan = 6; + UpdateConnectionPassword update_connection_password = 7; } - Routes route = 7; + Routes route = 8; } diff --git a/glide-core/src/socket_listener.rs b/glide-core/src/socket_listener.rs index f823f908e5..b7f967e0bd 100644 --- a/glide-core/src/socket_listener.rs +++ b/glide-core/src/socket_listener.rs @@ -375,7 +375,7 @@ async fn invoke_script( async fn send_transaction( request: Transaction, - mut client: Client, + client: &mut Client, routing: Option, ) -> ClientUsageResult { let mut pipeline = redis::Pipeline::with_capacity(request.commands.capacity()); @@ -461,7 +461,7 @@ fn get_route( } } -fn handle_request(request: CommandRequest, client: Client, writer: Rc) { +fn handle_request(request: CommandRequest, mut client: Client, writer: Rc) { task::spawn_local(async move { let mut updated_inflight_counter = true; let client_clone = client.clone(); @@ -489,7 +489,7 @@ fn handle_request(request: CommandRequest, client: Client, writer: Rc) { } command_request::Command::Transaction(transaction) => { match get_route(request.route.0, None) { - Ok(routes) => send_transaction(transaction, client, routes).await, + Ok(routes) => send_transaction(transaction, &mut client, routes).await, Err(e) => Err(e), } } @@ -522,6 +522,17 @@ fn handle_request(request: CommandRequest, client: Client, writer: Rc) { Err(e) => Err(e), } } + command_request::Command::UpdateConnectionPassword( + update_connection_password_command, + ) => client + .update_connection_password( + update_connection_password_command + .password + .map(|chars| chars.to_string()), + update_connection_password_command.re_auth, + ) + .await + .map_err(|err| err.into()), }, None => { log_debug( diff --git a/node/src/BaseClient.ts b/node/src/BaseClient.ts index 36cee35b72..665acb4fae 100644 --- a/node/src/BaseClient.ts +++ b/node/src/BaseClient.ts @@ -915,7 +915,8 @@ export class BaseClient { | command_request.Command | command_request.Command[] | command_request.ScriptInvocation - | command_request.ClusterScan, + | command_request.ClusterScan + | command_request.UpdateConnectionPassword, options: WritePromiseOptions = {}, ): Promise { const route = toProtobufRoute(options?.route); @@ -985,7 +986,8 @@ export class BaseClient { | command_request.Command | command_request.Command[] | command_request.ScriptInvocation - | command_request.ClusterScan, + | command_request.ClusterScan + | command_request.UpdateConnectionPassword, route?: command_request.Routes, ) { const message = Array.isArray(command) @@ -1005,10 +1007,15 @@ export class BaseClient { callbackIdx, clusterScan: command, }) - : command_request.CommandRequest.create({ - callbackIdx, - scriptInvocation: command, - }); + : command instanceof command_request.UpdateConnectionPassword + ? command_request.CommandRequest.create({ + callbackIdx, + updateConnectionPassword: command, + }) + : command_request.CommandRequest.create({ + callbackIdx, + scriptInvocation: command, + }); message.route = route; this.writeOrBufferRequest( @@ -7672,4 +7679,46 @@ export class BaseClient { throw err; } } + + /** + * Update the current connection with a new password. + * + * This method is useful in scenarios where the server password has changed or when utilizing short-lived passwords for enhanced security. + * It allows the client to update its password to reconnect upon disconnection without the need to recreate the client instance. + * This ensures that the internal reconnection mechanism can handle reconnection seamlessly, preventing the loss of in-flight commands. + * + * This method updates the client's internal password configuration and does not perform password rotation on the server side. + * + * @param password - The new password to update the current password, or `null` to remove the current password. + * @param reAuth - If `true`, the client will re-authenticate immediately with the new password. If `false`, the new password will be used for the next connection attempt. + * @returns Always `"OK"`. + * + * @example + * ```typescript + * await client.updateConnectionPassword("newPassword", true) // "OK" + * ``` + */ + async updateConnectionPassword(password: string | null, reAuth: boolean) { + const updateConnectionPassword = + command_request.UpdateConnectionPassword.create({ + password: password, + reAuth, + }); + + const response = await this.createWritePromise( + updateConnectionPassword, + ); + + if (response === "OK" && !this.config?.credentials) { + this.config = { + ...this.config!, + credentials: { + ...this.config!.credentials, + password: password ? password : "", + }, + }; + } + + return response; + } } diff --git a/node/tests/SharedTests.ts b/node/tests/SharedTests.ts index 6cada7b66f..1a6ddaa5cf 100644 --- a/node/tests/SharedTests.ts +++ b/node/tests/SharedTests.ts @@ -12240,6 +12240,161 @@ export function runBaseTests(config: { }, config.timeout, ); + + describe.each([ProtocolVersion.RESP2, ProtocolVersion.RESP3])( + "update_connection_password_%p", + (protocol) => { + const NEW_PASSWORD = "new_password"; + const WRONG_PASSWORD = "wrong_password"; + /** + * Test replacing the connection password without immediate re-authentication. + * Verifies that: + * 1. The client can update its internal password + * 2. The client remains connected with current auth + * 3. The client can reconnect using the new password after server password change + * Currently, this test is only supported for cluster mode, + * since standalone mode dont have retry mechanism. + */ + it("test_update_connection_password", async () => { + await runTest(async (client: BaseClient) => { + try { + if (client instanceof GlideClient) { + return; + } + + const result = await client.updateConnectionPassword( + NEW_PASSWORD, + false, + ); + expect(result).toEqual("OK"); + + await client.set("test_key", "test_value"); + const value = await client.get("test_key"); + expect(value).toEqual("test_value"); + await expect( + client.configSet({ + requirepass: NEW_PASSWORD, + }), + ).resolves.toBe("OK"); + await client.customCommand([ + "CLIENT", + "KILL", + "TYPE", + "normal", + "skipme", + "no", + ]); + await client.set("test_key2", "test_value2"); + const value2 = await client.get("test_key2"); + expect(value2).toEqual("test_value2"); + await client.configSet({ + requirepass: "", + }); + } finally { + client?.close(); + } + }, protocol); + }); + + /** + * Test that immediate re-authentication fails when no server password is set. + * This verifies proper error handling when trying to re-authenticate with a + * password when the server has no password set. + */ + it("test_update_connection_password_no_server_auth", async () => { + await runTest(async (client: BaseClient) => { + try { + await expect( + client.updateConnectionPassword(NEW_PASSWORD, true), + ).rejects.toThrow(RequestError); + } finally { + client?.close(); + } + }, protocol); + }); + + /** + * Test replacing connection password with a long password string. + * Verifies that the client can handle long passwords (1000 characters). + */ + it("test_update_connection_password_long", async () => { + await runTest(async (client: BaseClient) => { + try { + const longPassword = "p".repeat(1000); + await expect( + client.updateConnectionPassword( + longPassword, + false, + ), + ).resolves.toBe("OK"); + await client.configSet({ + requirepass: "", + }); + } finally { + client?.close(); + } + }, protocol); + }); + + /** + * Test that re-authentication fails when using wrong password. + * Verifies proper error handling when immediate re-authentication is attempted + * with a password that doesn't match the server's password. + */ + it("test_replace_password_reauth_wrong_password", async () => { + await runTest(async (client: BaseClient) => { + try { + await client.configSet({ + requirepass: NEW_PASSWORD, + }); + + await expect( + client.updateConnectionPassword( + WRONG_PASSWORD, + true, + ), + ).rejects.toThrow(RequestError); + await client.updateConnectionPassword( + NEW_PASSWORD, + true, + ); + await client.configSet({ + requirepass: "", + }); + } finally { + client?.close(); + } + }, protocol); + }); + /** + * Test replacing connection password with immediate re-authentication. + * Verifies that: + * 1. The client can update its password and re-authenticate immediately + * 2. The client remains operational after re-authentication + */ + it("test_update_connection_password_with_reauth", async () => { + await runTest(async (client: BaseClient) => { + try { + await client.configSet({ + requirepass: NEW_PASSWORD, + }); + + await expect( + client.updateConnectionPassword(NEW_PASSWORD, true), + ).resolves.toBe("OK"); + await client.set("test_key", "test_value"); + const value = await client.get("test_key"); + expect(value).toEqual("test_value"); + await client.configSet({ + requirepass: "", + }); + } finally { + client?.close(); + } + }, protocol); + }); + }, + ); } export function runCommonTests(config: { diff --git a/package.json b/package.json index fa682d7107..2f59fcc5a8 100644 --- a/package.json +++ b/package.json @@ -7,6 +7,6 @@ "eslint-config-prettier": "^9.1.0", "prettier": "^3.3.3", "typescript": "^5.6.2", - "typescript-eslint": "^8.5.0" + "typescript-eslint": "^8.13" } } diff --git a/python/dev_requirements.txt b/python/dev_requirements.txt index 36f3438740..02e9c4fd53 100644 --- a/python/dev_requirements.txt +++ b/python/dev_requirements.txt @@ -4,3 +4,4 @@ isort == 5.10 mypy == 1.2 mypy-protobuf == 3.5 packaging >= 22.0 +pyrsistent diff --git a/python/python/glide/__init__.py b/python/python/glide/__init__.py index 7d92c38b63..6f77cac760 100644 --- a/python/python/glide/__init__.py +++ b/python/python/glide/__init__.py @@ -260,7 +260,7 @@ "TrimByMaxLen", "TrimByMinId", "UpdateOptions", - "ClusterScanCursor" + "ClusterScanCursor", # PubSub "PubSubMsg", # Json diff --git a/python/python/glide/async_commands/core.py b/python/python/glide/async_commands/core.py index edb5178c0a..4c29fbc3c9 100644 --- a/python/python/glide/async_commands/core.py +++ b/python/python/glide/async_commands/core.py @@ -392,6 +392,41 @@ async def _cluster_scan( type: Optional[ObjectType] = ..., ) -> TResult: ... + async def _update_connection_password( + self, password: Optional[str], re_auth: bool + ) -> TResult: ... + + async def update_connection_password( + self, password: Optional[str], re_auth: bool + ) -> TOK: + """ + Update the current connection password with a new password. + + **Note:** This method updates the client's internal password configuration and does + not perform password rotation on the server side. + + This method is useful in scenarios where the server password has changed or when + utilizing short-lived passwords for enhanced security. It allows the client to + update its password to reconnect upon disconnection without the need to recreate + the client instance. This ensures that the internal reconnection mechanism can + handle reconnection seamlessly, preventing the loss of in-flight commands. + + Args: + password (Optional[str]): The new password to use for the connection, + if `None` the password will be removed. + re_auth (bool): + - `True`: The client will re-authenticate immediately with the new password. + - `False`: The new password will be used for the next connection attempt. + + Returns: + TOK: A simple OK response. + + Example: + >>> await client.update_connection_password("new_password", re_auth=True) + 'OK' + """ + return cast(TOK, await self._update_connection_password(password, re_auth)) + async def set( self, key: TEncodable, diff --git a/python/python/glide/glide_client.py b/python/python/glide/glide_client.py index f53644fa3d..2838ae288e 100644 --- a/python/python/glide/glide_client.py +++ b/python/python/glide/glide_client.py @@ -9,7 +9,7 @@ from glide.async_commands.command_args import ObjectType from glide.async_commands.core import CoreCommands from glide.async_commands.standalone_commands import StandaloneCommands -from glide.config import BaseClientConfiguration +from glide.config import BaseClientConfiguration, ServerCredentials from glide.constants import DEFAULT_READ_BYTES_SIZE, OK, TEncodable, TRequest, TResult from glide.exceptions import ( ClosingError, @@ -26,6 +26,7 @@ from glide.protobuf.response_pb2 import RequestErrorType, Response from glide.protobuf_codec import PartialMessageException, ProtobufCodec from glide.routes import Route, set_protobuf_route +from pyrsistent import optional from .glide import ( DEFAULT_TIMEOUT_IN_MILLISECONDS, @@ -532,6 +533,21 @@ async def _reader_loop(self) -> None: else: await self._process_response(response=response) + async def _update_connection_password( + self, password: Optional[str], re_auth: bool + ) -> TResult: + request = CommandRequest() + request.callback_idx = self._get_callback_index() + request.update_connection_password.password = password + request.update_connection_password.re_auth = re_auth + response = await self._write_request_await_response(request) + # Update the client binding side password if managed to change core configuration password + if response is OK: + if self.config.credentials is None: + self.config.credentials = ServerCredentials(password=password or "") + self.config.credentials.password = password or "" + return response + class GlideClusterClient(BaseClient, ClusterCommands): """ diff --git a/python/python/tests/test_auth.py b/python/python/tests/test_auth.py new file mode 100644 index 0000000000..694c8c345b --- /dev/null +++ b/python/python/tests/test_auth.py @@ -0,0 +1,153 @@ +import pytest +from glide.config import ProtocolVersion +from glide.constants import OK +from glide.exceptions import RequestError +from glide.glide_client import GlideClient, GlideClusterClient, TGlideClient +from glide.routes import AllNodes + +NEW_PASSWORD = "new_secure_password" +WRONG_PASSWORD = "wrong_password" + + +async def auth_client(client: TGlideClient, password): + """ + Authenticates the given TGlideClient server connected. + """ + if isinstance(client, GlideClient): + await client.custom_command(["AUTH", password]) + if isinstance(client, GlideClusterClient): + await client.custom_command(["AUTH", password], route=AllNodes()) + + +async def config_set_new_password(client: TGlideClient, password): + """ + Sets a new password for the given TGlideClient server connected. + This function updates the server to require a new password. + """ + if isinstance(client, GlideClient): + await client.config_set({"requirepass": password}) + if isinstance(client, GlideClusterClient): + await client.config_set({"requirepass": password}, route=AllNodes()) + + +async def kill_connections(client: TGlideClient): + """ + Kills all connections to the given TGlideClient server connected. + """ + if isinstance(client, GlideClient): + await client.custom_command( + ["CLIENT", "KILL", "TYPE", "normal", "skipme", "no"] + ) + if isinstance(client, GlideClusterClient): + await client.custom_command( + ["CLIENT", "KILL", "TYPE", "normal", "skipme", "no"], route=AllNodes() + ) + + +@pytest.mark.asyncio +class TestAuthCommands: + """Test cases for password authentication and management""" + + @pytest.fixture(autouse=True) + async def setup(self, glide_client: TGlideClient): + """ + Teardown the test environment, make sure that theres no password set on the server side + """ + try: + await auth_client(glide_client, NEW_PASSWORD) + await config_set_new_password(glide_client, "") + except RequestError: + pass + yield + try: + await auth_client(glide_client, NEW_PASSWORD) + await config_set_new_password(glide_client, "") + except RequestError: + pass + + @pytest.mark.parametrize("cluster_mode", [True]) + @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3]) + async def test_update_connection_password(self, glide_client: TGlideClient): + """ + Test replacing the connection password without immediate re-authentication. + Verifies that: + 1. The client can update its internal password + 2. The client remains connected with current auth + 3. The client can reconnect using the new password after server password change + Currently, this test is only supported for cluster mode, + since standalone mode dont have retry mechanism. + """ + result = await glide_client.update_connection_password( + NEW_PASSWORD, re_auth=False + ) + assert result == OK + # Verify that the client is still authenticated + assert await glide_client.set("test_key", "test_value") == OK + value = await glide_client.get("test_key") + assert value == b"test_value" + await config_set_new_password(glide_client, NEW_PASSWORD) + await kill_connections(glide_client) + # Verify that the client is able to reconnect with the new password + value = await glide_client.get("test_key") + assert value == b"test_value" + + @pytest.mark.parametrize("cluster_mode", [True, False]) + @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3]) + async def test_update_connection_password_no_server_auth( + self, glide_client: TGlideClient + ): + """ + Test that immediate re-authentication fails when no server password is set. + This verifies proper error handling when trying to re-authenticate with a + password when the server has no password set. + """ + with pytest.raises(RequestError): + await glide_client.update_connection_password(WRONG_PASSWORD, re_auth=True) + + @pytest.mark.parametrize("cluster_mode", [True, False]) + @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3]) + async def test_update_connection_password_long(self, glide_client: TGlideClient): + """ + Test replacing connection password with a long password string. + Verifies that the client can handle long passwords (1000 characters). + """ + long_password = "p" * 1000 + result = await glide_client.update_connection_password( + long_password, re_auth=False + ) + assert result == OK + + @pytest.mark.parametrize("cluster_mode", [True, False]) + @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3]) + async def test_replace_password_reauth_wrong_password( + self, glide_client: TGlideClient + ): + """ + Test that re-authentication fails when using wrong password. + Verifies proper error handling when immediate re-authentication is attempted + with a password that doesn't match the server's password. + """ + await config_set_new_password(glide_client, NEW_PASSWORD) + with pytest.raises(RequestError): + await glide_client.update_connection_password(WRONG_PASSWORD, re_auth=True) + + @pytest.mark.parametrize("cluster_mode", [True, False]) + @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3]) + async def test_update_connection_password_with_reauth( + self, glide_client: TGlideClient + ): + """ + Test replacing connection password with immediate re-authentication. + Verifies that: + 1. The client can update its password and re-authenticate immediately + 2. The client remains operational after re-authentication + """ + await config_set_new_password(glide_client, NEW_PASSWORD) + result = await glide_client.update_connection_password( + NEW_PASSWORD, re_auth=True + ) + assert result == OK + # Verify that the client is still authenticated + assert await glide_client.set("test_key", "test_value") == OK + value = await glide_client.get("test_key") + assert value == b"test_value" diff --git a/python/requirements.txt b/python/requirements.txt index 63b2be3603..93c90b2cac 100644 --- a/python/requirements.txt +++ b/python/requirements.txt @@ -1,7 +1,7 @@ async-timeout==4.0.2;python_version<"3.11" maturin==0.13.0 protobuf==3.20.* -pytest==7.1.2 -pytest-asyncio==0.19.0 +pytest +pytest-asyncio typing_extensions==4.8.0;python_version<"3.11" pytest-html