Skip to content

Commit

Permalink
make infcx optional in canonicalizer
Browse files Browse the repository at this point in the history
This doesn't change behavior.
It should prevent unintentional resolution of inference variables
during canonicalization, which previously caused a soundness bug.
See PR description for more.
  • Loading branch information
aliemjay committed Dec 14, 2023
1 parent fafe66d commit 3b55869
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 47 deletions.
80 changes: 35 additions & 45 deletions compiler/rustc_infer/src/infer/canonical/canonicalizer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ impl<'tcx> InferCtxt<'tcx> {
{
Canonicalizer::canonicalize(
value,
self,
Some(self),
self.tcx,
&CanonicalizeAllFreeRegionsPreservingUniverses,
query_state,
Expand Down Expand Up @@ -99,7 +99,7 @@ impl<'tcx> InferCtxt<'tcx> {
let mut query_state = OriginalQueryValues::default();
Canonicalizer::canonicalize(
value,
self,
Some(self),
self.tcx,
&CanonicalizeQueryResponse,
&mut query_state,
Expand All @@ -113,7 +113,7 @@ impl<'tcx> InferCtxt<'tcx> {
let mut query_state = OriginalQueryValues::default();
Canonicalizer::canonicalize(
value,
self,
Some(self),
self.tcx,
&CanonicalizeUserTypeAnnotation,
&mut query_state,
Expand Down Expand Up @@ -153,11 +153,11 @@ impl<'tcx> InferCtxt<'tcx> {
self.tcx,
param_env,
query_state,
|query_state| {
|tcx, param_env, query_state| {
Canonicalizer::canonicalize(
param_env,
self,
self.tcx,
None,
tcx,
&CanonicalizeFreeRegionsOtherThanStatic,
query_state,
)
Expand All @@ -167,7 +167,7 @@ impl<'tcx> InferCtxt<'tcx> {
Canonicalizer::canonicalize_with_base(
base,
value,
self,
Some(self),
self.tcx,
canonicalize_region_mode,
query_state,
Expand Down Expand Up @@ -204,9 +204,10 @@ impl CanonicalizeMode for CanonicalizeQueryResponse {
canonicalizer: &mut Canonicalizer<'_, 'tcx>,
mut r: ty::Region<'tcx>,
) -> ty::Region<'tcx> {
let infcx = canonicalizer.infcx.unwrap();

if let ty::ReVar(vid) = *r {
r = canonicalizer
.infcx
r = infcx
.inner
.borrow_mut()
.unwrap_region_constraints()
Expand All @@ -226,7 +227,8 @@ impl CanonicalizeMode for CanonicalizeQueryResponse {
),

ty::ReVar(vid) => {
let universe = canonicalizer.region_var_universe(vid);
let universe =
infcx.inner.borrow_mut().unwrap_region_constraints().var_universe(vid);
canonicalizer.canonical_var_for_region(
CanonicalVarInfo { kind: CanonicalVarKind::Region(universe) },
r,
Expand Down Expand Up @@ -319,7 +321,7 @@ impl CanonicalizeMode for CanonicalizeAllFreeRegionsPreservingUniverses {
canonicalizer: &mut Canonicalizer<'_, 'tcx>,
r: ty::Region<'tcx>,
) -> ty::Region<'tcx> {
let universe = canonicalizer.infcx.universe_of_region(r);
let universe = canonicalizer.infcx.unwrap().universe_of_region(r);
canonicalizer.canonical_var_for_region(
CanonicalVarInfo { kind: CanonicalVarKind::Region(universe) },
r,
Expand Down Expand Up @@ -356,7 +358,8 @@ impl CanonicalizeMode for CanonicalizeFreeRegionsOtherThanStatic {
}

struct Canonicalizer<'cx, 'tcx> {
infcx: &'cx InferCtxt<'tcx>,
/// Set to `None` to disable the resolution of inference variables.
infcx: Option<&'cx InferCtxt<'tcx>>,
tcx: TyCtxt<'tcx>,
variables: SmallVec<[CanonicalVarInfo<'tcx>; 8]>,
query_state: &'cx mut OriginalQueryValues<'tcx>,
Expand Down Expand Up @@ -410,14 +413,14 @@ impl<'cx, 'tcx> TypeFolder<TyCtxt<'tcx>> for Canonicalizer<'cx, 'tcx> {
// We need to canonicalize the *root* of our ty var.
// This is so that our canonical response correctly reflects
// any equated inference vars correctly!
let root_vid = self.infcx.root_var(vid);
let root_vid = self.infcx.unwrap().root_var(vid);
if root_vid != vid {
t = Ty::new_var(self.infcx.tcx, root_vid);
t = Ty::new_var(self.tcx, root_vid);
vid = root_vid;
}

debug!("canonical: type var found with vid {:?}", vid);
match self.infcx.probe_ty_var(vid) {
match self.infcx.unwrap().probe_ty_var(vid) {
// `t` could be a float / int variable; canonicalize that instead.
Ok(t) => {
debug!("(resolved to {:?})", t);
Expand All @@ -442,7 +445,7 @@ impl<'cx, 'tcx> TypeFolder<TyCtxt<'tcx>> for Canonicalizer<'cx, 'tcx> {
}

ty::Infer(ty::IntVar(vid)) => {
let nt = self.infcx.opportunistic_resolve_int_var(vid);
let nt = self.infcx.unwrap().opportunistic_resolve_int_var(vid);
if nt != t {
return self.fold_ty(nt);
} else {
Expand All @@ -453,7 +456,7 @@ impl<'cx, 'tcx> TypeFolder<TyCtxt<'tcx>> for Canonicalizer<'cx, 'tcx> {
}
}
ty::Infer(ty::FloatVar(vid)) => {
let nt = self.infcx.opportunistic_resolve_float_var(vid);
let nt = self.infcx.unwrap().opportunistic_resolve_float_var(vid);
if nt != t {
return self.fold_ty(nt);
} else {
Expand Down Expand Up @@ -524,14 +527,14 @@ impl<'cx, 'tcx> TypeFolder<TyCtxt<'tcx>> for Canonicalizer<'cx, 'tcx> {
// We need to canonicalize the *root* of our const var.
// This is so that our canonical response correctly reflects
// any equated inference vars correctly!
let root_vid = self.infcx.root_const_var(vid);
let root_vid = self.infcx.unwrap().root_const_var(vid);
if root_vid != vid {
ct = ty::Const::new_var(self.infcx.tcx, root_vid, ct.ty());
ct = ty::Const::new_var(self.tcx, root_vid, ct.ty());
vid = root_vid;
}

debug!("canonical: const var found with vid {:?}", vid);
match self.infcx.probe_const_var(vid) {
match self.infcx.unwrap().probe_const_var(vid) {
Ok(c) => {
debug!("(resolved to {:?})", c);
return self.fold_const(c);
Expand All @@ -552,8 +555,8 @@ impl<'cx, 'tcx> TypeFolder<TyCtxt<'tcx>> for Canonicalizer<'cx, 'tcx> {
}
}
ty::ConstKind::Infer(InferConst::EffectVar(vid)) => {
match self.infcx.probe_effect_var(vid) {
Some(value) => return self.fold_const(value.as_const(self.infcx.tcx)),
match self.infcx.unwrap().probe_effect_var(vid) {
Some(value) => return self.fold_const(value.as_const(self.tcx)),
None => {
return self.canonicalize_const_var(
CanonicalVarInfo { kind: CanonicalVarKind::Effect },
Expand Down Expand Up @@ -596,7 +599,7 @@ impl<'cx, 'tcx> Canonicalizer<'cx, 'tcx> {
/// `canonicalize_query` and `canonicalize_response`.
fn canonicalize<V>(
value: V,
infcx: &InferCtxt<'tcx>,
infcx: Option<&InferCtxt<'tcx>>,
tcx: TyCtxt<'tcx>,
canonicalize_region_mode: &dyn CanonicalizeMode,
query_state: &mut OriginalQueryValues<'tcx>,
Expand All @@ -623,7 +626,7 @@ impl<'cx, 'tcx> Canonicalizer<'cx, 'tcx> {
fn canonicalize_with_base<U, V>(
base: Canonical<'tcx, U>,
value: V,
infcx: &InferCtxt<'tcx>,
infcx: Option<&InferCtxt<'tcx>>,
tcx: TyCtxt<'tcx>,
canonicalize_region_mode: &dyn CanonicalizeMode,
query_state: &mut OriginalQueryValues<'tcx>,
Expand Down Expand Up @@ -826,11 +829,6 @@ impl<'cx, 'tcx> Canonicalizer<'cx, 'tcx> {
)
}

/// Returns the universe in which `vid` is defined.
fn region_var_universe(&self, vid: ty::RegionVid) -> ty::UniverseIndex {
self.infcx.inner.borrow_mut().unwrap_region_constraints().var_universe(vid)
}

/// Creates a canonical variable (with the given `info`)
/// representing the region `r`; return a region referencing it.
fn canonical_var_for_region(
Expand All @@ -848,14 +846,9 @@ impl<'cx, 'tcx> Canonicalizer<'cx, 'tcx> {
/// *that*. Otherwise, create a new canonical variable for
/// `ty_var`.
fn canonicalize_ty_var(&mut self, info: CanonicalVarInfo<'tcx>, ty_var: Ty<'tcx>) -> Ty<'tcx> {
let infcx = self.infcx;
let bound_to = infcx.shallow_resolve(ty_var);
if bound_to != ty_var {
self.fold_ty(bound_to)
} else {
let var = self.canonical_var(info, ty_var.into());
Ty::new_bound(self.tcx, self.binder_index, var.into())
}
debug_assert!(!self.infcx.is_some_and(|infcx| ty_var != infcx.shallow_resolve(ty_var)));
let var = self.canonical_var(info, ty_var.into());
Ty::new_bound(self.tcx, self.binder_index, var.into())
}

/// Given a type variable `const_var` of the given kind, first check
Expand All @@ -867,13 +860,10 @@ impl<'cx, 'tcx> Canonicalizer<'cx, 'tcx> {
info: CanonicalVarInfo<'tcx>,
const_var: ty::Const<'tcx>,
) -> ty::Const<'tcx> {
let infcx = self.infcx;
let bound_to = infcx.shallow_resolve(const_var);
if bound_to != const_var {
self.fold_const(bound_to)
} else {
let var = self.canonical_var(info, const_var.into());
ty::Const::new_bound(self.tcx, self.binder_index, var, self.fold_ty(const_var.ty()))
}
debug_assert!(
!self.infcx.is_some_and(|infcx| const_var != infcx.shallow_resolve(const_var))
);
let var = self.canonical_var(info, const_var.into());
ty::Const::new_bound(self.tcx, self.binder_index, var, self.fold_ty(const_var.ty()))
}
}
12 changes: 10 additions & 2 deletions compiler/rustc_middle/src/infer/canonical.rs
Original file line number Diff line number Diff line change
Expand Up @@ -306,12 +306,20 @@ pub struct CanonicalParamEnvCache<'tcx> {
}

impl<'tcx> CanonicalParamEnvCache<'tcx> {
/// Gets the cached canonical form of `key` or executes
/// `canonicalize_op` and caches the result if not present.
///
/// `canonicalize_op` is intentionally not allowed to be a closure to
/// statically prevent it from capturing `InferCtxt` and resolving
/// inference variables, which invalidates the cache.
pub fn get_or_insert(
&self,
tcx: TyCtxt<'tcx>,
key: ty::ParamEnv<'tcx>,
state: &mut OriginalQueryValues<'tcx>,
canonicalize_op: impl FnOnce(
canonicalize_op: fn(
TyCtxt<'tcx>,
ty::ParamEnv<'tcx>,
&mut OriginalQueryValues<'tcx>,
) -> Canonical<'tcx, ty::ParamEnv<'tcx>>,
) -> Canonical<'tcx, ty::ParamEnv<'tcx>> {
Expand All @@ -336,7 +344,7 @@ impl<'tcx> CanonicalParamEnvCache<'tcx> {
canonical.clone()
}
Entry::Vacant(e) => {
let canonical = canonicalize_op(state);
let canonical = canonicalize_op(tcx, key, state);
let OriginalQueryValues { var_values, universe_map } = state;
assert_eq!(universe_map.len(), 1);
e.insert((canonical.clone(), tcx.arena.alloc_slice(var_values)));
Expand Down

0 comments on commit 3b55869

Please sign in to comment.