Skip to content

Commit

Permalink
An ugly, awful hack that works
Browse files Browse the repository at this point in the history
  • Loading branch information
jackh726 committed Mar 12, 2022
1 parent 335ffbf commit 717ed47
Show file tree
Hide file tree
Showing 19 changed files with 389 additions and 67 deletions.
10 changes: 7 additions & 3 deletions compiler/rustc_borrowck/src/type_check/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2101,8 +2101,9 @@ impl<'a, 'tcx> TypeChecker<'a, 'tcx> {

let ty_fn_ptr_from = tcx.mk_fn_ptr(fn_sig);

let ty = self.normalize(*ty, location);
if let Err(terr) = self.eq_types(
*ty,
ty,
ty_fn_ptr_from,
location.to_locations(),
ConstraintCategory::Cast,
Expand All @@ -2124,9 +2125,11 @@ impl<'a, 'tcx> TypeChecker<'a, 'tcx> {
_ => bug!(),
};
let ty_fn_ptr_from = tcx.mk_fn_ptr(tcx.signature_unclosure(sig, *unsafety));
let ty_fn_ptr_from = self.normalize(ty_fn_ptr_from, location);

let ty = self.normalize(*ty, location);
if let Err(terr) = self.eq_types(
*ty,
ty,
ty_fn_ptr_from,
location.to_locations(),
ConstraintCategory::Cast,
Expand Down Expand Up @@ -2154,8 +2157,9 @@ impl<'a, 'tcx> TypeChecker<'a, 'tcx> {

let ty_fn_ptr_from = tcx.safe_to_unsafe_fn_ty(fn_sig);

let ty = self.normalize(*ty, location);
if let Err(terr) = self.eq_types(
*ty,
ty,
ty_fn_ptr_from,
location.to_locations(),
ConstraintCategory::Cast,
Expand Down
4 changes: 4 additions & 0 deletions compiler/rustc_infer/src/infer/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -920,6 +920,10 @@ impl<'a, 'tcx> InferCtxt<'a, 'tcx> {
.region_constraints_added_in_snapshot(&snapshot.undo_snapshot)
}

pub fn any_instantiations(&self, snapshot: &CombinedSnapshot<'a, 'tcx>) -> bool {
self.inner.borrow_mut().any_instantiations(&snapshot.undo_snapshot)
}

pub fn add_given(&self, sub: ty::Region<'tcx>, sup: ty::RegionVid) {
self.inner.borrow_mut().unwrap_region_constraints().add_given(sub, sup);
}
Expand Down
31 changes: 31 additions & 0 deletions compiler/rustc_infer/src/infer/undo_log.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,23 @@ pub(crate) enum UndoLog<'tcx> {
PushRegionObligation,
}

impl<'tcx> std::fmt::Debug for UndoLog<'tcx> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::TypeVariables(_) => f.debug_tuple("TypeVariables").finish(),
Self::ConstUnificationTable(_) => f.debug_tuple("ConstUnificationTable").finish(),
Self::IntUnificationTable(_) => f.debug_tuple("IntUnificationTable").finish(),
Self::FloatUnificationTable(_) => f.debug_tuple("FloatUnificationTable").finish(),
Self::RegionConstraintCollector(_) => {
f.debug_tuple("RegionConstraintCollector").finish()
}
Self::RegionUnificationTable(_) => f.debug_tuple("RegionUnificationTable").finish(),
Self::ProjectionCache(_) => f.debug_tuple("ProjectionCache").finish(),
Self::PushRegionObligation => write!(f, "PushRegionObligation"),
}
}
}

macro_rules! impl_from {
($($ctor: ident ($ty: ty),)*) => {
$(
Expand Down Expand Up @@ -165,6 +182,10 @@ impl<'tcx> InferCtxtInner<'tcx> {

self.undo_log.num_open_snapshots -= 1;
}

pub fn any_instantiations(&mut self, snapshot: &Snapshot<'tcx>) -> bool {
self.undo_log.instantiations_in_snapshot(snapshot).next().is_some()
}
}

impl<'tcx> InferCtxtUndoLogs<'tcx> {
Expand All @@ -173,6 +194,16 @@ impl<'tcx> InferCtxtUndoLogs<'tcx> {
Snapshot { undo_len: self.logs.len(), _marker: PhantomData }
}

pub(crate) fn instantiations_in_snapshot(
&self,
s: &Snapshot<'tcx>,
) -> impl Iterator<Item = &'_ type_variable::UndoLog<'tcx>> + Clone {
self.logs[s.undo_len..].iter().filter_map(|log| match log {
UndoLog::TypeVariables(log) => Some(log),
_ => None,
})
}

pub(crate) fn region_constraints_in_snapshot(
&self,
s: &Snapshot<'tcx>,
Expand Down
5 changes: 4 additions & 1 deletion compiler/rustc_trait_selection/src/traits/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,10 @@ pub use self::object_safety::is_vtable_safe_method;
pub use self::object_safety::MethodViolationCode;
pub use self::object_safety::ObjectSafetyViolation;
pub use self::on_unimplemented::{OnUnimplementedDirective, OnUnimplementedNote};
pub use self::project::{normalize, normalize_projection_type, normalize_to};
pub use self::project::{
normalize, normalize_projection_type, normalize_to, normalize_with_depth_to,
project_and_unify_type,
};
pub use self::select::{EvaluationCache, SelectionCache, SelectionContext};
pub use self::select::{EvaluationResult, IntercrateAmbiguityCause, OverflowError};
pub use self::specialize::specialization_graph::FutureCompatOverlapError;
Expand Down
106 changes: 72 additions & 34 deletions compiler/rustc_trait_selection/src/traits/project.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ pub type ProjectionObligation<'tcx> = Obligation<'tcx, ty::ProjectionPredicate<'

pub type ProjectionTyObligation<'tcx> = Obligation<'tcx, ty::ProjectionTy<'tcx>>;

pub(super) struct InProgress;
pub struct InProgress;

/// When attempting to resolve `<T as TraitRef>::Name` ...
#[derive(Debug)]
Expand Down Expand Up @@ -188,7 +188,7 @@ pub(super) fn poly_project_and_unify_type<'cx, 'tcx>(
/// If successful, this may result in additional obligations.
///
/// See [poly_project_and_unify_type] for an explanation of the return value.
fn project_and_unify_type<'cx, 'tcx>(
pub fn project_and_unify_type<'cx, 'tcx>(
selcx: &mut SelectionContext<'cx, 'tcx>,
obligation: &ProjectionObligation<'tcx>,
) -> Result<
Expand Down Expand Up @@ -484,7 +484,7 @@ impl<'a, 'b, 'tcx> TypeFolder<'tcx> for AssocTypeNormalizer<'a, 'b, 'tcx> {
// there won't be bound vars there.

let data = data.super_fold_with(self);
let normalized_ty = if self.eager_inference_replacement {
let normalized_ty = if self.selcx.normalization_mode.eager_inference_replacement {
normalize_projection_type(
self.selcx,
self.param_env,
Expand Down Expand Up @@ -914,7 +914,8 @@ fn opt_normalize_projection_type<'a, 'b, 'tcx>(
// Don't use the projection cache in intercrate mode -
// the `infcx` may be re-used between intercrate in non-intercrate
// mode, which could lead to using incorrect cache results.
let use_cache = !selcx.is_intercrate();
let use_cache =
!selcx.is_intercrate() && selcx.normalization_mode.allow_infer_constraint_during_projection;

let projection_ty = infcx.resolve_vars_if_possible(projection_ty);
let cache_key = ProjectionCacheKey::new(projection_ty);
Expand Down Expand Up @@ -1168,7 +1169,7 @@ fn project<'cx, 'tcx>(

match candidates {
ProjectionCandidateSet::Single(candidate) => {
Ok(Projected::Progress(confirm_candidate(selcx, obligation, candidate)))
Ok(confirm_candidate(selcx, obligation, candidate))
}
ProjectionCandidateSet::None => Ok(Projected::NoProgress(
// FIXME(associated_const_generics): this may need to change in the future?
Expand Down Expand Up @@ -1583,7 +1584,7 @@ fn confirm_candidate<'cx, 'tcx>(
selcx: &mut SelectionContext<'cx, 'tcx>,
obligation: &ProjectionTyObligation<'tcx>,
candidate: ProjectionCandidate<'tcx>,
) -> Progress<'tcx> {
) -> Projected<'tcx> {
debug!(?obligation, ?candidate, "confirm_candidate");
let mut progress = match candidate {
ProjectionCandidate::ParamEnv(poly_projection)
Expand All @@ -1596,7 +1597,7 @@ fn confirm_candidate<'cx, 'tcx>(
}

ProjectionCandidate::Select(impl_source) => {
confirm_select_candidate(selcx, obligation, impl_source)
Projected::Progress(confirm_select_candidate(selcx, obligation, impl_source))
}
};

Expand All @@ -1605,9 +1606,11 @@ fn confirm_candidate<'cx, 'tcx>(
// with new region variables, we need to resolve them to existing variables
// when possible for this to work. See `auto-trait-projection-recursion.rs`
// for a case where this matters.
if progress.term.has_infer_regions() {
progress.term =
progress.term.fold_with(&mut OpportunisticRegionResolver::new(selcx.infcx()));
if let Projected::Progress(progress) = &mut progress {
if progress.term.has_infer_regions() {
progress.term =
progress.term.fold_with(&mut OpportunisticRegionResolver::new(selcx.infcx()));
}
}
progress
}
Expand Down Expand Up @@ -1688,9 +1691,12 @@ fn confirm_generator_candidate<'cx, 'tcx>(
}
});

confirm_param_env_candidate(selcx, obligation, predicate, false)
.with_addl_obligations(impl_source.nested)
.with_addl_obligations(obligations)
let progress = confirm_param_env_candidate(selcx, obligation, predicate, false);
let progress = match progress {
Projected::Progress(progress) => progress,
Projected::NoProgress(_) => bug!(),
};
progress.with_addl_obligations(impl_source.nested).with_addl_obligations(obligations)
}

fn confirm_discriminant_kind_candidate<'cx, 'tcx>(
Expand All @@ -1715,7 +1721,12 @@ fn confirm_discriminant_kind_candidate<'cx, 'tcx>(

// We get here from `poly_project_and_unify_type` which replaces bound vars
// with placeholders, so dummy is okay here.
confirm_param_env_candidate(selcx, obligation, ty::Binder::dummy(predicate), false)
let progress =
confirm_param_env_candidate(selcx, obligation, ty::Binder::dummy(predicate), false);
match progress {
Projected::Progress(progress) => progress,
Projected::NoProgress(_) => bug!(),
}
}

fn confirm_pointee_candidate<'cx, 'tcx>(
Expand Down Expand Up @@ -1746,8 +1757,12 @@ fn confirm_pointee_candidate<'cx, 'tcx>(
term: metadata_ty.into(),
};

confirm_param_env_candidate(selcx, obligation, ty::Binder::dummy(predicate), false)
.with_addl_obligations(obligations)
let progress =
confirm_param_env_candidate(selcx, obligation, ty::Binder::dummy(predicate), false);
match progress {
Projected::Progress(progress) => progress.with_addl_obligations(obligations),
Projected::NoProgress(_) => bug!(),
}
}

fn confirm_fn_pointer_candidate<'cx, 'tcx>(
Expand Down Expand Up @@ -1819,15 +1834,19 @@ fn confirm_callable_candidate<'cx, 'tcx>(
term: ret_type.into(),
});

confirm_param_env_candidate(selcx, obligation, predicate, true)
let progress = confirm_param_env_candidate(selcx, obligation, predicate, true);
match progress {
Projected::Progress(progress) => progress,
Projected::NoProgress(_) => bug!(),
}
}

fn confirm_param_env_candidate<'cx, 'tcx>(
selcx: &mut SelectionContext<'cx, 'tcx>,
obligation: &ProjectionTyObligation<'tcx>,
poly_cache_entry: ty::PolyProjectionPredicate<'tcx>,
potentially_unnormalized_candidate: bool,
) -> Progress<'tcx> {
) -> Projected<'tcx> {
let infcx = selcx.infcx();
let cause = &obligation.cause;
let param_env = obligation.param_env;
Expand Down Expand Up @@ -1868,23 +1887,42 @@ fn confirm_param_env_candidate<'cx, 'tcx>(

debug!(?cache_projection, ?obligation_projection);

match infcx.at(cause, param_env).eq(cache_projection, obligation_projection) {
Ok(InferOk { value: _, obligations }) => {
nested_obligations.extend(obligations);
assoc_ty_own_obligations(selcx, obligation, &mut nested_obligations);
// FIXME(associated_const_equality): Handle consts here as well? Maybe this progress type should just take
// a term instead.
Progress { term: cache_entry.term, obligations: nested_obligations }
}
Err(e) => {
let msg = format!(
"Failed to unify obligation `{:?}` with poly_projection `{:?}`: {:?}",
obligation, poly_cache_entry, e,
);
debug!("confirm_param_env_candidate: {}", msg);
let err = infcx.tcx.ty_error_with_message(obligation.cause.span, &msg);
Progress { term: err.into(), obligations: vec![] }
match infcx.commit_if_ok(|snapshot| {
let progress = match infcx.at(cause, param_env).eq(cache_projection, obligation_projection)
{
Ok(InferOk { value: _, obligations }) => {
nested_obligations.extend(obligations);
assoc_ty_own_obligations(selcx, obligation, &mut nested_obligations);
// FIXME(associated_const_equality): Handle consts here as well? Maybe this progress type should just take
// a term instead.
Progress { term: cache_entry.term, obligations: nested_obligations }
}
Err(e) => {
let msg = format!(
"Failed to unify obligation `{:?}` with poly_projection `{:?}`: {:?}",
obligation, poly_cache_entry, e,
);
debug!("confirm_param_env_candidate: {}", msg);
let err = infcx.tcx.ty_error_with_message(obligation.cause.span, &msg);
Progress { term: err.into(), obligations: vec![] }
}
};

let any_instantiations = infcx.any_instantiations(&snapshot);

if any_instantiations && !selcx.normalization_mode.allow_infer_constraint_during_projection
{
Err(ty::Term::Ty(
infcx
.tcx
.mk_projection(obligation_projection.item_def_id, obligation_projection.substs),
))
} else {
Ok(progress)
}
}) {
Ok(p) => Projected::Progress(p),
Err(p) => Projected::NoProgress(p),
}
}

Expand Down
31 changes: 31 additions & 0 deletions compiler/rustc_trait_selection/src/traits/select/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,17 @@ pub struct SelectionContext<'cx, 'tcx> {
/// policy. In essence, canonicalized queries need their errors propagated
/// rather than immediately reported because we do not have accurate spans.
query_mode: TraitQueryMode,

pub normalization_mode: NormalizationMode,
}

#[derive(Copy, Clone)]
pub struct NormalizationMode {
pub allow_infer_constraint_during_projection: bool,
/// If true, when a projection is unable to be completed, an inference
/// variable will be created and an obligation registered to project to that
/// inference variable. Also, constants will be eagerly evaluated.
pub eager_inference_replacement: bool,
}

// A stack that walks back up the stack frame.
Expand Down Expand Up @@ -221,6 +232,10 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
intercrate_ambiguity_causes: None,
allow_negative_impls: false,
query_mode: TraitQueryMode::Standard,
normalization_mode: NormalizationMode {
allow_infer_constraint_during_projection: true,
eager_inference_replacement: true,
},
}
}

Expand All @@ -232,6 +247,10 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
intercrate_ambiguity_causes: None,
allow_negative_impls: false,
query_mode: TraitQueryMode::Standard,
normalization_mode: NormalizationMode {
allow_infer_constraint_during_projection: true,
eager_inference_replacement: true,
},
}
}

Expand All @@ -247,6 +266,10 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
intercrate_ambiguity_causes: None,
allow_negative_impls,
query_mode: TraitQueryMode::Standard,
normalization_mode: NormalizationMode {
allow_infer_constraint_during_projection: true,
eager_inference_replacement: true,
},
}
}

Expand All @@ -262,6 +285,10 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
intercrate_ambiguity_causes: None,
allow_negative_impls: false,
query_mode,
normalization_mode: NormalizationMode {
allow_infer_constraint_during_projection: true,
eager_inference_replacement: true,
},
}
}

Expand Down Expand Up @@ -297,6 +324,10 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
self.intercrate
}

pub fn normalization_mode(&self) -> NormalizationMode {
self.normalization_mode
}

///////////////////////////////////////////////////////////////////////////
// Selection
//
Expand Down
Loading

0 comments on commit 717ed47

Please sign in to comment.