Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

buffer: wake tasks waiting for channel capacity when terminating #480

Merged
merged 4 commits into from
Oct 28, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions tower-test/tests/mock.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use tokio_test::{assert_pending, assert_ready};
use tower_test::{assert_request_eq, mock};

#[tokio::test]
#[tokio::test(flavor = "current_thread")]
async fn single_request_ready() {
let (mut service, mut handle) = mock::spawn();

Expand All @@ -16,7 +16,7 @@ async fn single_request_ready() {
assert_eq!(response.await.unwrap(), "world");
}

#[tokio::test]
#[tokio::test(flavor = "current_thread")]
#[should_panic]
async fn backpressure() {
let (mut service, mut handle) = mock::spawn::<_, ()>();
Expand Down
4 changes: 2 additions & 2 deletions tower/src/buffer/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,8 @@ where
Request: Send + 'static,
{
let (tx, rx) = mpsc::unbounded_channel();
let (handle, worker) = Worker::new(service, rx);
let semaphore = Semaphore::new(bound);
let (semaphore, wake_waiters) = Semaphore::new_with_close(bound);
let (handle, worker) = Worker::new(service, rx, wake_waiters);
(
Buffer {
tx,
Expand Down
23 changes: 22 additions & 1 deletion tower/src/buffer/worker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ use tower_service::Service;
/// as part of the public API. This is the "sealed" pattern to include "private"
/// types in public traits that are not meant for consumers of the library to
/// implement (only call).
#[pin_project]
#[pin_project(PinnedDrop)]
#[derive(Debug)]
pub struct Worker<T, Request>
where
Expand All @@ -33,6 +33,7 @@ where
finish: bool,
failed: Option<ServiceError>,
handle: Handle,
close: Option<crate::semaphore::Close>,
}

/// Get the error out
Expand All @@ -49,6 +50,7 @@ where
pub(crate) fn new(
service: T,
rx: mpsc::UnboundedReceiver<Message<Request, T::Future>>,
close: crate::semaphore::Close,
) -> (Handle, Worker<T, Request>) {
let handle = Handle {
inner: Arc::new(Mutex::new(None)),
Expand All @@ -61,6 +63,7 @@ where
rx,
service,
handle: handle.clone(),
close: Some(close),
};

(handle, worker)
Expand Down Expand Up @@ -195,6 +198,11 @@ where
.as_ref()
.expect("Worker::failed did not set self.failed?")
.clone()));
// Wake any tasks waiting on channel capacity.
if let Some(close) = self.close.take() {
tracing::debug!("waking pending tasks");
close.close();
}
}
}
}
Expand All @@ -208,6 +216,19 @@ where
}
}

#[pin_project::pinned_drop]
impl<T, Request> PinnedDrop for Worker<T, Request>
where
T: Service<Request>,
T::Error: Into<crate::BoxError>,
{
fn drop(mut self: Pin<&mut Self>) {
if let Some(close) = self.as_mut().close.take() {
close.close();
}
}
}

impl Handle {
pub(crate) fn get_error_on_closed(&self) -> crate::BoxError {
self.inner
Expand Down
41 changes: 40 additions & 1 deletion tower/src/semaphore.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use std::{
future::Future,
mem,
pin::Pin,
sync::Arc,
sync::{Arc, Weak},
task::{Context, Poll},
};
use tokio::sync;
Expand All @@ -16,13 +16,32 @@ pub(crate) struct Semaphore {
state: State,
}

#[derive(Debug)]
pub(crate) struct Close {
semaphore: Weak<sync::Semaphore>,
permits: usize,
}

enum State {
Waiting(Pin<Box<dyn Future<Output = Permit> + Send + 'static>>),
Ready(Permit),
Empty,
}

impl Semaphore {
pub(crate) fn new_with_close(permits: usize) -> (Self, Close) {
let semaphore = Arc::new(sync::Semaphore::new(permits));
let close = Close {
semaphore: Arc::downgrade(&semaphore),
permits,
};
let semaphore = Self {
semaphore,
state: State::Empty,
};
(semaphore, close)
}

pub(crate) fn new(permits: usize) -> Self {
Self {
semaphore: Arc::new(sync::Semaphore::new(permits)),
Expand Down Expand Up @@ -72,3 +91,23 @@ impl fmt::Debug for State {
}
}
}

impl Close {
/// Close the semaphore, waking any remaining tasks currently awaiting a permit.
pub(crate) fn close(self) {
// The maximum number of permits that a `tokio::sync::Semaphore`
// can hold is usize::MAX >> 3. If we attempt to add more than that
// number of permits, the semaphore will panic.
// XXX(eliza): another shift is kinda janky but if we add (usize::MAX
// > 3 - initial permits) the semaphore impl panics (I think due to a
// bug in tokio?).
// TODO(eliza): Tokio should _really_ just expose `Semaphore::close`
// publicly so we don't have to do this nonsense...
const MAX: usize = std::usize::MAX >> 4;
if let Some(semaphore) = self.semaphore.upgrade() {
// If we added `MAX - available_permits`, any tasks that are
// currently holding permits could drop them, overflowing the max.
semaphore.add_permits(MAX - self.permits);
}
}
}
3 changes: 3 additions & 0 deletions tower/tests/balance/main.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
#![cfg(feature = "balance")]
#[path = "../support.rs"]
mod support;

use std::future::Future;
use std::task::{Context, Poll};
Expand Down Expand Up @@ -32,6 +34,7 @@ impl tower::load::Load for Mock {

#[test]
fn stress() {
let _t = support::trace_init();
let mut task = task::spawn(());
let (tx, rx) = tokio::sync::mpsc::unbounded_channel::<Result<_, &'static str>>();
let mut cache = Balance::<_, Req>::new(rx);
Expand Down
149 changes: 141 additions & 8 deletions tower/tests/buffer/main.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,21 @@
#![cfg(feature = "buffer")]

#[path = "../support.rs"]
mod support;
use std::thread;
use tokio_test::{assert_pending, assert_ready, assert_ready_err, assert_ready_ok, task};
use tower::buffer::{error, Buffer};
use tower::{util::ServiceExt, Service};
use tower_test::{assert_request_eq, mock};

fn let_worker_work() {
// Allow the Buffer's executor to do work
thread::sleep(::std::time::Duration::from_millis(100));
}

#[tokio::test]
#[tokio::test(flavor = "current_thread")]
async fn req_and_res() {
let _t = support::trace_init();

let (mut service, mut handle) = new_service();

assert_ready_ok!(service.poll_ready());
Expand All @@ -23,8 +27,10 @@ async fn req_and_res() {
assert_eq!(assert_ready_ok!(response.poll()), "world");
}

#[tokio::test]
#[tokio::test(flavor = "current_thread")]
async fn clears_canceled_requests() {
let _t = support::trace_init();

let (mut service, mut handle) = new_service();

handle.allow(1);
Expand Down Expand Up @@ -59,8 +65,10 @@ async fn clears_canceled_requests() {
assert_eq!(assert_ready_ok!(res3.poll()), "world3");
}

#[tokio::test]
#[tokio::test(flavor = "current_thread")]
async fn when_inner_is_not_ready() {
let _t = support::trace_init();

let (mut service, mut handle) = new_service();

// Make the service NotReady
Expand All @@ -81,9 +89,10 @@ async fn when_inner_is_not_ready() {
assert_eq!(assert_ready_ok!(res1.poll()), "world");
}

#[tokio::test]
#[tokio::test(flavor = "current_thread")]
async fn when_inner_fails() {
use std::error::Error as StdError;
let _t = support::trace_init();

let (mut service, mut handle) = new_service();

Expand All @@ -105,8 +114,10 @@ async fn when_inner_fails() {
}
}

#[tokio::test]
#[tokio::test(flavor = "current_thread")]
async fn poll_ready_when_worker_is_dropped_early() {
let _t = support::trace_init();

let (service, _handle) = mock::pair::<(), ()>();

let (service, worker) = Buffer::pair(service, 1);
Expand All @@ -120,8 +131,10 @@ async fn poll_ready_when_worker_is_dropped_early() {
assert!(err.is::<error::Closed>(), "should be a Closed: {:?}", err);
}

#[tokio::test]
#[tokio::test(flavor = "current_thread")]
async fn response_future_when_worker_is_dropped_early() {
let _t = support::trace_init();

let (service, mut handle) = mock::pair::<_, ()>();

let (service, worker) = Buffer::pair(service, 1);
Expand All @@ -140,8 +153,10 @@ async fn response_future_when_worker_is_dropped_early() {
assert!(err.is::<error::Closed>(), "should be a Closed: {:?}", err);
}

#[tokio::test]
#[tokio::test(flavor = "current_thread")]
async fn waits_for_channel_capacity() {
let _t = support::trace_init();

let (service, mut handle) = mock::pair::<&'static str, &'static str>();

let (service, worker) = Buffer::pair(service, 3);
Expand Down Expand Up @@ -213,6 +228,124 @@ async fn waits_for_channel_capacity() {
assert_ready_ok!(response4.poll());
}

#[tokio::test(flavor = "current_thread")]
async fn wakes_pending_waiters_on_close() {
let _t = support::trace_init();

let (service, mut handle) = mock::pair::<_, ()>();

let (mut service, worker) = Buffer::pair(service, 1);
let mut worker = task::spawn(worker);

// keep the request in the worker
handle.allow(0);
let service1 = service.ready_and().await.unwrap();
assert_pending!(worker.poll());
let mut response = task::spawn(service1.call("hello"));

let mut service1 = service.clone();
let mut ready_and1 = task::spawn(service1.ready_and());
assert_pending!(worker.poll());
assert_pending!(ready_and1.poll(), "no capacity");

let mut service1 = service.clone();
let mut ready_and2 = task::spawn(service1.ready_and());
assert_pending!(worker.poll());
assert_pending!(ready_and2.poll(), "no capacity");

// kill the worker task
drop(worker);

let err = assert_ready_err!(response.poll());
assert!(
err.is::<error::Closed>(),
"response should fail with a Closed, got: {:?}",
err
);

assert!(
ready_and1.is_woken(),
"dropping worker should wake ready_and task 1"
);
let err = assert_ready_err!(ready_and1.poll());
assert!(
err.is::<error::Closed>(),
"ready_and 1 should fail with a Closed, got: {:?}",
err
);

assert!(
ready_and2.is_woken(),
"dropping worker should wake ready_and task 2"
);
let err = assert_ready_err!(ready_and1.poll());
assert!(
err.is::<error::Closed>(),
"ready_and 2 should fail with a Closed, got: {:?}",
err
);
}

#[tokio::test(flavor = "current_thread")]
async fn wakes_pending_waiters_on_failure() {
let _t = support::trace_init();

let (service, mut handle) = mock::pair::<_, ()>();

let (mut service, worker) = Buffer::pair(service, 1);
let mut worker = task::spawn(worker);

// keep the request in the worker
handle.allow(0);
let service1 = service.ready_and().await.unwrap();
assert_pending!(worker.poll());
let mut response = task::spawn(service1.call("hello"));

let mut service1 = service.clone();
let mut ready_and1 = task::spawn(service1.ready_and());
assert_pending!(worker.poll());
assert_pending!(ready_and1.poll(), "no capacity");

let mut service1 = service.clone();
let mut ready_and2 = task::spawn(service1.ready_and());
assert_pending!(worker.poll());
assert_pending!(ready_and2.poll(), "no capacity");

// fail the inner service
handle.send_error("foobar");
// worker task terminates
assert_ready!(worker.poll());

let err = assert_ready_err!(response.poll());
assert!(
err.is::<error::ServiceError>(),
"response should fail with a ServiceError, got: {:?}",
err
);

assert!(
ready_and1.is_woken(),
"dropping worker should wake ready_and task 1"
);
let err = assert_ready_err!(ready_and1.poll());
assert!(
err.is::<error::ServiceError>(),
"ready_and 1 should fail with a ServiceError, got: {:?}",
err
);

assert!(
ready_and2.is_woken(),
"dropping worker should wake ready_and task 2"
);
let err = assert_ready_err!(ready_and1.poll());
assert!(
err.is::<error::ServiceError>(),
"ready_and 2 should fail with a ServiceError, got: {:?}",
err
);
}

type Mock = mock::Mock<&'static str, &'static str>;
type Handle = mock::Handle<&'static str, &'static str>;

Expand Down
Loading