Skip to content

Commit

Permalink
relocate upvars onto Unresumed state
Browse files Browse the repository at this point in the history
  • Loading branch information
dingxiangfei2009 committed Jan 18, 2024
1 parent 6ae4cfb commit 4b04bb1
Show file tree
Hide file tree
Showing 32 changed files with 671 additions and 352 deletions.
3 changes: 1 addition & 2 deletions compiler/rustc_abi/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1503,8 +1503,7 @@ pub struct LayoutS<FieldIdx: Idx, VariantIdx: Idx> {

/// Encodes information about multi-variant layouts.
/// Even with `Multiple` variants, a layout still has its own fields! Those are then
/// shared between all variants. One of them will be the discriminant,
/// but e.g. coroutines can have more.
/// shared between all variants. One of them will be the discriminant.
///
/// To access all fields of this layout, both `fields` and the fields of the active variant
/// must be taken into account.
Expand Down
22 changes: 9 additions & 13 deletions compiler/rustc_borrowck/src/type_check/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -801,14 +801,11 @@ impl<'a, 'b, 'tcx> TypeVerifier<'a, 'b, 'tcx> {
}),
};
}
ty::Coroutine(_, args) => {
// Only prefix fields (upvars and current state) are
// accessible without a variant index.
return match args.as_coroutine().prefix_tys().get(field.index()) {
Some(ty) => Ok(*ty),
None => Err(FieldAccessError::OutOfRange {
field_count: args.as_coroutine().prefix_tys().len(),
}),
ty::Coroutine(_def_id, args) => {
let upvar_tys = args.as_coroutine().upvar_tys();
return match upvar_tys.get(field.index()) {
Some(&ty) => Ok(ty),
None => Err(FieldAccessError::OutOfRange { field_count: upvar_tys.len() }),
};
}
ty::Tuple(tys) => {
Expand Down Expand Up @@ -1821,11 +1818,10 @@ impl<'a, 'tcx> TypeChecker<'a, 'tcx> {
// It doesn't make sense to look at a field beyond the prefix;
// these require a variant index, and are not initialized in
// aggregate rvalues.
match args.as_coroutine().prefix_tys().get(field_index.as_usize()) {
let upvar_tys = args.as_coroutine().upvar_tys();
match upvar_tys.get(field_index.as_usize()) {
Some(ty) => Ok(*ty),
None => Err(FieldAccessError::OutOfRange {
field_count: args.as_coroutine().prefix_tys().len(),
}),
None => Err(FieldAccessError::OutOfRange { field_count: upvar_tys.len() }),
}
}
AggregateKind::Array(ty) => Ok(ty),
Expand Down Expand Up @@ -2442,7 +2438,7 @@ impl<'a, 'tcx> TypeChecker<'a, 'tcx> {

self.prove_aggregate_predicates(aggregate_kind, location);

if *aggregate_kind == AggregateKind::Tuple {
if matches!(aggregate_kind, AggregateKind::Tuple) {
// tuple rvalue field type is always the type of the op. Nothing to check here.
return;
}
Expand Down
3 changes: 3 additions & 0 deletions compiler/rustc_codegen_cranelift/src/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -781,6 +781,9 @@ fn codegen_stmt<'tcx>(
let variant_dest = lval.downcast_variant(fx, variant_index);
(variant_index, variant_dest, active_field_index)
}
mir::AggregateKind::Coroutine(_, _) => {
(FIRST_VARIANT, lval.project_downcast(fx, FIRST_VARIANT), None)
}
_ => (FIRST_VARIANT, lval, None),
};
if active_field_index.is_some() {
Expand Down
3 changes: 2 additions & 1 deletion compiler/rustc_codegen_llvm/src/debuginfo/metadata.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ use rustc_hir::def::CtorKind;
use rustc_hir::def_id::{DefId, LOCAL_CRATE};
use rustc_middle::bug;
use rustc_middle::ty::layout::{LayoutOf, TyAndLayout};
use rustc_middle::ty::List;
use rustc_middle::ty::{
self, AdtKind, Instance, ParamEnv, PolyExistentialTraitRef, Ty, TyCtxt, Visibility,
};
Expand Down Expand Up @@ -1066,7 +1067,7 @@ fn build_upvar_field_di_nodes<'ll, 'tcx>(
closure_or_coroutine_di_node: &'ll DIType,
) -> SmallVec<&'ll DIType> {
let (&def_id, up_var_tys) = match closure_or_coroutine_ty.kind() {
ty::Coroutine(def_id, args) => (def_id, args.as_coroutine().prefix_tys()),
ty::Coroutine(def_id, _args) => (def_id, List::empty()),
ty::Closure(def_id, args) => (def_id, args.as_closure().upvar_tys()),
_ => {
bug!(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -685,7 +685,6 @@ fn build_union_fields_for_direct_tag_coroutine<'ll, 'tcx>(

let coroutine_layout = cx.tcx.optimized_mir(coroutine_def_id).coroutine_layout().unwrap();

let common_upvar_names = cx.tcx.closure_saved_names_of_captured_variables(coroutine_def_id);
let variant_range = coroutine_args.variant_range(coroutine_def_id, cx.tcx);
let variant_count = (variant_range.start.as_u32()..variant_range.end.as_u32()).len();

Expand Down Expand Up @@ -720,7 +719,6 @@ fn build_union_fields_for_direct_tag_coroutine<'ll, 'tcx>(
coroutine_type_and_layout,
coroutine_type_di_node,
coroutine_layout,
common_upvar_names,
);

let span = coroutine_layout.variant_source_info[variant_index].span;
Expand Down
33 changes: 2 additions & 31 deletions compiler/rustc_codegen_llvm/src/debuginfo/metadata/enums/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ use rustc_codegen_ssa::debuginfo::{
wants_c_like_enum_debuginfo,
};
use rustc_hir::def::CtorKind;
use rustc_index::IndexSlice;
use rustc_middle::{
bug,
mir::CoroutineLayout,
Expand All @@ -13,7 +12,6 @@ use rustc_middle::{
AdtDef, CoroutineArgs, Ty, VariantDef,
},
};
use rustc_span::Symbol;
use rustc_target::abi::{
FieldIdx, HasDataLayout, Integer, Primitive, TagEncoding, VariantIdx, Variants,
};
Expand Down Expand Up @@ -324,7 +322,6 @@ pub fn build_coroutine_variant_struct_type_di_node<'ll, 'tcx>(
coroutine_type_and_layout: TyAndLayout<'tcx>,
coroutine_type_di_node: &'ll DIType,
coroutine_layout: &CoroutineLayout<'tcx>,
common_upvar_names: &IndexSlice<FieldIdx, Symbol>,
) -> &'ll DIType {
let variant_name = CoroutineArgs::variant_name(variant_index);
let unique_type_id = UniqueTypeId::for_enum_variant_struct_type(
Expand All @@ -335,11 +332,6 @@ pub fn build_coroutine_variant_struct_type_di_node<'ll, 'tcx>(

let variant_layout = coroutine_type_and_layout.for_variant(cx, variant_index);

let coroutine_args = match coroutine_type_and_layout.ty.kind() {
ty::Coroutine(_, args) => args.as_coroutine(),
_ => unreachable!(),
};

type_map::build_type_with_children(
cx,
type_map::stub(
Expand All @@ -353,7 +345,7 @@ pub fn build_coroutine_variant_struct_type_di_node<'ll, 'tcx>(
),
|cx, variant_struct_type_di_node| {
// Fields that just belong to this variant/state
let state_specific_fields: SmallVec<_> = (0..variant_layout.fields.count())
(0..variant_layout.fields.count())
.map(|field_index| {
let coroutine_saved_local = coroutine_layout.variant_fields[variant_index]
[FieldIdx::from_usize(field_index)];
Expand All @@ -375,28 +367,7 @@ pub fn build_coroutine_variant_struct_type_di_node<'ll, 'tcx>(
type_di_node(cx, field_type),
)
})
.collect();

// Fields that are common to all states
let common_fields: SmallVec<_> = coroutine_args
.prefix_tys()
.iter()
.zip(common_upvar_names)
.enumerate()
.map(|(index, (upvar_ty, upvar_name))| {
build_field_di_node(
cx,
variant_struct_type_di_node,
upvar_name.as_str(),
cx.size_and_align_of(upvar_ty),
coroutine_type_and_layout.fields.offset(index),
DIFlags::FlagZero,
type_di_node(cx, upvar_ty),
)
})
.collect();

state_specific_fields.into_iter().chain(common_fields).collect()
.collect()
},
|cx| build_generic_type_param_di_nodes(cx, coroutine_type_and_layout.ty),
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -170,9 +170,6 @@ pub(super) fn build_coroutine_di_node<'ll, 'tcx>(
)
};

let common_upvar_names =
cx.tcx.closure_saved_names_of_captured_variables(coroutine_def_id);

// Build variant struct types
let variant_struct_type_di_nodes: SmallVec<_> = variants
.indices()
Expand Down Expand Up @@ -200,7 +197,6 @@ pub(super) fn build_coroutine_di_node<'ll, 'tcx>(
coroutine_type_and_layout,
coroutine_type_di_node,
coroutine_layout,
common_upvar_names,
),
source_info,
}
Expand Down
3 changes: 3 additions & 0 deletions compiler/rustc_codegen_ssa/src/mir/rvalue.rs
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,9 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
let variant_dest = dest.project_downcast(bx, variant_index);
(variant_index, variant_dest, active_field_index)
}
mir::AggregateKind::Coroutine(_, _) => {
(FIRST_VARIANT, dest.project_downcast(bx, FIRST_VARIANT), None)
}
_ => (FIRST_VARIANT, dest, None),
};
if active_field_index.is_some() {
Expand Down
3 changes: 3 additions & 0 deletions compiler/rustc_const_eval/src/interpret/step.rs
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,9 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
let variant_dest = self.project_downcast(dest, variant_index)?;
(variant_index, variant_dest, active_field_index)
}
mir::AggregateKind::Coroutine(_def_id, _args) => {
(FIRST_VARIANT, self.project_downcast(dest, FIRST_VARIANT)?, None)
}
_ => (FIRST_VARIANT, dest.clone(), None),
};
if active_field_index.is_some() {
Expand Down
12 changes: 5 additions & 7 deletions compiler/rustc_const_eval/src/transform/validate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -695,14 +695,12 @@ impl<'a, 'tcx> Visitor<'tcx> for TypeChecker<'a, 'tcx> {
};

ty::EarlyBinder::bind(f_ty.ty).instantiate(self.tcx, args)
} else if let Some(&ty) = args.as_coroutine().upvar_tys().get(f.as_usize())
{
ty
} else {
let Some(&f_ty) = args.as_coroutine().prefix_tys().get(f.index())
else {
fail_out_of_bounds(self, location);
return;
};

f_ty
fail_out_of_bounds(self, location);
return;
};

check_equal(self, location, f_ty);
Expand Down
6 changes: 5 additions & 1 deletion compiler/rustc_middle/src/mir/tcx.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,11 @@ impl<'tcx> PlaceTy<'tcx> {
T: ::std::fmt::Debug + Copy,
{
if self.variant_index.is_some() && !matches!(elem, ProjectionElem::Field(..)) {
bug!("cannot use non field projection on downcasted place")
bug!(
"cannot use non field projection on downcasted place from {:?} (variant {:?}), got {elem:?}",
self.ty,
self.variant_index
)
}
let answer = match *elem {
ProjectionElem::Deref => {
Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_middle/src/ty/layout.rs
Original file line number Diff line number Diff line change
Expand Up @@ -919,7 +919,7 @@ where
if i == tag_field {
return TyMaybeWithLayout::TyAndLayout(tag_layout(tag));
}
TyMaybeWithLayout::Ty(args.as_coroutine().prefix_tys()[i])
bug!("coroutine has no prefix field");
}
},

Expand Down
7 changes: 0 additions & 7 deletions compiler/rustc_middle/src/ty/sty.rs
Original file line number Diff line number Diff line change
Expand Up @@ -549,13 +549,6 @@ impl<'tcx> CoroutineArgs<'tcx> {
})
})
}

/// This is the types of the fields of a coroutine which are not stored in a
/// variant.
#[inline]
pub fn prefix_tys(self) -> &'tcx List<Ty<'tcx>> {
self.upvar_tys()
}
}

#[derive(Debug, Copy, Clone, HashStable)]
Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_mir_dataflow/src/framework/engine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ where
}

None if dump_enabled(tcx, A::NAME, def_id) => {
create_dump_file(tcx, ".dot", false, A::NAME, &pass_name.unwrap_or("-----"), body)?
create_dump_file(tcx, ".dot", true, A::NAME, &pass_name.unwrap_or("-----"), body)?
}

_ => return (Ok(()), results),
Expand Down
66 changes: 58 additions & 8 deletions compiler/rustc_mir_dataflow/src/impls/borrowed_locals.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use rustc_index::bit_set::BitSet;
use rustc_middle::mir::visit::Visitor;
use rustc_middle::mir::*;
use rustc_middle::{mir::visit::Visitor, ty};

use crate::{AnalysisDomain, GenKill, GenKillAnalysis};

Expand All @@ -12,11 +12,28 @@ use crate::{AnalysisDomain, GenKill, GenKillAnalysis};
/// `MaybeBorrowedLocals` is used to compute which locals are live during a yield expression for
/// immovable coroutines.
#[derive(Clone, Copy)]
pub struct MaybeBorrowedLocals;
pub struct MaybeBorrowedLocals {
upvar_start: Option<(Local, usize)>,
}

impl MaybeBorrowedLocals {
/// `upvar_start` is to signal that upvars are treated as locals,
/// and locals greater than this value refers to upvars accessed
/// through the tuple `ty::CAPTURE_STRUCT_LOCAL`, aka. _1.
pub fn new(upvar_start: Option<(Local, usize)>) -> Self {
Self { upvar_start }
}

pub(super) fn transfer_function<'a, T>(&'a self, trans: &'a mut T) -> TransferFunction<'a, T> {
TransferFunction { trans }
TransferFunction { trans, upvar_start: self.upvar_start }
}

pub fn domain_size(&self, body: &Body<'_>) -> usize {
if let Some((start, len)) = self.upvar_start {
start.as_usize() + len
} else {
body.local_decls.len()
}
}
}

Expand All @@ -26,7 +43,7 @@ impl<'tcx> AnalysisDomain<'tcx> for MaybeBorrowedLocals {

fn bottom_value(&self, body: &Body<'tcx>) -> Self::Domain {
// bottom = unborrowed
BitSet::new_empty(body.local_decls().len())
BitSet::new_empty(self.domain_size(body))
}

fn initialize_start_block(&self, _: &Body<'tcx>, _: &mut Self::Domain) {
Expand All @@ -38,7 +55,7 @@ impl<'tcx> GenKillAnalysis<'tcx> for MaybeBorrowedLocals {
type Idx = Local;

fn domain_size(&self, body: &Body<'tcx>) -> usize {
body.local_decls.len()
self.domain_size(body)
}

fn statement_effect(
Expand Down Expand Up @@ -72,6 +89,7 @@ impl<'tcx> GenKillAnalysis<'tcx> for MaybeBorrowedLocals {
/// A `Visitor` that defines the transfer function for `MaybeBorrowedLocals`.
pub(super) struct TransferFunction<'a, T> {
trans: &'a mut T,
upvar_start: Option<(Local, usize)>,
}

impl<'tcx, T> Visitor<'tcx> for TransferFunction<'_, T>
Expand All @@ -97,7 +115,20 @@ where
Rvalue::AddressOf(_, borrowed_place)
| Rvalue::Ref(_, BorrowKind::Mut { .. } | BorrowKind::Shared, borrowed_place) => {
if !borrowed_place.is_indirect() {
self.trans.gen(borrowed_place.local);
if borrowed_place.local == ty::CAPTURE_STRUCT_LOCAL
&& let Some((upvar_start, nr_upvars)) = self.upvar_start
{
match **borrowed_place.projection {
[ProjectionElem::Field(field, _), ..]
if field.as_usize() < nr_upvars =>
{
self.trans.gen(upvar_start + field.as_usize())
}
_ => bug!("unexpected upvar access"),
}
} else {
self.trans.gen(borrowed_place.local);
}
}
}

Expand Down Expand Up @@ -132,7 +163,26 @@ where
//
// [#61069]: https://github.com/rust-lang/rust/pull/61069
if !dropped_place.is_indirect() {
self.trans.gen(dropped_place.local);
if dropped_place.local == ty::CAPTURE_STRUCT_LOCAL
&& let Some((upvar_start, nr_upvars)) = self.upvar_start
{
match **dropped_place.projection {
[] => {
for field in 0..nr_upvars {
self.trans.gen(upvar_start + field)
}
self.trans.gen(dropped_place.local)
}
[ProjectionElem::Field(field, _), ..]
if field.as_usize() < nr_upvars =>
{
self.trans.gen(upvar_start + field.as_usize())
}
_ => bug!("unexpected upvar access"),
}
} else {
self.trans.gen(dropped_place.local);
}
}
}

Expand Down Expand Up @@ -169,6 +219,6 @@ pub fn borrowed_locals(body: &Body<'_>) -> BitSet<Local> {
}

let mut borrowed = Borrowed(BitSet::new_empty(body.local_decls.len()));
TransferFunction { trans: &mut borrowed }.visit_body(body);
TransferFunction { trans: &mut borrowed, upvar_start: None }.visit_body(body);
borrowed.0
}
Loading

0 comments on commit 4b04bb1

Please sign in to comment.