Skip to content

Commit

Permalink
Rollup merge of rust-lang#109641 - compiler-errors:dont-elaborate-non…
Browse files Browse the repository at this point in the history
…-obl, r=oli-obk

Don't elaborate non-obligations into obligations

It's suspicious to elaborate a `PolyTraitRef` or `Predicate` into an `Obligation`, since the former does not have a param-env associated with it, but the latter does. This is a footgun that, while not being misused *currently* in the compiler, easily could be misused by someone less familiar with the elaborator's inner workings.

This PR just changes the API -- ideally, the elaborator wouldn't even have to deal with obligations if we're not elaborating obligations, but that would require a bit more abstraction than I could be bothered with today.
  • Loading branch information
matthiaskrgr authored Mar 27, 2023
2 parents 2b7dc94 + 1ce4b37 commit 6535e66
Show file tree
Hide file tree
Showing 20 changed files with 112 additions and 131 deletions.
9 changes: 3 additions & 6 deletions compiler/rustc_hir_analysis/src/astconv/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1427,13 +1427,10 @@ impl<'o, 'tcx> dyn AstConv<'tcx> + 'o {
for (base_trait_ref, span, constness) in regular_traits_refs_spans {
assert_eq!(constness, ty::BoundConstness::NotConst);

for obligation in traits::elaborate_trait_ref(tcx, base_trait_ref) {
debug!(
"conv_object_ty_poly_trait_ref: observing object predicate `{:?}`",
obligation.predicate
);
for pred in traits::elaborate_trait_ref(tcx, base_trait_ref) {
debug!("conv_object_ty_poly_trait_ref: observing object predicate `{:?}`", pred);

let bound_predicate = obligation.predicate.kind();
let bound_predicate = pred.kind();
match bound_predicate.skip_binder() {
ty::PredicateKind::Clause(ty::Clause::Trait(pred)) => {
let pred = bound_predicate.rebind(pred);
Expand Down
7 changes: 2 additions & 5 deletions compiler/rustc_hir_analysis/src/check/wfcheck.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1912,14 +1912,13 @@ impl<'tcx> WfCheckingCtxt<'_, 'tcx> {
// Check elaborated bounds.
let implied_obligations = traits::elaborate_predicates_with_span(tcx, predicates_with_span);

for obligation in implied_obligations {
for (pred, obligation_span) in implied_obligations {
// We lower empty bounds like `Vec<dyn Copy>:` as
// `WellFormed(Vec<dyn Copy>)`, which will later get checked by
// regular WF checking
if let ty::PredicateKind::WellFormed(..) = obligation.predicate.kind().skip_binder() {
if let ty::PredicateKind::WellFormed(..) = pred.kind().skip_binder() {
continue;
}
let pred = obligation.predicate;
// Match the existing behavior.
if pred.is_global() && !pred.has_late_bound_vars() {
let pred = self.normalize(span, None, pred);
Expand All @@ -1930,8 +1929,6 @@ impl<'tcx> WfCheckingCtxt<'_, 'tcx> {
if let Some(hir::Generics { predicates, .. }) =
hir_node.and_then(|node| node.generics())
{
let obligation_span = obligation.cause.span();

span = predicates
.iter()
// There seems to be no better way to find out which predicate we are in
Expand Down
11 changes: 4 additions & 7 deletions compiler/rustc_hir_analysis/src/collect/item_bounds.rs
Original file line number Diff line number Diff line change
Expand Up @@ -130,12 +130,9 @@ pub(super) fn item_bounds(
tcx: TyCtxt<'_>,
def_id: DefId,
) -> ty::EarlyBinder<&'_ ty::List<ty::Predicate<'_>>> {
let bounds = tcx.mk_predicates_from_iter(
util::elaborate_predicates(
tcx,
tcx.explicit_item_bounds(def_id).iter().map(|&(bound, _span)| bound),
)
.map(|obligation| obligation.predicate),
);
let bounds = tcx.mk_predicates_from_iter(util::elaborate_predicates(
tcx,
tcx.explicit_item_bounds(def_id).iter().map(|&(bound, _span)| bound),
));
ty::EarlyBinder(bounds)
}
34 changes: 13 additions & 21 deletions compiler/rustc_hir_analysis/src/impl_wf_check/min_specialization.rs
Original file line number Diff line number Diff line change
Expand Up @@ -318,16 +318,8 @@ fn check_predicates<'tcx>(
span: Span,
) {
let instantiated = tcx.predicates_of(impl1_def_id).instantiate(tcx, impl1_substs);
let impl1_predicates: Vec<_> = traits::elaborate_predicates_with_span(
tcx,
std::iter::zip(
instantiated.predicates,
// Don't drop predicates (unsound!) because `spans` is too short
instantiated.spans.into_iter().chain(std::iter::repeat(span)),
),
)
.map(|obligation| (obligation.predicate, obligation.cause.span))
.collect();
let impl1_predicates: Vec<_> =
traits::elaborate_predicates_with_span(tcx, instantiated.into_iter()).collect();

let mut impl2_predicates = if impl2_node.is_from_trait() {
// Always applicable traits have to be always applicable without any
Expand All @@ -341,7 +333,6 @@ fn check_predicates<'tcx>(
.predicates
.into_iter(),
)
.map(|obligation| obligation.predicate)
.collect()
};
debug!(?impl1_predicates, ?impl2_predicates);
Expand All @@ -361,12 +352,16 @@ fn check_predicates<'tcx>(
// which is sound because we forbid impls like the following
//
// impl<D: Debug> AlwaysApplicable for D { }
let always_applicable_traits = impl1_predicates.iter().copied().filter(|&(predicate, _)| {
matches!(
trait_predicate_kind(tcx, predicate),
Some(TraitSpecializationKind::AlwaysApplicable)
)
});
let always_applicable_traits = impl1_predicates
.iter()
.copied()
.filter(|&(predicate, _)| {
matches!(
trait_predicate_kind(tcx, predicate),
Some(TraitSpecializationKind::AlwaysApplicable)
)
})
.map(|(pred, _span)| pred);

// Include the well-formed predicates of the type parameters of the impl.
for arg in tcx.impl_trait_ref(impl1_def_id).unwrap().subst_identity().substs {
Expand All @@ -380,10 +375,7 @@ fn check_predicates<'tcx>(
traits::elaborate_obligations(tcx, obligations).map(|obligation| obligation.predicate),
)
}
impl2_predicates.extend(
traits::elaborate_predicates_with_span(tcx, always_applicable_traits)
.map(|obligation| obligation.predicate),
);
impl2_predicates.extend(traits::elaborate_predicates(tcx, always_applicable_traits));

for (predicate, span) in impl1_predicates {
if !impl2_predicates.iter().any(|pred2| trait_predicates_eq(tcx, predicate, *pred2, span)) {
Expand Down
10 changes: 5 additions & 5 deletions compiler/rustc_hir_typeck/src/closure.rs
Original file line number Diff line number Diff line change
Expand Up @@ -204,25 +204,25 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
let mut expected_sig = None;
let mut expected_kind = None;

for obligation in traits::elaborate_predicates_with_span(
for (pred, span) in traits::elaborate_predicates_with_span(
self.tcx,
// Reverse the obligations here, since `elaborate_*` uses a stack,
// and we want to keep inference generally in the same order of
// the registered obligations.
predicates.rev(),
) {
debug!(?obligation.predicate);
let bound_predicate = obligation.predicate.kind();
debug!(?pred);
let bound_predicate = pred.kind();

// Given a Projection predicate, we can potentially infer
// the complete signature.
if expected_sig.is_none()
&& let ty::PredicateKind::Clause(ty::Clause::Projection(proj_predicate)) = bound_predicate.skip_binder()
{
let inferred_sig = self.normalize(
obligation.cause.span,
span,
self.deduce_sig_from_projection(
Some(obligation.cause.span),
Some(span),
bound_predicate.rebind(proj_predicate),
),
);
Expand Down
8 changes: 2 additions & 6 deletions compiler/rustc_hir_typeck/src/method/confirm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -576,17 +576,13 @@ impl<'a, 'tcx> ConfirmContext<'a, 'tcx> {

traits::elaborate_predicates(self.tcx, predicates.predicates.iter().copied())
// We don't care about regions here.
.filter_map(|obligation| match obligation.predicate.kind().skip_binder() {
.filter_map(|pred| match pred.kind().skip_binder() {
ty::PredicateKind::Clause(ty::Clause::Trait(trait_pred))
if trait_pred.def_id() == sized_def_id =>
{
let span = predicates
.iter()
.find_map(
|(p, span)| {
if p == obligation.predicate { Some(span) } else { None }
},
)
.find_map(|(p, span)| if p == pred { Some(span) } else { None })
.unwrap_or(rustc_span::DUMMY_SP);
Some((trait_pred, span))
}
Expand Down
78 changes: 43 additions & 35 deletions compiler/rustc_infer/src/traits/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,44 +74,58 @@ pub struct Elaborator<'tcx> {
pub fn elaborate_trait_ref<'tcx>(
tcx: TyCtxt<'tcx>,
trait_ref: ty::PolyTraitRef<'tcx>,
) -> Elaborator<'tcx> {
) -> impl Iterator<Item = ty::Predicate<'tcx>> {
elaborate_predicates(tcx, std::iter::once(trait_ref.without_const().to_predicate(tcx)))
}

pub fn elaborate_trait_refs<'tcx>(
tcx: TyCtxt<'tcx>,
trait_refs: impl Iterator<Item = ty::PolyTraitRef<'tcx>>,
) -> Elaborator<'tcx> {
let predicates = trait_refs.map(|trait_ref| trait_ref.without_const().to_predicate(tcx));
) -> impl Iterator<Item = ty::Predicate<'tcx>> {
let predicates = trait_refs.map(move |trait_ref| trait_ref.without_const().to_predicate(tcx));
elaborate_predicates(tcx, predicates)
}

pub fn elaborate_predicates<'tcx>(
tcx: TyCtxt<'tcx>,
predicates: impl Iterator<Item = ty::Predicate<'tcx>>,
) -> Elaborator<'tcx> {
let obligations = predicates
.map(|predicate| {
predicate_obligation(predicate, ty::ParamEnv::empty(), ObligationCause::dummy())
})
.collect();
elaborate_obligations(tcx, obligations)
) -> impl Iterator<Item = ty::Predicate<'tcx>> {
elaborate_obligations(
tcx,
predicates
.map(|predicate| {
Obligation::new(
tcx,
// We'll dump the cause/param-env later
ObligationCause::dummy(),
ty::ParamEnv::empty(),
predicate,
)
})
.collect(),
)
.map(|obl| obl.predicate)
}

pub fn elaborate_predicates_with_span<'tcx>(
tcx: TyCtxt<'tcx>,
predicates: impl Iterator<Item = (ty::Predicate<'tcx>, Span)>,
) -> Elaborator<'tcx> {
let obligations = predicates
.map(|(predicate, span)| {
predicate_obligation(
predicate,
ty::ParamEnv::empty(),
ObligationCause::dummy_with_span(span),
)
})
.collect();
elaborate_obligations(tcx, obligations)
) -> impl Iterator<Item = (ty::Predicate<'tcx>, Span)> {
elaborate_obligations(
tcx,
predicates
.map(|(predicate, span)| {
Obligation::new(
tcx,
// We'll dump the cause/param-env later
ObligationCause::dummy_with_span(span),
ty::ParamEnv::empty(),
predicate,
)
})
.collect(),
)
.map(|obl| (obl.predicate, obl.cause.span))
}

pub fn elaborate_obligations<'tcx>(
Expand Down Expand Up @@ -141,10 +155,6 @@ impl<'tcx> Elaborator<'tcx> {
self.stack.extend(obligations.into_iter().filter(|o| self.visited.insert(o.predicate)));
}

pub fn filter_to_traits(self) -> FilterToTraits<Self> {
FilterToTraits::new(self)
}

fn elaborate(&mut self, obligation: &PredicateObligation<'tcx>) {
let tcx = self.visited.tcx;

Expand Down Expand Up @@ -325,20 +335,18 @@ impl<'tcx> Iterator for Elaborator<'tcx> {
// Supertrait iterator
///////////////////////////////////////////////////////////////////////////

pub type Supertraits<'tcx> = FilterToTraits<Elaborator<'tcx>>;

pub fn supertraits<'tcx>(
tcx: TyCtxt<'tcx>,
trait_ref: ty::PolyTraitRef<'tcx>,
) -> Supertraits<'tcx> {
elaborate_trait_ref(tcx, trait_ref).filter_to_traits()
) -> impl Iterator<Item = ty::PolyTraitRef<'tcx>> {
FilterToTraits::new(elaborate_trait_ref(tcx, trait_ref))
}

pub fn transitive_bounds<'tcx>(
tcx: TyCtxt<'tcx>,
bounds: impl Iterator<Item = ty::PolyTraitRef<'tcx>>,
) -> Supertraits<'tcx> {
elaborate_trait_refs(tcx, bounds).filter_to_traits()
trait_refs: impl Iterator<Item = ty::PolyTraitRef<'tcx>>,
) -> impl Iterator<Item = ty::PolyTraitRef<'tcx>> {
FilterToTraits::new(elaborate_trait_refs(tcx, trait_refs))
}

/// A specialized variant of `elaborate_trait_refs` that only elaborates trait references that may
Expand Down Expand Up @@ -393,12 +401,12 @@ impl<I> FilterToTraits<I> {
}
}

impl<'tcx, I: Iterator<Item = PredicateObligation<'tcx>>> Iterator for FilterToTraits<I> {
impl<'tcx, I: Iterator<Item = ty::Predicate<'tcx>>> Iterator for FilterToTraits<I> {
type Item = ty::PolyTraitRef<'tcx>;

fn next(&mut self) -> Option<ty::PolyTraitRef<'tcx>> {
while let Some(obligation) = self.base_iterator.next() {
if let Some(data) = obligation.predicate.to_opt_poly_trait_pred() {
while let Some(pred) = self.base_iterator.next() {
if let Some(data) = pred.to_opt_poly_trait_pred() {
return Some(data.map_bound(|t| t.trait_ref));
}
}
Expand Down
4 changes: 2 additions & 2 deletions compiler/rustc_lint/src/unused.rs
Original file line number Diff line number Diff line change
Expand Up @@ -258,11 +258,11 @@ impl<'tcx> LateLintPass<'tcx> for UnusedResults {
cx.tcx,
cx.tcx.explicit_item_bounds(def).iter().cloned(),
)
.find_map(|obligation| {
.find_map(|(pred, _span)| {
// We only look at the `DefId`, so it is safe to skip the binder here.
if let ty::PredicateKind::Clause(ty::Clause::Trait(
ref poly_trait_predicate,
)) = obligation.predicate.kind().skip_binder()
)) = pred.kind().skip_binder()
{
let def_id = poly_trait_predicate.trait_ref.def_id;

Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_mir_transform/src/const_prop.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ impl<'tcx> MirPass<'tcx> for ConstProp {
.filter_map(|(p, _)| if p.is_global() { Some(*p) } else { None });
if traits::impossible_predicates(
tcx,
traits::elaborate_predicates(tcx, predicates).map(|o| o.predicate).collect(),
traits::elaborate_predicates(tcx, predicates).collect(),
) {
trace!("ConstProp skipped for {:?}: found unsatisfiable predicates", def_id);
return;
Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_mir_transform/src/const_prop_lint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ impl<'tcx> MirLint<'tcx> for ConstProp {
.filter_map(|(p, _)| if p.is_global() { Some(*p) } else { None });
if traits::impossible_predicates(
tcx,
traits::elaborate_predicates(tcx, predicates).map(|o| o.predicate).collect(),
traits::elaborate_predicates(tcx, predicates).collect(),
) {
trace!("ConstProp skipped for {:?}: found unsatisfiable predicates", def_id);
return;
Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_trait_selection/src/solve/assembly.rs
Original file line number Diff line number Diff line change
Expand Up @@ -470,7 +470,7 @@ impl<'tcx> EvalCtxt<'_, 'tcx> {
for assumption in
elaborate_predicates(tcx, bounds.iter().map(|bound| bound.with_self_ty(tcx, self_ty)))
{
match G::consider_object_bound_candidate(self, goal, assumption.predicate) {
match G::consider_object_bound_candidate(self, goal, assumption) {
Ok(result) => {
candidates.push(Candidate { source: CandidateSource::BuiltinImpl, result })
}
Expand Down
3 changes: 1 addition & 2 deletions compiler/rustc_trait_selection/src/traits/auto_trait.rs
Original file line number Diff line number Diff line change
Expand Up @@ -349,8 +349,7 @@ impl<'tcx> AutoTraitFinder<'tcx> {
let normalized_preds = elaborate_predicates(
tcx,
computed_preds.clone().chain(user_computed_preds.iter().cloned()),
)
.map(|o| o.predicate);
);
new_env = ty::ParamEnv::new(
tcx.mk_predicates_from_iter(normalized_preds),
param_env.reveal(),
Expand Down
4 changes: 2 additions & 2 deletions compiler/rustc_trait_selection/src/traits/coherence.rs
Original file line number Diff line number Diff line change
Expand Up @@ -367,8 +367,8 @@ fn negative_impl_exists<'tcx>(
}

// Try to prove a negative obligation exists for super predicates
for o in util::elaborate_predicates(infcx.tcx, iter::once(o.predicate)) {
if resolve_negative_obligation(infcx.fork(), &o, body_def_id) {
for pred in util::elaborate_predicates(infcx.tcx, iter::once(o.predicate)) {
if resolve_negative_obligation(infcx.fork(), &o.with(infcx.tcx, pred), body_def_id) {
return true;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,15 +82,15 @@ pub fn recompute_applicable_impls<'tcx>(

let predicates =
tcx.predicates_of(obligation.cause.body_id.to_def_id()).instantiate_identity(tcx);
for obligation in elaborate_predicates_with_span(tcx, predicates.into_iter()) {
let kind = obligation.predicate.kind();
for (pred, span) in elaborate_predicates_with_span(tcx, predicates.into_iter()) {
let kind = pred.kind();
if let ty::PredicateKind::Clause(ty::Clause::Trait(trait_pred)) = kind.skip_binder()
&& param_env_candidate_may_apply(kind.rebind(trait_pred))
{
if kind.rebind(trait_pred.trait_ref) == ty::TraitRef::identity(tcx, trait_pred.def_id()) {
ambiguities.push(Ambiguity::ParamEnv(tcx.def_span(trait_pred.def_id())))
} else {
ambiguities.push(Ambiguity::ParamEnv(obligation.cause.span))
ambiguities.push(Ambiguity::ParamEnv(span))
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1624,8 +1624,8 @@ impl<'tcx> InferCtxtPrivExt<'tcx> for TypeErrCtxt<'_, 'tcx> {
}
};

for obligation in super::elaborate_predicates(self.tcx, std::iter::once(cond)) {
let bound_predicate = obligation.predicate.kind();
for pred in super::elaborate_predicates(self.tcx, std::iter::once(cond)) {
let bound_predicate = pred.kind();
if let ty::PredicateKind::Clause(ty::Clause::Trait(implication)) =
bound_predicate.skip_binder()
{
Expand Down
Loading

0 comments on commit 6535e66

Please sign in to comment.