From 5805866f27ad5ab47ac70e2956ff47d83f2da8ca Mon Sep 17 00:00:00 2001 From: Matt Alonso Date: Mon, 25 Mar 2024 11:34:56 -0500 Subject: [PATCH] Fix #167: Notify waiters when dropping a bad connection from the pool --- bb8/src/inner.rs | 1 + bb8/tests/test.rs | 55 +++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 56 insertions(+) diff --git a/bb8/src/inner.rs b/bb8/src/inner.rs index 8ba6f53..4e73ddd 100644 --- a/bb8/src/inner.rs +++ b/bb8/src/inner.rs @@ -141,6 +141,7 @@ where (_, _) => { let approvals = locked.dropped(1, &self.inner.statics); self.spawn_replenishing_approvals(approvals); + self.inner.notify.notify_waiters(); } } } diff --git a/bb8/tests/test.rs b/bb8/tests/test.rs index 1c4658c..82f4339 100644 --- a/bb8/tests/test.rs +++ b/bb8/tests/test.rs @@ -831,3 +831,58 @@ async fn test_customize_connection_acquire() { let connection_1_or_2 = pool.get().await.unwrap(); assert!(connection_1_or_2.custom_field == 1 || connection_1_or_2.custom_field == 2); } + +#[tokio::test] +async fn test_broken_connections_dont_starve_pool() { + use std::sync::RwLock; + use std::{convert::Infallible, time::Duration}; + + #[derive(Default)] + struct ConnectionManager { + counter: RwLock, + } + #[derive(Debug)] + struct Connection; + + #[async_trait::async_trait] + impl bb8::ManageConnection for ConnectionManager { + type Connection = Connection; + type Error = Infallible; + + async fn connect(&self) -> Result { + Ok(Connection) + } + + async fn is_valid(&self, _: &mut Self::Connection) -> Result<(), Self::Error> { + Ok(()) + } + + fn has_broken(&self, _: &mut Self::Connection) -> bool { + let mut counter = self.counter.write().unwrap(); + let res = *counter < 5; + *counter += 1; + res + } + } + + let pool = bb8::Pool::builder() + .max_size(5) + .connection_timeout(Duration::from_secs(10)) + .build(ConnectionManager::default()) + .await + .unwrap(); + + let mut futures = Vec::new(); + + for _ in 0..10 { + let pool = pool.clone(); + futures.push(tokio::spawn(async move { + let conn = pool.get().await.unwrap(); + drop(conn); + })); + } + + for future in futures { + future.await.unwrap(); + } +}