Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Use all projection sub-obligations during trait evaluation #86896

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 18 additions & 9 deletions compiler/rustc_infer/src/traits/project.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ pub enum ProjectionCacheEntry<'tcx> {
Ambiguous,
Recur,
Error,
NormalizedTy(NormalizedTy<'tcx>),
NormalizedTy { value: NormalizedTy<'tcx>, full_obligations: Vec<PredicateObligation<'tcx>> },
}

impl<'tcx> ProjectionCacheStorage<'tcx> {
Expand Down Expand Up @@ -139,7 +139,12 @@ impl<'tcx> ProjectionCache<'_, 'tcx> {
}

/// Indicates that `key` was normalized to `value`.
pub fn insert_ty(&mut self, key: ProjectionCacheKey<'tcx>, value: NormalizedTy<'tcx>) {
pub fn insert_ty(
&mut self,
key: ProjectionCacheKey<'tcx>,
value: NormalizedTy<'tcx>,
full_obligations: Vec<PredicateObligation<'tcx>>,
) {
debug!(
"ProjectionCacheEntry::insert_ty: adding cache entry: key={:?}, value={:?}",
key, value
Expand All @@ -149,7 +154,8 @@ impl<'tcx> ProjectionCache<'_, 'tcx> {
debug!("Not overwriting Recur");
return;
}
let fresh_key = map.insert(key, ProjectionCacheEntry::NormalizedTy(value));
let fresh_key =
map.insert(key, ProjectionCacheEntry::NormalizedTy { value, full_obligations });
assert!(!fresh_key, "never started projecting `{:?}`", key);
}

Expand All @@ -160,7 +166,7 @@ impl<'tcx> ProjectionCache<'_, 'tcx> {
pub fn complete(&mut self, key: ProjectionCacheKey<'tcx>) {
let mut map = self.map();
let ty = match map.get(&key) {
Some(&ProjectionCacheEntry::NormalizedTy(ref ty)) => {
Some(&ProjectionCacheEntry::NormalizedTy { value: ref ty, .. }) => {
debug!("ProjectionCacheEntry::complete({:?}) - completing {:?}", key, ty);
ty.value
}
Expand All @@ -174,7 +180,10 @@ impl<'tcx> ProjectionCache<'_, 'tcx> {

map.insert(
key,
ProjectionCacheEntry::NormalizedTy(Normalized { value: ty, obligations: vec![] }),
ProjectionCacheEntry::NormalizedTy {
value: Normalized { value: ty, obligations: vec![] },
full_obligations: vec![],
},
);
}

Expand All @@ -186,10 +195,10 @@ impl<'tcx> ProjectionCache<'_, 'tcx> {
if !ty.obligations.is_empty() {
self.map().insert(
key,
ProjectionCacheEntry::NormalizedTy(Normalized {
value: ty.value,
obligations: vec![],
}),
ProjectionCacheEntry::NormalizedTy {
value: Normalized { value: ty.value, obligations: vec![] },
full_obligations: vec![],
},
);
}
}
Expand Down
20 changes: 16 additions & 4 deletions compiler/rustc_trait_selection/src/traits/project.rs
Original file line number Diff line number Diff line change
Expand Up @@ -516,7 +516,7 @@ fn opt_normalize_projection_type<'a, 'b, 'tcx>(
Err(ProjectionCacheEntry::Recur) => {
return Err(InProgress);
}
Err(ProjectionCacheEntry::NormalizedTy(ty)) => {
Err(ProjectionCacheEntry::NormalizedTy { value: ty, full_obligations }) => {
// This is the hottest path in this function.
//
// If we find the value in the cache, then return it along
Expand All @@ -529,7 +529,11 @@ fn opt_normalize_projection_type<'a, 'b, 'tcx>(
// evaluation this is not the case, and dropping the trait
// evaluations can causes ICEs (e.g., #43132).
debug!(?ty, "found normalized ty");
obligations.extend(ty.obligations);
if selcx.skip_projection_cache() {
obligations.extend(full_obligations);
} else {
obligations.extend(ty.obligations);
}
return Ok(Some(ty.value));
}
Err(ProjectionCacheEntry::Error) => {
Expand Down Expand Up @@ -571,14 +575,22 @@ fn opt_normalize_projection_type<'a, 'b, 'tcx>(
};

let cache_value = prune_cache_value_obligations(infcx, &result);
infcx.inner.borrow_mut().projection_cache().insert_ty(cache_key, cache_value);
infcx.inner.borrow_mut().projection_cache().insert_ty(
cache_key,
cache_value,
result.obligations.clone(),
);
obligations.extend(result.obligations);
Ok(Some(result.value))
}
Ok(ProjectedTy::NoProgress(projected_ty)) => {
debug!(?projected_ty, "opt_normalize_projection_type: no progress");
let result = Normalized { value: projected_ty, obligations: vec![] };
infcx.inner.borrow_mut().projection_cache().insert_ty(cache_key, result.clone());
infcx.inner.borrow_mut().projection_cache().insert_ty(
cache_key,
result.clone(),
vec![],
);
// No need to extend `obligations`.
Ok(Some(result.value))
}
Expand Down
16 changes: 14 additions & 2 deletions compiler/rustc_trait_selection/src/traits/select/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ pub struct SelectionContext<'cx, 'tcx> {
/// policy. In essence, canonicalized queries need their errors propagated
/// rather than immediately reported because we do not have accurate spans.
query_mode: TraitQueryMode,
skip_projection_cache: bool,
}

// A stack that walks back up the stack frame.
Expand Down Expand Up @@ -221,6 +222,7 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
intercrate_ambiguity_causes: None,
allow_negative_impls: false,
query_mode: TraitQueryMode::Standard,
skip_projection_cache: false,
}
}

Expand All @@ -232,6 +234,7 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
intercrate_ambiguity_causes: None,
allow_negative_impls: false,
query_mode: TraitQueryMode::Standard,
skip_projection_cache: false,
}
}

Expand All @@ -247,6 +250,7 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
intercrate_ambiguity_causes: None,
allow_negative_impls,
query_mode: TraitQueryMode::Standard,
skip_projection_cache: false,
}
}

Expand All @@ -262,9 +266,14 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
intercrate_ambiguity_causes: None,
allow_negative_impls: false,
query_mode,
skip_projection_cache: false,
}
}

pub fn skip_projection_cache(&self) -> bool {
self.skip_projection_cache
}

/// Enables tracking of intercrate ambiguity causes. These are
/// used in coherence to give improved diagnostics. We don't do
/// this until we detect a coherence error because it can lead to
Expand Down Expand Up @@ -379,12 +388,15 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
&mut self,
obligation: &PredicateObligation<'tcx>,
) -> Result<EvaluationResult, OverflowError> {
self.evaluation_probe(|this| {
let old = std::mem::replace(&mut self.skip_projection_cache, true);
let res = self.evaluation_probe(|this| {
this.evaluate_predicate_recursively(
TraitObligationStackList::empty(&ProvisionalEvaluationCache::default()),
obligation.clone(),
)
})
});
self.skip_projection_cache = old;
res
}

fn evaluation_probe(
Expand Down