Skip to content

Commit

Permalink
pattern lowering: make sure we never call user-defined PartialEq inst…
Browse files Browse the repository at this point in the history
…ances
  • Loading branch information
RalfJung committed Jul 13, 2024
1 parent cf94bfb commit f1df521
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 37 deletions.
7 changes: 2 additions & 5 deletions compiler/rustc_middle/src/thir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -783,16 +783,13 @@ pub enum PatKind<'tcx> {
},

/// One of the following:
/// * `&str` (represented as a valtree), which will be handled as a string pattern and thus
/// * `&str`/`&[u8]` (represented as a valtree), which will be handled as a string pattern and thus
/// exhaustiveness checking will detect if you use the same string twice in different
/// patterns.
/// * integer, bool, char or float (represented as a valtree), which will be handled by
/// exhaustiveness to cover exactly its own value, similar to `&str`, but these values are
/// much simpler.
/// * Opaque constants (represented as `mir::ConstValue`), that must not be matched
/// structurally. So anything that does not derive `PartialEq` and `Eq`.
///
/// These are always compared with the matched place using (the semantics of) `PartialEq`.
/// * `String`, if `string_deref_patterns` is enabled.
Constant {
value: mir::Const<'tcx>,
},
Expand Down
45 changes: 17 additions & 28 deletions compiler/rustc_mir_build/src/build/matches/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,8 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
&& tcx.is_lang_item(def.did(), LangItem::String)
{
if !tcx.features().string_deref_patterns {
bug!(
span_bug!(
test.span,
"matching on `String` went through without enabling string_deref_patterns"
);
}
Expand Down Expand Up @@ -432,40 +433,28 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
}
}

match *ty.kind() {
ty::Ref(_, deref_ty, _) => ty = deref_ty,
_ => {
// non_scalar_compare called on non-reference type
let temp = self.temp(ty, source_info.span);
self.cfg.push_assign(block, source_info, temp, Rvalue::Use(expect));
let ref_ty = Ty::new_imm_ref(self.tcx, self.tcx.lifetimes.re_erased, ty);
let ref_temp = self.temp(ref_ty, source_info.span);

self.cfg.push_assign(
block,
source_info,
ref_temp,
Rvalue::Ref(self.tcx.lifetimes.re_erased, BorrowKind::Shared, temp),
);
expect = Operand::Move(ref_temp);

let ref_temp = self.temp(ref_ty, source_info.span);
self.cfg.push_assign(
block,
source_info,
ref_temp,
Rvalue::Ref(self.tcx.lifetimes.re_erased, BorrowKind::Shared, val),
);
val = ref_temp;
// Figure out the type on which we are calling `PartialEq`. This involves an extra wrapping
// reference: we can only compare two `&T`, and then compare_ty will be `T`.
// Make sure that we do *not* call any user-defined code here.
// The only types that can end up here are string and byte literals,
// which have their comparison defined in `core`.
// (Interestingly this means that exhaustiveness analysis relies, for soundness,
// on the `PartialEq` impls for `str` and `[u8]` to b correct!)
let compare_ty = match *ty.kind() {
ty::Ref(_, deref_ty, _)
if deref_ty == self.tcx.types.str_ || deref_ty != self.tcx.types.u8 =>
{
deref_ty
}
}
_ => span_bug!(source_info.span, "invalid type for non-scalar compare: {}", ty),
};

let eq_def_id = self.tcx.require_lang_item(LangItem::PartialEq, Some(source_info.span));
let method = trait_method(
self.tcx,
eq_def_id,
sym::eq,
self.tcx.with_opt_host_effect_param(self.def_id, eq_def_id, [ty, ty]),
self.tcx.with_opt_host_effect_param(self.def_id, eq_def_id, [compare_ty, compare_ty]),
);

let bool_ty = self.tcx.types.bool;
Expand Down
18 changes: 14 additions & 4 deletions compiler/rustc_pattern_analysis/src/rustc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -462,7 +462,12 @@ impl<'p, 'tcx: 'p> RustcPatCtxt<'p, 'tcx> {
// This is a box pattern.
ty::Adt(adt, ..) if adt.is_box() => Struct,
ty::Ref(..) => Ref,
_ => bug!("pattern has unexpected type: pat: {:?}, ty: {:?}", pat, ty),
_ => span_bug!(
pat.span,
"pattern has unexpected type: pat: {:?}, ty: {:?}",
pat.kind,
ty.inner()
),
};
}
PatKind::DerefPattern { .. } => {
Expand Down Expand Up @@ -518,7 +523,12 @@ impl<'p, 'tcx: 'p> RustcPatCtxt<'p, 'tcx> {
.map(|ipat| self.lower_pat(&ipat.pattern).at_index(ipat.field.index()))
.collect();
}
_ => bug!("pattern has unexpected type: pat: {:?}, ty: {:?}", pat, ty),
_ => span_bug!(
pat.span,
"pattern has unexpected type: pat: {:?}, ty: {}",
pat.kind,
ty.inner()
),
}
}
PatKind::Constant { value } => {
Expand Down Expand Up @@ -663,7 +673,7 @@ impl<'p, 'tcx: 'p> RustcPatCtxt<'p, 'tcx> {
}
}
}
_ => bug!("invalid type for range pattern: {}", ty.inner()),
_ => span_bug!(pat.span, "invalid type for range pattern: {}", ty.inner()),
};
fields = vec![];
arity = 0;
Expand All @@ -674,7 +684,7 @@ impl<'p, 'tcx: 'p> RustcPatCtxt<'p, 'tcx> {
Some(length.eval_target_usize(cx.tcx, cx.param_env) as usize)
}
ty::Slice(_) => None,
_ => span_bug!(pat.span, "bad ty {:?} for slice pattern", ty),
_ => span_bug!(pat.span, "bad ty {} for slice pattern", ty.inner()),
};
let kind = if slice.is_some() {
SliceKind::VarLen(prefix.len(), suffix.len())
Expand Down

0 comments on commit f1df521

Please sign in to comment.