Skip to content

Commit

Permalink
Merge pull request #18774 from Veykril/push-ysppqxpuknnw
Browse files Browse the repository at this point in the history
Implement parameter variance inference
  • Loading branch information
Veykril authored Dec 29, 2024
2 parents 0337e79 + a102ea1 commit d3ebb14
Show file tree
Hide file tree
Showing 23 changed files with 1,335 additions and 129 deletions.
21 changes: 16 additions & 5 deletions src/tools/rust-analyzer/crates/hir-ty/src/chalk_db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -950,22 +950,33 @@ pub(crate) fn fn_def_datum_query(db: &dyn HirDatabase, fn_def_id: FnDefId) -> Ar

pub(crate) fn fn_def_variance_query(db: &dyn HirDatabase, fn_def_id: FnDefId) -> Variances {
let callable_def: CallableDefId = from_chalk(db, fn_def_id);
let generic_params =
generics(db.upcast(), GenericDefId::from_callable(db.upcast(), callable_def));
Variances::from_iter(
Interner,
std::iter::repeat(chalk_ir::Variance::Invariant).take(generic_params.len()),
db.variances_of(GenericDefId::from_callable(db.upcast(), callable_def))
.as_deref()
.unwrap_or_default()
.iter()
.map(|v| match v {
crate::variance::Variance::Covariant => chalk_ir::Variance::Covariant,
crate::variance::Variance::Invariant => chalk_ir::Variance::Invariant,
crate::variance::Variance::Contravariant => chalk_ir::Variance::Contravariant,
crate::variance::Variance::Bivariant => chalk_ir::Variance::Invariant,
}),
)
}

pub(crate) fn adt_variance_query(
db: &dyn HirDatabase,
chalk_ir::AdtId(adt_id): AdtId,
) -> Variances {
let generic_params = generics(db.upcast(), adt_id.into());
Variances::from_iter(
Interner,
std::iter::repeat(chalk_ir::Variance::Invariant).take(generic_params.len()),
db.variances_of(adt_id.into()).as_deref().unwrap_or_default().iter().map(|v| match v {
crate::variance::Variance::Covariant => chalk_ir::Variance::Covariant,
crate::variance::Variance::Invariant => chalk_ir::Variance::Invariant,
crate::variance::Variance::Contravariant => chalk_ir::Variance::Contravariant,
crate::variance::Variance::Bivariant => chalk_ir::Variance::Invariant,
}),
)
}

Expand Down
18 changes: 15 additions & 3 deletions src/tools/rust-analyzer/crates/hir-ty/src/chalk_ext.rs
Original file line number Diff line number Diff line change
Expand Up @@ -443,13 +443,25 @@ impl ProjectionTyExt for ProjectionTy {
}

pub trait DynTyExt {
fn principal(&self) -> Option<&TraitRef>;
fn principal(&self) -> Option<Binders<Binders<&TraitRef>>>;
fn principal_id(&self) -> Option<chalk_ir::TraitId<Interner>>;
}

impl DynTyExt for DynTy {
fn principal(&self) -> Option<&TraitRef> {
fn principal(&self) -> Option<Binders<Binders<&TraitRef>>> {
self.bounds.as_ref().filter_map(|bounds| {
bounds.interned().first().and_then(|b| {
b.as_ref().filter_map(|b| match b {
crate::WhereClause::Implemented(trait_ref) => Some(trait_ref),
_ => None,
})
})
})
}

fn principal_id(&self) -> Option<chalk_ir::TraitId<Interner>> {
self.bounds.skip_binders().interned().first().and_then(|b| match b.skip_binders() {
crate::WhereClause::Implemented(trait_ref) => Some(trait_ref),
crate::WhereClause::Implemented(trait_ref) => Some(trait_ref.trait_id),
_ => None,
})
}
Expand Down
4 changes: 4 additions & 0 deletions src/tools/rust-analyzer/crates/hir-ty/src/db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,10 @@ pub trait HirDatabase: DefDatabase + Upcast<dyn DefDatabase> {
#[ra_salsa::invoke(chalk_db::adt_variance_query)]
fn adt_variance(&self, adt_id: chalk_db::AdtId) -> chalk_db::Variances;

#[ra_salsa::invoke(crate::variance::variances_of)]
#[ra_salsa::cycle(crate::variance::variances_of_cycle)]
fn variances_of(&self, def: GenericDefId) -> Option<Arc<[crate::variance::Variance]>>;

#[ra_salsa::invoke(chalk_db::associated_ty_value_query)]
fn associated_ty_value(
&self,
Expand Down
8 changes: 4 additions & 4 deletions src/tools/rust-analyzer/crates/hir-ty/src/generics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,14 @@ use triomphe::Arc;

use crate::{db::HirDatabase, lt_to_placeholder_idx, to_placeholder_idx, Interner, Substitution};

pub(crate) fn generics(db: &dyn DefDatabase, def: GenericDefId) -> Generics {
pub fn generics(db: &dyn DefDatabase, def: GenericDefId) -> Generics {
let parent_generics = parent_generic_def(db, def).map(|def| Box::new(generics(db, def)));
let params = db.generic_params(def);
let has_trait_self_param = params.trait_self_param().is_some();
Generics { def, params, parent_generics, has_trait_self_param }
}
#[derive(Clone, Debug)]
pub(crate) struct Generics {
pub struct Generics {
def: GenericDefId,
params: Arc<GenericParams>,
parent_generics: Option<Box<Generics>>,
Expand Down Expand Up @@ -153,7 +153,7 @@ impl Generics {
(parent_len, self_param, type_params, const_params, impl_trait_params, lifetime_params)
}

pub(crate) fn type_or_const_param_idx(&self, param: TypeOrConstParamId) -> Option<usize> {
pub fn type_or_const_param_idx(&self, param: TypeOrConstParamId) -> Option<usize> {
self.find_type_or_const_param(param)
}

Expand All @@ -174,7 +174,7 @@ impl Generics {
}
}

pub(crate) fn lifetime_idx(&self, lifetime: LifetimeParamId) -> Option<usize> {
pub fn lifetime_idx(&self, lifetime: LifetimeParamId) -> Option<usize> {
self.find_lifetime(lifetime)
}

Expand Down
4 changes: 2 additions & 2 deletions src/tools/rust-analyzer/crates/hir-ty/src/infer/closure.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,8 @@ impl InferenceContext<'_> {
.map(|b| b.into_value_and_skipped_binders().0);
self.deduce_closure_kind_from_predicate_clauses(clauses)
}
TyKind::Dyn(dyn_ty) => dyn_ty.principal().and_then(|trait_ref| {
self.fn_trait_kind_from_trait_id(from_chalk_trait_id(trait_ref.trait_id))
TyKind::Dyn(dyn_ty) => dyn_ty.principal_id().and_then(|trait_id| {
self.fn_trait_kind_from_trait_id(from_chalk_trait_id(trait_id))
}),
TyKind::InferenceVar(ty, chalk_ir::TyVariableKind::General) => {
let clauses = self.clauses_for_self_ty(*ty);
Expand Down
1 change: 1 addition & 0 deletions src/tools/rust-analyzer/crates/hir-ty/src/lang_items.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ pub fn is_box(db: &dyn HirDatabase, adt: AdtId) -> bool {

pub fn is_unsafe_cell(db: &dyn HirDatabase, adt: AdtId) -> bool {
let AdtId::StructId(id) = adt else { return false };

db.struct_data(id).flags.contains(StructFlags::IS_UNSAFE_CELL)
}

Expand Down
9 changes: 5 additions & 4 deletions src/tools/rust-analyzer/crates/hir-ty/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ extern crate ra_ap_rustc_pattern_analysis as rustc_pattern_analysis;
mod builder;
mod chalk_db;
mod chalk_ext;
mod generics;
mod infer;
mod inhabitedness;
mod interner;
Expand All @@ -39,6 +38,7 @@ pub mod db;
pub mod diagnostics;
pub mod display;
pub mod dyn_compatibility;
pub mod generics;
pub mod lang_items;
pub mod layout;
pub mod method_resolution;
Expand All @@ -50,6 +50,7 @@ pub mod traits;
mod test_db;
#[cfg(test)]
mod tests;
mod variance;

use std::hash::Hash;

Expand Down Expand Up @@ -88,10 +89,9 @@ pub use infer::{
PointerCast,
};
pub use interner::Interner;
pub use lower::diagnostics::*;
pub use lower::{
associated_type_shorthand_candidates, ImplTraitLoweringMode, ParamLoweringMode, TyDefId,
TyLoweringContext, ValueTyDefId,
associated_type_shorthand_candidates, diagnostics::*, ImplTraitLoweringMode, ParamLoweringMode,
TyDefId, TyLoweringContext, ValueTyDefId,
};
pub use mapping::{
from_assoc_type_id, from_chalk_trait_id, from_foreign_def_id, from_placeholder_idx,
Expand All @@ -101,6 +101,7 @@ pub use mapping::{
pub use method_resolution::check_orphan_rules;
pub use traits::TraitEnvironment;
pub use utils::{all_super_traits, direct_super_traits, is_fn_unsafe_to_call};
pub use variance::Variance;

pub use chalk_ir::{
cast::Cast,
Expand Down
13 changes: 6 additions & 7 deletions src/tools/rust-analyzer/crates/hir-ty/src/method_resolution.rs
Original file line number Diff line number Diff line change
Expand Up @@ -805,8 +805,8 @@ fn is_inherent_impl_coherent(
| TyKind::Scalar(_) => def_map.is_rustc_coherence_is_core(),

&TyKind::Adt(AdtId(adt), _) => adt.module(db.upcast()).krate() == def_map.krate(),
TyKind::Dyn(it) => it.principal().map_or(false, |trait_ref| {
from_chalk_trait_id(trait_ref.trait_id).module(db.upcast()).krate() == def_map.krate()
TyKind::Dyn(it) => it.principal_id().map_or(false, |trait_id| {
from_chalk_trait_id(trait_id).module(db.upcast()).krate() == def_map.krate()
}),

_ => true,
Expand Down Expand Up @@ -834,9 +834,8 @@ fn is_inherent_impl_coherent(
.contains(StructFlags::IS_RUSTC_HAS_INCOHERENT_INHERENT_IMPL),
hir_def::AdtId::EnumId(it) => db.enum_data(it).rustc_has_incoherent_inherent_impls,
},
TyKind::Dyn(it) => it.principal().map_or(false, |trait_ref| {
db.trait_data(from_chalk_trait_id(trait_ref.trait_id))
.rustc_has_incoherent_inherent_impls
TyKind::Dyn(it) => it.principal_id().map_or(false, |trait_id| {
db.trait_data(from_chalk_trait_id(trait_id)).rustc_has_incoherent_inherent_impls
}),

_ => false,
Expand Down Expand Up @@ -896,8 +895,8 @@ pub fn check_orphan_rules(db: &dyn HirDatabase, impl_: ImplId) -> bool {
match unwrap_fundamental(ty).kind(Interner) {
&TyKind::Adt(AdtId(id), _) => is_local(id.module(db.upcast()).krate()),
TyKind::Error => true,
TyKind::Dyn(it) => it.principal().map_or(false, |trait_ref| {
is_local(from_chalk_trait_id(trait_ref.trait_id).module(db.upcast()).krate())
TyKind::Dyn(it) => it.principal_id().map_or(false, |trait_id| {
is_local(from_chalk_trait_id(trait_id).module(db.upcast()).krate())
}),
_ => false,
}
Expand Down
78 changes: 50 additions & 28 deletions src/tools/rust-analyzer/crates/hir-ty/src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,15 @@ fn check_impl(ra_fixture: &str, allow_none: bool, only_types: bool, display_sour
None => continue,
};
let def_map = module.def_map(&db);
visit_module(&db, &def_map, module.local_id, &mut |it| defs.push(it));
visit_module(&db, &def_map, module.local_id, &mut |it| {
defs.push(match it {
ModuleDefId::FunctionId(it) => it.into(),
ModuleDefId::EnumVariantId(it) => it.into(),
ModuleDefId::ConstId(it) => it.into(),
ModuleDefId::StaticId(it) => it.into(),
_ => return,
})
});
}
defs.sort_by_key(|def| match def {
DefWithBodyId::FunctionId(it) => {
Expand Down Expand Up @@ -375,7 +383,15 @@ fn infer_with_mismatches(content: &str, include_mismatches: bool) -> String {
let def_map = module.def_map(&db);

let mut defs: Vec<DefWithBodyId> = Vec::new();
visit_module(&db, &def_map, module.local_id, &mut |it| defs.push(it));
visit_module(&db, &def_map, module.local_id, &mut |it| {
defs.push(match it {
ModuleDefId::FunctionId(it) => it.into(),
ModuleDefId::EnumVariantId(it) => it.into(),
ModuleDefId::ConstId(it) => it.into(),
ModuleDefId::StaticId(it) => it.into(),
_ => return,
})
});
defs.sort_by_key(|def| match def {
DefWithBodyId::FunctionId(it) => {
let loc = it.lookup(&db);
Expand Down Expand Up @@ -405,30 +421,30 @@ fn infer_with_mismatches(content: &str, include_mismatches: bool) -> String {
buf
}

fn visit_module(
pub(crate) fn visit_module(
db: &TestDB,
crate_def_map: &DefMap,
module_id: LocalModuleId,
cb: &mut dyn FnMut(DefWithBodyId),
cb: &mut dyn FnMut(ModuleDefId),
) {
visit_scope(db, crate_def_map, &crate_def_map[module_id].scope, cb);
for impl_id in crate_def_map[module_id].scope.impls() {
let impl_data = db.impl_data(impl_id);
for &item in impl_data.items.iter() {
match item {
AssocItemId::FunctionId(it) => {
let def = it.into();
cb(def);
let body = db.body(def);
let body = db.body(it.into());
cb(it.into());
visit_body(db, &body, cb);
}
AssocItemId::ConstId(it) => {
let def = it.into();
cb(def);
let body = db.body(def);
let body = db.body(it.into());
cb(it.into());
visit_body(db, &body, cb);
}
AssocItemId::TypeAliasId(_) => (),
AssocItemId::TypeAliasId(it) => {
cb(it.into());
}
}
}
}
Expand All @@ -437,33 +453,27 @@ fn visit_module(
db: &TestDB,
crate_def_map: &DefMap,
scope: &ItemScope,
cb: &mut dyn FnMut(DefWithBodyId),
cb: &mut dyn FnMut(ModuleDefId),
) {
for decl in scope.declarations() {
cb(decl);
match decl {
ModuleDefId::FunctionId(it) => {
let def = it.into();
cb(def);
let body = db.body(def);
let body = db.body(it.into());
visit_body(db, &body, cb);
}
ModuleDefId::ConstId(it) => {
let def = it.into();
cb(def);
let body = db.body(def);
let body = db.body(it.into());
visit_body(db, &body, cb);
}
ModuleDefId::StaticId(it) => {
let def = it.into();
cb(def);
let body = db.body(def);
let body = db.body(it.into());
visit_body(db, &body, cb);
}
ModuleDefId::AdtId(hir_def::AdtId::EnumId(it)) => {
db.enum_data(it).variants.iter().for_each(|&(it, _)| {
let def = it.into();
cb(def);
let body = db.body(def);
let body = db.body(it.into());
cb(it.into());
visit_body(db, &body, cb);
});
}
Expand All @@ -473,7 +483,7 @@ fn visit_module(
match item {
AssocItemId::FunctionId(it) => cb(it.into()),
AssocItemId::ConstId(it) => cb(it.into()),
AssocItemId::TypeAliasId(_) => (),
AssocItemId::TypeAliasId(it) => cb(it.into()),
}
}
}
Expand All @@ -483,7 +493,7 @@ fn visit_module(
}
}

fn visit_body(db: &TestDB, body: &Body, cb: &mut dyn FnMut(DefWithBodyId)) {
fn visit_body(db: &TestDB, body: &Body, cb: &mut dyn FnMut(ModuleDefId)) {
for (_, def_map) in body.blocks(db) {
for (mod_id, _) in def_map.modules() {
visit_module(db, &def_map, mod_id, cb);
Expand Down Expand Up @@ -553,7 +563,13 @@ fn salsa_bug() {
let module = db.module_for_file(pos.file_id);
let crate_def_map = module.def_map(&db);
visit_module(&db, &crate_def_map, module.local_id, &mut |def| {
db.infer(def);
db.infer(match def {
ModuleDefId::FunctionId(it) => it.into(),
ModuleDefId::EnumVariantId(it) => it.into(),
ModuleDefId::ConstId(it) => it.into(),
ModuleDefId::StaticId(it) => it.into(),
_ => return,
});
});

let new_text = "
Expand Down Expand Up @@ -586,6 +602,12 @@ fn salsa_bug() {
let module = db.module_for_file(pos.file_id);
let crate_def_map = module.def_map(&db);
visit_module(&db, &crate_def_map, module.local_id, &mut |def| {
db.infer(def);
db.infer(match def {
ModuleDefId::FunctionId(it) => it.into(),
ModuleDefId::EnumVariantId(it) => it.into(),
ModuleDefId::ConstId(it) => it.into(),
ModuleDefId::StaticId(it) => it.into(),
_ => return,
});
});
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,13 @@ fn check_closure_captures(ra_fixture: &str, expect: Expect) {

let mut captures_info = Vec::new();
for def in defs {
let def = match def {
hir_def::ModuleDefId::FunctionId(it) => it.into(),
hir_def::ModuleDefId::EnumVariantId(it) => it.into(),
hir_def::ModuleDefId::ConstId(it) => it.into(),
hir_def::ModuleDefId::StaticId(it) => it.into(),
_ => continue,
};
let infer = db.infer(def);
let db = &db;
captures_info.extend(infer.closure_info.iter().flat_map(|(closure_id, (captures, _))| {
Expand Down
Loading

0 comments on commit d3ebb14

Please sign in to comment.