From 73ae6a0ba8bb00138f83d89b66dc6b73c4fa66e8 Mon Sep 17 00:00:00 2001 From: Lukas Wirth Date: Sat, 28 Dec 2024 15:08:26 +0100 Subject: [PATCH 1/8] Downgrade salsa log levels --- .../ra-salsa-macros/src/query_group.rs | 2 +- .../crates/ra-salsa/src/derived/slot.rs | 45 ++++++++++------- .../crates/ra-salsa/src/derived_lru/slot.rs | 49 ++++++++++++------- .../crates/ra-salsa/src/input.rs | 14 +++--- .../rust-analyzer/crates/ra-salsa/src/lib.rs | 4 +- .../rust-analyzer/crates/ra-salsa/src/lru.rs | 24 ++++----- .../crates/ra-salsa/src/runtime.rs | 12 ++--- .../ra-salsa/src/runtime/local_state.rs | 4 +- 8 files changed, 86 insertions(+), 68 deletions(-) diff --git a/src/tools/rust-analyzer/crates/ra-salsa/ra-salsa-macros/src/query_group.rs b/src/tools/rust-analyzer/crates/ra-salsa/ra-salsa-macros/src/query_group.rs index 88db6093ee0e7..d761a5e798e89 100644 --- a/src/tools/rust-analyzer/crates/ra-salsa/ra-salsa-macros/src/query_group.rs +++ b/src/tools/rust-analyzer/crates/ra-salsa/ra-salsa-macros/src/query_group.rs @@ -242,7 +242,7 @@ pub(crate) fn query_group(args: TokenStream, input: TokenStream) -> TokenStream let tracing = if let QueryStorage::Memoized | QueryStorage::LruMemoized = query.storage { let s = format!("{trait_name}::{fn_name}"); Some(quote! { - let _p = tracing::debug_span!(#s, #(#key_names = tracing::field::debug(&#key_names)),*).entered(); + let _p = tracing::trace_span!(#s, #(#key_names = tracing::field::debug(&#key_names)),*).entered(); }) } else { None diff --git a/src/tools/rust-analyzer/crates/ra-salsa/src/derived/slot.rs b/src/tools/rust-analyzer/crates/ra-salsa/src/derived/slot.rs index 6c5ccba173b99..cfe2c48f411f1 100644 --- a/src/tools/rust-analyzer/crates/ra-salsa/src/derived/slot.rs +++ b/src/tools/rust-analyzer/crates/ra-salsa/src/derived/slot.rs @@ -13,7 +13,7 @@ use crate::{Database, DatabaseKeyIndex, Event, EventKind, QueryDb}; use parking_lot::{RawRwLock, RwLock}; use std::ops::Deref; use std::sync::atomic::{AtomicBool, Ordering}; -use tracing::{debug, info}; +use tracing::trace; pub(super) struct Slot where @@ -126,7 +126,7 @@ where // doing any `set` invocations while the query function runs. let revision_now = runtime.current_revision(); - info!("{:?}: invoked at {:?}", self, revision_now,); + trace!("{:?}: invoked at {:?}", self, revision_now,); // First, do a check with a read-lock. loop { @@ -152,7 +152,7 @@ where ) -> StampedValue { let runtime = db.salsa_runtime(); - debug!("{:?}: read_upgrade(revision_now={:?})", self, revision_now,); + trace!("{:?}: read_upgrade(revision_now={:?})", self, revision_now,); // Check with an upgradable read to see if there is a value // already. (This permits other readers but prevents anyone @@ -184,7 +184,7 @@ where // inputs and check whether they are out of date. if let Some(memo) = &mut old_memo { if let Some(value) = memo.verify_value(db.ops_database(), revision_now, &active_query) { - info!("{:?}: validated old memoized value", self,); + trace!("{:?}: validated old memoized value", self,); db.salsa_event(Event { runtime_id: runtime.id(), @@ -212,7 +212,7 @@ where old_memo: Option>, key: &Q::Key, ) -> StampedValue { - tracing::info!("{:?}: executing query", self.database_key_index().debug(db)); + tracing::trace!("{:?}: executing query", self.database_key_index().debug(db)); db.salsa_event(Event { runtime_id: db.salsa_runtime().id(), @@ -224,7 +224,7 @@ where let value = match Cycle::catch(|| Q::execute(db, key.clone())) { Ok(v) => v, Err(cycle) => { - tracing::debug!( + tracing::trace!( "{:?}: caught cycle {:?}, have strategy {:?}", self.database_key_index().debug(db), cycle, @@ -272,9 +272,10 @@ where // consumers must be aware of. Becoming *more* durable // is not. See the test `constant_to_non_constant`. if revisions.durability >= old_memo.revisions.durability && old_memo.value == value { - debug!( + trace!( "read_upgrade({:?}): value is equal, back-dating to {:?}", - self, old_memo.revisions.changed_at, + self, + old_memo.revisions.changed_at, ); assert!(old_memo.revisions.changed_at <= revisions.changed_at); @@ -290,7 +291,7 @@ where let memo_value = new_value.value.clone(); - debug!("read_upgrade({:?}): result.revisions = {:#?}", self, revisions,); + trace!("read_upgrade({:?}): result.revisions = {:#?}", self, revisions,); panic_guard.proceed(Some(Memo { value: memo_value, verified_at: revision_now, revisions })); @@ -339,9 +340,11 @@ where } QueryState::Memoized(memo) => { - debug!( + trace!( "{:?}: found memoized value, verified_at={:?}, changed_at={:?}", - self, memo.verified_at, memo.revisions.changed_at, + self, + memo.verified_at, + memo.revisions.changed_at, ); if memo.verified_at < revision_now { @@ -355,7 +358,7 @@ where value: value.clone(), }; - info!("{:?}: returning memoized value changed at {:?}", self, value.changed_at); + trace!("{:?}: returning memoized value changed at {:?}", self, value.changed_at); ProbeState::UpToDate(value) } @@ -387,7 +390,7 @@ where } pub(super) fn invalidate(&self, new_revision: Revision) -> Option { - tracing::debug!("Slot::invalidate(new_revision = {:?})", new_revision); + tracing::trace!("Slot::invalidate(new_revision = {:?})", new_revision); match &mut *self.state.write() { QueryState::Memoized(memo) => { memo.revisions.untracked = true; @@ -411,9 +414,11 @@ where db.unwind_if_cancelled(); - debug!( + trace!( "maybe_changed_after({:?}) called with revision={:?}, revision_now={:?}", - self, revision, revision_now, + self, + revision, + revision_now, ); // Do an initial probe with just the read-lock. @@ -680,9 +685,11 @@ where assert!(self.verified_at != revision_now); let verified_at = self.verified_at; - debug!( + trace!( "verify_revisions: verified_at={:?}, revision_now={:?}, inputs={:#?}", - verified_at, revision_now, self.revisions.inputs + verified_at, + revision_now, + self.revisions.inputs ); if self.check_durability(db.salsa_runtime()) { @@ -708,7 +715,7 @@ where let changed_input = inputs.slice.iter().find(|&&input| db.maybe_changed_after(input, verified_at)); if let Some(input) = changed_input { - debug!("validate_memoized_value: `{:?}` may have changed", input); + trace!("validate_memoized_value: `{:?}` may have changed", input); return false; } @@ -721,7 +728,7 @@ where /// True if this memo is known not to have changed based on its durability. fn check_durability(&self, runtime: &Runtime) -> bool { let last_changed = runtime.last_changed_revision(self.revisions.durability); - debug!( + trace!( "check_durability(last_changed={:?} <= verified_at={:?}) = {:?}", last_changed, self.verified_at, diff --git a/src/tools/rust-analyzer/crates/ra-salsa/src/derived_lru/slot.rs b/src/tools/rust-analyzer/crates/ra-salsa/src/derived_lru/slot.rs index ff9cc4eade2cf..73a5e07aa05ab 100644 --- a/src/tools/rust-analyzer/crates/ra-salsa/src/derived_lru/slot.rs +++ b/src/tools/rust-analyzer/crates/ra-salsa/src/derived_lru/slot.rs @@ -17,7 +17,7 @@ use parking_lot::{RawRwLock, RwLock}; use std::marker::PhantomData; use std::ops::Deref; use std::sync::atomic::{AtomicBool, Ordering}; -use tracing::{debug, info}; +use tracing::trace; pub(super) struct Slot where @@ -140,7 +140,7 @@ where // doing any `set` invocations while the query function runs. let revision_now = runtime.current_revision(); - info!("{:?}: invoked at {:?}", self, revision_now,); + trace!("{:?}: invoked at {:?}", self, revision_now,); // First, do a check with a read-lock. loop { @@ -168,7 +168,7 @@ where ) -> StampedValue { let runtime = db.salsa_runtime(); - debug!("{:?}: read_upgrade(revision_now={:?})", self, revision_now,); + trace!("{:?}: read_upgrade(revision_now={:?})", self, revision_now,); // Check with an upgradable read to see if there is a value // already. (This permits other readers but prevents anyone @@ -202,7 +202,7 @@ where // inputs and check whether they are out of date. if let Some(memo) = &mut old_memo { if let Some(value) = memo.verify_value(db.ops_database(), revision_now, &active_query) { - info!("{:?}: validated old memoized value", self,); + trace!("{:?}: validated old memoized value", self,); db.salsa_event(Event { runtime_id: runtime.id(), @@ -230,7 +230,7 @@ where old_memo: Option>, key: &Q::Key, ) -> StampedValue { - tracing::info!("{:?}: executing query", self.database_key_index().debug(db)); + tracing::trace!("{:?}: executing query", self.database_key_index().debug(db)); db.salsa_event(Event { runtime_id: db.salsa_runtime().id(), @@ -242,7 +242,7 @@ where let value = match Cycle::catch(|| Q::execute(db, key.clone())) { Ok(v) => v, Err(cycle) => { - tracing::debug!( + tracing::trace!( "{:?}: caught cycle {:?}, have strategy {:?}", self.database_key_index().debug(db), cycle, @@ -293,9 +293,10 @@ where if revisions.durability >= old_memo.revisions.durability && MP::memoized_value_eq(old_value, &value) { - debug!( + trace!( "read_upgrade({:?}): value is equal, back-dating to {:?}", - self, old_memo.revisions.changed_at, + self, + old_memo.revisions.changed_at, ); assert!(old_memo.revisions.changed_at <= revisions.changed_at); @@ -313,7 +314,7 @@ where let memo_value = if self.should_memoize_value(key) { Some(new_value.value.clone()) } else { None }; - debug!("read_upgrade({:?}): result.revisions = {:#?}", self, revisions,); + trace!("read_upgrade({:?}): result.revisions = {:#?}", self, revisions,); panic_guard.proceed(Some(Memo { value: memo_value, verified_at: revision_now, revisions })); @@ -362,9 +363,11 @@ where } QueryState::Memoized(memo) => { - debug!( + trace!( "{:?}: found memoized value, verified_at={:?}, changed_at={:?}", - self, memo.verified_at, memo.revisions.changed_at, + self, + memo.verified_at, + memo.revisions.changed_at, ); if memo.verified_at < revision_now { @@ -378,7 +381,11 @@ where value: value.clone(), }; - info!("{:?}: returning memoized value changed at {:?}", self, value.changed_at); + trace!( + "{:?}: returning memoized value changed at {:?}", + self, + value.changed_at + ); ProbeState::UpToDate(value) } else { @@ -426,7 +433,7 @@ where } pub(super) fn invalidate(&self, new_revision: Revision) -> Option { - tracing::debug!("Slot::invalidate(new_revision = {:?})", new_revision); + tracing::trace!("Slot::invalidate(new_revision = {:?})", new_revision); match &mut *self.state.write() { QueryState::Memoized(memo) => { memo.revisions.untracked = true; @@ -450,9 +457,11 @@ where db.unwind_if_cancelled(); - debug!( + trace!( "maybe_changed_after({:?}) called with revision={:?}, revision_now={:?}", - self, revision, revision_now, + self, + revision, + revision_now, ); // Do an initial probe with just the read-lock. @@ -734,9 +743,11 @@ where assert!(self.verified_at != revision_now); let verified_at = self.verified_at; - debug!( + trace!( "verify_revisions: verified_at={:?}, revision_now={:?}, inputs={:#?}", - verified_at, revision_now, self.revisions.inputs + verified_at, + revision_now, + self.revisions.inputs ); if self.check_durability(db.salsa_runtime()) { @@ -762,7 +773,7 @@ where let changed_input = inputs.slice.iter().find(|&&input| db.maybe_changed_after(input, verified_at)); if let Some(input) = changed_input { - debug!("validate_memoized_value: `{:?}` may have changed", input); + trace!("validate_memoized_value: `{:?}` may have changed", input); return false; } @@ -775,7 +786,7 @@ where /// True if this memo is known not to have changed based on its durability. fn check_durability(&self, runtime: &Runtime) -> bool { let last_changed = runtime.last_changed_revision(self.revisions.durability); - debug!( + trace!( "check_durability(last_changed={:?} <= verified_at={:?}) = {:?}", last_changed, self.verified_at, diff --git a/src/tools/rust-analyzer/crates/ra-salsa/src/input.rs b/src/tools/rust-analyzer/crates/ra-salsa/src/input.rs index f04f48e3bab85..4992a0c7271cc 100644 --- a/src/tools/rust-analyzer/crates/ra-salsa/src/input.rs +++ b/src/tools/rust-analyzer/crates/ra-salsa/src/input.rs @@ -14,7 +14,7 @@ use crate::{DatabaseKeyIndex, QueryDb}; use indexmap::map::Entry; use parking_lot::RwLock; use std::iter; -use tracing::debug; +use tracing::trace; /// Input queries store the result plus a list of the other queries /// that they invoked. This means we can avoid recomputing them when @@ -73,11 +73,11 @@ where return true; }; - debug!("maybe_changed_after(slot={:?}, revision={:?})", Q::default(), revision,); + trace!("maybe_changed_after(slot={:?}, revision={:?})", Q::default(), revision,); let changed_at = slot.stamped_value.read().changed_at; - debug!("maybe_changed_after: changed_at = {:?}", changed_at); + trace!("maybe_changed_after: changed_at = {:?}", changed_at); changed_at > revision } @@ -140,7 +140,7 @@ where Q: Query, { fn set(&self, runtime: &mut Runtime, key: &Q::Key, value: Q::Value, durability: Durability) { - tracing::debug!("{:?}({:?}) = {:?} ({:?})", Q::default(), key, value, durability); + tracing::trace!("{:?}({:?}) = {:?} ({:?})", Q::default(), key, value, durability); // The value is changing, so we need a new revision (*). We also // need to update the 'last changed' revision by invoking @@ -234,14 +234,14 @@ where ) -> bool { debug_assert!(revision < db.salsa_runtime().current_revision()); - debug!("maybe_changed_after(slot={:?}, revision={:?})", Q::default(), revision,); + trace!("maybe_changed_after(slot={:?}, revision={:?})", Q::default(), revision,); let Some(value) = &*self.slot.stamped_value.read() else { return true; }; let changed_at = value.changed_at; - debug!("maybe_changed_after: changed_at = {:?}", changed_at); + trace!("maybe_changed_after: changed_at = {:?}", changed_at); changed_at > revision } @@ -298,7 +298,7 @@ where Q: Query, { fn set(&self, runtime: &mut Runtime, (): &Q::Key, value: Q::Value, durability: Durability) { - tracing::debug!("{:?} = {:?} ({:?})", Q::default(), value, durability); + tracing::trace!("{:?} = {:?} ({:?})", Q::default(), value, durability); // The value is changing, so we need a new revision (*). We also // need to update the 'last changed' revision by invoking diff --git a/src/tools/rust-analyzer/crates/ra-salsa/src/lib.rs b/src/tools/rust-analyzer/crates/ra-salsa/src/lib.rs index 8530521d9157b..843b6d31f0c33 100644 --- a/src/tools/rust-analyzer/crates/ra-salsa/src/lib.rs +++ b/src/tools/rust-analyzer/crates/ra-salsa/src/lib.rs @@ -79,7 +79,7 @@ pub trait Database: plumbing::DatabaseOps { let current_revision = runtime.current_revision(); let pending_revision = runtime.pending_revision(); - tracing::debug!( + tracing::trace!( "unwind_if_cancelled: current_revision={:?}, pending_revision={:?}", current_revision, pending_revision @@ -684,7 +684,7 @@ impl Cycle { } pub(crate) fn throw(self) -> ! { - tracing::debug!("throwing cycle {:?}", self); + tracing::trace!("throwing cycle {:?}", self); std::panic::resume_unwind(Box::new(self)) } diff --git a/src/tools/rust-analyzer/crates/ra-salsa/src/lru.rs b/src/tools/rust-analyzer/crates/ra-salsa/src/lru.rs index a6f96beeab11a..7fbd42f92627a 100644 --- a/src/tools/rust-analyzer/crates/ra-salsa/src/lru.rs +++ b/src/tools/rust-analyzer/crates/ra-salsa/src/lru.rs @@ -103,11 +103,11 @@ where /// Records that `node` was used. This may displace an old node (if the LRU limits are pub(crate) fn record_use(&self, node: &Arc) -> Option> { - tracing::debug!("record_use(node={:?})", node); + tracing::trace!("record_use(node={:?})", node); // Load green zone length and check if the LRU cache is even enabled. let green_zone = self.green_zone.load(Ordering::Acquire); - tracing::debug!("record_use: green_zone={}", green_zone); + tracing::trace!("record_use: green_zone={}", green_zone); if green_zone == 0 { return None; } @@ -115,7 +115,7 @@ where // Find current index of list (if any) and the current length // of our green zone. let index = node.lru_index().load(); - tracing::debug!("record_use: index={}", index); + tracing::trace!("record_use: index={}", index); // Already a member of the list, and in the green zone -- nothing to do! if index < green_zone { @@ -162,9 +162,9 @@ where let entries = std::mem::replace(&mut self.entries, Vec::with_capacity(self.end_red_zone as usize)); - tracing::debug!("green_zone = {:?}", self.green_zone()); - tracing::debug!("yellow_zone = {:?}", self.yellow_zone()); - tracing::debug!("red_zone = {:?}", self.red_zone()); + tracing::trace!("green_zone = {:?}", self.green_zone()); + tracing::trace!("yellow_zone = {:?}", self.yellow_zone()); + tracing::trace!("red_zone = {:?}", self.red_zone()); // We expect to resize when the LRU cache is basically empty. // So just forget all the old LRU indices to start. @@ -180,7 +180,7 @@ where /// list may displace an old member of the red zone, in which case /// that is returned. fn record_use(&mut self, node: &Arc) -> Option> { - tracing::debug!("record_use(node={:?})", node); + tracing::trace!("record_use(node={:?})", node); // NB: When this is invoked, we have typically already loaded // the LRU index (to check if it is in green zone). But that @@ -212,7 +212,7 @@ where if len < self.end_red_zone { self.entries.push(node.clone()); node.lru_index().store(len); - tracing::debug!("inserted node {:?} at {}", node, len); + tracing::trace!("inserted node {:?} at {}", node, len); return self.record_use(node); } @@ -220,7 +220,7 @@ where // zone and then promoting. let victim_index = self.pick_index(self.red_zone()); let victim_node = std::mem::replace(&mut self.entries[victim_index as usize], node.clone()); - tracing::debug!("evicting red node {:?} from {}", victim_node, victim_index); + tracing::trace!("evicting red node {:?} from {}", victim_node, victim_index); victim_node.lru_index().clear(); self.promote_red_to_green(node, victim_index); Some(victim_node) @@ -241,7 +241,7 @@ where // going to invoke `self.promote_yellow` next, and it will get // updated then. let yellow_index = self.pick_index(self.yellow_zone()); - tracing::debug!( + tracing::trace!( "demoting yellow node {:?} from {} to red at {}", self.entries[yellow_index as usize], yellow_index, @@ -265,7 +265,7 @@ where // Pick a yellow at random and switch places with it. let green_index = self.pick_index(self.green_zone()); - tracing::debug!( + tracing::trace!( "demoting green node {:?} from {} to yellow at {}", self.entries[green_index as usize], green_index, @@ -275,7 +275,7 @@ where self.entries[yellow_index as usize].lru_index().store(yellow_index); node.lru_index().store(green_index); - tracing::debug!("promoted {:?} to green index {}", node, green_index); + tracing::trace!("promoted {:?} to green index {}", node, green_index); } fn pick_index(&mut self, zone: std::ops::Range) -> u16 { diff --git a/src/tools/rust-analyzer/crates/ra-salsa/src/runtime.rs b/src/tools/rust-analyzer/crates/ra-salsa/src/runtime.rs index 5fe5f4b46d3e8..cb16ba0044dfd 100644 --- a/src/tools/rust-analyzer/crates/ra-salsa/src/runtime.rs +++ b/src/tools/rust-analyzer/crates/ra-salsa/src/runtime.rs @@ -9,7 +9,7 @@ use parking_lot::{Mutex, RwLock}; use std::hash::Hash; use std::panic::panic_any; use std::sync::atomic::{AtomicU32, Ordering}; -use tracing::debug; +use tracing::trace; use triomphe::{Arc, ThinArc}; mod dependency_graph; @@ -177,7 +177,7 @@ impl Runtime { where F: FnOnce(Revision) -> Option, { - tracing::debug!("increment_revision()"); + tracing::trace!("increment_revision()"); if !self.permits_increment() { panic!("increment_revision invoked during a query computation"); @@ -196,7 +196,7 @@ impl Runtime { let new_revision = current_revision.next(); - debug!("increment_revision: incremented to {:?}", new_revision); + trace!("increment_revision: incremented to {:?}", new_revision); if let Some(d) = op(new_revision) { for rev in &self.shared_state.revisions[1..=d.index()] { @@ -267,7 +267,7 @@ impl Runtime { database_key_index: DatabaseKeyIndex, to_id: RuntimeId, ) { - debug!("unblock_cycle_and_maybe_throw(database_key={:?})", database_key_index); + trace!("unblock_cycle_and_maybe_throw(database_key={:?})", database_key_index); let mut from_stack = self.local_state.take_query_stack(); let from_id = self.id(); @@ -305,7 +305,7 @@ impl Runtime { Cycle::new(Arc::new(v)) }; - debug!("cycle {:?}, cycle_query {:#?}", cycle.debug(db), cycle_query,); + trace!("cycle {:?}, cycle_query {:#?}", cycle.debug(db), cycle_query,); // We can remove the cycle participants from the list of dependencies; // they are a strongly connected component (SCC) and we only care about @@ -323,7 +323,7 @@ impl Runtime { CycleRecoveryStrategy::Fallback => false, }) .for_each(|aq| { - debug!("marking {:?} for fallback", aq.database_key_index.debug(db)); + trace!("marking {:?} for fallback", aq.database_key_index.debug(db)); aq.take_inputs_from(&cycle_query); assert!(aq.cycle.is_none()); aq.cycle = Some(cycle.clone()); diff --git a/src/tools/rust-analyzer/crates/ra-salsa/src/runtime/local_state.rs b/src/tools/rust-analyzer/crates/ra-salsa/src/runtime/local_state.rs index 738696718868f..4ab4bad0cc508 100644 --- a/src/tools/rust-analyzer/crates/ra-salsa/src/runtime/local_state.rs +++ b/src/tools/rust-analyzer/crates/ra-salsa/src/runtime/local_state.rs @@ -1,4 +1,4 @@ -use tracing::debug; +use tracing::trace; use triomphe::ThinArc; use crate::durability::Durability; @@ -78,7 +78,7 @@ impl LocalState { durability: Durability, changed_at: Revision, ) { - debug!( + trace!( "report_query_read_and_unwind_if_cycle_resulted(input={:?}, durability={:?}, changed_at={:?})", input, durability, changed_at ); From 873cf255def56e5187aceff89e27146192a8366d Mon Sep 17 00:00:00 2001 From: Lukas Wirth Date: Sat, 28 Dec 2024 15:08:26 +0100 Subject: [PATCH 2/8] Add DynTyExt::principal_id --- .../crates/hir-ty/src/chalk_ext.rs | 18 +++++++++++++++--- .../crates/hir-ty/src/infer/closure.rs | 4 ++-- .../crates/hir-ty/src/lang_items.rs | 1 + .../crates/hir-ty/src/method_resolution.rs | 13 ++++++------- 4 files changed, 24 insertions(+), 12 deletions(-) diff --git a/src/tools/rust-analyzer/crates/hir-ty/src/chalk_ext.rs b/src/tools/rust-analyzer/crates/hir-ty/src/chalk_ext.rs index 302558162ac96..51c178b90d72b 100644 --- a/src/tools/rust-analyzer/crates/hir-ty/src/chalk_ext.rs +++ b/src/tools/rust-analyzer/crates/hir-ty/src/chalk_ext.rs @@ -443,13 +443,25 @@ impl ProjectionTyExt for ProjectionTy { } pub trait DynTyExt { - fn principal(&self) -> Option<&TraitRef>; + fn principal(&self) -> Option>>; + fn principal_id(&self) -> Option>; } impl DynTyExt for DynTy { - fn principal(&self) -> Option<&TraitRef> { + fn principal(&self) -> Option>> { + self.bounds.as_ref().filter_map(|bounds| { + bounds.interned().first().and_then(|b| { + b.as_ref().filter_map(|b| match b { + crate::WhereClause::Implemented(trait_ref) => Some(trait_ref), + _ => None, + }) + }) + }) + } + + fn principal_id(&self) -> Option> { self.bounds.skip_binders().interned().first().and_then(|b| match b.skip_binders() { - crate::WhereClause::Implemented(trait_ref) => Some(trait_ref), + crate::WhereClause::Implemented(trait_ref) => Some(trait_ref.trait_id), _ => None, }) } diff --git a/src/tools/rust-analyzer/crates/hir-ty/src/infer/closure.rs b/src/tools/rust-analyzer/crates/hir-ty/src/infer/closure.rs index 5a251683b962a..c59013beafbae 100644 --- a/src/tools/rust-analyzer/crates/hir-ty/src/infer/closure.rs +++ b/src/tools/rust-analyzer/crates/hir-ty/src/infer/closure.rs @@ -96,8 +96,8 @@ impl InferenceContext<'_> { .map(|b| b.into_value_and_skipped_binders().0); self.deduce_closure_kind_from_predicate_clauses(clauses) } - TyKind::Dyn(dyn_ty) => dyn_ty.principal().and_then(|trait_ref| { - self.fn_trait_kind_from_trait_id(from_chalk_trait_id(trait_ref.trait_id)) + TyKind::Dyn(dyn_ty) => dyn_ty.principal_id().and_then(|trait_id| { + self.fn_trait_kind_from_trait_id(from_chalk_trait_id(trait_id)) }), TyKind::InferenceVar(ty, chalk_ir::TyVariableKind::General) => { let clauses = self.clauses_for_self_ty(*ty); diff --git a/src/tools/rust-analyzer/crates/hir-ty/src/lang_items.rs b/src/tools/rust-analyzer/crates/hir-ty/src/lang_items.rs index f704b59d303e5..ff9c52fbb6c17 100644 --- a/src/tools/rust-analyzer/crates/hir-ty/src/lang_items.rs +++ b/src/tools/rust-analyzer/crates/hir-ty/src/lang_items.rs @@ -13,6 +13,7 @@ pub fn is_box(db: &dyn HirDatabase, adt: AdtId) -> bool { pub fn is_unsafe_cell(db: &dyn HirDatabase, adt: AdtId) -> bool { let AdtId::StructId(id) = adt else { return false }; + db.struct_data(id).flags.contains(StructFlags::IS_UNSAFE_CELL) } diff --git a/src/tools/rust-analyzer/crates/hir-ty/src/method_resolution.rs b/src/tools/rust-analyzer/crates/hir-ty/src/method_resolution.rs index 5a72b97653dbf..952580c3b70db 100644 --- a/src/tools/rust-analyzer/crates/hir-ty/src/method_resolution.rs +++ b/src/tools/rust-analyzer/crates/hir-ty/src/method_resolution.rs @@ -805,8 +805,8 @@ fn is_inherent_impl_coherent( | TyKind::Scalar(_) => def_map.is_rustc_coherence_is_core(), &TyKind::Adt(AdtId(adt), _) => adt.module(db.upcast()).krate() == def_map.krate(), - TyKind::Dyn(it) => it.principal().map_or(false, |trait_ref| { - from_chalk_trait_id(trait_ref.trait_id).module(db.upcast()).krate() == def_map.krate() + TyKind::Dyn(it) => it.principal_id().map_or(false, |trait_id| { + from_chalk_trait_id(trait_id).module(db.upcast()).krate() == def_map.krate() }), _ => true, @@ -834,9 +834,8 @@ fn is_inherent_impl_coherent( .contains(StructFlags::IS_RUSTC_HAS_INCOHERENT_INHERENT_IMPL), hir_def::AdtId::EnumId(it) => db.enum_data(it).rustc_has_incoherent_inherent_impls, }, - TyKind::Dyn(it) => it.principal().map_or(false, |trait_ref| { - db.trait_data(from_chalk_trait_id(trait_ref.trait_id)) - .rustc_has_incoherent_inherent_impls + TyKind::Dyn(it) => it.principal_id().map_or(false, |trait_id| { + db.trait_data(from_chalk_trait_id(trait_id)).rustc_has_incoherent_inherent_impls }), _ => false, @@ -896,8 +895,8 @@ pub fn check_orphan_rules(db: &dyn HirDatabase, impl_: ImplId) -> bool { match unwrap_fundamental(ty).kind(Interner) { &TyKind::Adt(AdtId(id), _) => is_local(id.module(db.upcast()).krate()), TyKind::Error => true, - TyKind::Dyn(it) => it.principal().map_or(false, |trait_ref| { - is_local(from_chalk_trait_id(trait_ref.trait_id).module(db.upcast()).krate()) + TyKind::Dyn(it) => it.principal_id().map_or(false, |trait_id| { + is_local(from_chalk_trait_id(trait_id).module(db.upcast()).krate()) }), _ => false, } From 4ea29d619cc435e6087ea07f9605180a1dd22d92 Mon Sep 17 00:00:00 2001 From: Lukas Wirth Date: Sat, 28 Dec 2024 15:08:26 +0100 Subject: [PATCH 3/8] Implement parameter variance inference --- .../crates/hir-ty/src/chalk_db.rs | 21 +- .../rust-analyzer/crates/hir-ty/src/db.rs | 4 + .../crates/hir-ty/src/generics.rs | 8 + .../rust-analyzer/crates/hir-ty/src/lib.rs | 1 + .../rust-analyzer/crates/hir-ty/src/tests.rs | 78 +- .../hir-ty/src/tests/closure_captures.rs | 7 + .../crates/hir-ty/src/tests/incremental.rs | 17 +- .../crates/hir-ty/src/variance.rs | 1172 +++++++++++++++++ 8 files changed, 1271 insertions(+), 37 deletions(-) create mode 100644 src/tools/rust-analyzer/crates/hir-ty/src/variance.rs diff --git a/src/tools/rust-analyzer/crates/hir-ty/src/chalk_db.rs b/src/tools/rust-analyzer/crates/hir-ty/src/chalk_db.rs index 55d0edd5e0c91..0a2612219a467 100644 --- a/src/tools/rust-analyzer/crates/hir-ty/src/chalk_db.rs +++ b/src/tools/rust-analyzer/crates/hir-ty/src/chalk_db.rs @@ -950,11 +950,18 @@ pub(crate) fn fn_def_datum_query(db: &dyn HirDatabase, fn_def_id: FnDefId) -> Ar pub(crate) fn fn_def_variance_query(db: &dyn HirDatabase, fn_def_id: FnDefId) -> Variances { let callable_def: CallableDefId = from_chalk(db, fn_def_id); - let generic_params = - generics(db.upcast(), GenericDefId::from_callable(db.upcast(), callable_def)); Variances::from_iter( Interner, - std::iter::repeat(chalk_ir::Variance::Invariant).take(generic_params.len()), + db.variances_of(GenericDefId::from_callable(db.upcast(), callable_def)) + .as_deref() + .unwrap_or_default() + .iter() + .map(|v| match v { + crate::variance::Variance::Covariant => chalk_ir::Variance::Covariant, + crate::variance::Variance::Invariant => chalk_ir::Variance::Invariant, + crate::variance::Variance::Contravariant => chalk_ir::Variance::Contravariant, + crate::variance::Variance::Bivariant => chalk_ir::Variance::Invariant, + }), ) } @@ -962,10 +969,14 @@ pub(crate) fn adt_variance_query( db: &dyn HirDatabase, chalk_ir::AdtId(adt_id): AdtId, ) -> Variances { - let generic_params = generics(db.upcast(), adt_id.into()); Variances::from_iter( Interner, - std::iter::repeat(chalk_ir::Variance::Invariant).take(generic_params.len()), + db.variances_of(adt_id.into()).as_deref().unwrap_or_default().iter().map(|v| match v { + crate::variance::Variance::Covariant => chalk_ir::Variance::Covariant, + crate::variance::Variance::Invariant => chalk_ir::Variance::Invariant, + crate::variance::Variance::Contravariant => chalk_ir::Variance::Contravariant, + crate::variance::Variance::Bivariant => chalk_ir::Variance::Invariant, + }), ) } diff --git a/src/tools/rust-analyzer/crates/hir-ty/src/db.rs b/src/tools/rust-analyzer/crates/hir-ty/src/db.rs index 6856eaa3e02f0..6b05682667087 100644 --- a/src/tools/rust-analyzer/crates/hir-ty/src/db.rs +++ b/src/tools/rust-analyzer/crates/hir-ty/src/db.rs @@ -271,6 +271,10 @@ pub trait HirDatabase: DefDatabase + Upcast { #[ra_salsa::invoke(chalk_db::adt_variance_query)] fn adt_variance(&self, adt_id: chalk_db::AdtId) -> chalk_db::Variances; + #[ra_salsa::invoke(crate::variance::variances_of)] + #[ra_salsa::cycle(crate::variance::variances_of_cycle)] + fn variances_of(&self, def: GenericDefId) -> Option>; + #[ra_salsa::invoke(chalk_db::associated_ty_value_query)] fn associated_ty_value( &self, diff --git a/src/tools/rust-analyzer/crates/hir-ty/src/generics.rs b/src/tools/rust-analyzer/crates/hir-ty/src/generics.rs index fe7541d237478..e7a2721afee5c 100644 --- a/src/tools/rust-analyzer/crates/hir-ty/src/generics.rs +++ b/src/tools/rust-analyzer/crates/hir-ty/src/generics.rs @@ -132,6 +132,14 @@ impl Generics { self.params.len() } + pub(crate) fn len_self_lifetimes(&self) -> usize { + self.params.len_lifetimes() + } + + pub(crate) fn has_trait_self(&self) -> bool { + self.params.trait_self_param().is_some() + } + /// (parent total, self param, type params, const params, impl trait list, lifetimes) pub(crate) fn provenance_split(&self) -> (usize, bool, usize, usize, usize, usize) { let mut self_param = false; diff --git a/src/tools/rust-analyzer/crates/hir-ty/src/lib.rs b/src/tools/rust-analyzer/crates/hir-ty/src/lib.rs index 224fcf313a4b5..88134f564c018 100644 --- a/src/tools/rust-analyzer/crates/hir-ty/src/lib.rs +++ b/src/tools/rust-analyzer/crates/hir-ty/src/lib.rs @@ -50,6 +50,7 @@ pub mod traits; mod test_db; #[cfg(test)] mod tests; +mod variance; use std::hash::Hash; diff --git a/src/tools/rust-analyzer/crates/hir-ty/src/tests.rs b/src/tools/rust-analyzer/crates/hir-ty/src/tests.rs index cabeeea2bd86d..b7607b5f6396b 100644 --- a/src/tools/rust-analyzer/crates/hir-ty/src/tests.rs +++ b/src/tools/rust-analyzer/crates/hir-ty/src/tests.rs @@ -127,7 +127,15 @@ fn check_impl(ra_fixture: &str, allow_none: bool, only_types: bool, display_sour None => continue, }; let def_map = module.def_map(&db); - visit_module(&db, &def_map, module.local_id, &mut |it| defs.push(it)); + visit_module(&db, &def_map, module.local_id, &mut |it| { + defs.push(match it { + ModuleDefId::FunctionId(it) => it.into(), + ModuleDefId::EnumVariantId(it) => it.into(), + ModuleDefId::ConstId(it) => it.into(), + ModuleDefId::StaticId(it) => it.into(), + _ => return, + }) + }); } defs.sort_by_key(|def| match def { DefWithBodyId::FunctionId(it) => { @@ -375,7 +383,15 @@ fn infer_with_mismatches(content: &str, include_mismatches: bool) -> String { let def_map = module.def_map(&db); let mut defs: Vec = Vec::new(); - visit_module(&db, &def_map, module.local_id, &mut |it| defs.push(it)); + visit_module(&db, &def_map, module.local_id, &mut |it| { + defs.push(match it { + ModuleDefId::FunctionId(it) => it.into(), + ModuleDefId::EnumVariantId(it) => it.into(), + ModuleDefId::ConstId(it) => it.into(), + ModuleDefId::StaticId(it) => it.into(), + _ => return, + }) + }); defs.sort_by_key(|def| match def { DefWithBodyId::FunctionId(it) => { let loc = it.lookup(&db); @@ -405,11 +421,11 @@ fn infer_with_mismatches(content: &str, include_mismatches: bool) -> String { buf } -fn visit_module( +pub(crate) fn visit_module( db: &TestDB, crate_def_map: &DefMap, module_id: LocalModuleId, - cb: &mut dyn FnMut(DefWithBodyId), + cb: &mut dyn FnMut(ModuleDefId), ) { visit_scope(db, crate_def_map, &crate_def_map[module_id].scope, cb); for impl_id in crate_def_map[module_id].scope.impls() { @@ -417,18 +433,18 @@ fn visit_module( for &item in impl_data.items.iter() { match item { AssocItemId::FunctionId(it) => { - let def = it.into(); - cb(def); - let body = db.body(def); + let body = db.body(it.into()); + cb(it.into()); visit_body(db, &body, cb); } AssocItemId::ConstId(it) => { - let def = it.into(); - cb(def); - let body = db.body(def); + let body = db.body(it.into()); + cb(it.into()); visit_body(db, &body, cb); } - AssocItemId::TypeAliasId(_) => (), + AssocItemId::TypeAliasId(it) => { + cb(it.into()); + } } } } @@ -437,33 +453,27 @@ fn visit_module( db: &TestDB, crate_def_map: &DefMap, scope: &ItemScope, - cb: &mut dyn FnMut(DefWithBodyId), + cb: &mut dyn FnMut(ModuleDefId), ) { for decl in scope.declarations() { + cb(decl); match decl { ModuleDefId::FunctionId(it) => { - let def = it.into(); - cb(def); - let body = db.body(def); + let body = db.body(it.into()); visit_body(db, &body, cb); } ModuleDefId::ConstId(it) => { - let def = it.into(); - cb(def); - let body = db.body(def); + let body = db.body(it.into()); visit_body(db, &body, cb); } ModuleDefId::StaticId(it) => { - let def = it.into(); - cb(def); - let body = db.body(def); + let body = db.body(it.into()); visit_body(db, &body, cb); } ModuleDefId::AdtId(hir_def::AdtId::EnumId(it)) => { db.enum_data(it).variants.iter().for_each(|&(it, _)| { - let def = it.into(); - cb(def); - let body = db.body(def); + let body = db.body(it.into()); + cb(it.into()); visit_body(db, &body, cb); }); } @@ -473,7 +483,7 @@ fn visit_module( match item { AssocItemId::FunctionId(it) => cb(it.into()), AssocItemId::ConstId(it) => cb(it.into()), - AssocItemId::TypeAliasId(_) => (), + AssocItemId::TypeAliasId(it) => cb(it.into()), } } } @@ -483,7 +493,7 @@ fn visit_module( } } - fn visit_body(db: &TestDB, body: &Body, cb: &mut dyn FnMut(DefWithBodyId)) { + fn visit_body(db: &TestDB, body: &Body, cb: &mut dyn FnMut(ModuleDefId)) { for (_, def_map) in body.blocks(db) { for (mod_id, _) in def_map.modules() { visit_module(db, &def_map, mod_id, cb); @@ -553,7 +563,13 @@ fn salsa_bug() { let module = db.module_for_file(pos.file_id); let crate_def_map = module.def_map(&db); visit_module(&db, &crate_def_map, module.local_id, &mut |def| { - db.infer(def); + db.infer(match def { + ModuleDefId::FunctionId(it) => it.into(), + ModuleDefId::EnumVariantId(it) => it.into(), + ModuleDefId::ConstId(it) => it.into(), + ModuleDefId::StaticId(it) => it.into(), + _ => return, + }); }); let new_text = " @@ -586,6 +602,12 @@ fn salsa_bug() { let module = db.module_for_file(pos.file_id); let crate_def_map = module.def_map(&db); visit_module(&db, &crate_def_map, module.local_id, &mut |def| { - db.infer(def); + db.infer(match def { + ModuleDefId::FunctionId(it) => it.into(), + ModuleDefId::EnumVariantId(it) => it.into(), + ModuleDefId::ConstId(it) => it.into(), + ModuleDefId::StaticId(it) => it.into(), + _ => return, + }); }); } diff --git a/src/tools/rust-analyzer/crates/hir-ty/src/tests/closure_captures.rs b/src/tools/rust-analyzer/crates/hir-ty/src/tests/closure_captures.rs index b63d632dd26ca..7de92d6b16078 100644 --- a/src/tools/rust-analyzer/crates/hir-ty/src/tests/closure_captures.rs +++ b/src/tools/rust-analyzer/crates/hir-ty/src/tests/closure_captures.rs @@ -24,6 +24,13 @@ fn check_closure_captures(ra_fixture: &str, expect: Expect) { let mut captures_info = Vec::new(); for def in defs { + let def = match def { + hir_def::ModuleDefId::FunctionId(it) => it.into(), + hir_def::ModuleDefId::EnumVariantId(it) => it.into(), + hir_def::ModuleDefId::ConstId(it) => it.into(), + hir_def::ModuleDefId::StaticId(it) => it.into(), + _ => continue, + }; let infer = db.infer(def); let db = &db; captures_info.extend(infer.closure_info.iter().flat_map(|(closure_id, (captures, _))| { diff --git a/src/tools/rust-analyzer/crates/hir-ty/src/tests/incremental.rs b/src/tools/rust-analyzer/crates/hir-ty/src/tests/incremental.rs index 0a24eeb1fe82d..3757d722ac83b 100644 --- a/src/tools/rust-analyzer/crates/hir-ty/src/tests/incremental.rs +++ b/src/tools/rust-analyzer/crates/hir-ty/src/tests/incremental.rs @@ -1,4 +1,5 @@ use base_db::SourceDatabaseFileInputExt as _; +use hir_def::ModuleDefId; use test_fixture::WithFixture; use crate::{db::HirDatabase, test_db::TestDB}; @@ -19,7 +20,9 @@ fn foo() -> i32 { let module = db.module_for_file(pos.file_id.file_id()); let crate_def_map = module.def_map(&db); visit_module(&db, &crate_def_map, module.local_id, &mut |def| { - db.infer(def); + if let ModuleDefId::FunctionId(it) = def { + db.infer(it.into()); + } }); }); assert!(format!("{events:?}").contains("infer")) @@ -39,7 +42,9 @@ fn foo() -> i32 { let module = db.module_for_file(pos.file_id.file_id()); let crate_def_map = module.def_map(&db); visit_module(&db, &crate_def_map, module.local_id, &mut |def| { - db.infer(def); + if let ModuleDefId::FunctionId(it) = def { + db.infer(it.into()); + } }); }); assert!(!format!("{events:?}").contains("infer"), "{events:#?}") @@ -66,7 +71,9 @@ fn baz() -> i32 { let module = db.module_for_file(pos.file_id.file_id()); let crate_def_map = module.def_map(&db); visit_module(&db, &crate_def_map, module.local_id, &mut |def| { - db.infer(def); + if let ModuleDefId::FunctionId(it) = def { + db.infer(it.into()); + } }); }); assert!(format!("{events:?}").contains("infer")) @@ -91,7 +98,9 @@ fn baz() -> i32 { let module = db.module_for_file(pos.file_id.file_id()); let crate_def_map = module.def_map(&db); visit_module(&db, &crate_def_map, module.local_id, &mut |def| { - db.infer(def); + if let ModuleDefId::FunctionId(it) = def { + db.infer(it.into()); + } }); }); assert!(format!("{events:?}").matches("infer").count() == 1, "{events:#?}") diff --git a/src/tools/rust-analyzer/crates/hir-ty/src/variance.rs b/src/tools/rust-analyzer/crates/hir-ty/src/variance.rs new file mode 100644 index 0000000000000..2339e37fa1409 --- /dev/null +++ b/src/tools/rust-analyzer/crates/hir-ty/src/variance.rs @@ -0,0 +1,1172 @@ +//! Module for inferring the variance of type and lifetime parameters. See the [rustc dev guide] +//! chapter for more info. +//! +//! [rustc dev guide]: https://rustc-dev-guide.rust-lang.org/variance.html + +use crate::db::HirDatabase; +use crate::generics::{generics, Generics}; +use crate::{ + AliasTy, Const, ConstScalar, DynTyExt, FnPointer, GenericArg, GenericArgData, Interner, + Lifetime, LifetimeData, Ty, TyKind, +}; +use base_db::ra_salsa::Cycle; +use chalk_ir::Mutability; +use hir_def::data::adt::StructFlags; +use hir_def::{AdtId, GenericDefId, GenericParamId, VariantId}; +use std::fmt; +use std::ops::Not; +use triomphe::Arc; + +pub(crate) fn variances_of(db: &dyn HirDatabase, def: GenericDefId) -> Option> { + tracing::debug!("variances_of(def={:?})", def); + match def { + GenericDefId::FunctionId(_) => (), + GenericDefId::AdtId(adt) => { + if let AdtId::StructId(id) = adt { + let flags = &db.struct_data(id).flags; + if flags.contains(StructFlags::IS_UNSAFE_CELL) { + return Some(Arc::from_iter(vec![Variance::Invariant; 1])); + } else if flags.contains(StructFlags::IS_PHANTOM_DATA) { + return Some(Arc::from_iter(vec![Variance::Covariant; 1])); + } + } + } + _ => return None, + } + + let generics = generics(db.upcast(), def); + let count = generics.len(); + if count == 0 { + return None; + } + let mut ctxt = Context { + def, + has_trait_self: generics.parent_generics().map_or(false, |it| it.has_trait_self()), + len_self: generics.len_self(), + len_self_lifetimes: generics.len_self_lifetimes(), + generics, + constraints: Vec::new(), + db, + }; + + ctxt.build_constraints_for_item(); + let res = ctxt.solve(); + res.is_empty().not().then(|| Arc::from_iter(res)) +} + +pub(crate) fn variances_of_cycle( + db: &dyn HirDatabase, + _cycle: &Cycle, + def: &GenericDefId, +) -> Option> { + let generics = generics(db.upcast(), *def); + let count = generics.len(); + + if count == 0 { + return None; + } + Some(Arc::from(vec![Variance::Bivariant; count])) +} + +#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)] +pub enum Variance { + Covariant, // T <: T iff A <: B -- e.g., function return type + Invariant, // T <: T iff B == A -- e.g., type of mutable cell + Contravariant, // T <: T iff B <: A -- e.g., function param type + Bivariant, // T <: T -- e.g., unused type parameter +} + +impl fmt::Display for Variance { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Variance::Covariant => write!(f, "covariant"), + Variance::Invariant => write!(f, "invariant"), + Variance::Contravariant => write!(f, "contravariant"), + Variance::Bivariant => write!(f, "bivariant"), + } + } +} + +impl Variance { + /// `a.xform(b)` combines the variance of a context with the + /// variance of a type with the following meaning. If we are in a + /// context with variance `a`, and we encounter a type argument in + /// a position with variance `b`, then `a.xform(b)` is the new + /// variance with which the argument appears. + /// + /// Example 1: + /// ```ignore (illustrative) + /// *mut Vec + /// ``` + /// Here, the "ambient" variance starts as covariant. `*mut T` is + /// invariant with respect to `T`, so the variance in which the + /// `Vec` appears is `Covariant.xform(Invariant)`, which + /// yields `Invariant`. Now, the type `Vec` is covariant with + /// respect to its type argument `T`, and hence the variance of + /// the `i32` here is `Invariant.xform(Covariant)`, which results + /// (again) in `Invariant`. + /// + /// Example 2: + /// ```ignore (illustrative) + /// fn(*const Vec, *mut Vec` appears is + /// `Contravariant.xform(Covariant)` or `Contravariant`. The same + /// is true for its `i32` argument. In the `*mut T` case, the + /// variance of `Vec` is `Contravariant.xform(Invariant)`, + /// and hence the outermost type is `Invariant` with respect to + /// `Vec` (and its `i32` argument). + /// + /// Source: Figure 1 of "Taming the Wildcards: + /// Combining Definition- and Use-Site Variance" published in PLDI'11. + fn xform(self, v: Variance) -> Variance { + match (self, v) { + // Figure 1, column 1. + (Variance::Covariant, Variance::Covariant) => Variance::Covariant, + (Variance::Covariant, Variance::Contravariant) => Variance::Contravariant, + (Variance::Covariant, Variance::Invariant) => Variance::Invariant, + (Variance::Covariant, Variance::Bivariant) => Variance::Bivariant, + + // Figure 1, column 2. + (Variance::Contravariant, Variance::Covariant) => Variance::Contravariant, + (Variance::Contravariant, Variance::Contravariant) => Variance::Covariant, + (Variance::Contravariant, Variance::Invariant) => Variance::Invariant, + (Variance::Contravariant, Variance::Bivariant) => Variance::Bivariant, + + // Figure 1, column 3. + (Variance::Invariant, _) => Variance::Invariant, + + // Figure 1, column 4. + (Variance::Bivariant, _) => Variance::Bivariant, + } + } + + fn glb(self, v: Variance) -> Variance { + // Greatest lower bound of the variance lattice as + // defined in The Paper: + // + // * + // - + + // o + match (self, v) { + (Variance::Invariant, _) | (_, Variance::Invariant) => Variance::Invariant, + + (Variance::Covariant, Variance::Contravariant) => Variance::Invariant, + (Variance::Contravariant, Variance::Covariant) => Variance::Invariant, + + (Variance::Covariant, Variance::Covariant) => Variance::Covariant, + + (Variance::Contravariant, Variance::Contravariant) => Variance::Contravariant, + + (x, Variance::Bivariant) | (Variance::Bivariant, x) => x, + } + } +} +#[derive(Copy, Clone, Debug)] +struct InferredIndex(usize); + +#[derive(Clone)] +enum VarianceTerm { + ConstantTerm(Variance), + TransformTerm(Box, Box), +} + +impl fmt::Debug for VarianceTerm { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + VarianceTerm::ConstantTerm(c1) => write!(f, "{c1:?}"), + VarianceTerm::TransformTerm(v1, v2) => write!(f, "({v1:?} \u{00D7} {v2:?})"), + } + } +} + +struct Context<'db> { + db: &'db dyn HirDatabase, + def: GenericDefId, + has_trait_self: bool, + len_self: usize, + len_self_lifetimes: usize, + generics: Generics, + constraints: Vec, +} + +/// Declares that the variable `decl_id` appears in a location with +/// variance `variance`. +#[derive(Clone)] +struct Constraint { + inferred: InferredIndex, + variance: VarianceTerm, +} + +impl Context<'_> { + fn build_constraints_for_item(&mut self) { + match self.def { + GenericDefId::AdtId(adt) => { + let db = self.db; + let mut add_constraints_from_variant = |variant| { + let subst = self.generics.placeholder_subst(db); + for (_, field) in db.field_types(variant).iter() { + self.add_constraints_from_ty( + &field.clone().substitute(Interner, &subst), + &VarianceTerm::ConstantTerm(Variance::Covariant), + ); + } + }; + match adt { + AdtId::StructId(s) => add_constraints_from_variant(VariantId::StructId(s)), + AdtId::UnionId(u) => add_constraints_from_variant(VariantId::UnionId(u)), + AdtId::EnumId(e) => { + db.enum_data(e).variants.iter().for_each(|&(variant, _)| { + add_constraints_from_variant(VariantId::EnumVariantId(variant)) + }); + } + } + } + GenericDefId::FunctionId(f) => { + let subst = self.generics.placeholder_subst(self.db); + self.add_constraints_from_sig2( + &self + .db + .callable_item_signature(f.into()) + .substitute(Interner, &subst) + .params_and_return, + &VarianceTerm::ConstantTerm(Variance::Covariant), + ); + } + _ => {} + } + } + + fn contravariant(&mut self, variance: &VarianceTerm) -> VarianceTerm { + self.xform(variance, &VarianceTerm::ConstantTerm(Variance::Contravariant)) + } + + fn invariant(&mut self, variance: &VarianceTerm) -> VarianceTerm { + self.xform(variance, &VarianceTerm::ConstantTerm(Variance::Invariant)) + } + + fn xform(&mut self, v1: &VarianceTerm, v2: &VarianceTerm) -> VarianceTerm { + match (v1, v2) { + // Applying a "covariant" transform is always a no-op + (_, VarianceTerm::ConstantTerm(Variance::Covariant)) => v1.clone(), + (VarianceTerm::ConstantTerm(c1), VarianceTerm::ConstantTerm(c2)) => { + VarianceTerm::ConstantTerm(c1.xform(*c2)) + } + _ => VarianceTerm::TransformTerm(Box::new(v1.clone()), Box::new(v2.clone())), + } + } + + fn add_constraints_from_invariant_args( + &mut self, + args: &[GenericArg], + variance: &VarianceTerm, + ) { + tracing::debug!( + "add_constraints_from_invariant_args(args={:?}, variance={:?})", + args, + variance + ); + let variance_i = self.invariant(variance); + + for k in args { + match k.data(Interner) { + GenericArgData::Lifetime(lt) => self.add_constraints_from_region(lt, &variance_i), + GenericArgData::Ty(ty) => self.add_constraints_from_ty(ty, &variance_i), + GenericArgData::Const(val) => self.add_constraints_from_const(val, &variance_i), + } + } + } + + /// Adds constraints appropriate for an instance of `ty` appearing + /// in a context with the generics defined in `generics` and + /// ambient variance `variance` + fn add_constraints_from_ty(&mut self, ty: &Ty, variance: &VarianceTerm) { + tracing::debug!("add_constraints_from_ty(ty={:?}, variance={:?})", ty, variance); + match ty.kind(Interner) { + TyKind::Scalar(_) | TyKind::Never | TyKind::Str | TyKind::Foreign(..) => { + // leaf type -- noop + } + + TyKind::FnDef(..) | TyKind::Coroutine(..) | TyKind::Closure(..) => { + panic!("Unexpected unnameable type in variance computation: {ty:?}"); + } + + TyKind::Ref(mutbl, lifetime, ty) => { + self.add_constraints_from_region(lifetime, variance); + self.add_constraints_from_mt(ty, *mutbl, variance); + } + + TyKind::Array(typ, len) => { + self.add_constraints_from_const(len, variance); + self.add_constraints_from_ty(typ, variance); + } + + TyKind::Slice(typ) => { + self.add_constraints_from_ty(typ, variance); + } + + TyKind::Raw(mutbl, ty) => { + self.add_constraints_from_mt(ty, *mutbl, variance); + } + + TyKind::Tuple(_, subtys) => { + for subty in subtys.type_parameters(Interner) { + self.add_constraints_from_ty(&subty, variance); + } + } + + TyKind::Adt(def, args) => { + self.add_constraints_from_args(def.0.into(), args.as_slice(Interner), variance); + } + + TyKind::Alias(AliasTy::Opaque(opaque)) => { + self.add_constraints_from_invariant_args( + opaque.substitution.as_slice(Interner), + variance, + ); + } + TyKind::Alias(AliasTy::Projection(proj)) => { + self.add_constraints_from_invariant_args( + proj.substitution.as_slice(Interner), + variance, + ); + } + // FIXME: check this + TyKind::AssociatedType(_, subst) => { + self.add_constraints_from_invariant_args(subst.as_slice(Interner), variance); + } + // FIXME: check this + TyKind::OpaqueType(_, subst) => { + self.add_constraints_from_invariant_args(subst.as_slice(Interner), variance); + } + + TyKind::Dyn(it) => { + // The type `dyn Trait +'a` is covariant w/r/t `'a`: + self.add_constraints_from_region(&it.lifetime, variance); + + if let Some(trait_ref) = it.principal() { + // Trait are always invariant so we can take advantage of that. + self.add_constraints_from_invariant_args( + trait_ref + .map(|it| it.map(|it| it.substitution.clone())) + .substitute( + Interner, + &[GenericArg::new( + Interner, + chalk_ir::GenericArgData::Ty(TyKind::Error.intern(Interner)), + )], + ) + .skip_binders() + .as_slice(Interner), + variance, + ); + } + + // FIXME + // for projection in data.projection_bounds() { + // match projection.skip_binder().term.unpack() { + // TyKind::TermKind::Ty(ty) => { + // self.add_constraints_from_ty( ty, self.invariant); + // } + // TyKind::TermKind::Const(c) => { + // self.add_constraints_from_const( c, self.invariant) + // } + // } + // } + } + + // Chalk has no params, so use placeholders for now? + TyKind::Placeholder(index) => { + let idx = crate::from_placeholder_idx(self.db, *index); + let index = idx.local_id.into_raw().into_u32() as usize + self.len_self_lifetimes; + let inferred = if idx.parent == self.def { + InferredIndex(self.has_trait_self as usize + index) + } else { + InferredIndex(self.len_self + index) + }; + tracing::debug!("add_constraint(index={:?}, variance={:?})", inferred, variance); + self.constraints.push(Constraint { inferred, variance: variance.clone() }); + } + TyKind::Function(f) => { + self.add_constraints_from_sig(f, variance); + } + + TyKind::Error => { + // we encounter this when walking the trait references for object + // types, where we use Error as the Self type + } + + TyKind::CoroutineWitness(..) | TyKind::BoundVar(..) | TyKind::InferenceVar(..) => { + panic!("unexpected type encountered in variance inference: {:?}", ty); + } + } + } + + /// Adds constraints appropriate for a nominal type (enum, struct, + /// object, etc) appearing in a context with ambient variance `variance` + fn add_constraints_from_args( + &mut self, + def_id: GenericDefId, + args: &[GenericArg], + variance: &VarianceTerm, + ) { + tracing::debug!( + "add_constraints_from_args(def_id={:?}, args={:?}, variance={:?})", + def_id, + args, + variance + ); + + // We don't record `inferred_starts` entries for empty generics. + if args.is_empty() { + return; + } + if def_id == self.def { + // HACK: Workaround for the trivial cycle salsa case (see + // recursive_one_bivariant_more_non_bivariant_params test) + let variance_i = self.xform(variance, &VarianceTerm::ConstantTerm(Variance::Bivariant)); + for k in args { + match k.data(Interner) { + GenericArgData::Lifetime(lt) => { + self.add_constraints_from_region(lt, &variance_i) + } + GenericArgData::Ty(ty) => self.add_constraints_from_ty(ty, &variance_i), + GenericArgData::Const(val) => self.add_constraints_from_const(val, variance), + } + } + } else { + let Some(variances) = self.db.variances_of(def_id) else { + return; + }; + + for (i, k) in args.iter().enumerate() { + let variance_decl = &VarianceTerm::ConstantTerm(variances[i]); + let variance_i = self.xform(variance, variance_decl); + match k.data(Interner) { + GenericArgData::Lifetime(lt) => { + self.add_constraints_from_region(lt, &variance_i) + } + GenericArgData::Ty(ty) => self.add_constraints_from_ty(ty, &variance_i), + GenericArgData::Const(val) => self.add_constraints_from_const(val, variance), + } + } + } + } + + /// Adds constraints appropriate for a const expression `val` + /// in a context with ambient variance `variance` + fn add_constraints_from_const(&mut self, c: &Const, variance: &VarianceTerm) { + match &c.data(Interner).value { + chalk_ir::ConstValue::Concrete(c) => { + if let ConstScalar::UnevaluatedConst(_, subst) = &c.interned { + self.add_constraints_from_invariant_args(subst.as_slice(Interner), variance); + } + } + _ => {} + } + } + + /// Adds constraints appropriate for a function with signature + /// `sig` appearing in a context with ambient variance `variance` + fn add_constraints_from_sig(&mut self, sig: &FnPointer, variance: &VarianceTerm) { + let contra = self.contravariant(variance); + let mut tys = sig.substitution.0.iter(Interner).filter_map(move |p| p.ty(Interner)); + self.add_constraints_from_ty(tys.next_back().unwrap(), variance); + for input in tys { + self.add_constraints_from_ty(input, &contra); + } + } + + fn add_constraints_from_sig2(&mut self, sig: &[Ty], variance: &VarianceTerm) { + let contra = self.contravariant(variance); + let mut tys = sig.iter(); + self.add_constraints_from_ty(tys.next_back().unwrap(), variance); + for input in tys { + self.add_constraints_from_ty(input, &contra); + } + } + + /// Adds constraints appropriate for a region appearing in a + /// context with ambient variance `variance` + fn add_constraints_from_region(&mut self, region: &Lifetime, variance: &VarianceTerm) { + match region.data(Interner) { + // FIXME: chalk has no params? + LifetimeData::Placeholder(index) => { + let idx = crate::lt_from_placeholder_idx(self.db, *index); + let index = idx.local_id.into_raw().into_u32() as usize; + let inferred = if idx.parent == self.def { + InferredIndex(index) + } else { + InferredIndex(self.has_trait_self as usize + self.len_self + index) + }; + tracing::debug!("add_constraint(index={:?}, variance={:?})", inferred, variance); + self.constraints.push(Constraint { inferred, variance: variance.clone() }); + } + LifetimeData::Static => {} + + LifetimeData::BoundVar(..) => { + // Either a higher-ranked region inside of a type or a + // late-bound function parameter. + // + // We do not compute constraints for either of these. + } + + LifetimeData::Error => {} + + LifetimeData::Phantom(..) | LifetimeData::InferenceVar(..) | LifetimeData::Erased => { + // We don't expect to see anything but 'static or bound + // regions when visiting member types or method types. + panic!( + "unexpected region encountered in variance \ + inference: {:?}", + region + ); + } + } + } + + /// Adds constraints appropriate for a mutability-type pair + /// appearing in a context with ambient variance `variance` + fn add_constraints_from_mt(&mut self, ty: &Ty, mt: Mutability, variance: &VarianceTerm) { + match mt { + Mutability::Mut => { + let invar = self.invariant(variance); + self.add_constraints_from_ty(ty, &invar); + } + + Mutability::Not => { + self.add_constraints_from_ty(ty, variance); + } + } + } +} + +impl Context<'_> { + fn solve(self) -> Vec { + let mut solutions = vec![Variance::Bivariant; self.generics.len()]; + // Propagate constraints until a fixed point is reached. Note + // that the maximum number of iterations is 2C where C is the + // number of constraints (each variable can change values at most + // twice). Since number of constraints is linear in size of the + // input, so is the inference process. + let mut changed = true; + while changed { + changed = false; + + for constraint in &self.constraints { + let Constraint { inferred, variance: term } = constraint; + let InferredIndex(inferred) = inferred; + let variance = Self::evaluate(term); + let old_value = solutions[*inferred]; + let new_value = variance.glb(old_value); + if old_value != new_value { + solutions[*inferred] = new_value; + changed = true; + } + } + } + + // Const parameters are always invariant. + // Make all const parameters invariant. + for (idx, param) in self.generics.iter_id().enumerate() { + if let GenericParamId::ConstParamId(_) = param { + solutions[idx] = Variance::Invariant; + } + } + + // Functions are permitted to have unused generic parameters: make those invariant. + if let GenericDefId::FunctionId(_) = self.def { + for variance in &mut solutions { + if *variance == Variance::Bivariant { + *variance = Variance::Invariant; + } + } + } + + solutions + } + + fn evaluate(term: &VarianceTerm) -> Variance { + match term { + VarianceTerm::ConstantTerm(v) => *v, + VarianceTerm::TransformTerm(t1, t2) => { + let v1 = Self::evaluate(t1); + let v2 = Self::evaluate(t2); + v1.xform(v2) + } + } + } +} + +#[cfg(test)] +mod tests { + use expect_test::{expect, Expect}; + use hir_def::{ + generics::GenericParamDataRef, src::HasSource, AdtId, GenericDefId, ModuleDefId, + }; + use itertools::Itertools; + use stdx::format_to; + use syntax::{ast::HasName, AstNode}; + use test_fixture::WithFixture; + + use hir_def::Lookup; + + use crate::{db::HirDatabase, test_db::TestDB, variance::generics}; + + #[test] + fn phantom_data() { + check( + r#" +//- minicore: phantom_data + +struct Covariant { + t: core::marker::PhantomData +} +"#, + expect![[r#" + Covariant[A: covariant] + "#]], + ); + } + + #[test] + fn rustc_test_variance_types() { + check( + r#" +//- minicore: cell + +use core::cell::UnsafeCell; + +struct InvariantMut<'a,A:'a,B:'a> { //~ ERROR ['a: +, A: o, B: o] + t: &'a mut (A,B) +} + +struct InvariantCell { //~ ERROR [A: o] + t: UnsafeCell +} + +struct InvariantIndirect { //~ ERROR [A: o] + t: InvariantCell +} + +struct Covariant { //~ ERROR [A: +] + t: A, u: fn() -> A +} + +struct Contravariant { //~ ERROR [A: -] + t: fn(A) +} + +enum Enum { //~ ERROR [A: +, B: -, C: o] + Foo(Covariant), + Bar(Contravariant),` + Zed(Covariant,Contravariant) +} +"#, + expect![[r#" + InvariantMut['a: covariant, A: invariant, B: invariant] + InvariantCell[A: invariant] + InvariantIndirect[A: invariant] + Covariant[A: covariant] + Contravariant[A: contravariant] + Enum[A: covariant, B: contravariant, C: invariant] + "#]], + ); + } + + #[test] + fn type_resolve_error_two_structs_deep() { + check( + r#" +struct Hello<'a> { + missing: Missing<'a>, +} + +struct Other<'a> { + hello: Hello<'a>, +} +"#, + expect![[r#" + Hello['a: bivariant] + Other['a: bivariant] + "#]], + ); + } + + #[test] + fn rustc_test_variance_associated_consts() { + // FIXME: Should be invariant + check( + r#" +trait Trait { + const Const: usize; +} + +struct Foo { //~ ERROR [T: o] + field: [u8; ::Const] +} +"#, + expect![[r#" + Foo[T: bivariant] + "#]], + ); + } + + #[test] + fn rustc_test_variance_associated_types() { + check( + r#" +trait Trait<'a> { + type Type; + + fn method(&'a self) { } +} + +struct Foo<'a, T : Trait<'a>> { //~ ERROR ['a: +, T: +] + field: (T, &'a ()) +} + +struct Bar<'a, T : Trait<'a>> { //~ ERROR ['a: o, T: o] + field: >::Type +} + +"#, + expect![[r#" + method[Self: contravariant, 'a: contravariant] + Foo['a: covariant, T: covariant] + Bar['a: invariant, T: invariant] + "#]], + ); + } + + #[test] + fn rustc_test_variance_associated_types2() { + // FIXME: RPITs have variance, but we can't treat them as their own thing right now + check( + r#" +trait Foo { + type Bar; +} + +fn make() -> *const dyn Foo {} +"#, + expect![""], + ); + } + + #[test] + fn rustc_test_variance_trait_bounds() { + check( + r#" +trait Getter { + fn get(&self) -> T; +} + +trait Setter { + fn get(&self, _: T); +} + +struct TestStruct> { //~ ERROR [U: +, T: +] + t: T, u: U +} + +enum TestEnum> { //~ ERROR [U: *, T: +] + //~^ ERROR: `U` is never used + Foo(T) +} + +struct TestContraStruct> { //~ ERROR [U: *, T: +] + //~^ ERROR: `U` is never used + t: T +} + +struct TestBox+Setter> { //~ ERROR [U: *, T: +] + //~^ ERROR: `U` is never used + t: T +} +"#, + expect![[r#" + get[Self: contravariant, T: covariant] + get[Self: contravariant, T: contravariant] + TestStruct[U: covariant, T: covariant] + TestEnum[U: bivariant, T: covariant] + TestContraStruct[U: bivariant, T: covariant] + TestBox[U: bivariant, T: covariant] + "#]], + ); + } + + #[test] + fn rustc_test_variance_trait_matching() { + check( + r#" + +trait Get { + fn get(&self) -> T; +} + +struct Cloner { + t: T +} + +impl Get for Cloner { + fn get(&self) -> T {} +} + +fn get<'a, G>(get: &G) -> i32 + where G : Get<&'a i32> +{} + +fn pick<'b, G>(get: &'b G, if_odd: &'b i32) -> i32 + where G : Get<&'b i32> +{} +"#, + expect![[r#" + get[Self: contravariant, T: covariant] + Cloner[T: covariant] + get[T: invariant] + get['a: invariant, G: contravariant] + pick['b: contravariant, G: contravariant] + "#]], + ); + } + + #[test] + fn rustc_test_variance_trait_object_bound() { + check( + r#" +enum Option { + Some(T), + None +} +trait T { fn foo(&self); } + +struct TOption<'a> { //~ ERROR ['a: +] + v: Option<*const (dyn T + 'a)>, +} +"#, + expect![[r#" + Option[T: covariant] + foo[Self: contravariant] + TOption['a: covariant] + "#]], + ); + } + + #[test] + fn rustc_test_variance_types_bounds() { + check( + r#" +//- minicore: send +struct TestImm { //~ ERROR [A: +, B: +] + x: A, + y: B, +} + +struct TestMut { //~ ERROR [A: +, B: o] + x: A, + y: &'static mut B, +} + +struct TestIndirect { //~ ERROR [A: +, B: o] + m: TestMut +} + +struct TestIndirect2 { //~ ERROR [A: o, B: o] + n: TestMut, + m: TestMut +} + +trait Getter { + fn get(&self) -> A; +} + +trait Setter { + fn set(&mut self, a: A); +} + +struct TestObject { //~ ERROR [A: o, R: o] + n: *const (dyn Setter + Send), + m: *const (dyn Getter + Send), +} +"#, + expect![[r#" + TestImm[A: covariant, B: covariant] + TestMut[A: covariant, B: invariant] + TestIndirect[A: covariant, B: invariant] + TestIndirect2[A: invariant, B: invariant] + get[Self: contravariant, A: covariant] + set[Self: invariant, A: contravariant] + TestObject[A: invariant, R: invariant] + "#]], + ); + } + + #[test] + fn rustc_test_variance_unused_region_param() { + check( + r#" +struct SomeStruct<'a> { x: u32 } //~ ERROR parameter `'a` is never used +enum SomeEnum<'a> { Nothing } //~ ERROR parameter `'a` is never used +trait SomeTrait<'a> { fn foo(&self); } // OK on traits. +"#, + expect![[r#" + SomeStruct['a: bivariant] + SomeEnum['a: bivariant] + foo[Self: contravariant, 'a: invariant] + "#]], + ); + } + + #[test] + fn rustc_test_variance_unused_type_param() { + check( + r#" +//- minicore: sized +struct SomeStruct { x: u32 } +enum SomeEnum { Nothing } +enum ListCell { + Cons(*const ListCell), + Nil +} + +struct SelfTyAlias(*const Self); +struct WithBounds {} +struct WithWhereBounds where T: Sized {} +struct WithOutlivesBounds {} +struct DoubleNothing { + s: SomeStruct, +} + +"#, + expect![[r#" + SomeStruct[A: bivariant] + SomeEnum[A: bivariant] + ListCell[T: bivariant] + SelfTyAlias[T: bivariant] + WithBounds[T: bivariant] + WithWhereBounds[T: bivariant] + WithOutlivesBounds[T: bivariant] + DoubleNothing[T: bivariant] + "#]], + ); + } + + #[test] + fn rustc_test_variance_use_contravariant_struct1() { + check( + r#" +struct SomeStruct(fn(T)); + +fn foo<'min,'max>(v: SomeStruct<&'max ()>) + -> SomeStruct<&'min ()> + where 'max : 'min +{} +"#, + expect![[r#" + SomeStruct[T: contravariant] + foo['min: contravariant, 'max: covariant] + "#]], + ); + } + + #[test] + fn rustc_test_variance_use_contravariant_struct2() { + check( + r#" +struct SomeStruct(fn(T)); + +fn bar<'min,'max>(v: SomeStruct<&'min ()>) + -> SomeStruct<&'max ()> + where 'max : 'min +{} +"#, + expect![[r#" + SomeStruct[T: contravariant] + bar['min: covariant, 'max: contravariant] + "#]], + ); + } + + #[test] + fn rustc_test_variance_use_covariant_struct1() { + check( + r#" +struct SomeStruct(T); + +fn foo<'min,'max>(v: SomeStruct<&'min ()>) + -> SomeStruct<&'max ()> + where 'max : 'min +{} +"#, + expect![[r#" + SomeStruct[T: covariant] + foo['min: contravariant, 'max: covariant] + "#]], + ); + } + + #[test] + fn rustc_test_variance_use_covariant_struct2() { + check( + r#" +struct SomeStruct(T); + +fn foo<'min,'max>(v: SomeStruct<&'max ()>) + -> SomeStruct<&'min ()> + where 'max : 'min +{} +"#, + expect![[r#" + SomeStruct[T: covariant] + foo['min: covariant, 'max: contravariant] + "#]], + ); + } + + #[test] + fn rustc_test_variance_use_invariant_struct1() { + check( + r#" +struct SomeStruct(*mut T); + +fn foo<'min,'max>(v: SomeStruct<&'max ()>) + -> SomeStruct<&'min ()> + where 'max : 'min +{} + +fn bar<'min,'max>(v: SomeStruct<&'min ()>) + -> SomeStruct<&'max ()> + where 'max : 'min +{} +"#, + expect![[r#" + SomeStruct[T: invariant] + foo['min: invariant, 'max: invariant] + bar['min: invariant, 'max: invariant] + "#]], + ); + } + + #[test] + fn recursive_one_bivariant_more_non_bivariant_params() { + // FIXME: This is wrong, this should be `BivariantPartialIndirect[T: bivariant, U: covariant]` (likewise for Wrapper) + // This is a limitation of current salsa where a cycle may only set a fallback value to the + // query result which is not what we want! We want to treat the cycle call as fallback + // without setting the query result to the fallback. + // `BivariantPartial` works as we workaround for the trivial case of being self-referential + check( + r#" +struct BivariantPartial(*const BivariantPartial, U); +struct Wrapper(BivariantPartialIndirect); +struct BivariantPartialIndirect(*const Wrapper, U); +"#, + expect![[r#" + BivariantPartial[T: bivariant, U: covariant] + Wrapper[T: bivariant, U: bivariant] + BivariantPartialIndirect[T: bivariant, U: bivariant] + "#]], + ); + } + + #[track_caller] + fn check(ra_fixture: &str, expected: Expect) { + // use tracing_subscriber::{layer::SubscriberExt, Layer}; + // let my_layer = tracing_subscriber::fmt::layer(); + // let _g = tracing::subscriber::set_default(tracing_subscriber::registry().with( + // my_layer.with_filter(tracing_subscriber::filter::filter_fn(|metadata| { + // metadata.target().starts_with("hir_ty::variance") + // })), + // )); + let (db, file_id) = TestDB::with_single_file(ra_fixture); + + let mut defs: Vec = Vec::new(); + let module = db.module_for_file_opt(file_id).unwrap(); + let def_map = module.def_map(&db); + crate::tests::visit_module(&db, &def_map, module.local_id, &mut |it| { + defs.push(match it { + ModuleDefId::FunctionId(it) => it.into(), + ModuleDefId::AdtId(it) => it.into(), + ModuleDefId::ConstId(it) => it.into(), + ModuleDefId::TraitId(it) => it.into(), + ModuleDefId::TraitAliasId(it) => it.into(), + ModuleDefId::TypeAliasId(it) => it.into(), + _ => return, + }) + }); + let defs = defs + .into_iter() + .filter_map(|def| { + Some(( + def, + match def { + GenericDefId::FunctionId(it) => { + let loc = it.lookup(&db); + loc.source(&db).value.name().unwrap() + } + GenericDefId::AdtId(AdtId::EnumId(it)) => { + let loc = it.lookup(&db); + loc.source(&db).value.name().unwrap() + } + GenericDefId::AdtId(AdtId::StructId(it)) => { + let loc = it.lookup(&db); + loc.source(&db).value.name().unwrap() + } + GenericDefId::AdtId(AdtId::UnionId(it)) => { + let loc = it.lookup(&db); + loc.source(&db).value.name().unwrap() + } + GenericDefId::TraitId(it) => { + let loc = it.lookup(&db); + loc.source(&db).value.name().unwrap() + } + GenericDefId::TraitAliasId(it) => { + let loc = it.lookup(&db); + loc.source(&db).value.name().unwrap() + } + GenericDefId::TypeAliasId(it) => { + let loc = it.lookup(&db); + loc.source(&db).value.name().unwrap() + } + GenericDefId::ImplId(_) => return None, + GenericDefId::ConstId(_) => return None, + }, + )) + }) + .sorted_by_key(|(_, n)| n.syntax().text_range().start()); + let mut res = String::new(); + for (def, name) in defs { + let Some(variances) = db.variances_of(def) else { + continue; + }; + format_to!( + res, + "{name}[{}]\n", + generics(&db, def) + .iter() + .map(|(_, param)| match param { + GenericParamDataRef::TypeParamData(type_param_data) => { + type_param_data.name.as_ref().unwrap() + } + GenericParamDataRef::ConstParamData(const_param_data) => + &const_param_data.name, + GenericParamDataRef::LifetimeParamData(lifetime_param_data) => { + &lifetime_param_data.name + } + }) + .zip_eq(&*variances) + .format_with(", ", |(name, var), f| f(&format_args!( + "{}: {var}", + name.as_str() + ))) + ); + } + + expected.assert_eq(&res); + } +} From 0e50c3c81be8bc1c80a7a5ed833ff3fc98e3257f Mon Sep 17 00:00:00 2001 From: Lukas Wirth Date: Sat, 28 Dec 2024 18:55:38 +0100 Subject: [PATCH 4/8] Remove unnecessary VarianceTerm --- .../crates/hir-ty/src/variance.rs | 105 ++++++------------ 1 file changed, 31 insertions(+), 74 deletions(-) diff --git a/src/tools/rust-analyzer/crates/hir-ty/src/variance.rs b/src/tools/rust-analyzer/crates/hir-ty/src/variance.rs index 2339e37fa1409..ca16e986af5e9 100644 --- a/src/tools/rust-analyzer/crates/hir-ty/src/variance.rs +++ b/src/tools/rust-analyzer/crates/hir-ty/src/variance.rs @@ -170,21 +170,6 @@ impl Variance { #[derive(Copy, Clone, Debug)] struct InferredIndex(usize); -#[derive(Clone)] -enum VarianceTerm { - ConstantTerm(Variance), - TransformTerm(Box, Box), -} - -impl fmt::Debug for VarianceTerm { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - VarianceTerm::ConstantTerm(c1) => write!(f, "{c1:?}"), - VarianceTerm::TransformTerm(v1, v2) => write!(f, "({v1:?} \u{00D7} {v2:?})"), - } - } -} - struct Context<'db> { db: &'db dyn HirDatabase, def: GenericDefId, @@ -200,7 +185,7 @@ struct Context<'db> { #[derive(Clone)] struct Constraint { inferred: InferredIndex, - variance: VarianceTerm, + variance: Variance, } impl Context<'_> { @@ -213,7 +198,7 @@ impl Context<'_> { for (_, field) in db.field_types(variant).iter() { self.add_constraints_from_ty( &field.clone().substitute(Interner, &subst), - &VarianceTerm::ConstantTerm(Variance::Covariant), + Variance::Covariant, ); } }; @@ -235,37 +220,22 @@ impl Context<'_> { .callable_item_signature(f.into()) .substitute(Interner, &subst) .params_and_return, - &VarianceTerm::ConstantTerm(Variance::Covariant), + Variance::Covariant, ); } _ => {} } } - fn contravariant(&mut self, variance: &VarianceTerm) -> VarianceTerm { - self.xform(variance, &VarianceTerm::ConstantTerm(Variance::Contravariant)) + fn contravariant(&mut self, variance: Variance) -> Variance { + variance.xform(Variance::Contravariant) } - fn invariant(&mut self, variance: &VarianceTerm) -> VarianceTerm { - self.xform(variance, &VarianceTerm::ConstantTerm(Variance::Invariant)) + fn invariant(&mut self, variance: Variance) -> Variance { + variance.xform(Variance::Invariant) } - fn xform(&mut self, v1: &VarianceTerm, v2: &VarianceTerm) -> VarianceTerm { - match (v1, v2) { - // Applying a "covariant" transform is always a no-op - (_, VarianceTerm::ConstantTerm(Variance::Covariant)) => v1.clone(), - (VarianceTerm::ConstantTerm(c1), VarianceTerm::ConstantTerm(c2)) => { - VarianceTerm::ConstantTerm(c1.xform(*c2)) - } - _ => VarianceTerm::TransformTerm(Box::new(v1.clone()), Box::new(v2.clone())), - } - } - - fn add_constraints_from_invariant_args( - &mut self, - args: &[GenericArg], - variance: &VarianceTerm, - ) { + fn add_constraints_from_invariant_args(&mut self, args: &[GenericArg], variance: Variance) { tracing::debug!( "add_constraints_from_invariant_args(args={:?}, variance={:?})", args, @@ -275,9 +245,9 @@ impl Context<'_> { for k in args { match k.data(Interner) { - GenericArgData::Lifetime(lt) => self.add_constraints_from_region(lt, &variance_i), - GenericArgData::Ty(ty) => self.add_constraints_from_ty(ty, &variance_i), - GenericArgData::Const(val) => self.add_constraints_from_const(val, &variance_i), + GenericArgData::Lifetime(lt) => self.add_constraints_from_region(lt, variance_i), + GenericArgData::Ty(ty) => self.add_constraints_from_ty(ty, variance_i), + GenericArgData::Const(val) => self.add_constraints_from_const(val, variance_i), } } } @@ -285,7 +255,7 @@ impl Context<'_> { /// Adds constraints appropriate for an instance of `ty` appearing /// in a context with the generics defined in `generics` and /// ambient variance `variance` - fn add_constraints_from_ty(&mut self, ty: &Ty, variance: &VarianceTerm) { + fn add_constraints_from_ty(&mut self, ty: &Ty, variance: Variance) { tracing::debug!("add_constraints_from_ty(ty={:?}, variance={:?})", ty, variance); match ty.kind(Interner) { TyKind::Scalar(_) | TyKind::Never | TyKind::Str | TyKind::Foreign(..) => { @@ -390,7 +360,7 @@ impl Context<'_> { InferredIndex(self.len_self + index) }; tracing::debug!("add_constraint(index={:?}, variance={:?})", inferred, variance); - self.constraints.push(Constraint { inferred, variance: variance.clone() }); + self.constraints.push(Constraint { inferred, variance }); } TyKind::Function(f) => { self.add_constraints_from_sig(f, variance); @@ -413,7 +383,7 @@ impl Context<'_> { &mut self, def_id: GenericDefId, args: &[GenericArg], - variance: &VarianceTerm, + variance: Variance, ) { tracing::debug!( "add_constraints_from_args(def_id={:?}, args={:?}, variance={:?})", @@ -429,13 +399,13 @@ impl Context<'_> { if def_id == self.def { // HACK: Workaround for the trivial cycle salsa case (see // recursive_one_bivariant_more_non_bivariant_params test) - let variance_i = self.xform(variance, &VarianceTerm::ConstantTerm(Variance::Bivariant)); + let variance_i = variance.xform(Variance::Bivariant); for k in args { match k.data(Interner) { GenericArgData::Lifetime(lt) => { - self.add_constraints_from_region(lt, &variance_i) + self.add_constraints_from_region(lt, variance_i) } - GenericArgData::Ty(ty) => self.add_constraints_from_ty(ty, &variance_i), + GenericArgData::Ty(ty) => self.add_constraints_from_ty(ty, variance_i), GenericArgData::Const(val) => self.add_constraints_from_const(val, variance), } } @@ -445,13 +415,12 @@ impl Context<'_> { }; for (i, k) in args.iter().enumerate() { - let variance_decl = &VarianceTerm::ConstantTerm(variances[i]); - let variance_i = self.xform(variance, variance_decl); + let variance_i = variance.xform(variances[i]); match k.data(Interner) { GenericArgData::Lifetime(lt) => { - self.add_constraints_from_region(lt, &variance_i) + self.add_constraints_from_region(lt, variance_i) } - GenericArgData::Ty(ty) => self.add_constraints_from_ty(ty, &variance_i), + GenericArgData::Ty(ty) => self.add_constraints_from_ty(ty, variance_i), GenericArgData::Const(val) => self.add_constraints_from_const(val, variance), } } @@ -460,7 +429,7 @@ impl Context<'_> { /// Adds constraints appropriate for a const expression `val` /// in a context with ambient variance `variance` - fn add_constraints_from_const(&mut self, c: &Const, variance: &VarianceTerm) { + fn add_constraints_from_const(&mut self, c: &Const, variance: Variance) { match &c.data(Interner).value { chalk_ir::ConstValue::Concrete(c) => { if let ConstScalar::UnevaluatedConst(_, subst) = &c.interned { @@ -473,27 +442,27 @@ impl Context<'_> { /// Adds constraints appropriate for a function with signature /// `sig` appearing in a context with ambient variance `variance` - fn add_constraints_from_sig(&mut self, sig: &FnPointer, variance: &VarianceTerm) { + fn add_constraints_from_sig(&mut self, sig: &FnPointer, variance: Variance) { let contra = self.contravariant(variance); let mut tys = sig.substitution.0.iter(Interner).filter_map(move |p| p.ty(Interner)); self.add_constraints_from_ty(tys.next_back().unwrap(), variance); for input in tys { - self.add_constraints_from_ty(input, &contra); + self.add_constraints_from_ty(input, contra); } } - fn add_constraints_from_sig2(&mut self, sig: &[Ty], variance: &VarianceTerm) { + fn add_constraints_from_sig2(&mut self, sig: &[Ty], variance: Variance) { let contra = self.contravariant(variance); let mut tys = sig.iter(); self.add_constraints_from_ty(tys.next_back().unwrap(), variance); for input in tys { - self.add_constraints_from_ty(input, &contra); + self.add_constraints_from_ty(input, contra); } } /// Adds constraints appropriate for a region appearing in a /// context with ambient variance `variance` - fn add_constraints_from_region(&mut self, region: &Lifetime, variance: &VarianceTerm) { + fn add_constraints_from_region(&mut self, region: &Lifetime, variance: Variance) { match region.data(Interner) { // FIXME: chalk has no params? LifetimeData::Placeholder(index) => { @@ -532,11 +501,11 @@ impl Context<'_> { /// Adds constraints appropriate for a mutability-type pair /// appearing in a context with ambient variance `variance` - fn add_constraints_from_mt(&mut self, ty: &Ty, mt: Mutability, variance: &VarianceTerm) { + fn add_constraints_from_mt(&mut self, ty: &Ty, mt: Mutability, variance: Variance) { match mt { Mutability::Mut => { let invar = self.invariant(variance); - self.add_constraints_from_ty(ty, &invar); + self.add_constraints_from_ty(ty, invar); } Mutability::Not => { @@ -559,13 +528,12 @@ impl Context<'_> { changed = false; for constraint in &self.constraints { - let Constraint { inferred, variance: term } = constraint; + let &Constraint { inferred, variance } = constraint; let InferredIndex(inferred) = inferred; - let variance = Self::evaluate(term); - let old_value = solutions[*inferred]; + let old_value = solutions[inferred]; let new_value = variance.glb(old_value); if old_value != new_value { - solutions[*inferred] = new_value; + solutions[inferred] = new_value; changed = true; } } @@ -590,17 +558,6 @@ impl Context<'_> { solutions } - - fn evaluate(term: &VarianceTerm) -> Variance { - match term { - VarianceTerm::ConstantTerm(v) => *v, - VarianceTerm::TransformTerm(t1, t2) => { - let v1 = Self::evaluate(t1); - let v2 = Self::evaluate(t2); - v1.xform(v2) - } - } - } } #[cfg(test)] From d66a337658a4e175380c1ff59a73375b76237b9f Mon Sep 17 00:00:00 2001 From: Lukas Wirth Date: Sat, 28 Dec 2024 19:19:06 +0100 Subject: [PATCH 5/8] Get rid of constrain and solve steps --- .../crates/hir-ty/src/generics.rs | 8 -- .../crates/hir-ty/src/variance.rs | 131 ++++++------------ 2 files changed, 44 insertions(+), 95 deletions(-) diff --git a/src/tools/rust-analyzer/crates/hir-ty/src/generics.rs b/src/tools/rust-analyzer/crates/hir-ty/src/generics.rs index e7a2721afee5c..fe7541d237478 100644 --- a/src/tools/rust-analyzer/crates/hir-ty/src/generics.rs +++ b/src/tools/rust-analyzer/crates/hir-ty/src/generics.rs @@ -132,14 +132,6 @@ impl Generics { self.params.len() } - pub(crate) fn len_self_lifetimes(&self) -> usize { - self.params.len_lifetimes() - } - - pub(crate) fn has_trait_self(&self) -> bool { - self.params.trait_self_param().is_some() - } - /// (parent total, self param, type params, const params, impl trait list, lifetimes) pub(crate) fn provenance_split(&self) -> (usize, bool, usize, usize, usize, usize) { let mut self_param = false; diff --git a/src/tools/rust-analyzer/crates/hir-ty/src/variance.rs b/src/tools/rust-analyzer/crates/hir-ty/src/variance.rs index ca16e986af5e9..0cce1aec2b490 100644 --- a/src/tools/rust-analyzer/crates/hir-ty/src/variance.rs +++ b/src/tools/rust-analyzer/crates/hir-ty/src/variance.rs @@ -39,19 +39,9 @@ pub(crate) fn variances_of(db: &dyn HirDatabase, def: GenericDefId) -> Option { db: &'db dyn HirDatabase, - def: GenericDefId, - has_trait_self: bool, - len_self: usize, - len_self_lifetimes: usize, generics: Generics, - constraints: Vec, -} - -/// Declares that the variable `decl_id` appears in a location with -/// variance `variance`. -#[derive(Clone)] -struct Constraint { - inferred: InferredIndex, - variance: Variance, + variances: Vec, } impl Context<'_> { - fn build_constraints_for_item(&mut self) { - match self.def { + fn solve(mut self) -> Vec { + tracing::debug!("solve(generics={:?})", self.generics); + match self.generics.def() { GenericDefId::AdtId(adt) => { let db = self.db; let mut add_constraints_from_variant = |variant| { @@ -225,6 +204,26 @@ impl Context<'_> { } _ => {} } + let mut variances = self.variances; + + // Const parameters are always invariant. + // Make all const parameters invariant. + for (idx, param) in self.generics.iter_id().enumerate() { + if let GenericParamId::ConstParamId(_) = param { + variances[idx] = Variance::Invariant; + } + } + + // Functions are permitted to have unused generic parameters: make those invariant. + if let GenericDefId::FunctionId(_) = self.generics.def() { + for variance in &mut variances { + if *variance == Variance::Bivariant { + *variance = Variance::Invariant; + } + } + } + + variances } fn contravariant(&mut self, variance: Variance) -> Variance { @@ -353,14 +352,8 @@ impl Context<'_> { // Chalk has no params, so use placeholders for now? TyKind::Placeholder(index) => { let idx = crate::from_placeholder_idx(self.db, *index); - let index = idx.local_id.into_raw().into_u32() as usize + self.len_self_lifetimes; - let inferred = if idx.parent == self.def { - InferredIndex(self.has_trait_self as usize + index) - } else { - InferredIndex(self.len_self + index) - }; - tracing::debug!("add_constraint(index={:?}, variance={:?})", inferred, variance); - self.constraints.push(Constraint { inferred, variance }); + let inferred = InferredIndex(self.generics.type_or_const_param_idx(idx).unwrap()); + self.constrain(inferred, variance); } TyKind::Function(f) => { self.add_constraints_from_sig(f, variance); @@ -396,7 +389,7 @@ impl Context<'_> { if args.is_empty() { return; } - if def_id == self.def { + if def_id == self.generics.def() { // HACK: Workaround for the trivial cycle salsa case (see // recursive_one_bivariant_more_non_bivariant_params test) let variance_i = variance.xform(Variance::Bivariant); @@ -463,18 +456,17 @@ impl Context<'_> { /// Adds constraints appropriate for a region appearing in a /// context with ambient variance `variance` fn add_constraints_from_region(&mut self, region: &Lifetime, variance: Variance) { + tracing::debug!( + "add_constraints_from_region(region={:?}, variance={:?})", + region, + variance + ); match region.data(Interner) { // FIXME: chalk has no params? LifetimeData::Placeholder(index) => { let idx = crate::lt_from_placeholder_idx(self.db, *index); - let index = idx.local_id.into_raw().into_u32() as usize; - let inferred = if idx.parent == self.def { - InferredIndex(index) - } else { - InferredIndex(self.has_trait_self as usize + self.len_self + index) - }; - tracing::debug!("add_constraint(index={:?}, variance={:?})", inferred, variance); - self.constraints.push(Constraint { inferred, variance: variance.clone() }); + let inferred = InferredIndex(self.generics.lifetime_idx(idx).unwrap()); + self.constrain(inferred, variance); } LifetimeData::Static => {} @@ -513,50 +505,15 @@ impl Context<'_> { } } } -} - -impl Context<'_> { - fn solve(self) -> Vec { - let mut solutions = vec![Variance::Bivariant; self.generics.len()]; - // Propagate constraints until a fixed point is reached. Note - // that the maximum number of iterations is 2C where C is the - // number of constraints (each variable can change values at most - // twice). Since number of constraints is linear in size of the - // input, so is the inference process. - let mut changed = true; - while changed { - changed = false; - - for constraint in &self.constraints { - let &Constraint { inferred, variance } = constraint; - let InferredIndex(inferred) = inferred; - let old_value = solutions[inferred]; - let new_value = variance.glb(old_value); - if old_value != new_value { - solutions[inferred] = new_value; - changed = true; - } - } - } - // Const parameters are always invariant. - // Make all const parameters invariant. - for (idx, param) in self.generics.iter_id().enumerate() { - if let GenericParamId::ConstParamId(_) = param { - solutions[idx] = Variance::Invariant; - } - } - - // Functions are permitted to have unused generic parameters: make those invariant. - if let GenericDefId::FunctionId(_) = self.def { - for variance in &mut solutions { - if *variance == Variance::Bivariant { - *variance = Variance::Invariant; - } - } - } - - solutions + fn constrain(&mut self, inferred: InferredIndex, variance: Variance) { + tracing::debug!( + "constrain(index={:?}, variance={:?}, to={:?})", + inferred, + self.variances[inferred.0], + variance + ); + self.variances[inferred.0] = self.variances[inferred.0].glb(variance); } } From e54cf80b983298d19b983a9a378fe96f25daed2a Mon Sep 17 00:00:00 2001 From: Lukas Wirth Date: Sat, 28 Dec 2024 19:51:04 +0100 Subject: [PATCH 6/8] Simplify --- .../crates/hir-ty/src/variance.rs | 190 ++++++++---------- 1 file changed, 89 insertions(+), 101 deletions(-) diff --git a/src/tools/rust-analyzer/crates/hir-ty/src/variance.rs b/src/tools/rust-analyzer/crates/hir-ty/src/variance.rs index 0cce1aec2b490..64286121b6aaa 100644 --- a/src/tools/rust-analyzer/crates/hir-ty/src/variance.rs +++ b/src/tools/rust-analyzer/crates/hir-ty/src/variance.rs @@ -6,8 +6,8 @@ use crate::db::HirDatabase; use crate::generics::{generics, Generics}; use crate::{ - AliasTy, Const, ConstScalar, DynTyExt, FnPointer, GenericArg, GenericArgData, Interner, - Lifetime, LifetimeData, Ty, TyKind, + AliasTy, Const, ConstScalar, DynTyExt, GenericArg, GenericArgData, Interner, Lifetime, + LifetimeData, Ty, TyKind, }; use base_db::ra_salsa::Cycle; use chalk_ir::Mutability; @@ -15,6 +15,7 @@ use hir_def::data::adt::StructFlags; use hir_def::{AdtId, GenericDefId, GenericParamId, VariantId}; use std::fmt; use std::ops::Not; +use stdx::never; use triomphe::Arc; pub(crate) fn variances_of(db: &dyn HirDatabase, def: GenericDefId) -> Option> { @@ -156,9 +157,19 @@ impl Variance { (x, Variance::Bivariant) | (Variance::Bivariant, x) => x, } } + + pub fn invariant(self) -> Self { + self.xform(Variance::Invariant) + } + + pub fn covariant(self) -> Self { + self.xform(Variance::Covariant) + } + + pub fn contravariant(self) -> Self { + self.xform(Variance::Contravariant) + } } -#[derive(Copy, Clone, Debug)] -struct InferredIndex(usize); struct Context<'db> { db: &'db dyn HirDatabase, @@ -193,12 +204,12 @@ impl Context<'_> { } GenericDefId::FunctionId(f) => { let subst = self.generics.placeholder_subst(self.db); - self.add_constraints_from_sig2( - &self - .db + self.add_constraints_from_sig( + self.db .callable_item_signature(f.into()) .substitute(Interner, &subst) - .params_and_return, + .params_and_return + .iter(), Variance::Covariant, ); } @@ -216,41 +227,15 @@ impl Context<'_> { // Functions are permitted to have unused generic parameters: make those invariant. if let GenericDefId::FunctionId(_) = self.generics.def() { - for variance in &mut variances { - if *variance == Variance::Bivariant { - *variance = Variance::Invariant; - } - } + variances + .iter_mut() + .filter(|&&mut v| v == Variance::Bivariant) + .for_each(|v| *v = Variance::Invariant); } variances } - fn contravariant(&mut self, variance: Variance) -> Variance { - variance.xform(Variance::Contravariant) - } - - fn invariant(&mut self, variance: Variance) -> Variance { - variance.xform(Variance::Invariant) - } - - fn add_constraints_from_invariant_args(&mut self, args: &[GenericArg], variance: Variance) { - tracing::debug!( - "add_constraints_from_invariant_args(args={:?}, variance={:?})", - args, - variance - ); - let variance_i = self.invariant(variance); - - for k in args { - match k.data(Interner) { - GenericArgData::Lifetime(lt) => self.add_constraints_from_region(lt, variance_i), - GenericArgData::Ty(ty) => self.add_constraints_from_ty(ty, variance_i), - GenericArgData::Const(val) => self.add_constraints_from_const(val, variance_i), - } - } - } - /// Adds constraints appropriate for an instance of `ty` appearing /// in a context with the generics defined in `generics` and /// ambient variance `variance` @@ -260,39 +245,31 @@ impl Context<'_> { TyKind::Scalar(_) | TyKind::Never | TyKind::Str | TyKind::Foreign(..) => { // leaf type -- noop } - TyKind::FnDef(..) | TyKind::Coroutine(..) | TyKind::Closure(..) => { - panic!("Unexpected unnameable type in variance computation: {ty:?}"); + never!("Unexpected unnameable type in variance computation: {:?}", ty); } - TyKind::Ref(mutbl, lifetime, ty) => { self.add_constraints_from_region(lifetime, variance); self.add_constraints_from_mt(ty, *mutbl, variance); } - TyKind::Array(typ, len) => { self.add_constraints_from_const(len, variance); self.add_constraints_from_ty(typ, variance); } - TyKind::Slice(typ) => { self.add_constraints_from_ty(typ, variance); } - TyKind::Raw(mutbl, ty) => { self.add_constraints_from_mt(ty, *mutbl, variance); } - TyKind::Tuple(_, subtys) => { for subty in subtys.type_parameters(Interner) { self.add_constraints_from_ty(&subty, variance); } } - TyKind::Adt(def, args) => { self.add_constraints_from_args(def.0.into(), args.as_slice(Interner), variance); } - TyKind::Alias(AliasTy::Opaque(opaque)) => { self.add_constraints_from_invariant_args( opaque.substitution.as_slice(Interner), @@ -313,7 +290,6 @@ impl Context<'_> { TyKind::OpaqueType(_, subst) => { self.add_constraints_from_invariant_args(subst.as_slice(Interner), variance); } - TyKind::Dyn(it) => { // The type `dyn Trait +'a` is covariant w/r/t `'a`: self.add_constraints_from_region(&it.lifetime, variance); @@ -352,20 +328,33 @@ impl Context<'_> { // Chalk has no params, so use placeholders for now? TyKind::Placeholder(index) => { let idx = crate::from_placeholder_idx(self.db, *index); - let inferred = InferredIndex(self.generics.type_or_const_param_idx(idx).unwrap()); - self.constrain(inferred, variance); + let index = self.generics.type_or_const_param_idx(idx).unwrap(); + self.constrain(index, variance); } TyKind::Function(f) => { - self.add_constraints_from_sig(f, variance); + self.add_constraints_from_sig( + f.substitution.0.iter(Interner).filter_map(move |p| p.ty(Interner)), + variance, + ); } - TyKind::Error => { // we encounter this when walking the trait references for object // types, where we use Error as the Self type } - TyKind::CoroutineWitness(..) | TyKind::BoundVar(..) | TyKind::InferenceVar(..) => { - panic!("unexpected type encountered in variance inference: {:?}", ty); + never!("unexpected type encountered in variance inference: {:?}", ty) + } + } + } + + fn add_constraints_from_invariant_args(&mut self, args: &[GenericArg], variance: Variance) { + let variance_i = variance.invariant(); + + for k in args { + match k.data(Interner) { + GenericArgData::Lifetime(lt) => self.add_constraints_from_region(lt, variance_i), + GenericArgData::Ty(ty) => self.add_constraints_from_ty(ty, variance_i), + GenericArgData::Const(val) => self.add_constraints_from_const(val, variance_i), } } } @@ -378,13 +367,6 @@ impl Context<'_> { args: &[GenericArg], variance: Variance, ) { - tracing::debug!( - "add_constraints_from_args(def_id={:?}, args={:?}, variance={:?})", - def_id, - args, - variance - ); - // We don't record `inferred_starts` entries for empty generics. if args.is_empty() { return; @@ -392,13 +374,12 @@ impl Context<'_> { if def_id == self.generics.def() { // HACK: Workaround for the trivial cycle salsa case (see // recursive_one_bivariant_more_non_bivariant_params test) - let variance_i = variance.xform(Variance::Bivariant); for k in args { match k.data(Interner) { GenericArgData::Lifetime(lt) => { - self.add_constraints_from_region(lt, variance_i) + self.add_constraints_from_region(lt, Variance::Bivariant) } - GenericArgData::Ty(ty) => self.add_constraints_from_ty(ty, variance_i), + GenericArgData::Ty(ty) => self.add_constraints_from_ty(ty, Variance::Bivariant), GenericArgData::Const(val) => self.add_constraints_from_const(val, variance), } } @@ -408,12 +389,13 @@ impl Context<'_> { }; for (i, k) in args.iter().enumerate() { - let variance_i = variance.xform(variances[i]); match k.data(Interner) { GenericArgData::Lifetime(lt) => { - self.add_constraints_from_region(lt, variance_i) + self.add_constraints_from_region(lt, variance.xform(variances[i])) + } + GenericArgData::Ty(ty) => { + self.add_constraints_from_ty(ty, variance.xform(variances[i])) } - GenericArgData::Ty(ty) => self.add_constraints_from_ty(ty, variance_i), GenericArgData::Const(val) => self.add_constraints_from_const(val, variance), } } @@ -435,20 +417,17 @@ impl Context<'_> { /// Adds constraints appropriate for a function with signature /// `sig` appearing in a context with ambient variance `variance` - fn add_constraints_from_sig(&mut self, sig: &FnPointer, variance: Variance) { - let contra = self.contravariant(variance); - let mut tys = sig.substitution.0.iter(Interner).filter_map(move |p| p.ty(Interner)); - self.add_constraints_from_ty(tys.next_back().unwrap(), variance); - for input in tys { - self.add_constraints_from_ty(input, contra); - } - } - - fn add_constraints_from_sig2(&mut self, sig: &[Ty], variance: Variance) { - let contra = self.contravariant(variance); - let mut tys = sig.iter(); - self.add_constraints_from_ty(tys.next_back().unwrap(), variance); - for input in tys { + fn add_constraints_from_sig<'a>( + &mut self, + mut sig_tys: impl DoubleEndedIterator, + variance: Variance, + ) { + let contra = variance.contravariant(); + let Some(output) = sig_tys.next_back() else { + return never!("function signature has no return type"); + }; + self.add_constraints_from_ty(output, variance); + for input in sig_tys { self.add_constraints_from_ty(input, contra); } } @@ -462,27 +441,23 @@ impl Context<'_> { variance ); match region.data(Interner) { - // FIXME: chalk has no params? LifetimeData::Placeholder(index) => { let idx = crate::lt_from_placeholder_idx(self.db, *index); - let inferred = InferredIndex(self.generics.lifetime_idx(idx).unwrap()); + let inferred = self.generics.lifetime_idx(idx).unwrap(); self.constrain(inferred, variance); } LifetimeData::Static => {} - LifetimeData::BoundVar(..) => { // Either a higher-ranked region inside of a type or a // late-bound function parameter. // // We do not compute constraints for either of these. } - LifetimeData::Error => {} - LifetimeData::Phantom(..) | LifetimeData::InferenceVar(..) | LifetimeData::Erased => { // We don't expect to see anything but 'static or bound // regions when visiting member types or method types. - panic!( + never!( "unexpected region encountered in variance \ inference: {:?}", region @@ -494,26 +469,23 @@ impl Context<'_> { /// Adds constraints appropriate for a mutability-type pair /// appearing in a context with ambient variance `variance` fn add_constraints_from_mt(&mut self, ty: &Ty, mt: Mutability, variance: Variance) { - match mt { - Mutability::Mut => { - let invar = self.invariant(variance); - self.add_constraints_from_ty(ty, invar); - } - - Mutability::Not => { - self.add_constraints_from_ty(ty, variance); - } - } + self.add_constraints_from_ty( + ty, + match mt { + Mutability::Mut => variance.invariant(), + Mutability::Not => variance, + }, + ); } - fn constrain(&mut self, inferred: InferredIndex, variance: Variance) { + fn constrain(&mut self, index: usize, variance: Variance) { tracing::debug!( "constrain(index={:?}, variance={:?}, to={:?})", - inferred, - self.variances[inferred.0], + index, + self.variances[index], variance ); - self.variances[inferred.0] = self.variances[inferred.0].glb(variance); + self.variances[index] = self.variances[index].glb(variance); } } @@ -967,6 +939,22 @@ fn bar<'min,'max>(v: SomeStruct<&'min ()>) ); } + #[test] + fn invalid_arg_counts() { + check( + r#" +struct S(T); +struct S2(S<>); +struct S3(S); +"#, + expect![[r#" + S[T: covariant] + S2[T: bivariant] + S3[T: covariant] + "#]], + ); + } + #[test] fn recursive_one_bivariant_more_non_bivariant_params() { // FIXME: This is wrong, this should be `BivariantPartialIndirect[T: bivariant, U: covariant]` (likewise for Wrapper) From bf27d88616b60a33f568bf19cc689552b0d22218 Mon Sep 17 00:00:00 2001 From: Lukas Wirth Date: Sat, 28 Dec 2024 20:31:20 +0100 Subject: [PATCH 7/8] Show variance of parameters on hover --- .../crates/hir-ty/src/generics.rs | 8 ++--- .../rust-analyzer/crates/hir-ty/src/lib.rs | 8 ++--- src/tools/rust-analyzer/crates/hir/src/lib.rs | 20 +++++++++-- .../crates/ide/src/hover/render.rs | 9 +++++ .../crates/ide/src/hover/tests.rs | 36 +++++++++++++++++-- 5 files changed, 69 insertions(+), 12 deletions(-) diff --git a/src/tools/rust-analyzer/crates/hir-ty/src/generics.rs b/src/tools/rust-analyzer/crates/hir-ty/src/generics.rs index fe7541d237478..abbf2a4f2efd4 100644 --- a/src/tools/rust-analyzer/crates/hir-ty/src/generics.rs +++ b/src/tools/rust-analyzer/crates/hir-ty/src/generics.rs @@ -26,14 +26,14 @@ use triomphe::Arc; use crate::{db::HirDatabase, lt_to_placeholder_idx, to_placeholder_idx, Interner, Substitution}; -pub(crate) fn generics(db: &dyn DefDatabase, def: GenericDefId) -> Generics { +pub fn generics(db: &dyn DefDatabase, def: GenericDefId) -> Generics { let parent_generics = parent_generic_def(db, def).map(|def| Box::new(generics(db, def))); let params = db.generic_params(def); let has_trait_self_param = params.trait_self_param().is_some(); Generics { def, params, parent_generics, has_trait_self_param } } #[derive(Clone, Debug)] -pub(crate) struct Generics { +pub struct Generics { def: GenericDefId, params: Arc, parent_generics: Option>, @@ -153,7 +153,7 @@ impl Generics { (parent_len, self_param, type_params, const_params, impl_trait_params, lifetime_params) } - pub(crate) fn type_or_const_param_idx(&self, param: TypeOrConstParamId) -> Option { + pub fn type_or_const_param_idx(&self, param: TypeOrConstParamId) -> Option { self.find_type_or_const_param(param) } @@ -174,7 +174,7 @@ impl Generics { } } - pub(crate) fn lifetime_idx(&self, lifetime: LifetimeParamId) -> Option { + pub fn lifetime_idx(&self, lifetime: LifetimeParamId) -> Option { self.find_lifetime(lifetime) } diff --git a/src/tools/rust-analyzer/crates/hir-ty/src/lib.rs b/src/tools/rust-analyzer/crates/hir-ty/src/lib.rs index 88134f564c018..3c18ea9281655 100644 --- a/src/tools/rust-analyzer/crates/hir-ty/src/lib.rs +++ b/src/tools/rust-analyzer/crates/hir-ty/src/lib.rs @@ -24,7 +24,6 @@ extern crate ra_ap_rustc_pattern_analysis as rustc_pattern_analysis; mod builder; mod chalk_db; mod chalk_ext; -mod generics; mod infer; mod inhabitedness; mod interner; @@ -39,6 +38,7 @@ pub mod db; pub mod diagnostics; pub mod display; pub mod dyn_compatibility; +pub mod generics; pub mod lang_items; pub mod layout; pub mod method_resolution; @@ -89,10 +89,9 @@ pub use infer::{ PointerCast, }; pub use interner::Interner; -pub use lower::diagnostics::*; pub use lower::{ - associated_type_shorthand_candidates, ImplTraitLoweringMode, ParamLoweringMode, TyDefId, - TyLoweringContext, ValueTyDefId, + associated_type_shorthand_candidates, diagnostics::*, ImplTraitLoweringMode, ParamLoweringMode, + TyDefId, TyLoweringContext, ValueTyDefId, }; pub use mapping::{ from_assoc_type_id, from_chalk_trait_id, from_foreign_def_id, from_placeholder_idx, @@ -102,6 +101,7 @@ pub use mapping::{ pub use method_resolution::check_orphan_rules; pub use traits::TraitEnvironment; pub use utils::{all_super_traits, direct_super_traits, is_fn_unsafe_to_call}; +pub use variance::Variance; pub use chalk_ir::{ cast::Cast, diff --git a/src/tools/rust-analyzer/crates/hir/src/lib.rs b/src/tools/rust-analyzer/crates/hir/src/lib.rs index ac8a62ee85060..f8af04302f0c7 100644 --- a/src/tools/rust-analyzer/crates/hir/src/lib.rs +++ b/src/tools/rust-analyzer/crates/hir/src/lib.rs @@ -101,7 +101,6 @@ pub use crate::{ PathResolution, Semantics, SemanticsImpl, SemanticsScope, TypeInfo, VisibleTraits, }, }; -pub use hir_ty::method_resolution::TyFingerprint; // Be careful with these re-exports. // @@ -151,8 +150,9 @@ pub use { display::{ClosureStyle, HirDisplay, HirDisplayError, HirWrite}, dyn_compatibility::{DynCompatibilityViolation, MethodViolationCode}, layout::LayoutError, + method_resolution::TyFingerprint, mir::{MirEvalError, MirLowerError}, - CastError, FnAbi, PointerCast, Safety, + CastError, FnAbi, PointerCast, Safety, Variance, }, // FIXME: Properly encapsulate mir hir_ty::{mir, Interner as ChalkTyInterner}, @@ -3957,6 +3957,22 @@ impl GenericParam { GenericParam::LifetimeParam(it) => it.id.parent.into(), } } + + pub fn variance(self, db: &dyn HirDatabase) -> Option { + let parent = match self { + GenericParam::TypeParam(it) => it.id.parent(), + // const parameters are always invariant + GenericParam::ConstParam(_) => return None, + GenericParam::LifetimeParam(it) => it.id.parent, + }; + let generics = hir_ty::generics::generics(db.upcast(), parent); + let index = match self { + GenericParam::TypeParam(it) => generics.type_or_const_param_idx(it.id.into())?, + GenericParam::ConstParam(_) => return None, + GenericParam::LifetimeParam(it) => generics.lifetime_idx(it.id)?, + }; + db.variances_of(parent)?.get(index).copied() + } } #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] diff --git a/src/tools/rust-analyzer/crates/ide/src/hover/render.rs b/src/tools/rust-analyzer/crates/ide/src/hover/render.rs index 119a864eb9649..d87c15154e726 100644 --- a/src/tools/rust-analyzer/crates/ide/src/hover/render.rs +++ b/src/tools/rust-analyzer/crates/ide/src/hover/render.rs @@ -594,12 +594,21 @@ pub(super) fn definition( _ => None, }; + let variance_info = || match def { + Definition::GenericParam(it) => it.variance(db).as_ref().map(ToString::to_string), + _ => None, + }; + let mut extra = String::new(); if hovered_definition { if let Some(notable_traits) = render_notable_trait(db, notable_traits, edition) { extra.push_str("\n___\n"); extra.push_str(¬able_traits); } + if let Some(variance_info) = variance_info() { + extra.push_str("\n___\n"); + extra.push_str(&variance_info); + } if let Some(layout_info) = layout_info() { extra.push_str("\n___\n"); extra.push_str(&layout_info); diff --git a/src/tools/rust-analyzer/crates/ide/src/hover/tests.rs b/src/tools/rust-analyzer/crates/ide/src/hover/tests.rs index ed8cd64cdbee0..fe7f0c79f5ba8 100644 --- a/src/tools/rust-analyzer/crates/ide/src/hover/tests.rs +++ b/src/tools/rust-analyzer/crates/ide/src/hover/tests.rs @@ -4721,7 +4721,7 @@ fn hover_type_param_sized_bounds() { //- minicore: sized trait Trait {} struct Foo(T); -impl Foo {} +impl Foo {} "#, expect![[r#" *T* @@ -4736,7 +4736,7 @@ impl Foo {} //- minicore: sized trait Trait {} struct Foo(T); -impl Foo {} +impl Foo {} "#, expect![[r#" *T* @@ -4764,6 +4764,10 @@ fn foo() {} ```rust T ``` + + --- + + invariant "#]], ); } @@ -4781,6 +4785,10 @@ fn foo() {} ```rust T ``` + + --- + + invariant "#]], ); } @@ -4798,6 +4806,10 @@ fn foo() {} ```rust T: ?Sized ``` + + --- + + invariant "#]], ); } @@ -4816,6 +4828,10 @@ fn foo() {} ```rust T: Trait ``` + + --- + + invariant "#]], ); } @@ -4834,6 +4850,10 @@ fn foo() {} ```rust T: Trait ``` + + --- + + invariant "#]], ); } @@ -4852,6 +4872,10 @@ fn foo() {} ```rust T: Trait + ?Sized ``` + + --- + + invariant "#]], ); } @@ -4869,6 +4893,10 @@ fn foo() {} ```rust T ``` + + --- + + invariant "#]], ); } @@ -4887,6 +4915,10 @@ fn foo() {} ```rust T: Trait ``` + + --- + + invariant "#]], ); } From a102ea1c2d57c51c95865034b259012e73700ea3 Mon Sep 17 00:00:00 2001 From: Lukas Wirth Date: Sun, 29 Dec 2024 10:52:47 +0100 Subject: [PATCH 8/8] Describe variance resolution approach differences to rustc --- .../crates/hir-ty/src/variance.rs | 63 ++++++++----------- 1 file changed, 27 insertions(+), 36 deletions(-) diff --git a/src/tools/rust-analyzer/crates/hir-ty/src/variance.rs b/src/tools/rust-analyzer/crates/hir-ty/src/variance.rs index 64286121b6aaa..30711b16dfb25 100644 --- a/src/tools/rust-analyzer/crates/hir-ty/src/variance.rs +++ b/src/tools/rust-analyzer/crates/hir-ty/src/variance.rs @@ -2,6 +2,16 @@ //! chapter for more info. //! //! [rustc dev guide]: https://rustc-dev-guide.rust-lang.org/variance.html +//! +//! The implementation here differs from rustc. Rustc does a crate wide fixpoint resolution +//! as the algorithm for determining variance is a fixpoint computation with potential cycles that +//! need to be resolved. rust-analyzer does not want a crate-wide analysis though as that would hurt +//! incrementality too much and as such our query is based on a per item basis. +//! +//! This does unfortunately run into the issue that we can run into query cycles which salsa +//! currently does not allow to be resolved via a fixpoint computation. This will likely be resolved +//! by the next salsa version. If not, we will likely have to adapt and go with the rustc approach +//! while installing firewall per item queries to prevent invalidation issues. use crate::db::HirDatabase; use crate::generics::{generics, Generics}; @@ -371,33 +381,19 @@ impl Context<'_> { if args.is_empty() { return; } - if def_id == self.generics.def() { - // HACK: Workaround for the trivial cycle salsa case (see - // recursive_one_bivariant_more_non_bivariant_params test) - for k in args { - match k.data(Interner) { - GenericArgData::Lifetime(lt) => { - self.add_constraints_from_region(lt, Variance::Bivariant) - } - GenericArgData::Ty(ty) => self.add_constraints_from_ty(ty, Variance::Bivariant), - GenericArgData::Const(val) => self.add_constraints_from_const(val, variance), - } - } - } else { - let Some(variances) = self.db.variances_of(def_id) else { - return; - }; + let Some(variances) = self.db.variances_of(def_id) else { + return; + }; - for (i, k) in args.iter().enumerate() { - match k.data(Interner) { - GenericArgData::Lifetime(lt) => { - self.add_constraints_from_region(lt, variance.xform(variances[i])) - } - GenericArgData::Ty(ty) => { - self.add_constraints_from_ty(ty, variance.xform(variances[i])) - } - GenericArgData::Const(val) => self.add_constraints_from_const(val, variance), + for (i, k) in args.iter().enumerate() { + match k.data(Interner) { + GenericArgData::Lifetime(lt) => { + self.add_constraints_from_region(lt, variance.xform(variances[i])) + } + GenericArgData::Ty(ty) => { + self.add_constraints_from_ty(ty, variance.xform(variances[i])) } + GenericArgData::Const(val) => self.add_constraints_from_const(val, variance), } } } @@ -956,22 +952,17 @@ struct S3(S); } #[test] - fn recursive_one_bivariant_more_non_bivariant_params() { - // FIXME: This is wrong, this should be `BivariantPartialIndirect[T: bivariant, U: covariant]` (likewise for Wrapper) + fn prove_fixedpoint() { + // FIXME: This is wrong, this should be `FixedPoint[T: covariant, U: covariant, V: covariant]` // This is a limitation of current salsa where a cycle may only set a fallback value to the - // query result which is not what we want! We want to treat the cycle call as fallback - // without setting the query result to the fallback. - // `BivariantPartial` works as we workaround for the trivial case of being self-referential + // query result, but we need to solve a fixpoint here. The new salsa will have this + // fortunately. check( r#" -struct BivariantPartial(*const BivariantPartial, U); -struct Wrapper(BivariantPartialIndirect); -struct BivariantPartialIndirect(*const Wrapper, U); +struct FixedPoint(&'static FixedPoint<(), T, U>, V); "#, expect![[r#" - BivariantPartial[T: bivariant, U: covariant] - Wrapper[T: bivariant, U: bivariant] - BivariantPartialIndirect[T: bivariant, U: bivariant] + FixedPoint[T: bivariant, U: bivariant, V: bivariant] "#]], ); }