diff --git a/compiler/rustc_codegen_llvm/src/debuginfo/metadata/enums/cpp_like.rs b/compiler/rustc_codegen_llvm/src/debuginfo/metadata/enums/cpp_like.rs index 4792b0798dfb8..4edef14422e5f 100644 --- a/compiler/rustc_codegen_llvm/src/debuginfo/metadata/enums/cpp_like.rs +++ b/compiler/rustc_codegen_llvm/src/debuginfo/metadata/enums/cpp_like.rs @@ -683,7 +683,8 @@ fn build_union_fields_for_direct_tag_coroutine<'ll, 'tcx>( _ => unreachable!(), }; - let coroutine_layout = cx.tcx.optimized_mir(coroutine_def_id).coroutine_layout().unwrap(); + let coroutine_layout = + cx.tcx.coroutine_layout(coroutine_def_id, coroutine_args.kind_ty()).unwrap(); let common_upvar_names = cx.tcx.closure_saved_names_of_captured_variables(coroutine_def_id); let variant_range = coroutine_args.variant_range(coroutine_def_id, cx.tcx); diff --git a/compiler/rustc_codegen_llvm/src/debuginfo/metadata/enums/native.rs b/compiler/rustc_codegen_llvm/src/debuginfo/metadata/enums/native.rs index 3dbe820b8ff9b..115d5187eafa8 100644 --- a/compiler/rustc_codegen_llvm/src/debuginfo/metadata/enums/native.rs +++ b/compiler/rustc_codegen_llvm/src/debuginfo/metadata/enums/native.rs @@ -135,7 +135,7 @@ pub(super) fn build_coroutine_di_node<'ll, 'tcx>( unique_type_id: UniqueTypeId<'tcx>, ) -> DINodeCreationResult<'ll> { let coroutine_type = unique_type_id.expect_ty(); - let &ty::Coroutine(coroutine_def_id, _) = coroutine_type.kind() else { + let &ty::Coroutine(coroutine_def_id, coroutine_args) = coroutine_type.kind() else { bug!("build_coroutine_di_node() called with non-coroutine type: `{:?}`", coroutine_type) }; @@ -158,8 +158,10 @@ pub(super) fn build_coroutine_di_node<'ll, 'tcx>( DIFlags::FlagZero, ), |cx, coroutine_type_di_node| { - let coroutine_layout = - cx.tcx.optimized_mir(coroutine_def_id).coroutine_layout().unwrap(); + let coroutine_layout = cx + .tcx + .coroutine_layout(coroutine_def_id, coroutine_args.as_coroutine().kind_ty()) + .unwrap(); let Variants::Multiple { tag_encoding: TagEncoding::Direct, ref variants, .. } = coroutine_type_and_layout.variants diff --git a/compiler/rustc_const_eval/src/transform/validate.rs b/compiler/rustc_const_eval/src/transform/validate.rs index 08e3e42a82e27..378b168a50c33 100644 --- a/compiler/rustc_const_eval/src/transform/validate.rs +++ b/compiler/rustc_const_eval/src/transform/validate.rs @@ -101,18 +101,17 @@ impl<'tcx> MirPass<'tcx> for Validator { } // Enforce that coroutine-closure layouts are identical. - if let Some(layout) = body.coroutine_layout() + if let Some(layout) = body.coroutine_layout_raw() && let Some(by_move_body) = body.coroutine_by_move_body() - && let Some(by_move_layout) = by_move_body.coroutine_layout() + && let Some(by_move_layout) = by_move_body.coroutine_layout_raw() { - if layout != by_move_layout { - // If this turns out not to be true, please let compiler-errors know. - // It is possible to support, but requires some changes to the layout - // computation code. + // FIXME(async_closures): We could do other validation here? + if layout.variant_fields.len() != by_move_layout.variant_fields.len() { cfg_checker.fail( Location::START, format!( - "Coroutine layout differs from by-move coroutine layout:\n\ + "Coroutine layout has different number of variant fields from \ + by-move coroutine layout:\n\ layout: {layout:#?}\n\ by_move_layout: {by_move_layout:#?}", ), @@ -715,13 +714,14 @@ impl<'a, 'tcx> Visitor<'tcx> for TypeChecker<'a, 'tcx> { // args of the coroutine. Otherwise, we prefer to use this body // since we may be in the process of computing this MIR in the // first place. - let gen_body = if def_id == self.caller_body.source.def_id() { - self.caller_body + let layout = if def_id == self.caller_body.source.def_id() { + // FIXME: This is not right for async closures. + self.caller_body.coroutine_layout_raw() } else { - self.tcx.optimized_mir(def_id) + self.tcx.coroutine_layout(def_id, args.as_coroutine().kind_ty()) }; - let Some(layout) = gen_body.coroutine_layout() else { + let Some(layout) = layout else { self.fail( location, format!("No coroutine layout for {parent_ty:?}"), diff --git a/compiler/rustc_middle/src/mir/mod.rs b/compiler/rustc_middle/src/mir/mod.rs index e4dce2bdc9e80..02af55fbf0e4f 100644 --- a/compiler/rustc_middle/src/mir/mod.rs +++ b/compiler/rustc_middle/src/mir/mod.rs @@ -652,8 +652,9 @@ impl<'tcx> Body<'tcx> { self.coroutine.as_ref().and_then(|coroutine| coroutine.resume_ty) } + /// Prefer going through [`TyCtxt::coroutine_layout`] rather than using this directly. #[inline] - pub fn coroutine_layout(&self) -> Option<&CoroutineLayout<'tcx>> { + pub fn coroutine_layout_raw(&self) -> Option<&CoroutineLayout<'tcx>> { self.coroutine.as_ref().and_then(|coroutine| coroutine.coroutine_layout.as_ref()) } diff --git a/compiler/rustc_middle/src/mir/pretty.rs b/compiler/rustc_middle/src/mir/pretty.rs index f0499cf344fca..41df2e3b5875a 100644 --- a/compiler/rustc_middle/src/mir/pretty.rs +++ b/compiler/rustc_middle/src/mir/pretty.rs @@ -126,7 +126,7 @@ fn dump_matched_mir_node<'tcx, F>( Some(promoted) => write!(file, "::{promoted:?}`")?, } writeln!(file, " {disambiguator} {pass_name}")?; - if let Some(ref layout) = body.coroutine_layout() { + if let Some(ref layout) = body.coroutine_layout_raw() { writeln!(file, "/* coroutine_layout = {layout:#?} */")?; } writeln!(file)?; diff --git a/compiler/rustc_middle/src/ty/mod.rs b/compiler/rustc_middle/src/ty/mod.rs index 6ce53ccc8cd7a..aad2f6a4cf8ac 100644 --- a/compiler/rustc_middle/src/ty/mod.rs +++ b/compiler/rustc_middle/src/ty/mod.rs @@ -60,6 +60,7 @@ pub use rustc_target::abi::{ReprFlags, ReprOptions}; pub use rustc_type_ir::{DebugWithInfcx, InferCtxtLike, WithInfcx}; pub use vtable::*; +use std::assert_matches::assert_matches; use std::fmt::Debug; use std::hash::{Hash, Hasher}; use std::marker::PhantomData; @@ -1826,8 +1827,40 @@ impl<'tcx> TyCtxt<'tcx> { /// Returns layout of a coroutine. Layout might be unavailable if the /// coroutine is tainted by errors. - pub fn coroutine_layout(self, def_id: DefId) -> Option<&'tcx CoroutineLayout<'tcx>> { - self.optimized_mir(def_id).coroutine_layout() + /// + /// Takes `coroutine_kind` which can be acquired from the `CoroutineArgs::kind_ty`, + /// e.g. `args.as_coroutine().kind_ty()`. + pub fn coroutine_layout( + self, + def_id: DefId, + coroutine_kind_ty: Ty<'tcx>, + ) -> Option<&'tcx CoroutineLayout<'tcx>> { + let mir = self.optimized_mir(def_id); + // Regular coroutine + if coroutine_kind_ty.is_unit() { + mir.coroutine_layout_raw() + } else { + // If we have a `Coroutine` that comes from an coroutine-closure, + // then it may be a by-move or by-ref body. + let ty::Coroutine(_, identity_args) = + *self.type_of(def_id).instantiate_identity().kind() + else { + unreachable!(); + }; + let identity_kind_ty = identity_args.as_coroutine().kind_ty(); + // If the types differ, then we must be getting the by-move body of + // a by-ref coroutine. + if identity_kind_ty == coroutine_kind_ty { + mir.coroutine_layout_raw() + } else { + assert_matches!(coroutine_kind_ty.to_opt_closure_kind(), Some(ClosureKind::FnOnce)); + assert_matches!( + identity_kind_ty.to_opt_closure_kind(), + Some(ClosureKind::Fn | ClosureKind::FnMut) + ); + mir.coroutine_by_move_body().unwrap().coroutine_layout_raw() + } + } } /// Given the `DefId` of an impl, returns the `DefId` of the trait it implements. diff --git a/compiler/rustc_middle/src/ty/sty.rs b/compiler/rustc_middle/src/ty/sty.rs index c85ee140fa4ec..510a4b59520c4 100644 --- a/compiler/rustc_middle/src/ty/sty.rs +++ b/compiler/rustc_middle/src/ty/sty.rs @@ -694,7 +694,8 @@ impl<'tcx> CoroutineArgs<'tcx> { #[inline] pub fn variant_range(&self, def_id: DefId, tcx: TyCtxt<'tcx>) -> Range { // FIXME requires optimized MIR - FIRST_VARIANT..tcx.coroutine_layout(def_id).unwrap().variant_fields.next_index() + FIRST_VARIANT + ..tcx.coroutine_layout(def_id, tcx.types.unit).unwrap().variant_fields.next_index() } /// The discriminant for the given variant. Panics if the `variant_index` is @@ -754,7 +755,7 @@ impl<'tcx> CoroutineArgs<'tcx> { def_id: DefId, tcx: TyCtxt<'tcx>, ) -> impl Iterator> + Captures<'tcx>> { - let layout = tcx.coroutine_layout(def_id).unwrap(); + let layout = tcx.coroutine_layout(def_id, self.kind_ty()).unwrap(); layout.variant_fields.iter().map(move |variant| { variant.iter().map(move |field| { ty::EarlyBinder::bind(layout.field_tys[*field].ty).instantiate(tcx, self.args) diff --git a/compiler/rustc_ty_utils/src/layout.rs b/compiler/rustc_ty_utils/src/layout.rs index 9c3d39307b26f..331970ac36233 100644 --- a/compiler/rustc_ty_utils/src/layout.rs +++ b/compiler/rustc_ty_utils/src/layout.rs @@ -745,7 +745,7 @@ fn coroutine_layout<'tcx>( let tcx = cx.tcx; let instantiate_field = |ty: Ty<'tcx>| EarlyBinder::bind(ty).instantiate(tcx, args); - let Some(info) = tcx.coroutine_layout(def_id) else { + let Some(info) = tcx.coroutine_layout(def_id, args.as_coroutine().kind_ty()) else { return Err(error(cx, LayoutError::Unknown(ty))); }; let (ineligible_locals, assignments) = coroutine_saved_local_eligibility(info); @@ -1072,7 +1072,7 @@ fn variant_info_for_coroutine<'tcx>( return (vec![], None); }; - let coroutine = cx.tcx.optimized_mir(def_id).coroutine_layout().unwrap(); + let coroutine = cx.tcx.coroutine_layout(def_id, args.as_coroutine().kind_ty()).unwrap(); let upvar_names = cx.tcx.closure_saved_names_of_captured_variables(def_id); let mut upvars_size = Size::ZERO;