-
Notifications
You must be signed in to change notification settings - Fork 195
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add: naive implementation of a minimum-throughput body with tests #1627
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
use aws_smithy_async::future::fn_stream::FnStream; | ||
use aws_smithy_async::rt::sleep::AsyncSleep; | ||
use aws_smithy_http::body::SdkBody; | ||
use aws_smithy_http::byte_stream::ByteStream; | ||
use aws_smithy_http::minimum_throughput::MinimumThroughputBody; | ||
use std::convert::Infallible; | ||
use std::time::Duration; | ||
|
||
// This test is flaky. The error message will end with something like "0.999 B/s was observed" | ||
#[should_panic = "minimum throughput was specified at 2 B/s, but throughput of"] | ||
#[tokio::test] | ||
async fn test_throughput_timeout_happens_for_slow_stream() { | ||
tracing_subscriber::fmt::init(); | ||
tracing::info!("tracing is working"); | ||
// let test_start_time = std::time::Instant::now(); | ||
// Have to return results b/c `hyper::body::Body::wrap_stream` expects them | ||
let stream: FnStream<Result<String, Infallible>, _> = FnStream::new(|tx| { | ||
let async_sleep = aws_smithy_async::rt::sleep::TokioSleep::new(); | ||
Box::pin(async move { | ||
for i in 0..10 { | ||
// Will send slightly less that 1 byte per second because ASCII digits have a size | ||
// of 1 byte and we sleep for 1 second after every digit we send. | ||
tx.send(Ok(format!("{}", i))).await.expect("failed to send"); | ||
async_sleep.sleep(Duration::from_secs(1)).await; | ||
} | ||
}) | ||
}); | ||
let body = ByteStream::new(SdkBody::from(hyper::body::Body::wrap_stream(stream))); | ||
let body = body.map(|body| { | ||
// Throw an error if the stream sends less than 2 bytes per second at any point | ||
let minimum_throughput = (2u64, Duration::from_secs(1)); | ||
SdkBody::from_dyn(aws_smithy_http::body::BoxBody::new( | ||
MinimumThroughputBody::new(body, minimum_throughput), | ||
)) | ||
}); | ||
let res = body.collect().await; | ||
|
||
match res { | ||
Ok(_) => panic!("Expected an error due to slow stream but no error occurred."), | ||
Err(e) => panic!("{}", e), | ||
} | ||
} | ||
|
||
#[tokio::test] | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I'm feeling mathy today, bear with me... I think a good test to add may be a shrinking sine wave. Something like [formula missing from this capture]. With that function, it should start off happy despite some significant drop-offs in speed periodically, but eventually (after 20 seconds or so) get to a point where it needs to time out. Another good test is a straight line going up starting at 1, something like [formula missing from this capture]. And then, another good one would be a full minute of good throughput, then 0.5 seconds of zero throughput, followed by good throughput again. |
||
async fn test_throughput_timeout_doesnt_happen_for_fast_stream() { | ||
// Have to return results b/c `hyper::body::Body::wrap_stream` expects them | ||
let stream: FnStream<Result<String, Infallible>, _> = FnStream::new(|tx| { | ||
let async_sleep = aws_smithy_async::rt::sleep::TokioSleep::new(); | ||
Box::pin(async move { | ||
for i in 0..10 { | ||
// Will send slightly less that 1 byte per millisecond because ASCII digits have a | ||
// size of 1 byte and we sleep for 1 millisecond after every digit we send. | ||
tx.send(Ok(format!("{}", i))).await.expect("failed to send"); | ||
async_sleep.sleep(Duration::from_millis(1)).await; | ||
} | ||
}) | ||
}); | ||
let body = ByteStream::new(SdkBody::from(hyper::body::Body::wrap_stream(stream))); | ||
let body = body.map(|body| { | ||
// Throw an error if the stream sends less than 1 bytes per 5ms at any point | ||
let minimum_throughput = (1u64, Duration::from_millis(5)); | ||
|
||
SdkBody::from_dyn(aws_smithy_http::body::BoxBody::new( | ||
MinimumThroughputBody::new(body, minimum_throughput), | ||
)) | ||
}); | ||
let _res = body.collect().await.unwrap(); | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,123 @@ | ||
use bytes::Buf; | ||
use http::HeaderMap; | ||
use std::pin::Pin; | ||
use std::task::{Context, Poll}; | ||
use std::time::{Duration, Instant}; | ||
|
||
pin_project_lite::pin_project! { | ||
/// A body-wrapper that will ensure that the wrapped body is emitting bytes faster than some | ||
/// `minimum_throughput`. | ||
pub struct MinimumThroughputBody<InnerBody> { | ||
// The wrapped body. It is polled for data and trailers, and its chunks are passed through.
#[pin] | ||
inner: InnerBody, | ||
// A record of when and how much data was read: one `(time of read, bytes in chunk)` entry is
// appended for every chunk the wrapped body yields. | ||
throughput_logs: Vec<(Instant, u64)>, | ||
// The minimum acceptable throughput, expressed as `(bytes, per duration)`. If the amount of | ||
// data per unit of time returned is less than this, an error will be returned instead. | ||
minimum_throughput: (u64, Duration), | ||
} | ||
} | ||
|
||
impl<T: http_body::Body> MinimumThroughputBody<T> { | ||
pub fn new(body: T, minimum_throughput: (u64, Duration)) -> Self { | ||
Self { | ||
inner: body, | ||
throughput_logs: Vec::new(), | ||
minimum_throughput, | ||
} | ||
} | ||
} | ||
|
||
impl<T> http_body::Body for MinimumThroughputBody<T> | ||
where | ||
T: http_body::Body<Data = bytes::Bytes, Error = Box<dyn std::error::Error + Send + Sync>>, | ||
{ | ||
type Data = T::Data; | ||
type Error = T::Error; | ||
|
||
fn poll_data( | ||
self: Pin<&mut Self>, | ||
cx: &mut Context<'_>, | ||
) -> Poll<Option<Result<Self::Data, Self::Error>>> { | ||
let this = self.project(); | ||
|
||
let poll_res = this.inner.poll_data(cx); | ||
|
||
if let Poll::Ready(Some(Ok(ref data))) = poll_res { | ||
this.throughput_logs | ||
.push((Instant::now(), data.remaining() as u64)); | ||
}; | ||
|
||
let mut logs = this.throughput_logs.iter(); | ||
if let Some((first_instant, first_bytes)) = logs.next() { | ||
let time_elapsed_since_first_poll = first_instant.elapsed(); | ||
let mut total_bytes_read = *first_bytes; | ||
|
||
while let Some((_, bytes_read)) = logs.next() { | ||
total_bytes_read += bytes_read; | ||
} | ||
|
||
let minimum_bytes_per_second = | ||
this.minimum_throughput.0 as f64 / this.minimum_throughput.1.as_secs_f64(); | ||
let actual_bytes_per_second = | ||
total_bytes_read as f64 / time_elapsed_since_first_poll.as_secs_f64(); | ||
Comment on lines
+62
to
+63
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think to be accurate for longer downloads, and to avoid consuming excess memory, you'll need to have a sliding window over the Additionally, we probably shouldn't do this comparison at all unless the configured duration has elapsed already. Otherwise, there's a chance that a slow-to-start download would always result in failure since it would only have one data point. |
||
|
||
// oh no, too slow! | ||
if actual_bytes_per_second < minimum_bytes_per_second { | ||
return Poll::Ready(Some(Err(Box::new(Error::ThroughputBelowMinimum { | ||
expected: this.minimum_throughput.clone(), | ||
actual: (total_bytes_read, time_elapsed_since_first_poll), | ||
})))); | ||
} | ||
}; | ||
|
||
poll_res | ||
} | ||
|
||
fn poll_trailers( | ||
self: Pin<&mut Self>, | ||
cx: &mut Context<'_>, | ||
) -> Poll<Result<Option<HeaderMap>, Self::Error>> { | ||
self.project().inner.poll_trailers(cx) | ||
} | ||
} | ||
|
||
// Errors produced by `MinimumThroughputBody`.
#[derive(Debug)] | ||
enum Error { | ||
// Returned (boxed) from `poll_data` when the observed throughput is below the configured
// minimum. Both fields are `(bytes, duration)` pairs: `expected` is the configured minimum,
// `actual` is the total bytes read and the time elapsed since the first recorded read.
ThroughputBelowMinimum { | ||
expected: (u64, Duration), | ||
actual: (u64, Duration), | ||
}, | ||
} | ||
|
||
impl std::fmt::Display for Error { | ||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { | ||
match self { | ||
Self::ThroughputBelowMinimum { expected, actual } => { | ||
let expected = format_throughput(expected); | ||
let actual = format_throughput(actual); | ||
write!( | ||
f, | ||
"minimum throughput was specified at {}, but throughput of {} was observed", | ||
expected, actual | ||
) | ||
} | ||
} | ||
} | ||
} | ||
|
||
// Empty impl: lets `Error` be boxed as the `Box<dyn std::error::Error + Send + Sync>` that
// `poll_data` returns. No trait methods are overridden.
impl std::error::Error for Error {} | ||
|
||
/// Format a given throughput as human-readable bytes per second.
///
/// `throughput` is a `(bytes, duration)` pair; the result is `bytes / duration-in-seconds`
/// rounded to at most three decimal places and suffixed with ` B/s`.
fn format_throughput(throughput: &(u64, Duration)) -> String {
    let (bytes, duration) = *throughput;
    let raw_bytes_per_second = bytes as f64 / duration.as_secs_f64();
    // The default float formatting renders a number like 2.000 as `2`, while a number like
    // 0.9982107441748642 is rendered in full. Scaling by 1000, rounding to the nearest integer
    // and scaling back limits the precision to three decimal places, so 0.9982107441748642
    // becomes 0.998. Note that this rounds rather than truncates. It loses accuracy for very
    // large floats but suffices for the numbers we're dealing with.
    let bytes_per_second = (raw_bytes_per_second * 1000.0).round() / 1000.0;

    format!("{bytes_per_second} B/s")
}
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ideally, we should be able to fake time keeping in the `MinimumThroughputBody` so that we don't need to sleep in tests. Essentially, be able to replace the source of `Instant::now()` inside of `MinimumThroughputBody`.