From debb33a0fd7bcd752fcb286bfa04ed7a469ac34c Mon Sep 17 00:00:00 2001 From: Geoffry Song Date: Wed, 20 Jan 2021 12:29:35 -0800 Subject: [PATCH] Let the ? operator work natively in try_stream!. Insteads of desugaring `?` in the macro, we can have the async block itself return `Result<(), E>`, and adjust the supporting code so that `?` just works. The benefit is that this allows `?` operators that are hidden behind macros. --- async-stream-impl/src/lib.rs | 38 +++++------------- async-stream/src/async_stream.rs | 67 ++++++++++++++++++++++++++++++++ async-stream/src/lib.rs | 2 +- async-stream/tests/try_stream.rs | 20 ++++++++++ 4 files changed, 98 insertions(+), 29 deletions(-) diff --git a/async-stream-impl/src/lib.rs b/async-stream-impl/src/lib.rs index b959f65..036bc64 100644 --- a/async-stream-impl/src/lib.rs +++ b/async-stream-impl/src/lib.rs @@ -5,8 +5,6 @@ use syn::parse::Parser; use syn::visit_mut::VisitMut; struct Scrub<'a> { - /// Whether the stream is a try stream. - is_try: bool, /// The unit expression, `()`. unit: Box, has_yielded: bool, @@ -24,9 +22,8 @@ fn parse_input(input: TokenStream) -> syn::Result<(TokenStream2, Vec) } impl<'a> Scrub<'a> { - fn new(is_try: bool, crate_path: &'a TokenStream2) -> Self { + fn new(crate_path: &'a TokenStream2) -> Self { Self { - is_try, unit: syn::parse_quote!(()), has_yielded: false, crate_path, @@ -44,26 +41,7 @@ impl VisitMut for Scrub<'_> { // let ident = &self.yielder; - *i = if self.is_try { - syn::parse_quote! { __yield_tx.send(::core::result::Result::Ok(#value_expr)).await } - } else { - syn::parse_quote! { __yield_tx.send(#value_expr).await } - }; - } - syn::Expr::Try(try_expr) => { - syn::visit_mut::visit_expr_try_mut(self, try_expr); - // let ident = &self.yielder; - let e = &try_expr.expr; - - *i = syn::parse_quote! { - match #e { - ::core::result::Result::Ok(v) => v, - ::core::result::Result::Err(e) => { - __yield_tx.send(::core::result::Result::Err(e.into())).await; - return; - } - } - }; + *i = syn::parse_quote! { __yield_tx.send(#value_expr).await }; } syn::Expr::Closure(_) | syn::Expr::Async(_) => { // Don't transform inner closures or async blocks. @@ -124,7 +102,7 @@ pub fn stream_inner(input: TokenStream) -> TokenStream { Err(e) => return e.to_compile_error().into(), }; - let mut scrub = Scrub::new(false, &crate_path); + let mut scrub = Scrub::new(&crate_path); for mut stmt in &mut stmts { scrub.visit_stmt_mut(&mut stmt); @@ -158,7 +136,7 @@ pub fn try_stream_inner(input: TokenStream) -> TokenStream { Err(e) => return e.to_compile_error().into(), }; - let mut scrub = Scrub::new(true, &crate_path); + let mut scrub = Scrub::new(&crate_path); for mut stmt in &mut stmts { scrub.visit_stmt_mut(&mut stmt); @@ -174,9 +152,13 @@ pub fn try_stream_inner(input: TokenStream) -> TokenStream { quote!({ let (mut __yield_tx, __yield_rx) = #crate_path::yielder::pair(); - #crate_path::AsyncStream::new(__yield_rx, async move { + #crate_path::AsyncTryStream::new(__yield_rx, async move { #dummy_yield - #(#stmts)* + let () = { + #(#stmts)* + }; + #[allow(unreachable_code)] + Ok(()) }) }) .into() diff --git a/async-stream/src/async_stream.rs b/async-stream/src/async_stream.rs index f60c87e..cf6076e 100644 --- a/async-stream/src/async_stream.rs +++ b/async-stream/src/async_stream.rs @@ -75,3 +75,70 @@ where } } } + +#[doc(hidden)] +#[derive(Debug)] +pub struct AsyncTryStream { + rx: Receiver, + done: bool, + generator: U, +} + +impl AsyncTryStream { + #[doc(hidden)] + pub fn new(rx: Receiver, generator: U) -> AsyncTryStream { + AsyncTryStream { + rx, + done: false, + generator, + } + } +} + +impl FusedStream for AsyncTryStream +where + U: Future>, +{ + fn is_terminated(&self) -> bool { + self.done + } +} + +impl Stream for AsyncTryStream +where + U: Future>, +{ + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + unsafe { + let me = Pin::get_unchecked_mut(self); + + if me.done { + return Poll::Ready(None); + } + + let mut dst = None; + let res = { + let _enter = me.rx.enter(&mut dst); + Pin::new_unchecked(&mut me.generator).poll(cx) + }; + + me.done = res.is_ready(); + + if let Poll::Ready(Err(e)) = res { + return Poll::Ready(Some(Err(e))); + } + + if let Some(val) = dst.take() { + return Poll::Ready(Some(Ok(val))); + } + + if me.done { + Poll::Ready(None) + } else { + Poll::Pending + } + } + } +} diff --git a/async-stream/src/lib.rs b/async-stream/src/lib.rs index b236351..77975c2 100644 --- a/async-stream/src/lib.rs +++ b/async-stream/src/lib.rs @@ -164,7 +164,7 @@ pub mod yielder; // Used by the macro, but not intended to be accessed publicly. #[doc(hidden)] -pub use crate::async_stream::AsyncStream; +pub use crate::async_stream::{AsyncStream, AsyncTryStream}; #[doc(hidden)] pub use async_stream_impl; diff --git a/async-stream/tests/try_stream.rs b/async-stream/tests/try_stream.rs index 063e37a..84c8ff9 100644 --- a/async-stream/tests/try_stream.rs +++ b/async-stream/tests/try_stream.rs @@ -1,6 +1,7 @@ use async_stream::try_stream; use futures_core::stream::Stream; +use futures_util::pin_mut; use futures_util::stream::StreamExt; #[tokio::test] @@ -78,3 +79,22 @@ async fn multi_try() { values ); } + +macro_rules! try_macro { + ($e:expr) => { + $e? + }; +} + +#[tokio::test] +async fn try_in_macro() { + let s = try_stream! { + yield "hi"; + try_macro!(Err("bye")); + }; + pin_mut!(s); + + assert_eq!(s.next().await, Some(Ok("hi"))); + assert_eq!(s.next().await, Some(Err("bye"))); + assert_eq!(s.next().await, None); +}