From d98b99af56e1260f520102a93f198ffe47793722 Mon Sep 17 00:00:00 2001 From: Michael Goulet Date: Sat, 11 Jan 2025 19:31:28 +0000 Subject: [PATCH] More assertions, tests, and miri coverage --- compiler/rustc_codegen_ssa/src/meth.rs | 2 + .../rustc_const_eval/src/interpret/call.rs | 45 ++++++++------- .../rustc_const_eval/src/interpret/cast.rs | 54 +++++++++--------- .../src/traits/vtable.rs | 57 ++++++++----------- src/tools/miri/tests/pass/dyn-upcast.rs | 52 +++++++++++++++++ src/tools/miri/tests/pass/dyn-upcast.stdout | 2 + ...ltiple-supertraits-modulo-binder-vtable.rs | 2 + ...le-supertraits-modulo-binder-vtable.stderr | 4 +- .../multiple-supertraits-modulo-binder.rs | 2 + ...supertraits-modulo-normalization-vtable.rs | 3 +- ...rtraits-modulo-normalization-vtable.stderr | 4 +- ...ltiple-supertraits-modulo-normalization.rs | 2 + .../supertraits-modulo-inner-binder.rs | 30 ++++++++++ 13 files changed, 173 insertions(+), 86 deletions(-) create mode 100644 tests/ui/traits/trait-upcasting/supertraits-modulo-inner-binder.rs diff --git a/compiler/rustc_codegen_ssa/src/meth.rs b/compiler/rustc_codegen_ssa/src/meth.rs index 886951855401e..399c592432aca 100644 --- a/compiler/rustc_codegen_ssa/src/meth.rs +++ b/compiler/rustc_codegen_ssa/src/meth.rs @@ -80,6 +80,8 @@ fn dyn_trait_in_self<'tcx>( if let GenericArgKind::Type(ty) = arg.unpack() && let ty::Dynamic(data, _, _) = ty.kind() { + // FIXME(arbitrary_self_types): This is likely broken for receivers which + // have a "non-self" trait objects as a generic argument. return data .principal() .map(|principal| tcx.instantiate_bound_regions_with_erased(principal)); diff --git a/compiler/rustc_const_eval/src/interpret/call.rs b/compiler/rustc_const_eval/src/interpret/call.rs index e6a34193c9d69..e2e6e16d8a71a 100644 --- a/compiler/rustc_const_eval/src/interpret/call.rs +++ b/compiler/rustc_const_eval/src/interpret/call.rs @@ -5,6 +5,7 @@ use std::borrow::Cow; use either::{Left, Right}; use rustc_abi::{self as abi, ExternAbi, FieldIdx, Integer, VariantIdx}; +use rustc_hir::def_id::DefId; use rustc_middle::ty::layout::{FnAbiOf, IntegerExt, LayoutOf, TyAndLayout}; use rustc_middle::ty::{self, AdtDef, Instance, Ty, VariantDef}; use rustc_middle::{bug, mir, span_bug}; @@ -693,25 +694,7 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> { trace!("Virtual call dispatches to {fn_inst:#?}"); // We can also do the lookup based on `def_id` and `dyn_ty`, and check that that // produces the same result. - if cfg!(debug_assertions) { - let tcx = *self.tcx; - - let trait_def_id = tcx.trait_of_item(def_id).unwrap(); - let virtual_trait_ref = - ty::TraitRef::from_method(tcx, trait_def_id, instance.args); - let existential_trait_ref = - ty::ExistentialTraitRef::erase_self_ty(tcx, virtual_trait_ref); - let concrete_trait_ref = existential_trait_ref.with_self_ty(tcx, dyn_ty); - - let concrete_method = Instance::expect_resolve_for_vtable( - tcx, - self.typing_env, - def_id, - instance.args.rebase_onto(tcx, trait_def_id, concrete_trait_ref.args), - self.cur_span(), - ); - assert_eq!(fn_inst, concrete_method); - } + self.assert_virtual_instance_matches_concrete(dyn_ty, def_id, instance, fn_inst); // Adjust receiver argument. Layout can be any (thin) ptr. let receiver_ty = Ty::new_mut_ptr(self.tcx.tcx, dyn_ty); @@ -744,6 +727,30 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> { } } + fn assert_virtual_instance_matches_concrete( + &self, + dyn_ty: Ty<'tcx>, + def_id: DefId, + virtual_instance: ty::Instance<'tcx>, + concrete_instance: ty::Instance<'tcx>, + ) { + let tcx = *self.tcx; + + let trait_def_id = tcx.trait_of_item(def_id).unwrap(); + let virtual_trait_ref = ty::TraitRef::from_method(tcx, trait_def_id, virtual_instance.args); + let existential_trait_ref = ty::ExistentialTraitRef::erase_self_ty(tcx, virtual_trait_ref); + let concrete_trait_ref = existential_trait_ref.with_self_ty(tcx, dyn_ty); + + let concrete_method = Instance::expect_resolve_for_vtable( + tcx, + self.typing_env, + def_id, + virtual_instance.args.rebase_onto(tcx, trait_def_id, concrete_trait_ref.args), + self.cur_span(), + ); + assert_eq!(concrete_instance, concrete_method); + } + /// Initiate a tail call to this function -- popping the current stack frame, pushing the new /// stack frame and initializing the arguments. pub(super) fn init_fn_tail_call( diff --git a/compiler/rustc_const_eval/src/interpret/cast.rs b/compiler/rustc_const_eval/src/interpret/cast.rs index 52bc2af928dce..e110c155da089 100644 --- a/compiler/rustc_const_eval/src/interpret/cast.rs +++ b/compiler/rustc_const_eval/src/interpret/cast.rs @@ -414,35 +414,33 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> { // Sanity-check that `supertrait_vtable_slot` in this type's vtable indeed produces // our destination trait. - if cfg!(debug_assertions) { - let vptr_entry_idx = - self.tcx.supertrait_vtable_slot((src_pointee_ty, dest_pointee_ty)); - let vtable_entries = self.vtable_entries(data_a.principal(), ty); - if let Some(entry_idx) = vptr_entry_idx { - let Some(&ty::VtblEntry::TraitVPtr(upcast_trait_ref)) = - vtable_entries.get(entry_idx) - else { - span_bug!( - self.cur_span(), - "invalid vtable entry index in {} -> {} upcast", - src_pointee_ty, - dest_pointee_ty - ); - }; - let erased_trait_ref = - ty::ExistentialTraitRef::erase_self_ty(*self.tcx, upcast_trait_ref); - assert!(data_b.principal().is_some_and(|b| self.eq_in_param_env( - erased_trait_ref, - self.tcx.instantiate_bound_regions_with_erased(b) - ))); - } else { - // In this case codegen would keep using the old vtable. We don't want to do - // that as it has the wrong trait. The reason codegen can do this is that - // one vtable is a prefix of the other, so we double-check that. - let vtable_entries_b = self.vtable_entries(data_b.principal(), ty); - assert!(&vtable_entries[..vtable_entries_b.len()] == vtable_entries_b); + let vptr_entry_idx = + self.tcx.supertrait_vtable_slot((src_pointee_ty, dest_pointee_ty)); + let vtable_entries = self.vtable_entries(data_a.principal(), ty); + if let Some(entry_idx) = vptr_entry_idx { + let Some(&ty::VtblEntry::TraitVPtr(upcast_trait_ref)) = + vtable_entries.get(entry_idx) + else { + span_bug!( + self.cur_span(), + "invalid vtable entry index in {} -> {} upcast", + src_pointee_ty, + dest_pointee_ty + ); }; - } + let erased_trait_ref = + ty::ExistentialTraitRef::erase_self_ty(*self.tcx, upcast_trait_ref); + assert!(data_b.principal().is_some_and(|b| self.eq_in_param_env( + erased_trait_ref, + self.tcx.instantiate_bound_regions_with_erased(b) + ))); + } else { + // In this case codegen would keep using the old vtable. We don't want to do + // that as it has the wrong trait. The reason codegen can do this is that + // one vtable is a prefix of the other, so we double-check that. + let vtable_entries_b = self.vtable_entries(data_b.principal(), ty); + assert!(&vtable_entries[..vtable_entries_b.len()] == vtable_entries_b); + }; // Get the destination trait vtable and return that. let new_vptr = self.get_vtable_ptr(ty, data_b)?; diff --git a/compiler/rustc_trait_selection/src/traits/vtable.rs b/compiler/rustc_trait_selection/src/traits/vtable.rs index abdf5df6f7282..000e6a765d35a 100644 --- a/compiler/rustc_trait_selection/src/traits/vtable.rs +++ b/compiler/rustc_trait_selection/src/traits/vtable.rs @@ -2,8 +2,6 @@ use std::fmt::Debug; use std::ops::ControlFlow; use rustc_hir::def_id::DefId; -use rustc_infer::infer::TyCtxtInferExt; -use rustc_infer::traits::ObligationCause; use rustc_infer::traits::util::PredicateSet; use rustc_middle::bug; use rustc_middle::query::Providers; @@ -14,7 +12,7 @@ use rustc_span::DUMMY_SP; use smallvec::{SmallVec, smallvec}; use tracing::debug; -use crate::traits::{ObligationCtxt, impossible_predicates, is_vtable_safe_method}; +use crate::traits::{impossible_predicates, is_vtable_safe_method}; #[derive(Clone, Debug)] pub enum VtblSegment<'tcx> { @@ -228,6 +226,11 @@ fn vtable_entries<'tcx>( trait_ref: ty::TraitRef<'tcx>, ) -> &'tcx [VtblEntry<'tcx>] { debug_assert!(!trait_ref.has_non_region_infer() && !trait_ref.has_non_region_param()); + debug_assert_eq!( + tcx.normalize_erasing_regions(ty::TypingEnv::fully_monomorphized(), trait_ref), + trait_ref, + "vtable trait ref should be normalized" + ); debug!("vtable_entries({:?})", trait_ref); @@ -305,6 +308,11 @@ fn vtable_entries<'tcx>( // for `Supertrait`'s methods in the vtable of `Subtrait`. pub(crate) fn first_method_vtable_slot<'tcx>(tcx: TyCtxt<'tcx>, key: ty::TraitRef<'tcx>) -> usize { debug_assert!(!key.has_non_region_infer() && !key.has_non_region_param()); + debug_assert_eq!( + tcx.normalize_erasing_regions(ty::TypingEnv::fully_monomorphized(), key), + key, + "vtable trait ref should be normalized" + ); let ty::Dynamic(source, _, _) = *key.self_ty().kind() else { bug!(); @@ -323,11 +331,9 @@ pub(crate) fn first_method_vtable_slot<'tcx>(tcx: TyCtxt<'tcx>, key: ty::TraitRe vptr_offset += TyCtxt::COMMON_VTABLE_ENTRIES.len(); } VtblSegment::TraitOwnEntries { trait_ref: vtable_principal, emit_vptr } => { - if trait_refs_are_compatible( - tcx, - ty::ExistentialTraitRef::erase_self_ty(tcx, vtable_principal), - target_principal, - ) { + if ty::ExistentialTraitRef::erase_self_ty(tcx, vtable_principal) + == target_principal + { return ControlFlow::Break(vptr_offset); } @@ -358,6 +364,12 @@ pub(crate) fn supertrait_vtable_slot<'tcx>( ), ) -> Option { debug_assert!(!key.has_non_region_infer() && !key.has_non_region_param()); + debug_assert_eq!( + tcx.normalize_erasing_regions(ty::TypingEnv::fully_monomorphized(), key), + key, + "upcasting trait refs should be normalized" + ); + let (source, target) = key; // If the target principal is `None`, we can just return `None`. @@ -384,11 +396,9 @@ pub(crate) fn supertrait_vtable_slot<'tcx>( VtblSegment::TraitOwnEntries { trait_ref: vtable_principal, emit_vptr } => { vptr_offset += tcx.own_existential_vtable_entries(vtable_principal.def_id).len(); - if trait_refs_are_compatible( - tcx, - ty::ExistentialTraitRef::erase_self_ty(tcx, vtable_principal), - target_principal, - ) { + if ty::ExistentialTraitRef::erase_self_ty(tcx, vtable_principal) + == target_principal + { if emit_vptr { return ControlFlow::Break(Some(vptr_offset)); } else { @@ -408,27 +418,6 @@ pub(crate) fn supertrait_vtable_slot<'tcx>( prepare_vtable_segments(tcx, source_principal, vtable_segment_callback).unwrap() } -fn trait_refs_are_compatible<'tcx>( - tcx: TyCtxt<'tcx>, - vtable_principal: ty::ExistentialTraitRef<'tcx>, - target_principal: ty::ExistentialTraitRef<'tcx>, -) -> bool { - if vtable_principal.def_id != target_principal.def_id { - return false; - } - - let (infcx, param_env) = - tcx.infer_ctxt().build_with_typing_env(ty::TypingEnv::fully_monomorphized()); - let ocx = ObligationCtxt::new(&infcx); - let source_principal = ocx.normalize(&ObligationCause::dummy(), param_env, vtable_principal); - let target_principal = ocx.normalize(&ObligationCause::dummy(), param_env, target_principal); - let Ok(()) = ocx.eq(&ObligationCause::dummy(), param_env, target_principal, source_principal) - else { - return false; - }; - ocx.select_all_or_error().is_empty() -} - pub(super) fn provide(providers: &mut Providers) { *providers = Providers { own_existential_vtable_entries, diff --git a/src/tools/miri/tests/pass/dyn-upcast.rs b/src/tools/miri/tests/pass/dyn-upcast.rs index 61410f7c4e0b5..f100c4d6a869e 100644 --- a/src/tools/miri/tests/pass/dyn-upcast.rs +++ b/src/tools/miri/tests/pass/dyn-upcast.rs @@ -10,6 +10,8 @@ fn main() { replace_vptr(); vtable_nop_cast(); drop_principal(); + modulo_binder(); + modulo_assoc(); } fn vtable_nop_cast() { @@ -482,3 +484,53 @@ fn drop_principal() { println!("before"); drop(y); } + +// Test for . +fn modulo_binder() { + trait Supertrait { + fn _print_numbers(&self, mem: &[usize; 100]) { + println!("{mem:?}"); + } + } + impl Supertrait for () {} + + trait Trait: Supertrait + Supertrait { + fn say_hello(&self, _: &usize) { + println!("Hello!"); + } + } + impl Trait for () {} + + (&() as &'static dyn for<'a> Trait<&'static (), &'a ()> + as &'static dyn Trait<&'static (), &'static ()>) + .say_hello(&0); +} + +// Test for . +fn modulo_assoc() { + trait Supertrait { + fn _print_numbers(&self, mem: &[usize; 100]) { + println!("{mem:?}"); + } + } + impl Supertrait for () {} + + trait Identity { + type Selff; + } + impl Identity for Selff { + type Selff = Selff; + } + + trait Middle: Supertrait<()> + Supertrait { + fn say_hello(&self, _: &usize) { + println!("Hello!"); + } + } + impl Middle for () {} + + trait Trait: Middle<<() as Identity>::Selff> {} + impl Trait for () {} + + (&() as &dyn Trait as &dyn Middle<()>).say_hello(&0); +} diff --git a/src/tools/miri/tests/pass/dyn-upcast.stdout b/src/tools/miri/tests/pass/dyn-upcast.stdout index edd99a114a112..379600db3d91f 100644 --- a/src/tools/miri/tests/pass/dyn-upcast.stdout +++ b/src/tools/miri/tests/pass/dyn-upcast.stdout @@ -2,3 +2,5 @@ before goodbye before goodbye +Hello! +Hello! diff --git a/tests/ui/traits/trait-upcasting/multiple-supertraits-modulo-binder-vtable.rs b/tests/ui/traits/trait-upcasting/multiple-supertraits-modulo-binder-vtable.rs index b41cf2e9f2679..796ddec46ac7a 100644 --- a/tests/ui/traits/trait-upcasting/multiple-supertraits-modulo-binder-vtable.rs +++ b/tests/ui/traits/trait-upcasting/multiple-supertraits-modulo-binder-vtable.rs @@ -1,5 +1,7 @@ #![feature(rustc_attrs)] +// Test for . + trait Supertrait { fn _print_numbers(&self, mem: &[usize; 100]) { } diff --git a/tests/ui/traits/trait-upcasting/multiple-supertraits-modulo-binder-vtable.stderr b/tests/ui/traits/trait-upcasting/multiple-supertraits-modulo-binder-vtable.stderr index 4bb7e1bdb6ad2..24fa1650ca14c 100644 --- a/tests/ui/traits/trait-upcasting/multiple-supertraits-modulo-binder-vtable.stderr +++ b/tests/ui/traits/trait-upcasting/multiple-supertraits-modulo-binder-vtable.stderr @@ -5,7 +5,7 @@ error: vtable entries: [ Method( Trait<&(), &'a ()> as Supertrait<&()>>::_print_numbers - shim(reify)), Method( Trait<&(), &'a ()> as Trait<&(), &()>>::say_hello - shim(reify)), ] - --> $DIR/multiple-supertraits-modulo-binder-vtable.rs:18:1 + --> $DIR/multiple-supertraits-modulo-binder-vtable.rs:20:1 | LL | type First = dyn for<'a> Trait<&'static (), &'a ()>; | ^^^^^^^^^^ @@ -17,7 +17,7 @@ error: vtable entries: [ Method( as Supertrait<&()>>::_print_numbers - shim(reify)), Method( as Trait<&(), &()>>::say_hello - shim(reify)), ] - --> $DIR/multiple-supertraits-modulo-binder-vtable.rs:22:1 + --> $DIR/multiple-supertraits-modulo-binder-vtable.rs:24:1 | LL | type Second = dyn Trait<&'static (), &'static ()>; | ^^^^^^^^^^^ diff --git a/tests/ui/traits/trait-upcasting/multiple-supertraits-modulo-binder.rs b/tests/ui/traits/trait-upcasting/multiple-supertraits-modulo-binder.rs index 7bc069f4f4407..510a1471af293 100644 --- a/tests/ui/traits/trait-upcasting/multiple-supertraits-modulo-binder.rs +++ b/tests/ui/traits/trait-upcasting/multiple-supertraits-modulo-binder.rs @@ -1,6 +1,8 @@ //@ run-pass //@ check-run-results +// Test for . + #![feature(trait_upcasting)] trait Supertrait { diff --git a/tests/ui/traits/trait-upcasting/multiple-supertraits-modulo-normalization-vtable.rs b/tests/ui/traits/trait-upcasting/multiple-supertraits-modulo-normalization-vtable.rs index 22d6bf94d8706..69a71859a5cc7 100644 --- a/tests/ui/traits/trait-upcasting/multiple-supertraits-modulo-normalization-vtable.rs +++ b/tests/ui/traits/trait-upcasting/multiple-supertraits-modulo-normalization-vtable.rs @@ -1,7 +1,8 @@ #![feature(rustc_attrs)] - #![feature(trait_upcasting)] +// Test for . + trait Supertrait { fn _print_numbers(&self, mem: &[usize; 100]) { println!("{mem:?}"); diff --git a/tests/ui/traits/trait-upcasting/multiple-supertraits-modulo-normalization-vtable.stderr b/tests/ui/traits/trait-upcasting/multiple-supertraits-modulo-normalization-vtable.stderr index 04b1afae7beca..757e2dc69390c 100644 --- a/tests/ui/traits/trait-upcasting/multiple-supertraits-modulo-normalization-vtable.stderr +++ b/tests/ui/traits/trait-upcasting/multiple-supertraits-modulo-normalization-vtable.stderr @@ -5,7 +5,7 @@ error: vtable entries: [ Method(<() as Supertrait<()>>::_print_numbers), Method(<() as Middle<()>>::say_hello), ] - --> $DIR/multiple-supertraits-modulo-normalization-vtable.rs:29:1 + --> $DIR/multiple-supertraits-modulo-normalization-vtable.rs:30:1 | LL | impl Trait for () {} | ^^^^^^^^^^^^^^^^^ @@ -17,7 +17,7 @@ error: vtable entries: [ Method( as Supertrait<()>>::_print_numbers - shim(reify)), Method( as Middle<()>>::say_hello - shim(reify)), ] - --> $DIR/multiple-supertraits-modulo-normalization-vtable.rs:33:1 + --> $DIR/multiple-supertraits-modulo-normalization-vtable.rs:34:1 | LL | type Virtual = dyn Middle<()>; | ^^^^^^^^^^^^ diff --git a/tests/ui/traits/trait-upcasting/multiple-supertraits-modulo-normalization.rs b/tests/ui/traits/trait-upcasting/multiple-supertraits-modulo-normalization.rs index fd0f62b4255a1..c744e6e64f575 100644 --- a/tests/ui/traits/trait-upcasting/multiple-supertraits-modulo-normalization.rs +++ b/tests/ui/traits/trait-upcasting/multiple-supertraits-modulo-normalization.rs @@ -3,6 +3,8 @@ #![feature(trait_upcasting)] +// Test for . + trait Supertrait { fn _print_numbers(&self, mem: &[usize; 100]) { println!("{mem:?}"); diff --git a/tests/ui/traits/trait-upcasting/supertraits-modulo-inner-binder.rs b/tests/ui/traits/trait-upcasting/supertraits-modulo-inner-binder.rs new file mode 100644 index 0000000000000..6cd74b6c7f75f --- /dev/null +++ b/tests/ui/traits/trait-upcasting/supertraits-modulo-inner-binder.rs @@ -0,0 +1,30 @@ +//@ run-pass + +#![feature(trait_upcasting)] + +trait Super { + fn call(&self) + where + U: HigherRanked, + { + } +} + +impl Super for () {} + +trait HigherRanked {} +impl HigherRanked for for<'a> fn(&'a ()) {} + +trait Unimplemented {} +impl HigherRanked for T {} + +trait Sub: Super + Super fn(&'a ())> {} +impl Sub for () {} + +fn main() { + let a: &dyn Sub = &(); + // `Super` and `Super fn(&'a ())>` have different + // vtables and we need to upcast to the latter! + let b: &dyn Super fn(&'a ())> = a; + b.call(); +}