Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Don't over-constrain projections in generic method signatures #92728

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions compiler/rustc_borrowck/src/type_check/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2101,8 +2101,9 @@ impl<'a, 'tcx> TypeChecker<'a, 'tcx> {

let ty_fn_ptr_from = tcx.mk_fn_ptr(fn_sig);

let ty = self.normalize(*ty, location);
if let Err(terr) = self.eq_types(
*ty,
ty,
ty_fn_ptr_from,
location.to_locations(),
ConstraintCategory::Cast,
Expand All @@ -2124,9 +2125,11 @@ impl<'a, 'tcx> TypeChecker<'a, 'tcx> {
_ => bug!(),
};
let ty_fn_ptr_from = tcx.mk_fn_ptr(tcx.signature_unclosure(sig, *unsafety));
let ty_fn_ptr_from = self.normalize(ty_fn_ptr_from, location);

let ty = self.normalize(*ty, location);
if let Err(terr) = self.eq_types(
*ty,
ty,
ty_fn_ptr_from,
location.to_locations(),
ConstraintCategory::Cast,
Expand Down Expand Up @@ -2154,8 +2157,9 @@ impl<'a, 'tcx> TypeChecker<'a, 'tcx> {

let ty_fn_ptr_from = tcx.safe_to_unsafe_fn_ty(fn_sig);

let ty = self.normalize(*ty, location);
if let Err(terr) = self.eq_types(
*ty,
ty,
ty_fn_ptr_from,
location.to_locations(),
ConstraintCategory::Cast,
Expand Down
4 changes: 4 additions & 0 deletions compiler/rustc_infer/src/infer/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -920,6 +920,10 @@ impl<'a, 'tcx> InferCtxt<'a, 'tcx> {
.region_constraints_added_in_snapshot(&snapshot.undo_snapshot)
}

pub fn any_instantiations(&self, snapshot: &CombinedSnapshot<'a, 'tcx>) -> bool {
self.inner.borrow_mut().any_instantiations(&snapshot.undo_snapshot)
}

pub fn add_given(&self, sub: ty::Region<'tcx>, sup: ty::RegionVid) {
self.inner.borrow_mut().unwrap_region_constraints().add_given(sub, sup);
}
Expand Down
31 changes: 31 additions & 0 deletions compiler/rustc_infer/src/infer/undo_log.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,23 @@ pub(crate) enum UndoLog<'tcx> {
PushRegionObligation,
}

impl<'tcx> std::fmt::Debug for UndoLog<'tcx> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::TypeVariables(_) => f.debug_tuple("TypeVariables").finish(),
Self::ConstUnificationTable(_) => f.debug_tuple("ConstUnificationTable").finish(),
Self::IntUnificationTable(_) => f.debug_tuple("IntUnificationTable").finish(),
Self::FloatUnificationTable(_) => f.debug_tuple("FloatUnificationTable").finish(),
Self::RegionConstraintCollector(_) => {
f.debug_tuple("RegionConstraintCollector").finish()
}
Self::RegionUnificationTable(_) => f.debug_tuple("RegionUnificationTable").finish(),
Self::ProjectionCache(_) => f.debug_tuple("ProjectionCache").finish(),
Self::PushRegionObligation => write!(f, "PushRegionObligation"),
}
}
}

macro_rules! impl_from {
($($ctor: ident ($ty: ty),)*) => {
$(
Expand Down Expand Up @@ -165,6 +182,10 @@ impl<'tcx> InferCtxtInner<'tcx> {

self.undo_log.num_open_snapshots -= 1;
}

pub fn any_instantiations(&mut self, snapshot: &Snapshot<'tcx>) -> bool {
self.undo_log.instantiations_in_snapshot(snapshot).next().is_some()
}
}

impl<'tcx> InferCtxtUndoLogs<'tcx> {
Expand All @@ -173,6 +194,16 @@ impl<'tcx> InferCtxtUndoLogs<'tcx> {
Snapshot { undo_len: self.logs.len(), _marker: PhantomData }
}

pub(crate) fn instantiations_in_snapshot(
&self,
s: &Snapshot<'tcx>,
) -> impl Iterator<Item = &'_ type_variable::UndoLog<'tcx>> + Clone {
self.logs[s.undo_len..].iter().filter_map(|log| match log {
UndoLog::TypeVariables(log) => Some(log),
_ => None,
})
}

pub(crate) fn region_constraints_in_snapshot(
&self,
s: &Snapshot<'tcx>,
Expand Down
5 changes: 4 additions & 1 deletion compiler/rustc_trait_selection/src/traits/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,10 @@ pub use self::object_safety::is_vtable_safe_method;
pub use self::object_safety::MethodViolationCode;
pub use self::object_safety::ObjectSafetyViolation;
pub use self::on_unimplemented::{OnUnimplementedDirective, OnUnimplementedNote};
pub use self::project::{normalize, normalize_projection_type, normalize_to};
pub use self::project::{
normalize, normalize_projection_type, normalize_to, normalize_with_depth_to,
project_and_unify_type,
};
pub use self::select::{EvaluationCache, SelectionCache, SelectionContext};
pub use self::select::{EvaluationResult, IntercrateAmbiguityCause, OverflowError};
pub use self::specialize::specialization_graph::FutureCompatOverlapError;
Expand Down
154 changes: 92 additions & 62 deletions compiler/rustc_trait_selection/src/traits/project.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ use super::{Normalized, NormalizedTy, ProjectionCacheEntry, ProjectionCacheKey};
use crate::infer::type_variable::{TypeVariableOrigin, TypeVariableOriginKind};
use crate::infer::{InferCtxt, InferOk, LateBoundRegionConversionTime};
use crate::traits::error_reporting::InferCtxtExt as _;
use crate::traits::select::ProjectionMatchesProjection;
use rustc_data_structures::sso::SsoHashSet;
use rustc_data_structures::stack::ensure_sufficient_stack;
use rustc_errors::ErrorGuaranteed;
Expand All @@ -42,7 +41,7 @@ pub type ProjectionObligation<'tcx> = Obligation<'tcx, ty::ProjectionPredicate<'

pub type ProjectionTyObligation<'tcx> = Obligation<'tcx, ty::ProjectionTy<'tcx>>;

pub(super) struct InProgress;
pub struct InProgress;

/// When attempting to resolve `<T as TraitRef>::Name` ...
#[derive(Debug)]
Expand Down Expand Up @@ -188,7 +187,7 @@ pub(super) fn poly_project_and_unify_type<'cx, 'tcx>(
/// If successful, this may result in additional obligations.
///
/// See [poly_project_and_unify_type] for an explanation of the return value.
fn project_and_unify_type<'cx, 'tcx>(
pub fn project_and_unify_type<'cx, 'tcx>(
selcx: &mut SelectionContext<'cx, 'tcx>,
obligation: &ProjectionObligation<'tcx>,
) -> Result<
Expand Down Expand Up @@ -484,7 +483,7 @@ impl<'a, 'b, 'tcx> TypeFolder<'tcx> for AssocTypeNormalizer<'a, 'b, 'tcx> {
// there won't be bound vars there.

let data = data.super_fold_with(self);
let normalized_ty = if self.eager_inference_replacement {
let normalized_ty = if self.selcx.normalization_mode.eager_inference_replacement {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the same as to what is done in #90887, but puts the flag on SelectionContext instead of AssocTypeNormalizer

normalize_projection_type(
self.selcx,
self.param_env,
Expand Down Expand Up @@ -914,7 +913,8 @@ fn opt_normalize_projection_type<'a, 'b, 'tcx>(
// Don't use the projection cache in intercrate mode -
// the `infcx` may be re-used between intercrate in non-intercrate
// mode, which could lead to using incorrect cache results.
let use_cache = !selcx.is_intercrate();
let use_cache =
!selcx.is_intercrate() && selcx.normalization_mode.allow_infer_constraint_during_projection;

let projection_ty = infcx.resolve_vars_if_possible(projection_ty);
let cache_key = ProjectionCacheKey::new(projection_ty);
Expand Down Expand Up @@ -1168,7 +1168,7 @@ fn project<'cx, 'tcx>(

match candidates {
ProjectionCandidateSet::Single(candidate) => {
Ok(Projected::Progress(confirm_candidate(selcx, obligation, candidate)))
Ok(confirm_candidate(selcx, obligation, candidate))
}
ProjectionCandidateSet::None => Ok(Projected::NoProgress(
// FIXME(associated_const_generics): this may need to change in the future?
Expand Down Expand Up @@ -1243,7 +1243,7 @@ fn assemble_candidates_from_trait_def<'cx, 'tcx>(
ProjectionCandidate::TraitDef,
bounds.iter(),
true,
);
)
}

/// In the case of a trait object like
Expand Down Expand Up @@ -1308,35 +1308,28 @@ fn assemble_candidates_from_predicates<'cx, 'tcx>(
let bound_predicate = predicate.kind();
if let ty::PredicateKind::Projection(data) = predicate.kind().skip_binder() {
let data = bound_predicate.rebind(data);
if data.projection_def_id() != obligation.predicate.item_def_id {
continue;
}
let same_def_id = data.projection_def_id() == obligation.predicate.item_def_id;

let is_match = infcx.probe(|_| {
selcx.match_projection_projections(
obligation,
data,
potentially_unnormalized_candidates,
)
});
let is_match = same_def_id
&& infcx.probe(|_| {
selcx.match_projection_projections(
obligation,
data,
potentially_unnormalized_candidates,
)
});

match is_match {
ProjectionMatchesProjection::Yes => {
candidate_set.push_candidate(ctor(data));

if potentially_unnormalized_candidates
&& !obligation.predicate.has_infer_types_or_consts()
{
// HACK: Pick the first trait def candidate for a fully
// inferred predicate. This is to allow duplicates that
// differ only in normalization.
return;
}
}
ProjectionMatchesProjection::Ambiguous => {
candidate_set.mark_ambiguous();
if is_match {
candidate_set.push_candidate(ctor(data));

if potentially_unnormalized_candidates
&& !obligation.predicate.has_infer_types_or_consts()
{
// HACK: Pick the first trait def candidate for a fully
// inferred predicate. This is to allow duplicates that
// differ only in normalization.
return;
}
ProjectionMatchesProjection::No => {}
}
}
}
Expand Down Expand Up @@ -1583,7 +1576,7 @@ fn confirm_candidate<'cx, 'tcx>(
selcx: &mut SelectionContext<'cx, 'tcx>,
obligation: &ProjectionTyObligation<'tcx>,
candidate: ProjectionCandidate<'tcx>,
) -> Progress<'tcx> {
) -> Projected<'tcx> {
debug!(?obligation, ?candidate, "confirm_candidate");
let mut progress = match candidate {
ProjectionCandidate::ParamEnv(poly_projection)
Expand All @@ -1596,7 +1589,7 @@ fn confirm_candidate<'cx, 'tcx>(
}

ProjectionCandidate::Select(impl_source) => {
confirm_select_candidate(selcx, obligation, impl_source)
Projected::Progress(confirm_select_candidate(selcx, obligation, impl_source))
}
};

Expand All @@ -1605,9 +1598,11 @@ fn confirm_candidate<'cx, 'tcx>(
// with new region variables, we need to resolve them to existing variables
// when possible for this to work. See `auto-trait-projection-recursion.rs`
// for a case where this matters.
if progress.term.has_infer_regions() {
progress.term =
progress.term.fold_with(&mut OpportunisticRegionResolver::new(selcx.infcx()));
if let Projected::Progress(progress) = &mut progress {
if progress.term.has_infer_regions() {
progress.term =
progress.term.fold_with(&mut OpportunisticRegionResolver::new(selcx.infcx()));
}
}
progress
}
Expand Down Expand Up @@ -1688,9 +1683,12 @@ fn confirm_generator_candidate<'cx, 'tcx>(
}
});

confirm_param_env_candidate(selcx, obligation, predicate, false)
.with_addl_obligations(impl_source.nested)
.with_addl_obligations(obligations)
let progress = confirm_param_env_candidate(selcx, obligation, predicate, false);
let progress = match progress {
Projected::Progress(progress) => progress,
Projected::NoProgress(_) => bug!(),
};
progress.with_addl_obligations(impl_source.nested).with_addl_obligations(obligations)
}

fn confirm_discriminant_kind_candidate<'cx, 'tcx>(
Expand All @@ -1715,7 +1713,12 @@ fn confirm_discriminant_kind_candidate<'cx, 'tcx>(

// We get here from `poly_project_and_unify_type` which replaces bound vars
// with placeholders, so dummy is okay here.
confirm_param_env_candidate(selcx, obligation, ty::Binder::dummy(predicate), false)
let progress =
confirm_param_env_candidate(selcx, obligation, ty::Binder::dummy(predicate), false);
match progress {
Projected::Progress(progress) => progress,
Projected::NoProgress(_) => bug!(),
}
}

fn confirm_pointee_candidate<'cx, 'tcx>(
Expand Down Expand Up @@ -1746,8 +1749,12 @@ fn confirm_pointee_candidate<'cx, 'tcx>(
term: metadata_ty.into(),
};

confirm_param_env_candidate(selcx, obligation, ty::Binder::dummy(predicate), false)
.with_addl_obligations(obligations)
let progress =
confirm_param_env_candidate(selcx, obligation, ty::Binder::dummy(predicate), false);
match progress {
Projected::Progress(progress) => progress.with_addl_obligations(obligations),
Projected::NoProgress(_) => bug!(),
}
}

fn confirm_fn_pointer_candidate<'cx, 'tcx>(
Expand Down Expand Up @@ -1819,15 +1826,19 @@ fn confirm_callable_candidate<'cx, 'tcx>(
term: ret_type.into(),
});

confirm_param_env_candidate(selcx, obligation, predicate, true)
let progress = confirm_param_env_candidate(selcx, obligation, predicate, true);
match progress {
Projected::Progress(progress) => progress,
Projected::NoProgress(_) => bug!(),
}
}

fn confirm_param_env_candidate<'cx, 'tcx>(
selcx: &mut SelectionContext<'cx, 'tcx>,
obligation: &ProjectionTyObligation<'tcx>,
poly_cache_entry: ty::PolyProjectionPredicate<'tcx>,
potentially_unnormalized_candidate: bool,
) -> Progress<'tcx> {
) -> Projected<'tcx> {
let infcx = selcx.infcx();
let cause = &obligation.cause;
let param_env = obligation.param_env;
Expand Down Expand Up @@ -1868,23 +1879,42 @@ fn confirm_param_env_candidate<'cx, 'tcx>(

debug!(?cache_projection, ?obligation_projection);

match infcx.at(cause, param_env).eq(cache_projection, obligation_projection) {
Ok(InferOk { value: _, obligations }) => {
nested_obligations.extend(obligations);
assoc_ty_own_obligations(selcx, obligation, &mut nested_obligations);
// FIXME(associated_const_equality): Handle consts here as well? Maybe this progress type should just take
// a term instead.
Progress { term: cache_entry.term, obligations: nested_obligations }
}
Err(e) => {
let msg = format!(
"Failed to unify obligation `{:?}` with poly_projection `{:?}`: {:?}",
obligation, poly_cache_entry, e,
);
debug!("confirm_param_env_candidate: {}", msg);
let err = infcx.tcx.ty_error_with_message(obligation.cause.span, &msg);
Progress { term: err.into(), obligations: vec![] }
match infcx.commit_if_ok(|snapshot| {
let progress = match infcx.at(cause, param_env).eq(cache_projection, obligation_projection)
{
Ok(InferOk { value: _, obligations }) => {
nested_obligations.extend(obligations);
assoc_ty_own_obligations(selcx, obligation, &mut nested_obligations);
// FIXME(associated_const_equality): Handle consts here as well? Maybe this progress type should just take
// a term instead.
Progress { term: cache_entry.term, obligations: nested_obligations }
}
Err(e) => {
let msg = format!(
"Failed to unify obligation `{:?}` with poly_projection `{:?}`: {:?}",
obligation, poly_cache_entry, e,
);
debug!("confirm_param_env_candidate: {}", msg);
let err = infcx.tcx.ty_error_with_message(obligation.cause.span, &msg);
Progress { term: err.into(), obligations: vec![] }
}
};

let any_instantiations = infcx.any_instantiations(&snapshot);

if any_instantiations && !selcx.normalization_mode.allow_infer_constraint_during_projection
Comment on lines +1903 to +1905
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now that I think about it, this could probably be done during candidate assembly.

{
Err(ty::Term::Ty(
infcx
.tcx
.mk_projection(obligation_projection.item_def_id, obligation_projection.substs),
))
} else {
Ok(progress)
}
}) {
Ok(p) => Projected::Progress(p),
Err(p) => Projected::NoProgress(p),
}
}

Expand Down
Loading