Skip to content

Commit

Permalink
Merge pull request #887 from Nadrieril/fix-binders3
Browse files Browse the repository at this point in the history
Initial cleanup of binders-related code
  • Loading branch information
Nadrieril authored Sep 17, 2024
2 parents cc29a3f + ca85aa9 commit 2d49e8c
Show file tree
Hide file tree
Showing 9 changed files with 110 additions and 168 deletions.
2 changes: 1 addition & 1 deletion engine/lib/import_thir.ml
Original file line number Diff line number Diff line change
Expand Up @@ -1073,7 +1073,7 @@ end) : EXPR = struct
List.fold ~init ~f:browse_path path
| Dyn -> Dyn
| SelfImpl { path; _ } -> List.fold ~init:Self ~f:browse_path path
| Builtin { trait } -> Builtin (c_trait_ref span trait)
| Builtin { trait } -> Builtin (c_trait_ref span trait.value)
| Todo str -> failwith @@ "impl_expr_atom: Todo " ^ str

and c_generic_value (span : Thir.span) (ty : Thir.generic_arg) : generic_value
Expand Down
7 changes: 1 addition & 6 deletions frontend/exporter/src/deterministic_hash.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,9 @@
use core::hash::Hasher;

/// Wrapper around any hasher to make it deterministic.
#[derive(Default)]
pub struct DeterministicHasher<T: Hasher>(T);

impl<T: Hasher> DeterministicHasher<T> {
pub fn new(inner: T) -> Self {
Self(inner)
}
}

/// Implementation of hasher that forces all bytes written to be platform agnostic.
impl<T: Hasher> core::hash::Hasher for DeterministicHasher<T> {
fn finish(&self) -> u64 {
Expand Down
15 changes: 2 additions & 13 deletions frontend/exporter/src/rustc_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,7 @@ impl<'tcx, T: ty::TypeFoldable<ty::TyCtxt<'tcx>>> ty::Binder<'tcx, T> {
tcx: ty::TyCtxt<'tcx>,
generics: &[ty::GenericArg<'tcx>],
) -> ty::Binder<'tcx, T> {
self.rebind(ty::EarlyBinder::bind(self.clone().skip_binder()).instantiate(tcx, generics))
}
}

#[extension_traits::extension(pub trait PredicateToPolyTraitPredicate)]
impl<'tcx> ty::Binder<'tcx, ty::PredicateKind<'tcx>> {
fn as_poly_trait_predicate(self) -> Option<ty::PolyTraitPredicate<'tcx>> {
self.try_map_bound(|kind| match kind {
ty::PredicateKind::Clause(ty::ClauseKind::Trait(trait_pred)) => Ok(trait_pred),
_ => Err(()),
})
.ok()
ty::EarlyBinder::bind(self).instantiate(tcx, generics)
}
}

Expand Down Expand Up @@ -65,7 +54,7 @@ impl<'tcx> ty::TyCtxt<'tcx> {
// An opaque type (e.g. `impl Trait`) provides
// predicates by itself: we need to account for them.
self.explicit_item_bounds(did)
.skip_binder()
.skip_binder() // Skips an `EarlyBinder`, likely for GATs
.iter()
.copied()
.collect()
Expand Down
70 changes: 31 additions & 39 deletions frontend/exporter/src/traits.rs
Original file line number Diff line number Diff line change
@@ -1,20 +1,20 @@
use crate::prelude::*;

#[derive(AdtInto)]
#[args(<'tcx, S: UnderOwnerState<'tcx> >, from: search_clause::PathChunk<'tcx>, state: S as tcx)]
#[args(<'tcx, S: UnderOwnerState<'tcx> >, from: search_clause::PathChunk<'tcx>, state: S as s)]
#[derive_group(Serializers)]
#[derive(Clone, Debug, Hash, PartialEq, Eq, PartialOrd, Ord, JsonSchema)]
pub enum ImplExprPathChunk {
AssocItem {
item: AssocItem,
predicate: Binder<TraitPredicate>,
#[value(predicate.predicate_id(tcx))]
#[value(<_ as SInto<_, Clause>>::sinto(predicate, s).id)]
predicate_id: PredicateId,
index: usize,
},
Parent {
predicate: Binder<TraitPredicate>,
#[value(predicate.predicate_id(tcx))]
#[value(<_ as SInto<_, Clause>>::sinto(predicate, s).id)]
predicate_id: PredicateId,
index: usize,
},
Expand Down Expand Up @@ -49,7 +49,7 @@ pub enum ImplExprAtom {
/// built-in implementation.
Dyn,
/// A built-in trait whose implementation is computed by the compiler, such as `Sync`.
Builtin { r#trait: TraitRef },
Builtin { r#trait: Binder<TraitRef> },
/// Anything else. Currently used for trait upcasting and trait aliases.
Todo(String),
}
Expand All @@ -62,7 +62,7 @@ pub enum ImplExprAtom {
#[derive(Clone, Debug, Hash, PartialEq, Eq, PartialOrd, Ord, JsonSchema)]
pub struct ImplExpr {
/// The trait this is an impl for.
pub r#trait: TraitRef,
pub r#trait: Binder<TraitRef>,
/// The kind of implemention of the root of the tree.
pub r#impl: ImplExprAtom,
/// A list of `ImplExpr`s required to fully specify the trait references in `impl`.
Expand All @@ -78,16 +78,6 @@ pub mod rustc {
use crate::rustc_utils::*;
use rustc_middle::ty::*;

fn predicates_to_poly_trait_predicates<'tcx>(
tcx: TyCtxt<'tcx>,
predicates: impl Iterator<Item = Predicate<'tcx>>,
generics: GenericArgsRef<'tcx>,
) -> impl Iterator<Item = PolyTraitPredicate<'tcx>> {
predicates
.map(move |pred| pred.kind().subst(tcx, generics))
.filter_map(|pred| pred.as_poly_trait_predicate())
}

#[derive(Clone, Debug)]
pub enum PathChunk<'tcx> {
AssocItem {
Expand Down Expand Up @@ -157,20 +147,28 @@ pub mod rustc {

#[extension_traits::extension(pub trait TraitPredicateExt)]
impl<'tcx, S: UnderOwnerState<'tcx>> PolyTraitPredicate<'tcx> {
fn predicates_to_poly_trait_predicates(
self,
s: &S,
predicates: impl Iterator<Item = Predicate<'tcx>>,
) -> impl Iterator<Item = PolyTraitPredicate<'tcx>> {
let tcx = s.base().tcx;
let generics = self.skip_binder().trait_ref.args;
predicates
.filter_map(|pred| pred.as_trait_clause())
.map(move |clause| clause.subst(tcx, generics))
}

#[tracing::instrument(level = "trace", skip(s))]
fn parents_trait_predicates(self, s: &S) -> Vec<(usize, PolyTraitPredicate<'tcx>)> {
let tcx = s.base().tcx;
let predicates = tcx
.predicates_defined_on_or_above(self.def_id())
.into_iter()
.map(|apred| apred.predicate);
predicates_to_poly_trait_predicates(
tcx,
predicates,
self.skip_binder().trait_ref.args,
)
.enumerate()
.collect()
self.predicates_to_poly_trait_predicates(s, predicates)
.enumerate()
.collect()
}
#[tracing::instrument(level = "trace", skip(s))]
fn associated_items_trait_predicates(
Expand All @@ -187,10 +185,9 @@ pub mod rustc {
.copied()
.map(|item| {
let bounds = tcx.item_bounds(item.def_id).map_bound(|clauses| {
predicates_to_poly_trait_predicates(
tcx,
self.predicates_to_poly_trait_predicates(
s,
clauses.into_iter().map(|clause| clause.as_predicate()),
self.skip_binder().trait_ref.args,
)
.enumerate()
.collect()
Expand Down Expand Up @@ -285,7 +282,7 @@ pub mod rustc {
}

impl ImplExprAtom {
fn with_args(self, args: Vec<ImplExpr>, r#trait: TraitRef) -> ImplExpr {
fn with_args(self, args: Vec<ImplExpr>, r#trait: Binder<TraitRef>) -> ImplExpr {
ImplExpr {
r#impl: self,
args,
Expand All @@ -304,15 +301,11 @@ pub mod rustc {
obligations
.into_iter()
.flat_map(|obligation| {
obligation
.predicate
.kind()
.as_poly_trait_predicate()
.map(|trait_ref| {
trait_ref
.map_bound(|p| p.trait_ref)
.impl_expr(s, obligation.param_env)
})
obligation.predicate.as_trait_clause().map(|trait_ref| {
trait_ref
.map_bound(|p| p.trait_ref)
.impl_expr(s, obligation.param_env)
})
})
.collect()
}
Expand Down Expand Up @@ -344,7 +337,6 @@ pub mod rustc {
) -> ImplExpr {
use rustc_trait_selection::traits::*;
let trait_ref: Binder<TraitRef> = self.sinto(s);
let trait_ref = trait_ref.value;
match select_trait_candidate(s, param_env, *self) {
ImplSource::UserDefined(ImplSourceUserDefinedData {
impl_def_id,
Expand Down Expand Up @@ -379,7 +371,7 @@ pub mod rustc {
.with_args(impl_exprs(s, &nested), trait_ref)
} else {
ImplExprAtom::LocalBound {
predicate_id: apred.predicate.predicate_id(s),
predicate_id: apred.predicate.sinto(s).id,
r#trait,
path,
}
Expand All @@ -394,7 +386,7 @@ pub mod rustc {
let atom = match source {
BuiltinImplSource::Object { .. } => ImplExprAtom::Dyn,
_ => ImplExprAtom::Builtin {
r#trait: self.skip_binder().sinto(s),
r#trait: trait_ref.clone(),
},
};
atom.with_args(vec![], trait_ref)
Expand All @@ -420,7 +412,7 @@ pub mod rustc {
// We don't want the id of the substituted clause id, but the
// original clause id (with, i.e., `Self`)
let s = &with_owner_id(s.base(), (), (), impl_trait_ref.def_id());
clause.predicate_id(s)
clause.sinto(s).id
};
let new_clause = clause.instantiate_supertrait(tcx, impl_trait_ref);
let impl_expr = new_clause
Expand Down
85 changes: 49 additions & 36 deletions frontend/exporter/src/types/copied.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2932,25 +2932,13 @@ pub enum Constness {
#[derive(Clone, Debug, JsonSchema)]
pub struct Generics<Body: IsBody> {
pub params: Vec<GenericParam<Body>>,
pub predicates: Vec<WherePredicate<Body>>,
#[value(region_bounds_at_current_owner(tcx))]
pub bounds: GenericBounds,
pub has_where_clause_predicates: bool,
pub where_clause_span: Span,
pub span: Span,
}

/// Reflects [`rustc_hir::WherePredicate`]
#[derive(AdtInto)]
#[args(<'tcx, S: UnderOwnerState<'tcx> >, from: rustc_hir::WherePredicate<'tcx>, state: S as tcx)]
#[derive_group(Serializers)]
#[derive(Clone, Debug, JsonSchema)]
pub enum WherePredicate<Body: IsBody> {
BoundPredicate(WhereBoundPredicate<Body>),
RegionPredicate(WhereRegionPredicate),
EqPredicate(WhereEqPredicate),
}

#[cfg(feature = "rustc")]
impl<'tcx, S: UnderOwnerState<'tcx>, Body: IsBody> SInto<S, ImplItem<Body>>
for rustc_hir::ImplItemRef
Expand Down Expand Up @@ -3095,7 +3083,7 @@ pub enum ImplItemKind<Body: IsBody> {
let assoc_item = tcx.opt_associated_item(owner_id).unwrap();
let impl_did = assoc_item.impl_container(tcx).unwrap();
tcx.explicit_item_bounds(assoc_item.trait_item_def_id.unwrap())
.skip_binder()
.skip_binder() // Skips an `EarlyBinder`, likely for GATs
.iter()
.copied()
.filter_map(|(clause, span)| super_clause_to_clause_and_impl_expr(s, impl_did, clause, span))
Expand Down Expand Up @@ -3703,10 +3691,27 @@ pub struct Clause {
#[cfg(feature = "rustc")]
impl<'tcx, S: UnderOwnerState<'tcx>> SInto<S, Clause> for rustc_middle::ty::Clause<'tcx> {
fn sinto(&self, s: &S) -> Clause {
Clause {
kind: self.kind().sinto(s),
id: self.predicate_id(s),
}
let kind = self.kind().sinto(s);
let id = kind
.clone()
.map(|clause_kind| PredicateKind::Clause(clause_kind))
.predicate_id();
Clause { kind, id }
}
}

#[cfg(feature = "rustc")]
impl<'tcx, S: UnderOwnerState<'tcx>> SInto<S, Clause>
for rustc_middle::ty::PolyTraitPredicate<'tcx>
{
fn sinto(&self, s: &S) -> Clause {
let kind: Binder<_> = self.sinto(s);
let kind: Binder<ClauseKind> = kind.map(|x| ClauseKind::Trait(x));
let id = kind
.clone()
.map(|clause_kind| PredicateKind::Clause(clause_kind))
.predicate_id();
Clause { kind, id }
}
}

Expand All @@ -3721,10 +3726,9 @@ pub struct Predicate {
#[cfg(feature = "rustc")]
impl<'tcx, S: UnderOwnerState<'tcx>> SInto<S, Predicate> for rustc_middle::ty::Predicate<'tcx> {
fn sinto(&self, s: &S) -> Predicate {
Predicate {
kind: self.kind().sinto(s),
id: self.predicate_id(s),
}
let kind = self.kind().sinto(s);
let id = kind.predicate_id();
Predicate { kind, id }
}
}

Expand All @@ -3747,6 +3751,30 @@ pub struct Binder<T> {
pub bound_vars: Vec<BoundVariableKind>,
}

impl<T> Binder<T> {
pub fn as_ref(&self) -> Binder<&T> {
Binder {
value: &self.value,
bound_vars: self.bound_vars.clone(),
}
}

pub fn hax_skip_binder(self) -> T {
self.value
}

pub fn map<U>(self, f: impl FnOnce(T) -> U) -> Binder<U> {
Binder {
value: f(self.value),
bound_vars: self.bound_vars,
}
}

pub fn inner_mut(&mut self) -> &mut T {
&mut self.value
}
}

/// Reflects [`rustc_middle::ty::GenericPredicates`]
#[derive(AdtInto)]
#[args(<'tcx, S: UnderOwnerState<'tcx>>, from: rustc_middle::ty::GenericPredicates<'tcx>, state: S as s)]
Expand Down Expand Up @@ -3992,21 +4020,6 @@ impl<'tcx, S: UnderOwnerState<'tcx>> SInto<S, Ident> for rustc_span::symbol::Ide
}
}

/// Reflects [`rustc_hir::WhereBoundPredicate`]
#[derive(AdtInto)]
#[args(<'tcx, S: UnderOwnerState<'tcx> >, from: rustc_hir::WhereBoundPredicate<'tcx>, state: S as tcx)]
#[derive_group(Serializers)]
#[derive(Clone, Debug, JsonSchema)]
pub struct WhereBoundPredicate<Body: IsBody> {
pub hir_id: HirId,
pub span: Span,
pub origin: PredicateOrigin,
pub bound_generic_params: Vec<GenericParam<Body>>,
pub bounded_ty: Ty,
// TODO: What to do with WhereBoundPredicate?
// pub bounds: GenericBounds,
}

/// Reflects [`rustc_hir::PredicateOrigin`]
#[derive(AdtInto)]
#[args(<S>, from: rustc_hir::PredicateOrigin, state: S as _s)]
Expand Down
13 changes: 1 addition & 12 deletions frontend/exporter/src/types/mir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -856,16 +856,6 @@ impl<'tcx, S: UnderOwnerState<'tcx> + HasMir<'tcx>> SInto<S, Place>
}
}

// TODO: we need this function because sometimes, Rust doesn't infer the proper
// typeclass instance.
#[cfg(feature = "rustc")]
pub(crate) fn poly_fn_sig_to_mir_poly_fn_sig<'tcx, S: BaseState<'tcx> + HasOwnerId>(
sig: &rustc_middle::ty::PolyFnSig<'tcx>,
s: &S,
) -> PolyFnSig {
sig.sinto(s)
}

#[derive_group(Serializers)]
#[derive(AdtInto, Clone, Debug, JsonSchema)]
#[args(<'tcx, S: UnderOwnerState<'tcx> + HasMir<'tcx>>, from: rustc_middle::mir::AggregateKind<'tcx>, state: S as s)]
Expand Down Expand Up @@ -901,8 +891,7 @@ pub enum AggregateKind {
// type, regions, etc. variables, which means we can treat the local
// closure like any top-level function.
let closure = generics.as_closure();
let sig = closure.sig();
let sig = poly_fn_sig_to_mir_poly_fn_sig(&sig, s);
let sig = closure.sig().sinto(s);
// Solve the trait obligations. Note that we solve the parent
let tcx = s.base().tcx;
Expand Down
Loading

0 comments on commit 2d49e8c

Please sign in to comment.