diff --git a/Cargo.lock b/Cargo.lock index a4dcfbbd90..8c80a2cf09 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -530,7 +530,7 @@ dependencies = [ "http 0.2.1", "indexmap", "slab", - "tokio", + "tokio 0.2.23", "tokio-util", "tracing", ] @@ -649,7 +649,7 @@ dependencies = [ "pin-project", "socket2", "time", - "tokio", + "tokio 0.2.23", "tower-service", "tracing", "want", @@ -663,8 +663,8 @@ dependencies = [ "http 0.2.1", "hyper", "pin-project", - "tokio", - "tokio-test", + "tokio 0.2.23", + "tokio-test 0.2.1", "tower", ] @@ -823,7 +823,7 @@ dependencies = [ "linkerd2-error", "linkerd2-opencensus", "regex 1.3.9", - "tokio", + "tokio 0.2.23", "tonic", "tower", "tracing", @@ -883,7 +883,7 @@ dependencies = [ "procinfo", "prost-types", "regex 1.3.9", - "tokio", + "tokio 0.2.23", "tokio-timer", "tonic", "tower", @@ -902,8 +902,8 @@ dependencies = [ "linkerd2-app-core", "linkerd2-app-inbound", "linkerd2-app-outbound", - "tokio", - "tokio-test", + "tokio 0.2.23", + "tokio-test 0.2.1", "tower", "tower-test", "tracing", @@ -918,7 +918,7 @@ dependencies = [ "http 0.2.1", "indexmap", "linkerd2-app-core", - "tokio", + "tokio 0.2.23", "tower", "tracing", ] @@ -942,7 +942,7 @@ dependencies = [ "regex 0.1.80", "rustls", "socket2", - "tokio", + "tokio 0.2.23", "tokio-rustls", "tonic", "tower", @@ -968,7 +968,7 @@ dependencies = [ "linkerd2-io", "linkerd2-retry", "pin-project", - "tokio", + "tokio 0.2.23", "tower", "tracing", "tracing-futures", @@ -979,7 +979,7 @@ name = "linkerd2-app-profiling" version = "0.1.0" dependencies = [ "linkerd2-app-integration", - "tokio", + "tokio 0.2.23", ] [[package]] @@ -993,8 +993,8 @@ dependencies = [ "hyper", "linkerd2-app-core", "regex 0.1.80", - "tokio", - "tokio-test", + "tokio 0.2.23", + "tokio-test 0.2.1", "tower", "tracing", "tracing-futures", @@ -1006,10 +1006,11 @@ name = "linkerd2-buffer" version = "0.1.0" dependencies = [ "futures 0.3.5", + "linkerd2-channel", "linkerd2-error", "pin-project", - "tokio", - "tokio-test", + "tokio 0.2.23", + "tokio-test 0.2.1", "tower", "tower-test", "tracing", @@ -1024,18 +1025,27 @@ dependencies = [ "linkerd2-error", "linkerd2-stack", "parking_lot", - "tokio", + "tokio 0.2.23", "tower", "tracing", ] +[[package]] +name = "linkerd2-channel" +version = "0.1.0" +dependencies = [ + "futures 0.3.5", + "tokio 0.3.5", + "tokio-test 0.3.0", +] + [[package]] name = "linkerd2-concurrency-limit" version = "0.1.0" dependencies = [ "futures 0.3.5", "pin-project", - "tokio", + "tokio 0.2.23", "tower", "tracing", ] @@ -1052,7 +1062,7 @@ dependencies = [ "linkerd2-dns-name", "linkerd2-error", "pin-project", - "tokio", + "tokio 0.2.23", "tracing", "trust-dns-resolver", ] @@ -1072,8 +1082,8 @@ dependencies = [ "futures 0.3.5", "linkerd2-error", "pin-project", - "tokio", - "tokio-test", + "tokio 0.2.23", + "tokio-test 0.2.1", ] [[package]] @@ -1083,7 +1093,7 @@ dependencies = [ "bytes 0.5.4", "futures 0.3.5", "pin-project", - "tokio", + "tokio 0.2.23", "tracing", ] @@ -1127,7 +1137,7 @@ dependencies = [ "pin-project", "quickcheck", "rand 0.7.2", - "tokio", + "tokio 0.2.23", ] [[package]] @@ -1193,9 +1203,9 @@ dependencies = [ "futures 0.3.5", "linkerd2-errno", "pin-project", - "tokio", + "tokio 0.2.23", "tokio-rustls", - "tokio-test", + "tokio-test 0.2.1", ] [[package]] @@ -1210,7 +1220,7 @@ dependencies = [ "indexmap", "parking_lot", "quickcheck", - "tokio", + "tokio 0.2.23", "tracing", ] @@ -1225,7 +1235,7 @@ dependencies = [ "linkerd2-metrics", "opencensus-proto", "pin-project", - "tokio", + "tokio 0.2.23", "tonic", "tower", "tracing", @@ -1240,7 +1250,7 @@ dependencies = [ "linkerd2-signal", "mimalloc", "num_cpus", - "tokio", + "tokio 0.2.23", "tracing", ] @@ -1295,11 +1305,12 @@ dependencies = [ "async-stream", "futures 0.3.5", "indexmap", + "linkerd2-channel", "linkerd2-error", "linkerd2-proxy-core", "linkerd2-stack", "pin-project", - "tokio", + "tokio 0.2.23", "tower", "tracing", "tracing-futures", @@ -1314,7 +1325,7 @@ dependencies = [ "linkerd2-dns", "linkerd2-error", "linkerd2-proxy-core", - "tokio", + "tokio 0.2.23", "tower", "tracing", "tracing-futures", @@ -1343,7 +1354,7 @@ dependencies = [ "linkerd2-timeout", "pin-project", "rand 0.7.2", - "tokio", + "tokio 0.2.23", "tower", "tracing", "tracing-futures", @@ -1361,7 +1372,7 @@ dependencies = [ "linkerd2-proxy-api", "linkerd2-proxy-transport", "pin-project", - "tokio", + "tokio 0.2.23", "tonic", "tracing", ] @@ -1398,7 +1409,7 @@ dependencies = [ "pin-project", "prost-types", "rand 0.7.2", - "tokio", + "tokio 0.2.23", "tonic", "tower", "tracing", @@ -1415,7 +1426,7 @@ dependencies = [ "linkerd2-stack", "pin-project", "rand 0.7.2", - "tokio", + "tokio 0.2.23", "tower", ] @@ -1438,7 +1449,7 @@ dependencies = [ "linkerd2-stack", "pin-project", "rustls", - "tokio", + "tokio 0.2.23", "tokio-rustls", "tokio-util", "tower", @@ -1492,7 +1503,7 @@ dependencies = [ "quickcheck", "rand 0.7.2", "regex 1.3.9", - "tokio", + "tokio 0.2.23", "tonic", "tower", "tracing", @@ -1503,7 +1514,7 @@ dependencies = [ name = "linkerd2-signal" version = "0.1.0" dependencies = [ - "tokio", + "tokio 0.2.23", "tracing", ] @@ -1515,8 +1526,8 @@ dependencies = [ "futures 0.3.5", "linkerd2-error", "pin-project", - "tokio", - "tokio-test", + "tokio 0.2.23", + "tokio-test 0.2.1", "tower", "tower-test", "tracing", @@ -1553,9 +1564,9 @@ dependencies = [ "linkerd2-error", "linkerd2-stack", "pin-project", - "tokio", + "tokio 0.2.23", "tokio-connect", - "tokio-test", + "tokio-test 0.2.1", "tower", "tower-test", "tracing", @@ -1573,7 +1584,7 @@ dependencies = [ "linkerd2-error", "linkerd2-stack", "rand 0.7.2", - "tokio", + "tokio 0.2.23", "tower", "tracing", ] @@ -1941,6 +1952,12 @@ version = "0.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "237844750cfbb86f67afe27eee600dfbbcb6188d734139b534cbfbf4f96792ae" +[[package]] +name = "pin-project-lite" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6b063f57ec186e6140e2b8b6921e5f1bd89c7356dda5b33acc5401203ca6131c" + [[package]] name = "pin-utils" version = "0.1.0" @@ -2585,14 +2602,27 @@ dependencies = [ "mio-uds", "num_cpus", "parking_lot", - "pin-project-lite", + "pin-project-lite 0.1.4", "signal-hook-registry", "slab", - "tokio-macros", + "tokio-macros 0.2.6", "tracing", "winapi 0.3.8", ] +[[package]] +name = "tokio" +version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a12a3eb39ee2c231be64487f1fcbe726c8f2514876a55480a5ab8559fc374252" +dependencies = [ + "autocfg 1.0.0", + "futures-core", + "pin-project-lite 0.2.0", + "slab", + "tokio-macros 0.3.1", +] + [[package]] name = "tokio-connect" version = "0.1.0" @@ -2634,6 +2664,17 @@ dependencies = [ "syn", ] +[[package]] +name = "tokio-macros" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "21d30fdbb5dc2d8f91049691aa1a9d4d4ae422a21c334ce8936e5886d30c5c45" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "tokio-rustls" version = "0.14.1" @@ -2642,7 +2683,7 @@ checksum = "e12831b255bcfa39dc0436b01e19fea231a37db570686c06ee72c423479f889a" dependencies = [ "futures-core", "rustls", - "tokio", + "tokio 0.2.23", "webpki", ] @@ -2654,7 +2695,18 @@ checksum = "ed0049c119b6d505c4447f5c64873636c7af6c75ab0d45fd9f618d82acb8016d" dependencies = [ "bytes 0.5.4", "futures-core", - "tokio", + "tokio 0.2.23", +] + +[[package]] +name = "tokio-test" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7edc5a2b13a713b9134debc9aa1265b0a1ccf6a81fa05af2ba26e4aee56cdf1a" +dependencies = [ + "bytes 0.5.4", + "futures-core", + "tokio 0.3.5", ] [[package]] @@ -2676,7 +2728,7 @@ source = "git+https://github.com/hawkw/tokio-trace?rev=a8240c5cbb4ff981def84920d dependencies = [ "num_cpus", "serde", - "tokio", + "tokio 0.2.23", "tracing-core", "tracing-subscriber", ] @@ -2692,8 +2744,8 @@ dependencies = [ "futures-io", "futures-sink", "log", - "pin-project-lite", - "tokio", + "pin-project-lite 0.1.4", + "tokio 0.2.23", ] [[package]] @@ -2743,7 +2795,7 @@ dependencies = [ "pin-project", "rand 0.7.2", "slab", - "tokio", + "tokio 0.2.23", "tower-layer 0.3.0 (git+https://github.com/tower-rs/tower?rev=ad348d8)", "tower-service", "tracing", @@ -2766,7 +2818,7 @@ version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ce50370d644a0364bf4877ffd4f76404156a248d104e2cc234cd391ea5cdc965" dependencies = [ - "tokio", + "tokio 0.2.23", "tower-service", ] @@ -2793,8 +2845,8 @@ checksum = "9ba4bbc2c1e4a8543c30d4c13a4c8314ed72d6e07581910f665aa13fde0153c8" dependencies = [ "futures-util", "pin-project", - "tokio", - "tokio-test", + "tokio 0.2.23", + "tokio-test 0.2.1", "tower-layer 0.3.0 (registry+https://github.com/rust-lang/crates.io-index)", "tower-service", ] @@ -2902,7 +2954,7 @@ dependencies = [ "rand 0.7.2", "smallvec", "thiserror", - "tokio", + "tokio 0.2.23", "url", ] @@ -2921,7 +2973,7 @@ dependencies = [ "resolv-conf", "smallvec", "thiserror", - "tokio", + "tokio 0.2.23", "trust-dns-proto", ] diff --git a/Cargo.toml b/Cargo.toml index e625d717e1..17c7d2d405 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,8 +10,9 @@ members = [ "linkerd/app/profiling", "linkerd/app/test", "linkerd/app", - "linkerd/cache", "linkerd/buffer", + "linkerd/cache", + "linkerd/channel", "linkerd/concurrency-limit", "linkerd/conditional", "linkerd/dns/name", diff --git a/linkerd/buffer/Cargo.toml b/linkerd/buffer/Cargo.toml index 804ade0348..453c3ceff9 100644 --- a/linkerd/buffer/Cargo.toml +++ b/linkerd/buffer/Cargo.toml @@ -7,6 +7,7 @@ publish = false [dependencies] futures = "0.3" +linkerd2-channel = { path = "../channel" } linkerd2-error = { path = "../error" } tokio = { version = "0.2", features = ["sync", "stream", "time", "macros"] } tower = { version = "0.3", default_features = false, features = ["util"] } diff --git a/linkerd/buffer/src/dispatch.rs b/linkerd/buffer/src/dispatch.rs index 26ec113ff0..59ab3dd60b 100644 --- a/linkerd/buffer/src/dispatch.rs +++ b/linkerd/buffer/src/dispatch.rs @@ -1,9 +1,9 @@ use crate::error::{IdleError, ServiceError}; use crate::InFlight; use futures::{prelude::*, select_biased}; +use linkerd2_channel as mpsc; use linkerd2_error::Error; use std::sync::Arc; -use tokio::sync::mpsc; use tower::util::ServiceExt; use tracing::trace; @@ -54,7 +54,7 @@ pub(crate) async fn run( e = idle().fuse() => { let error = ServiceError(Arc::new(e.into())); trace!(%error, "Idling out inner service"); - return; + break; } } } @@ -64,7 +64,7 @@ pub(crate) async fn run( mod test { use super::*; use std::time::Duration; - use tokio::sync::{mpsc, oneshot}; + use tokio::sync::oneshot; use tokio::time::delay_for; use tokio_test::{assert_pending, assert_ready, task}; use tower_test::mock; @@ -101,12 +101,13 @@ mod test { delay_for(max_idle).await; // Send a request after the deadline has fired but before the - // dispatch future is polled. Ensure that the request is admitted, resetting idleness. - tx.try_send({ + // dispatch future is polled. Ensure that the request is admitted, + // resetting idleness. + tx.send({ let (tx, _rx) = oneshot::channel(); super::InFlight { request: (), tx } }) - .ok() + .await .expect("request not sent"); assert_pending!(dispatch.poll()); diff --git a/linkerd/buffer/src/lib.rs b/linkerd/buffer/src/lib.rs index 2fe62c797a..1bc1a41b6d 100644 --- a/linkerd/buffer/src/lib.rs +++ b/linkerd/buffer/src/lib.rs @@ -1,8 +1,9 @@ #![recursion_limit = "256"] +use linkerd2_channel as mpsc; use linkerd2_error::Error; -use std::{future::Future, pin::Pin, time::Duration}; -use tokio::sync::{mpsc, oneshot}; +use std::{fmt, future::Future, pin::Pin, time::Duration}; +use tokio::sync::oneshot; mod dispatch; pub mod error; @@ -43,3 +44,13 @@ where let dispatch = dispatch::run(inner, rx, idle); (Buffer::new(tx), dispatch) } + +// Required so that `TrySendError`/`SendError` can be `expect`ed. +impl fmt::Debug for InFlight { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("InFlight") + .field("request_type", &std::any::type_name::()) + .field("response_type", &std::any::type_name::()) + .finish() + } +} diff --git a/linkerd/buffer/src/service.rs b/linkerd/buffer/src/service.rs index a6b9df64c7..f72dd09392 100644 --- a/linkerd/buffer/src/service.rs +++ b/linkerd/buffer/src/service.rs @@ -1,9 +1,10 @@ use crate::error::Closed; use crate::InFlight; +use linkerd2_channel as mpsc; use linkerd2_error::Error; use std::task::{Context, Poll}; use std::{future::Future, pin::Pin}; -use tokio::sync::{mpsc, oneshot}; +use tokio::sync::oneshot; pub struct Buffer { /// The queue on which in-flight requests are sent to the inner service. @@ -27,14 +28,13 @@ where type Future = Pin> + Send + 'static>>; fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { - self.tx.poll_ready(cx).map_err(|_| Closed(()).into()) + self.tx.poll_ready(cx).map_err(Into::into) } fn call(&mut self, request: Req) -> Self::Future { let (tx, rx) = oneshot::channel(); self.tx .try_send(InFlight { request, tx }) - .ok() .expect("poll_ready must be called"); Box::pin(async move { rx.await.map_err(|_| Closed(()))??.await }) } diff --git a/linkerd/channel/Cargo.toml b/linkerd/channel/Cargo.toml new file mode 100644 index 0000000000..13f8563020 --- /dev/null +++ b/linkerd/channel/Cargo.toml @@ -0,0 +1,17 @@ +[package] +name = "linkerd2-channel" +version = "0.1.0" +authors = ["Linkerd Developers "] +edition = "2018" +publish = false +description = """ +A bounded MPSC channel where senders expose a `poll_ready` method. +""" + +[dependencies] +tokio = { version = "0.3", features = ["sync", "stream"] } +futures = "0.3" + +[dev-dependencies] +tokio = { version = "0.3", features = ["sync", "stream", "macros"] } +tokio-test = "0.3" diff --git a/linkerd/channel/src/lib.rs b/linkerd/channel/src/lib.rs new file mode 100644 index 0000000000..0cc3c1807c --- /dev/null +++ b/linkerd/channel/src/lib.rs @@ -0,0 +1,191 @@ +use futures::{future, ready, Stream}; +use std::sync::{Arc, Weak}; +use std::task::{Context, Poll}; +use std::{fmt, future::Future, mem, pin::Pin}; +use tokio::sync::{mpsc, OwnedSemaphorePermit as Permit, Semaphore}; + +use self::error::{SendError, TrySendError}; +pub use tokio::sync::mpsc::error; + +/// Returns a new pollable, bounded MPSC channel. +/// +/// Unlike `tokio::sync`'s `MPSC` channel, this channel exposes a `poll_ready` +/// function, at the cost of an allocation when driving it to readiness. +pub fn channel(buffer: usize) -> (Sender, Receiver) { + assert!(buffer > 0, "mpsc bounded channel requires buffer > 0"); + let semaphore = Arc::new(Semaphore::new(buffer)); + let (tx, rx) = mpsc::unbounded_channel(); + let rx = Receiver { + rx, + semaphore: Arc::downgrade(&semaphore), + buffer, + }; + let tx = Sender { + tx, + semaphore, + state: State::Empty, + }; + (tx, rx) +} + +/// A bounded, pollable MPSC sender. +/// +/// This is similar to Tokio's bounded MPSC channel's `Sender` type, except that +/// it exposes a `poll_ready` function, at the cost of an allocation when +/// driving it to readiness. +pub struct Sender { + tx: mpsc::UnboundedSender<(T, Permit)>, + semaphore: Arc, + state: State, +} + +/// A bounded MPSC receiver. +/// +/// This is similar to Tokio's bounded MPSC channel's `Receiver` type. +pub struct Receiver { + rx: mpsc::UnboundedReceiver<(T, Permit)>, + semaphore: Weak, + buffer: usize, +} + +enum State { + Waiting(Pin + Send + Sync>>), + Acquired(Permit), + Empty, +} + +impl Sender { + pub fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll>> { + loop { + self.state = match self.state { + State::Empty => State::Waiting(Box::pin(self.semaphore.clone().acquire_owned())), + State::Waiting(ref mut f) => State::Acquired(ready!(Pin::new(f).poll(cx))), + State::Acquired(_) if self.tx.is_closed() => { + return Poll::Ready(Err(SendError(()))) + } + State::Acquired(_) => return Poll::Ready(Ok(())), + } + } + } + + pub async fn ready(&mut self) -> Result<(), SendError<()>> { + future::poll_fn(|cx| self.poll_ready(cx)).await + } + + pub fn try_send(&mut self, value: T) -> Result<(), TrySendError> { + if self.tx.is_closed() { + return Err(TrySendError::Closed(value)); + } + self.state = match mem::replace(&mut self.state, State::Empty) { + // Have we previously acquired a permit? + State::Acquired(permit) => { + self.send2(value, permit); + return Ok(()); + } + // Okay, can we acquire a permit now? + State::Empty => { + if let Ok(permit) = self.semaphore.clone().try_acquire_owned() { + self.send2(value, permit); + return Ok(()); + } + State::Empty + } + state => state, + }; + Err(TrySendError::Full(value)) + } + + pub async fn send(&mut self, value: T) -> Result<(), SendError> { + if let Err(_) = self.ready().await { + return Err(SendError(value)); + } + match mem::replace(&mut self.state, State::Empty) { + State::Acquired(permit) => { + self.send2(value, permit); + Ok(()) + } + state => panic!("unexpected state after poll_ready: {:?}", state), + } + } + + fn send2(&mut self, value: T, permit: Permit) { + self.tx.send((value, permit)).ok().expect("was not closed"); + } +} + +impl Clone for Sender { + fn clone(&self) -> Self { + Self { + tx: self.tx.clone(), + semaphore: self.semaphore.clone(), + state: State::Empty, + } + } +} + +impl fmt::Debug for Sender { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Sender") + .field("message_type", &std::any::type_name::()) + .field("state", &self.state) + .field("semaphore", &self.semaphore) + .finish() + } +} + +// === impl Receiver === + +impl Receiver { + pub async fn recv(&mut self) -> Option { + self.rx.recv().await.map(|(t, _)| t) + } + + pub fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll> { + let res = ready!(Pin::new(&mut self.rx).poll_next(cx)); + Poll::Ready(res.map(|(t, _)| t)) + } +} + +impl Stream for Receiver { + type Item = T; + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let res = ready!(Pin::new(&mut self.as_mut().rx).poll_next(cx)); + Poll::Ready(res.map(|(t, _)| t)) + } +} + +impl Drop for Receiver { + fn drop(&mut self) { + if let Some(semaphore) = self.semaphore.upgrade() { + // Close the buffer by releasing any senders waiting on channel capacity. + // If more than `usize::MAX >> 3` permits are added to the semaphore, it + // will panic. + const MAX: usize = std::usize::MAX >> 4; + semaphore.add_permits(MAX - self.buffer - semaphore.available_permits()); + } + } +} + +impl fmt::Debug for Receiver { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Receiver") + .field("message_type", &std::any::type_name::()) + .field("semaphore", &self.semaphore) + .finish() + } +} + +// === impl State === + +impl fmt::Debug for State { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Display::fmt( + match self { + State::Acquired(_) => "State::Acquired(..)", + State::Waiting(_) => "State::Waiting(..)", + State::Empty => "State::Empty", + }, + f, + ) + } +} diff --git a/linkerd/channel/tests/channel.rs b/linkerd/channel/tests/channel.rs new file mode 100644 index 0000000000..ab28302673 --- /dev/null +++ b/linkerd/channel/tests/channel.rs @@ -0,0 +1,270 @@ +#![warn(rust_2018_idioms)] +use linkerd2_channel::{channel, error::TrySendError, Receiver, Sender}; + +use std::sync::Arc; +use tokio_test::task; +use tokio_test::{assert_err, assert_ok, assert_pending, assert_ready, assert_ready_ok}; + +trait AssertSend: Send {} +impl AssertSend for Sender {} +impl AssertSend for Receiver {} + +#[tokio::test] +async fn send_recv_with_buffer() { + let (mut tx, mut rx) = channel::(16); + + // Using poll_ready / try_send + assert_ready_ok!(task::spawn(tx.ready()).poll()); + tx.try_send(1).unwrap(); + + // Without poll_ready + tx.send(2).await.unwrap(); + + drop(tx); + + let val = rx.recv().await; + assert_eq!(val, Some(1)); + + let val = rx.recv().await; + assert_eq!(val, Some(2)); + + let val = rx.recv().await; + assert!(val.is_none()); +} + +#[tokio::test] +async fn ready_disarm() { + let (tx, mut rx) = channel::(2); + let mut tx1 = tx.clone(); + let mut tx2 = tx.clone(); + let mut tx3 = tx.clone(); + let mut tx4 = tx; + + // We should be able to `poll_ready` two handles without problem + let _ = assert_ok!(tx1.ready().await); + let _ = assert_ok!(tx2.ready().await); + + // But a third should not be ready + let mut r3 = task::spawn(tx3.ready()); + assert_pending!(r3.poll()); + + let mut r4 = task::spawn(tx4.ready()); + assert_pending!(r4.poll()); + + // Using one of the readyd slots should allow a new handle to become ready + tx1.send(1).await.unwrap(); + + // We also need to receive for the slot to be free + assert!(!r3.is_woken()); + rx.recv().await.unwrap(); + // Now there's a free slot! + assert!(r3.is_woken()); + assert!(!r4.is_woken()); + + // Dropping a permit should also open up a slot + drop(tx2); + assert!(r4.is_woken()); + + let mut r1 = task::spawn(tx1.ready()); + assert_pending!(r1.poll()); +} + +#[tokio::test] +async fn send_recv_stream_with_buffer() { + use tokio::stream::StreamExt; + + let (mut tx, mut rx) = channel::(16); + + tokio::spawn(async move { + assert_ok!(tx.send(1).await); + assert_ok!(tx.send(2).await); + }); + + assert_eq!(Some(1), rx.next().await); + assert_eq!(Some(2), rx.next().await); + assert_eq!(None, rx.next().await); +} + +#[tokio::test] +async fn async_send_recv_with_buffer() { + let (mut tx, mut rx) = channel(16); + + tokio::spawn(async move { + assert_ok!(tx.send(1).await); + assert_ok!(tx.send(2).await); + }); + + assert_eq!(Some(1), rx.recv().await); + assert_eq!(Some(2), rx.recv().await); + assert_eq!(None, rx.recv().await); +} + +#[tokio::test] +async fn start_send_past_cap() { + let (mut tx1, mut rx) = channel(1); + let mut tx2 = tx1.clone(); + + assert_ok!(tx1.try_send(())); + + let mut r1 = task::spawn(tx1.ready()); + assert_pending!(r1.poll()); + + { + let mut r2 = task::spawn(tx2.ready()); + assert_pending!(r2.poll()); + + drop(r1); + drop(tx1); + + assert!(rx.recv().await.is_some()); + + assert!(r2.is_woken()); + } + + drop(tx2); + + assert!(rx.recv().await.is_none()); +} + +#[test] +#[should_panic] +fn buffer_gteq_one() { + channel::(0); +} + +#[tokio::test] +async fn no_t_bounds_buffer() { + struct NoImpls; + + let (tx, mut rx) = channel(100); + + // sender should be Debug even though T isn't Debug + println!("{:?}", tx); + // same with Receiver + println!("{:?}", rx); + // and sender should be Clone even though T isn't Clone + assert!(tx.clone().send(NoImpls).await.is_ok()); + + assert!(rx.recv().await.is_some()); +} + +#[tokio::test] +async fn send_recv_buffer_limited() { + let (mut tx, mut rx) = channel::(1); + + // ready capacity + assert_ok!(tx.ready().await); + + // Send first message + tx.try_send(1).unwrap(); + + // Not ready + let mut p2 = task::spawn(tx.ready()); + assert_pending!(p2.poll()); + + // Take the value + assert!(rx.recv().await.is_some()); + + // Notified + assert!(p2.is_woken()); + + // Send second + assert_ready_ok!(p2.poll()); + drop(p2); + tx.try_send(2).unwrap(); + + assert!(rx.recv().await.is_some()); +} + +#[tokio::test] +async fn tx_close_gets_none() { + let (_, mut rx) = channel::(10); + assert!(rx.recv().await.is_none()); +} + +#[tokio::test] +async fn try_send_fail() { + let (mut tx, mut rx) = channel(1); + + tx.ready().await.unwrap(); + tx.try_send("hello").unwrap(); + + // This should fail + match assert_err!(tx.try_send("fail")) { + TrySendError::Full(..) => {} + _ => panic!(), + } + + assert_eq!(rx.recv().await, Some("hello")); + + assert_ok!(tx.try_send("goodbye")); + drop(tx); + + assert_eq!(rx.recv().await, Some("goodbye")); + assert!(rx.recv().await.is_none()); +} + +#[tokio::test] +async fn drop_tx_releases_permit() { + // ready reserves a permit capacity, ensure that the capacity is + // released if tx is dropped w/o sending a value. + let (mut tx1, _rx) = channel::(1); + let mut tx2 = tx1.clone(); + + assert_ok!(tx1.ready().await); + + let mut ready2 = task::spawn(tx2.ready()); + assert_pending!(ready2.poll()); + + drop(tx1); + + assert!(ready2.is_woken()); + assert_ready_ok!(ready2.poll()); +} + +#[tokio::test] +async fn dropping_rx_closes_channel() { + let (mut tx, rx) = channel(100); + + let msg = Arc::new(()); + assert_ok!(tx.try_send(msg.clone())); + + drop(rx); + assert_err!(tx.ready().await); + assert_eq!(1, Arc::strong_count(&msg)); +} + +#[test] +fn dropping_rx_closes_channel_for_try() { + let (mut tx, rx) = channel(100); + + let msg = Arc::new(()); + tx.try_send(msg.clone()).unwrap(); + + drop(rx); + + { + let err = assert_err!(tx.try_send(msg.clone())); + match err { + TrySendError::Closed(..) => {} + _ => panic!(), + } + } + + assert_eq!(1, Arc::strong_count(&msg)); +} + +#[test] +fn unconsumed_messages_are_dropped() { + let msg = Arc::new(()); + + let (mut tx, rx) = channel(100); + + tx.try_send(msg.clone()).unwrap(); + + assert_eq!(2, Arc::strong_count(&msg)); + + drop((tx, rx)); + + assert_eq!(1, Arc::strong_count(&msg)); +} diff --git a/linkerd/proxy/discover/Cargo.toml b/linkerd/proxy/discover/Cargo.toml index 3c0941bee5..8c673fe481 100644 --- a/linkerd/proxy/discover/Cargo.toml +++ b/linkerd/proxy/discover/Cargo.toml @@ -11,6 +11,7 @@ Utilities to implement a Discover with the core Resolve type [dependencies] futures = "0.3" +linkerd2-channel = { path = "../../channel" } linkerd2-error = { path = "../../error" } linkerd2-proxy-core = { path = "../core" } linkerd2-stack = { path = "../../stack" } diff --git a/linkerd/proxy/discover/src/buffer.rs b/linkerd/proxy/discover/src/buffer.rs index 0dcc860912..dc6d49c197 100644 --- a/linkerd/proxy/discover/src/buffer.rs +++ b/linkerd/proxy/discover/src/buffer.rs @@ -1,11 +1,12 @@ use futures::{ready, Stream, TryFuture}; +use linkerd2_channel as mpsc; use linkerd2_error::{Error, Never}; use pin_project::pin_project; use std::future::Future; use std::pin::Pin; use std::task::{Context, Poll}; use std::time::Duration; -use tokio::sync::{mpsc, oneshot}; +use tokio::sync::oneshot; use tokio::time::{self, Delay}; use tower::discover; use tracing::warn;