Skip to content

Commit

Permalink
More assertions, tests, and miri coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
compiler-errors committed Jan 30, 2025
1 parent 9dc41a0 commit d98b99a
Show file tree
Hide file tree
Showing 13 changed files with 173 additions and 86 deletions.
2 changes: 2 additions & 0 deletions compiler/rustc_codegen_ssa/src/meth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,8 @@ fn dyn_trait_in_self<'tcx>(
if let GenericArgKind::Type(ty) = arg.unpack()
&& let ty::Dynamic(data, _, _) = ty.kind()
{
// FIXME(arbitrary_self_types): This is likely broken for receivers which
// have a "non-self" trait objects as a generic argument.
return data
.principal()
.map(|principal| tcx.instantiate_bound_regions_with_erased(principal));
Expand Down
45 changes: 26 additions & 19 deletions compiler/rustc_const_eval/src/interpret/call.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use std::borrow::Cow;

use either::{Left, Right};
use rustc_abi::{self as abi, ExternAbi, FieldIdx, Integer, VariantIdx};
use rustc_hir::def_id::DefId;
use rustc_middle::ty::layout::{FnAbiOf, IntegerExt, LayoutOf, TyAndLayout};
use rustc_middle::ty::{self, AdtDef, Instance, Ty, VariantDef};
use rustc_middle::{bug, mir, span_bug};
Expand Down Expand Up @@ -693,25 +694,7 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
trace!("Virtual call dispatches to {fn_inst:#?}");
// We can also do the lookup based on `def_id` and `dyn_ty`, and check that that
// produces the same result.
if cfg!(debug_assertions) {
let tcx = *self.tcx;

let trait_def_id = tcx.trait_of_item(def_id).unwrap();
let virtual_trait_ref =
ty::TraitRef::from_method(tcx, trait_def_id, instance.args);
let existential_trait_ref =
ty::ExistentialTraitRef::erase_self_ty(tcx, virtual_trait_ref);
let concrete_trait_ref = existential_trait_ref.with_self_ty(tcx, dyn_ty);

let concrete_method = Instance::expect_resolve_for_vtable(
tcx,
self.typing_env,
def_id,
instance.args.rebase_onto(tcx, trait_def_id, concrete_trait_ref.args),
self.cur_span(),
);
assert_eq!(fn_inst, concrete_method);
}
self.assert_virtual_instance_matches_concrete(dyn_ty, def_id, instance, fn_inst);

// Adjust receiver argument. Layout can be any (thin) ptr.
let receiver_ty = Ty::new_mut_ptr(self.tcx.tcx, dyn_ty);
Expand Down Expand Up @@ -744,6 +727,30 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
}
}

fn assert_virtual_instance_matches_concrete(
&self,
dyn_ty: Ty<'tcx>,
def_id: DefId,
virtual_instance: ty::Instance<'tcx>,
concrete_instance: ty::Instance<'tcx>,
) {
let tcx = *self.tcx;

let trait_def_id = tcx.trait_of_item(def_id).unwrap();
let virtual_trait_ref = ty::TraitRef::from_method(tcx, trait_def_id, virtual_instance.args);
let existential_trait_ref = ty::ExistentialTraitRef::erase_self_ty(tcx, virtual_trait_ref);
let concrete_trait_ref = existential_trait_ref.with_self_ty(tcx, dyn_ty);

let concrete_method = Instance::expect_resolve_for_vtable(
tcx,
self.typing_env,
def_id,
virtual_instance.args.rebase_onto(tcx, trait_def_id, concrete_trait_ref.args),
self.cur_span(),
);
assert_eq!(concrete_instance, concrete_method);
}

/// Initiate a tail call to this function -- popping the current stack frame, pushing the new
/// stack frame and initializing the arguments.
pub(super) fn init_fn_tail_call(
Expand Down
54 changes: 26 additions & 28 deletions compiler/rustc_const_eval/src/interpret/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -414,35 +414,33 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {

// Sanity-check that `supertrait_vtable_slot` in this type's vtable indeed produces
// our destination trait.
if cfg!(debug_assertions) {
let vptr_entry_idx =
self.tcx.supertrait_vtable_slot((src_pointee_ty, dest_pointee_ty));
let vtable_entries = self.vtable_entries(data_a.principal(), ty);
if let Some(entry_idx) = vptr_entry_idx {
let Some(&ty::VtblEntry::TraitVPtr(upcast_trait_ref)) =
vtable_entries.get(entry_idx)
else {
span_bug!(
self.cur_span(),
"invalid vtable entry index in {} -> {} upcast",
src_pointee_ty,
dest_pointee_ty
);
};
let erased_trait_ref =
ty::ExistentialTraitRef::erase_self_ty(*self.tcx, upcast_trait_ref);
assert!(data_b.principal().is_some_and(|b| self.eq_in_param_env(
erased_trait_ref,
self.tcx.instantiate_bound_regions_with_erased(b)
)));
} else {
// In this case codegen would keep using the old vtable. We don't want to do
// that as it has the wrong trait. The reason codegen can do this is that
// one vtable is a prefix of the other, so we double-check that.
let vtable_entries_b = self.vtable_entries(data_b.principal(), ty);
assert!(&vtable_entries[..vtable_entries_b.len()] == vtable_entries_b);
let vptr_entry_idx =
self.tcx.supertrait_vtable_slot((src_pointee_ty, dest_pointee_ty));
let vtable_entries = self.vtable_entries(data_a.principal(), ty);
if let Some(entry_idx) = vptr_entry_idx {
let Some(&ty::VtblEntry::TraitVPtr(upcast_trait_ref)) =
vtable_entries.get(entry_idx)
else {
span_bug!(
self.cur_span(),
"invalid vtable entry index in {} -> {} upcast",
src_pointee_ty,
dest_pointee_ty
);
};
}
let erased_trait_ref =
ty::ExistentialTraitRef::erase_self_ty(*self.tcx, upcast_trait_ref);
assert!(data_b.principal().is_some_and(|b| self.eq_in_param_env(
erased_trait_ref,
self.tcx.instantiate_bound_regions_with_erased(b)
)));
} else {
// In this case codegen would keep using the old vtable. We don't want to do
// that as it has the wrong trait. The reason codegen can do this is that
// one vtable is a prefix of the other, so we double-check that.
let vtable_entries_b = self.vtable_entries(data_b.principal(), ty);
assert!(&vtable_entries[..vtable_entries_b.len()] == vtable_entries_b);
};

// Get the destination trait vtable and return that.
let new_vptr = self.get_vtable_ptr(ty, data_b)?;
Expand Down
57 changes: 23 additions & 34 deletions compiler/rustc_trait_selection/src/traits/vtable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@ use std::fmt::Debug;
use std::ops::ControlFlow;

use rustc_hir::def_id::DefId;
use rustc_infer::infer::TyCtxtInferExt;
use rustc_infer::traits::ObligationCause;
use rustc_infer::traits::util::PredicateSet;
use rustc_middle::bug;
use rustc_middle::query::Providers;
Expand All @@ -14,7 +12,7 @@ use rustc_span::DUMMY_SP;
use smallvec::{SmallVec, smallvec};
use tracing::debug;

use crate::traits::{ObligationCtxt, impossible_predicates, is_vtable_safe_method};
use crate::traits::{impossible_predicates, is_vtable_safe_method};

#[derive(Clone, Debug)]
pub enum VtblSegment<'tcx> {
Expand Down Expand Up @@ -228,6 +226,11 @@ fn vtable_entries<'tcx>(
trait_ref: ty::TraitRef<'tcx>,
) -> &'tcx [VtblEntry<'tcx>] {
debug_assert!(!trait_ref.has_non_region_infer() && !trait_ref.has_non_region_param());
debug_assert_eq!(
tcx.normalize_erasing_regions(ty::TypingEnv::fully_monomorphized(), trait_ref),
trait_ref,
"vtable trait ref should be normalized"
);

debug!("vtable_entries({:?})", trait_ref);

Expand Down Expand Up @@ -305,6 +308,11 @@ fn vtable_entries<'tcx>(
// for `Supertrait`'s methods in the vtable of `Subtrait`.
pub(crate) fn first_method_vtable_slot<'tcx>(tcx: TyCtxt<'tcx>, key: ty::TraitRef<'tcx>) -> usize {
debug_assert!(!key.has_non_region_infer() && !key.has_non_region_param());
debug_assert_eq!(
tcx.normalize_erasing_regions(ty::TypingEnv::fully_monomorphized(), key),
key,
"vtable trait ref should be normalized"
);

let ty::Dynamic(source, _, _) = *key.self_ty().kind() else {
bug!();
Expand All @@ -323,11 +331,9 @@ pub(crate) fn first_method_vtable_slot<'tcx>(tcx: TyCtxt<'tcx>, key: ty::TraitRe
vptr_offset += TyCtxt::COMMON_VTABLE_ENTRIES.len();
}
VtblSegment::TraitOwnEntries { trait_ref: vtable_principal, emit_vptr } => {
if trait_refs_are_compatible(
tcx,
ty::ExistentialTraitRef::erase_self_ty(tcx, vtable_principal),
target_principal,
) {
if ty::ExistentialTraitRef::erase_self_ty(tcx, vtable_principal)
== target_principal
{
return ControlFlow::Break(vptr_offset);
}

Expand Down Expand Up @@ -358,6 +364,12 @@ pub(crate) fn supertrait_vtable_slot<'tcx>(
),
) -> Option<usize> {
debug_assert!(!key.has_non_region_infer() && !key.has_non_region_param());
debug_assert_eq!(
tcx.normalize_erasing_regions(ty::TypingEnv::fully_monomorphized(), key),
key,
"upcasting trait refs should be normalized"
);

let (source, target) = key;

// If the target principal is `None`, we can just return `None`.
Expand All @@ -384,11 +396,9 @@ pub(crate) fn supertrait_vtable_slot<'tcx>(
VtblSegment::TraitOwnEntries { trait_ref: vtable_principal, emit_vptr } => {
vptr_offset +=
tcx.own_existential_vtable_entries(vtable_principal.def_id).len();
if trait_refs_are_compatible(
tcx,
ty::ExistentialTraitRef::erase_self_ty(tcx, vtable_principal),
target_principal,
) {
if ty::ExistentialTraitRef::erase_self_ty(tcx, vtable_principal)
== target_principal
{
if emit_vptr {
return ControlFlow::Break(Some(vptr_offset));
} else {
Expand All @@ -408,27 +418,6 @@ pub(crate) fn supertrait_vtable_slot<'tcx>(
prepare_vtable_segments(tcx, source_principal, vtable_segment_callback).unwrap()
}

fn trait_refs_are_compatible<'tcx>(
tcx: TyCtxt<'tcx>,
vtable_principal: ty::ExistentialTraitRef<'tcx>,
target_principal: ty::ExistentialTraitRef<'tcx>,
) -> bool {
if vtable_principal.def_id != target_principal.def_id {
return false;
}

let (infcx, param_env) =
tcx.infer_ctxt().build_with_typing_env(ty::TypingEnv::fully_monomorphized());
let ocx = ObligationCtxt::new(&infcx);
let source_principal = ocx.normalize(&ObligationCause::dummy(), param_env, vtable_principal);
let target_principal = ocx.normalize(&ObligationCause::dummy(), param_env, target_principal);
let Ok(()) = ocx.eq(&ObligationCause::dummy(), param_env, target_principal, source_principal)
else {
return false;
};
ocx.select_all_or_error().is_empty()
}

pub(super) fn provide(providers: &mut Providers) {
*providers = Providers {
own_existential_vtable_entries,
Expand Down
52 changes: 52 additions & 0 deletions src/tools/miri/tests/pass/dyn-upcast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ fn main() {
replace_vptr();
vtable_nop_cast();
drop_principal();
modulo_binder();
modulo_assoc();
}

fn vtable_nop_cast() {
Expand Down Expand Up @@ -482,3 +484,53 @@ fn drop_principal() {
println!("before");
drop(y);
}

// Test for <https://github.com/rust-lang/rust/issues/135316>.
fn modulo_binder() {
trait Supertrait<T> {
fn _print_numbers(&self, mem: &[usize; 100]) {
println!("{mem:?}");
}
}
impl<T> Supertrait<T> for () {}

trait Trait<T, U>: Supertrait<T> + Supertrait<U> {
fn say_hello(&self, _: &usize) {
println!("Hello!");
}
}
impl<T, U> Trait<T, U> for () {}

(&() as &'static dyn for<'a> Trait<&'static (), &'a ()>
as &'static dyn Trait<&'static (), &'static ()>)
.say_hello(&0);
}

// Test for <https://github.com/rust-lang/rust/issues/135315>.
fn modulo_assoc() {
trait Supertrait<T> {
fn _print_numbers(&self, mem: &[usize; 100]) {
println!("{mem:?}");
}
}
impl<T> Supertrait<T> for () {}

trait Identity {
type Selff;
}
impl<Selff> Identity for Selff {
type Selff = Selff;
}

trait Middle<T>: Supertrait<()> + Supertrait<T> {
fn say_hello(&self, _: &usize) {
println!("Hello!");
}
}
impl<T> Middle<T> for () {}

trait Trait: Middle<<() as Identity>::Selff> {}
impl Trait for () {}

(&() as &dyn Trait as &dyn Middle<()>).say_hello(&0);
}
2 changes: 2 additions & 0 deletions src/tools/miri/tests/pass/dyn-upcast.stdout
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,5 @@ before
goodbye
before
goodbye
Hello!
Hello!
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#![feature(rustc_attrs)]

// Test for <https://github.com/rust-lang/rust/issues/135316>.

trait Supertrait<T> {
fn _print_numbers(&self, mem: &[usize; 100]) {
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ error: vtable entries: [
Method(<dyn for<'a> Trait<&(), &'a ()> as Supertrait<&()>>::_print_numbers - shim(reify)),
Method(<dyn for<'a> Trait<&(), &'a ()> as Trait<&(), &()>>::say_hello - shim(reify)),
]
--> $DIR/multiple-supertraits-modulo-binder-vtable.rs:18:1
--> $DIR/multiple-supertraits-modulo-binder-vtable.rs:20:1
|
LL | type First = dyn for<'a> Trait<&'static (), &'a ()>;
| ^^^^^^^^^^
Expand All @@ -17,7 +17,7 @@ error: vtable entries: [
Method(<dyn Trait<&(), &()> as Supertrait<&()>>::_print_numbers - shim(reify)),
Method(<dyn Trait<&(), &()> as Trait<&(), &()>>::say_hello - shim(reify)),
]
--> $DIR/multiple-supertraits-modulo-binder-vtable.rs:22:1
--> $DIR/multiple-supertraits-modulo-binder-vtable.rs:24:1
|
LL | type Second = dyn Trait<&'static (), &'static ()>;
| ^^^^^^^^^^^
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
//@ run-pass
//@ check-run-results

// Test for <https://github.com/rust-lang/rust/issues/135316>.

#![feature(trait_upcasting)]

trait Supertrait<T> {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
#![feature(rustc_attrs)]

#![feature(trait_upcasting)]

// Test for <https://github.com/rust-lang/rust/issues/135315>.

trait Supertrait<T> {
fn _print_numbers(&self, mem: &[usize; 100]) {
println!("{mem:?}");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ error: vtable entries: [
Method(<() as Supertrait<()>>::_print_numbers),
Method(<() as Middle<()>>::say_hello),
]
--> $DIR/multiple-supertraits-modulo-normalization-vtable.rs:29:1
--> $DIR/multiple-supertraits-modulo-normalization-vtable.rs:30:1
|
LL | impl Trait for () {}
| ^^^^^^^^^^^^^^^^^
Expand All @@ -17,7 +17,7 @@ error: vtable entries: [
Method(<dyn Middle<()> as Supertrait<()>>::_print_numbers - shim(reify)),
Method(<dyn Middle<()> as Middle<()>>::say_hello - shim(reify)),
]
--> $DIR/multiple-supertraits-modulo-normalization-vtable.rs:33:1
--> $DIR/multiple-supertraits-modulo-normalization-vtable.rs:34:1
|
LL | type Virtual = dyn Middle<()>;
| ^^^^^^^^^^^^
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

#![feature(trait_upcasting)]

// Test for <https://github.com/rust-lang/rust/issues/135315>.

trait Supertrait<T> {
fn _print_numbers(&self, mem: &[usize; 100]) {
println!("{mem:?}");
Expand Down
Loading

0 comments on commit d98b99a

Please sign in to comment.