Skip to content

Commit

Permalink
Auto merge of rust-lang#117703 - compiler-errors:recursive-async, r=lcnr
Browse files Browse the repository at this point in the history
Support async recursive calls (as long as they have indirection)

Before rust-lang#101692, we stored coroutine witness types directly inside of the coroutine. That means that a coroutine could not contain itself (as a witness field) without creating a cycle in the type representation of the coroutine, which we detected with the `OpaqueTypeExpander`, which is used to detect cycles when expanding opaque types after that are inferred to contain themselves.

After `-Zdrop-tracking-mir` was stabilized, we no longer store these generator witness fields directly, but instead behind a def-id based query. That means there is no technical obstacle in the compiler preventing coroutines from containing themselves per se, other than the fact that for a coroutine to have a non-infinite layout, it must contain itself wrapped in a layer of allocation indirection (like a `Box`).

This means that it should be valid for this code to work:

```
async fn async_fibonacci(i: u32) -> u32 {
    if i == 0 || i == 1 {
        i
    } else {
        Box::pin(async_fibonacci(i - 1)).await
          + Box::pin(async_fibonacci(i - 2)).await
    }
}
```

Whereas previously, you'd need to coerce the future to `Pin<Box<dyn Future<Output = ...>>` before `await`ing it, to prevent the async's desugared coroutine from containing itself across as await point.

This PR does two things:
1. Only report an error if an opaque expansion cycle is detected *not* through coroutine witness fields.
    * Instead, if we find an opaque cycle through coroutine witness fields, we compute the layout of the coroutine. If that results in a cycle error, we report it as a recursive async fn.
4. Reworks the way we report layout errors having to do with coroutines, to make up for the diagnostic regressions introduced by (1.). We actually do even better now, pointing out the call sites of the recursion!
  • Loading branch information
bors committed Jan 9, 2024
2 parents 387e7a5 + 9a75603 commit dc64103
Show file tree
Hide file tree
Showing 43 changed files with 394 additions and 243 deletions.
26 changes: 11 additions & 15 deletions compiler/rustc_error_codes/src/error_codes/E0733.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,35 +10,31 @@ async fn foo(n: usize) {
}
```

To perform async recursion, the `async fn` needs to be desugared such that the
`Future` is explicit in the return type:
The recursive invocation can be boxed:

```edition2018,compile_fail,E0720
use std::future::Future;
fn foo_desugared(n: usize) -> impl Future<Output = ()> {
async move {
if n > 0 {
foo_desugared(n - 1).await;
}
```edition2018
async fn foo(n: usize) {
if n > 0 {
Box::pin(foo(n - 1)).await;
}
}
```

Finally, the future is wrapped in a pinned box:
The `Box<...>` ensures that the result is of known size, and the pin is
required to keep it in the same place in memory.

Alternatively, the body can be boxed:

```edition2018
use std::future::Future;
use std::pin::Pin;
fn foo_recursive(n: usize) -> Pin<Box<dyn Future<Output = ()>>> {
fn foo(n: usize) -> Pin<Box<dyn Future<Output = ()>>> {
Box::pin(async move {
if n > 0 {
foo_recursive(n - 1).await;
foo(n - 1).await;
}
})
}
```

The `Box<...>` ensures that the result is of known size, and the pin is
required to keep it in the same place in memory.

[`async`]: https://doc.rust-lang.org/std/keyword.async.html
6 changes: 6 additions & 0 deletions compiler/rustc_hir/src/hir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1361,6 +1361,12 @@ impl CoroutineKind {
}
}

impl CoroutineKind {
pub fn is_fn_like(self) -> bool {
matches!(self, CoroutineKind::Desugared(_, CoroutineSource::Fn))
}
}

impl fmt::Display for CoroutineKind {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Expand Down
52 changes: 29 additions & 23 deletions compiler/rustc_hir_analysis/src/check/check.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ use rustc_middle::middle::stability::EvalResult;
use rustc_middle::traits::{DefiningAnchor, ObligationCauseCode};
use rustc_middle::ty::fold::BottomUpFolder;
use rustc_middle::ty::layout::{LayoutError, MAX_SIMD_LANES};
use rustc_middle::ty::util::{Discr, IntTypeExt};
use rustc_middle::ty::util::{Discr, InspectCoroutineFields, IntTypeExt};
use rustc_middle::ty::GenericArgKind;
use rustc_middle::ty::{
AdtDef, ParamEnv, RegionKind, TypeSuperVisitable, TypeVisitable, TypeVisitableExt,
Expand Down Expand Up @@ -213,13 +213,12 @@ fn check_opaque(tcx: TyCtxt<'_>, def_id: LocalDefId) {
return;
}

let args = GenericArgs::identity_for_item(tcx, item.owner_id);
let span = tcx.def_span(item.owner_id.def_id);

if tcx.type_of(item.owner_id.def_id).instantiate_identity().references_error() {
return;
}
if check_opaque_for_cycles(tcx, item.owner_id.def_id, args, span, origin).is_err() {
if check_opaque_for_cycles(tcx, item.owner_id.def_id, span).is_err() {
return;
}

Expand All @@ -230,19 +229,36 @@ fn check_opaque(tcx: TyCtxt<'_>, def_id: LocalDefId) {
pub(super) fn check_opaque_for_cycles<'tcx>(
tcx: TyCtxt<'tcx>,
def_id: LocalDefId,
args: GenericArgsRef<'tcx>,
span: Span,
origin: &hir::OpaqueTyOrigin,
) -> Result<(), ErrorGuaranteed> {
if tcx.try_expand_impl_trait_type(def_id.to_def_id(), args).is_err() {
let reported = match origin {
hir::OpaqueTyOrigin::AsyncFn(..) => async_opaque_type_cycle_error(tcx, span),
_ => opaque_type_cycle_error(tcx, def_id, span),
};
Err(reported)
} else {
Ok(())
let args = GenericArgs::identity_for_item(tcx, def_id);

// First, try to look at any opaque expansion cycles, considering coroutine fields
// (even though these aren't necessarily true errors).
if tcx
.try_expand_impl_trait_type(def_id.to_def_id(), args, InspectCoroutineFields::Yes)
.is_err()
{
// Look for true opaque expansion cycles, but ignore coroutines.
// This will give us any true errors. Coroutines are only problematic
// if they cause layout computation errors.
if tcx
.try_expand_impl_trait_type(def_id.to_def_id(), args, InspectCoroutineFields::No)
.is_err()
{
let reported = opaque_type_cycle_error(tcx, def_id, span);
return Err(reported);
}

// And also look for cycle errors in the layout of coroutines.
if let Err(&LayoutError::Cycle(guar)) =
tcx.layout_of(tcx.param_env(def_id).and(Ty::new_opaque(tcx, def_id.to_def_id(), args)))
{
return Err(guar);
}
}

Ok(())
}

/// Check that the concrete type behind `impl Trait` actually implements `Trait`.
Expand Down Expand Up @@ -1300,16 +1316,6 @@ pub(super) fn check_type_params_are_used<'tcx>(
}
}

fn async_opaque_type_cycle_error(tcx: TyCtxt<'_>, span: Span) -> ErrorGuaranteed {
struct_span_err!(tcx.dcx(), span, E0733, "recursion in an `async fn` requires boxing")
.span_label_mv(span, "recursive `async fn`")
.note_mv("a recursive `async fn` must be rewritten to return a boxed `dyn Future`")
.note_mv(
"consider using the `async_recursion` crate: https://crates.io/crates/async_recursion",
)
.emit()
}

/// Emit an error for recursive opaque types.
///
/// If this is a return `impl Trait`, find the item's return expressions and point at them. For
Expand Down
13 changes: 9 additions & 4 deletions compiler/rustc_middle/src/query/keys.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ pub trait Key: Sized {
None
}

fn ty_adt_id(&self) -> Option<DefId> {
fn ty_def_id(&self) -> Option<DefId> {
None
}
}
Expand Down Expand Up @@ -406,9 +406,10 @@ impl<'tcx> Key for Ty<'tcx> {
DUMMY_SP
}

fn ty_adt_id(&self) -> Option<DefId> {
match self.kind() {
fn ty_def_id(&self) -> Option<DefId> {
match *self.kind() {
ty::Adt(adt, _) => Some(adt.did()),
ty::Coroutine(def_id, ..) => Some(def_id),
_ => None,
}
}
Expand Down Expand Up @@ -452,6 +453,10 @@ impl<'tcx, T: Key> Key for ty::ParamEnvAnd<'tcx, T> {
fn default_span(&self, tcx: TyCtxt<'_>) -> Span {
self.value.default_span(tcx)
}

fn ty_def_id(&self) -> Option<DefId> {
self.value.ty_def_id()
}
}

impl Key for Symbol {
Expand Down Expand Up @@ -550,7 +555,7 @@ impl<'tcx> Key for (ValidityRequirement, ty::ParamEnvAnd<'tcx, Ty<'tcx>>) {
DUMMY_SP
}

fn ty_adt_id(&self) -> Option<DefId> {
fn ty_def_id(&self) -> Option<DefId> {
match self.1.value.kind() {
ty::Adt(adt, _) => Some(adt.did()),
_ => None,
Expand Down
2 changes: 2 additions & 0 deletions compiler/rustc_middle/src/query/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1387,6 +1387,8 @@ rustc_queries! {
) -> Result<ty::layout::TyAndLayout<'tcx>, &'tcx ty::layout::LayoutError<'tcx>> {
depth_limit
desc { "computing layout of `{}`", key.value }
// we emit our own error during query cycle handling
cycle_delay_bug
}

/// Compute a `FnAbi` suitable for indirect calls, i.e. to `fn` pointers.
Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_middle/src/query/plumbing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ pub struct DynamicQuery<'tcx, C: QueryCache> {
fn(tcx: TyCtxt<'tcx>, key: &C::Key, index: SerializedDepNodeIndex) -> bool,
pub hash_result: HashResult<C::Value>,
pub value_from_cycle_error:
fn(tcx: TyCtxt<'tcx>, cycle: &[QueryInfo], guar: ErrorGuaranteed) -> C::Value,
fn(tcx: TyCtxt<'tcx>, cycle_error: &CycleError, guar: ErrorGuaranteed) -> C::Value,
pub format_value: fn(&C::Value) -> String,
}

Expand Down
63 changes: 51 additions & 12 deletions compiler/rustc_middle/src/ty/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -702,6 +702,7 @@ impl<'tcx> TyCtxt<'tcx> {
self,
def_id: DefId,
args: GenericArgsRef<'tcx>,
inspect_coroutine_fields: InspectCoroutineFields,
) -> Result<Ty<'tcx>, Ty<'tcx>> {
let mut visitor = OpaqueTypeExpander {
seen_opaque_tys: FxHashSet::default(),
Expand All @@ -712,6 +713,7 @@ impl<'tcx> TyCtxt<'tcx> {
check_recursion: true,
expand_coroutines: true,
tcx: self,
inspect_coroutine_fields,
};

let expanded_type = visitor.expand_opaque_ty(def_id, args).unwrap();
Expand All @@ -729,16 +731,43 @@ impl<'tcx> TyCtxt<'tcx> {
DefKind::AssocFn if self.associated_item(def_id).fn_has_self_parameter => "method",
DefKind::Closure if let Some(coroutine_kind) = self.coroutine_kind(def_id) => {
match coroutine_kind {
hir::CoroutineKind::Desugared(hir::CoroutineDesugaring::Async, _) => {
"async closure"
}
hir::CoroutineKind::Desugared(hir::CoroutineDesugaring::AsyncGen, _) => {
"async gen closure"
}
hir::CoroutineKind::Desugared(
hir::CoroutineDesugaring::Async,
hir::CoroutineSource::Fn,
) => "async fn",
hir::CoroutineKind::Desugared(
hir::CoroutineDesugaring::Async,
hir::CoroutineSource::Block,
) => "async block",
hir::CoroutineKind::Desugared(
hir::CoroutineDesugaring::Async,
hir::CoroutineSource::Closure,
) => "async closure",
hir::CoroutineKind::Desugared(
hir::CoroutineDesugaring::AsyncGen,
hir::CoroutineSource::Fn,
) => "async gen fn",
hir::CoroutineKind::Desugared(
hir::CoroutineDesugaring::AsyncGen,
hir::CoroutineSource::Block,
) => "async gen block",
hir::CoroutineKind::Desugared(
hir::CoroutineDesugaring::AsyncGen,
hir::CoroutineSource::Closure,
) => "async gen closure",
hir::CoroutineKind::Desugared(
hir::CoroutineDesugaring::Gen,
hir::CoroutineSource::Fn,
) => "gen fn",
hir::CoroutineKind::Desugared(
hir::CoroutineDesugaring::Gen,
hir::CoroutineSource::Block,
) => "gen block",
hir::CoroutineKind::Desugared(
hir::CoroutineDesugaring::Gen,
hir::CoroutineSource::Closure,
) => "gen closure",
hir::CoroutineKind::Coroutine(_) => "coroutine",
hir::CoroutineKind::Desugared(hir::CoroutineDesugaring::Gen, _) => {
"gen closure"
}
}
}
_ => def_kind.descr(def_id),
Expand Down Expand Up @@ -865,6 +894,13 @@ struct OpaqueTypeExpander<'tcx> {
/// recursion, and 'false' otherwise to avoid unnecessary work.
check_recursion: bool,
tcx: TyCtxt<'tcx>,
inspect_coroutine_fields: InspectCoroutineFields,
}

#[derive(Copy, Clone, PartialEq, Eq, Debug)]
pub enum InspectCoroutineFields {
No,
Yes,
}

impl<'tcx> OpaqueTypeExpander<'tcx> {
Expand Down Expand Up @@ -906,9 +942,11 @@ impl<'tcx> OpaqueTypeExpander<'tcx> {
let expanded_ty = match self.expanded_cache.get(&(def_id, args)) {
Some(expanded_ty) => *expanded_ty,
None => {
for bty in self.tcx.coroutine_hidden_types(def_id) {
let hidden_ty = bty.instantiate(self.tcx, args);
self.fold_ty(hidden_ty);
if matches!(self.inspect_coroutine_fields, InspectCoroutineFields::Yes) {
for bty in self.tcx.coroutine_hidden_types(def_id) {
let hidden_ty = bty.instantiate(self.tcx, args);
self.fold_ty(hidden_ty);
}
}
let expanded_ty = Ty::new_coroutine_witness(self.tcx, def_id, args);
self.expanded_cache.insert((def_id, args), expanded_ty);
Expand Down Expand Up @@ -1486,6 +1524,7 @@ pub fn reveal_opaque_types_in_bounds<'tcx>(
check_recursion: false,
expand_coroutines: false,
tcx,
inspect_coroutine_fields: InspectCoroutineFields::No,
};
val.fold_with(&mut visitor)
}
Expand Down
Loading

0 comments on commit dc64103

Please sign in to comment.