diff --git a/crates/rattler_libsolv_rs/src/arena.rs b/crates/rattler_libsolv_rs/src/arena.rs index 07e8a4c33..c1b4f2a02 100644 --- a/crates/rattler_libsolv_rs/src/arena.rs +++ b/crates/rattler_libsolv_rs/src/arena.rs @@ -1,7 +1,7 @@ use std::cell::{Cell, UnsafeCell}; use std::cmp; use std::marker::PhantomData; -use std::ops::Index; +use std::ops::{Index, IndexMut}; const CHUNK_SIZE: usize = 128; @@ -77,11 +77,53 @@ impl Arena { TId::from_usize(id) } + /// Returns an iterator over the elements of the arena. + pub fn iter(&self) -> ArenaIter { + ArenaIter { + arena: self, + index: 0, + } + } + + /// Returns an mutable iterator over the elements of the arena. + pub fn iter_mut(&mut self) -> ArenaIterMut { + ArenaIterMut { + arena: self, + index: 0, + } + } + fn chunk_and_offset(index: usize) -> (usize, usize) { let offset = index % CHUNK_SIZE; let chunk = index / CHUNK_SIZE; (chunk, offset) } + + /// Returns mutable references to the two values references by the two distinct indices. + /// + /// Panics if one of the Ids is invalid or when the two ids are the same. + pub fn get_two_mut(&mut self, a: TId, b: TId) -> (&mut TValue, &mut TValue) { + let a_index = a.to_usize(); + let b_index = b.to_usize(); + assert!(a_index < self.len()); + assert!(b_index < self.len()); + assert_ne!(a_index, b_index); + let (a_chunk, a_offset) = Self::chunk_and_offset(a_index); + let (b_chunk, b_offset) = Self::chunk_and_offset(b_index); + // SAFE: because we check that the indices are less than the length and that both indices do + // not refer to the same item. + unsafe { + let chunks = self.chunks.get(); + ( + (*chunks) + .get_unchecked_mut(a_chunk) + .get_unchecked_mut(a_offset), + (*chunks) + .get_unchecked_mut(b_chunk) + .get_unchecked_mut(b_offset), + ) + } + } } impl Index for Arena { @@ -98,8 +140,78 @@ impl Index for Arena { } } +impl IndexMut for Arena { + fn index_mut(&mut self, index: TId) -> &mut Self::Output { + let index = index.to_usize(); + assert!(index < self.len()); + let (chunk, offset) = Self::chunk_and_offset(index); + // SAFE: because we check that the index is less than the length + unsafe { + self.chunks + .get_mut() + .get_unchecked_mut(chunk) + .get_unchecked_mut(offset) + } + } +} + /// A trait indicating that the type can be transformed to `usize` and back pub trait ArenaId { fn from_usize(x: usize) -> Self; fn to_usize(self) -> usize; } + +/// An iterator over the elements of an [`Arena`]. +pub struct ArenaIter<'a, TId: ArenaId, TValue> { + arena: &'a Arena, + index: usize, +} + +impl<'a, TId: ArenaId, TValue> Iterator for ArenaIter<'a, TId, TValue> { + type Item = (TId, &'a TValue); + + fn next(&mut self) -> Option { + if self.index < self.arena.len.get() { + let (chunk, offset) = Arena::::chunk_and_offset(self.index); + let element = unsafe { + let vec = self.arena.chunks.get(); + Some(( + TId::from_usize(self.index), + (*vec).get_unchecked(chunk).get_unchecked(offset), + )) + }; + + self.index += 1; + element + } else { + None + } + } +} + +/// An mutable iterator over the elements of an [`Arena`]. +pub struct ArenaIterMut<'a, TId: ArenaId, TValue> { + arena: &'a mut Arena, + index: usize, +} + +impl<'a, TId: ArenaId, TValue> Iterator for ArenaIterMut<'a, TId, TValue> { + type Item = (TId, &'a mut TValue); + + fn next(&mut self) -> Option { + if self.index < self.arena.len.get() { + let (chunk, offset) = Arena::::chunk_and_offset(self.index); + let element = unsafe { + let vec = self.arena.chunks.get(); + Some(( + TId::from_usize(self.index), + (*vec).get_unchecked_mut(chunk).get_unchecked_mut(offset), + )) + }; + self.index += 1; + element + } else { + None + } + } +} diff --git a/crates/rattler_libsolv_rs/src/id.rs b/crates/rattler_libsolv_rs/src/id.rs index 048aecfae..f552c7121 100644 --- a/crates/rattler_libsolv_rs/src/id.rs +++ b/crates/rattler_libsolv_rs/src/id.rs @@ -74,21 +74,12 @@ impl From for u32 { pub(crate) struct ClauseId(u32); impl ClauseId { - pub(crate) fn new(index: usize) -> Self { - debug_assert_ne!(index, 0); - debug_assert_ne!(index, u32::MAX as usize); - - Self(index as u32) - } - + /// There is a guarentee that ClauseId(0) will always be "Clause::InstallRoot". This assumption + /// is verified by the solver. pub(crate) fn install_root() -> Self { Self(0) } - pub(crate) fn index(self) -> usize { - self.0 as usize - } - pub(crate) fn is_null(self) -> bool { self.0 == u32::MAX } @@ -98,6 +89,17 @@ impl ClauseId { } } +impl ArenaId for ClauseId { + fn from_usize(x: usize) -> Self { + assert!(x < u32::MAX as usize, "clause id too big"); + Self(x as u32) + } + + fn to_usize(self) -> usize { + self.0 as usize + } +} + #[derive(Copy, Clone, Debug)] pub(crate) struct LearntClauseId(u32); diff --git a/crates/rattler_libsolv_rs/src/problem.rs b/crates/rattler_libsolv_rs/src/problem.rs index a14e9c8fc..44a641fc7 100644 --- a/crates/rattler_libsolv_rs/src/problem.rs +++ b/crates/rattler_libsolv_rs/src/problem.rs @@ -50,11 +50,11 @@ impl Problem { let unresolved_node = graph.add_node(ProblemNode::UnresolvedDependency); for clause_id in &self.clauses { - let clause = solver.clauses[clause_id.index()].kind; + let clause = &solver.clauses[*clause_id].kind; match clause { Clause::InstallRoot => (), Clause::Learnt(..) => unreachable!(), - Clause::Requires(package_id, version_set_id) => { + &Clause::Requires(package_id, version_set_id) => { let package_node = Self::add_node(&mut graph, &mut nodes, package_id); let candidates = solver.get_or_cache_sorted_candidates(version_set_id); @@ -81,19 +81,19 @@ impl Problem { } } } - Clause::Lock(locked, forbidden) => { + &Clause::Lock(locked, forbidden) => { let node2_id = Self::add_node(&mut graph, &mut nodes, forbidden); let conflict = ConflictCause::Locked(locked); graph.add_edge(root_node, node2_id, ProblemEdge::Conflict(conflict)); } - Clause::ForbidMultipleInstances(instance1_id, instance2_id) => { + &Clause::ForbidMultipleInstances(instance1_id, instance2_id) => { let node1_id = Self::add_node(&mut graph, &mut nodes, instance1_id); let node2_id = Self::add_node(&mut graph, &mut nodes, instance2_id); let conflict = ConflictCause::ForbidMultipleInstances; graph.add_edge(node1_id, node2_id, ProblemEdge::Conflict(conflict)); } - Clause::Constrains(package_id, dep_id, version_set_id) => { + &Clause::Constrains(package_id, dep_id, version_set_id) => { let package_node = Self::add_node(&mut graph, &mut nodes, package_id); let dep_node = Self::add_node(&mut graph, &mut nodes, dep_id); diff --git a/crates/rattler_libsolv_rs/src/solver/clause.rs b/crates/rattler_libsolv_rs/src/solver/clause.rs index c1607d057..dfc2af993 100644 --- a/crates/rattler_libsolv_rs/src/solver/clause.rs +++ b/crates/rattler_libsolv_rs/src/solver/clause.rs @@ -584,11 +584,11 @@ mod test { #[test] fn test_unlink_clause_different() { let clause1 = clause( - [ClauseId::new(2), ClauseId::new(3)], + [ClauseId::from_usize(2), ClauseId::from_usize(3)], [SolvableId::from_usize(1596), SolvableId::from_usize(1211)], ); let clause2 = clause( - [ClauseId::null(), ClauseId::new(3)], + [ClauseId::null(), ClauseId::from_usize(3)], [SolvableId::from_usize(1596), SolvableId::from_usize(1208)], ); let clause3 = clause( @@ -604,7 +604,10 @@ mod test { clause1.watched_literals, [SolvableId::from_usize(1596), SolvableId::from_usize(1211)] ); - assert_eq!(clause1.next_watches, [ClauseId::null(), ClauseId::new(3)]) + assert_eq!( + clause1.next_watches, + [ClauseId::null(), ClauseId::from_usize(3)] + ) } // Unlink 1 @@ -615,14 +618,17 @@ mod test { clause1.watched_literals, [SolvableId::from_usize(1596), SolvableId::from_usize(1211)] ); - assert_eq!(clause1.next_watches, [ClauseId::new(2), ClauseId::null()]) + assert_eq!( + clause1.next_watches, + [ClauseId::from_usize(2), ClauseId::null()] + ) } } #[test] fn test_unlink_clause_same() { let clause1 = clause( - [ClauseId::new(2), ClauseId::new(2)], + [ClauseId::from_usize(2), ClauseId::from_usize(2)], [SolvableId::from_usize(1596), SolvableId::from_usize(1211)], ); let clause2 = clause( @@ -638,7 +644,10 @@ mod test { clause1.watched_literals, [SolvableId::from_usize(1596), SolvableId::from_usize(1211)] ); - assert_eq!(clause1.next_watches, [ClauseId::null(), ClauseId::new(2)]) + assert_eq!( + clause1.next_watches, + [ClauseId::null(), ClauseId::from_usize(2)] + ) } // Unlink 1 @@ -649,7 +658,10 @@ mod test { clause1.watched_literals, [SolvableId::from_usize(1596), SolvableId::from_usize(1211)] ); - assert_eq!(clause1.next_watches, [ClauseId::new(2), ClauseId::null()]) + assert_eq!( + clause1.next_watches, + [ClauseId::from_usize(2), ClauseId::null()] + ) } } diff --git a/crates/rattler_libsolv_rs/src/solver/mod.rs b/crates/rattler_libsolv_rs/src/solver/mod.rs index 3fd8142d8..7b7ffd515 100644 --- a/crates/rattler_libsolv_rs/src/solver/mod.rs +++ b/crates/rattler_libsolv_rs/src/solver/mod.rs @@ -32,12 +32,12 @@ mod watch_map; pub struct Solver> { provider: D, - pub(crate) clauses: Vec, + pub(crate) clauses: Arena, watches: WatchMap, - learnt_clauses_start: ClauseId, learnt_clauses: Arena>, learnt_why: Mapping>, + learnt_clause_ids: Vec, /// A mapping from a solvable to a list of dependencies solvable_dependencies: Arena, @@ -73,11 +73,11 @@ impl> Solver Self { Self { provider, - clauses: Vec::new(), + clauses: Arena::new(), watches: WatchMap::new(), learnt_clauses: Arena::new(), - learnt_clauses_start: ClauseId::null(), learnt_why: Mapping::empty(), + learnt_clause_ids: Vec::new(), decision_tracker: DecisionTracker::new(), candidates: Arena::new(), solvable_dependencies: Default::default(), @@ -238,9 +238,14 @@ impl> Sol self.decision_tracker.clear(); self.learnt_clauses.clear(); self.learnt_why = Mapping::empty(); - self.clauses = vec![ClauseState::root()]; + self.clauses = Default::default(); self.root_requirements = root_requirements; + // The first clause will always be the install root clause. Here we verify that this is + // indeed the case. + let root_clause = self.clauses.alloc(ClauseState::root()); + assert_eq!(root_clause, ClauseId::install_root()); + // Create clauses for root's dependencies, and their dependencies, and so forth self.add_clauses_for_root_deps(); @@ -260,7 +265,7 @@ impl> Sol for (i, &candidate) in candidates.iter().enumerate() { for &other_candidate in &candidates[i + 1..] { self.clauses - .push(ClauseState::forbid_multiple(candidate, other_candidate)); + .alloc(ClauseState::forbid_multiple(candidate, other_candidate)); } } } @@ -280,14 +285,11 @@ impl> Sol for &other_candidate in &candidates.candidates { if other_candidate != locked_solvable_id { self.clauses - .push(ClauseState::lock(locked_solvable_id, other_candidate)); + .alloc(ClauseState::lock(locked_solvable_id, other_candidate)); } } } - // All new clauses are learnt after this point - self.learnt_clauses_start = ClauseId::new(self.clauses.len()); - // Create watches chains self.make_watches(); @@ -347,7 +349,6 @@ impl> Sol }; // Iterate over all the requirements and create clauses. - let mut clauses = Vec::new(); for &version_set_id in requirements { // Get the sorted candidates that can fulfill this requirement let candidates = self.get_or_cache_sorted_candidates(version_set_id); @@ -363,7 +364,7 @@ impl> Sol } } - clauses.push(ClauseState::requires( + self.clauses.alloc(ClauseState::requires( solvable_id, version_set_id, candidates, @@ -382,11 +383,9 @@ impl> Sol for forbidden_candidate in non_candidates { let clause = ClauseState::constrains(solvable_id, forbidden_candidate, version_set_id); - clauses.push(clause); + self.clauses.alloc(clause); } } - - self.clauses.extend(clauses); } } @@ -444,12 +443,11 @@ impl> Sol fn decide_requires_without_candidates(&mut self, level: u32) -> Result<(), ClauseId> { tracing::info!("=== Deciding assertions for requires without candidates"); - for (i, clause) in self.clauses.iter().enumerate() { + for (clause_id, clause) in self.clauses.iter() { if let Clause::Requires(solvable_id, _) = clause.kind { if !clause.has_watches() { // A requires clause without watches means it has a single literal (i.e. // there are no candidates) - let clause_id = ClauseId::new(i); let decided = self .decision_tracker .try_add_decision(Decision::new(solvable_id, false, clause_id), level) @@ -484,8 +482,10 @@ impl> Sol break; } + let clause_id = ClauseId::from_usize(i); + let (required_by, candidate) = { - let clause = &self.clauses[i]; + let clause = &self.clauses[clause_id]; i += 1; // We are only interested in requires clauses @@ -521,7 +521,7 @@ impl> Sol ) }; - level = self.set_propagate_learn(level, candidate, required_by, ClauseId::new(i))?; + level = self.set_propagate_learn(level, candidate, required_by, clause_id)?; // We have made progress, and should look at all clauses in the next iteration i = 0; @@ -576,7 +576,7 @@ impl> Sol ); tracing::info!( "During unit propagation for clause: {:?}", - self.clauses[conflicting_clause.index()].debug(self.pool()) + self.clauses[conflicting_clause].debug(self.pool()) ); tracing::info!( @@ -588,16 +588,15 @@ impl> Sol .iter() .find(|d| d.solvable_id == conflicting_solvable) .unwrap() - .derived_from - .index()] - .debug(self.pool()), + .derived_from] + .debug(self.pool()), ); } if level == 1 { tracing::info!("=== UNSOLVABLE"); for decision in self.decision_tracker.stack() { - let clause = &self.clauses[decision.derived_from.index()]; + let clause = &self.clauses[decision.derived_from]; let level = self.decision_tracker.level(decision.solvable_id); let action = if decision.value { "install" } else { "forbid" }; @@ -649,8 +648,8 @@ impl> Sol fn propagate(&mut self, level: u32) -> Result<(), (SolvableId, bool, ClauseId)> { // Learnt assertions (assertions are clauses that consist of a single literal, and therefore // do not have watches) - let learnt_clauses_start = self.learnt_clauses_start.index(); - for (i, clause) in self.clauses[learnt_clauses_start..].iter().enumerate() { + for &clause_id in self.learnt_clause_ids.iter() { + let clause = &self.clauses[clause_id]; let Clause::Learnt(learnt_index) = clause.kind else { unreachable!(); }; @@ -664,7 +663,6 @@ impl> Sol let literal = literals[0]; let decision = literal.satisfying_value(); - let clause_id = ClauseId::new(learnt_clauses_start + i); let decided = self .decision_tracker @@ -696,19 +694,14 @@ impl> Sol panic!("Linked list is circular!"); } - // This is a convoluted way of getting mutable access to the current and the previous clause, - // which is necessary when we have to remove the current clause from the list + // Get mutable access to both clauses. let (predecessor_clause, clause) = if let Some(prev_clause_id) = predecessor_clause_id { - if prev_clause_id < clause_id { - let (prev, current) = self.clauses.split_at_mut(clause_id.index()); - (Some(&mut prev[prev_clause_id.index()]), &mut current[0]) - } else { - let (current, prev) = self.clauses.split_at_mut(prev_clause_id.index()); - (Some(&mut prev[0]), &mut current[clause_id.index()]) - } + let (predecessor_clause, clause) = + self.clauses.get_two_mut(prev_clause_id, clause_id); + (Some(predecessor_clause), clause) } else { - (None, &mut self.clauses[clause_id.index()]) + (None, &mut self.clauses[clause_id]) }; // Update the prev_clause_id for the next run @@ -796,31 +789,21 @@ impl> Sol /// Because learnt clauses are not relevant for the user, they are not added to the `Problem`. /// Instead, we report the clauses that caused them. fn analyze_unsolvable_clause( - clauses: &[ClauseState], + clauses: &Arena, learnt_why: &Mapping>, - learnt_clauses_start: ClauseId, clause_id: ClauseId, problem: &mut Problem, seen: &mut HashSet, ) { - let clause = &clauses[clause_id.index()]; + let clause = &clauses[clause_id]; match clause.kind { - Clause::Learnt(..) => { + Clause::Learnt(learnt_clause_id) => { if !seen.insert(clause_id) { return; } - let clause_id = - LearntClauseId::from_usize(clause_id.index() - learnt_clauses_start.index()); - for &cause in &learnt_why[clause_id] { - Self::analyze_unsolvable_clause( - clauses, - learnt_why, - learnt_clauses_start, - cause, - problem, - seen, - ); + for &cause in &learnt_why[learnt_clause_id] { + Self::analyze_unsolvable_clause(clauses, learnt_why, cause, problem, seen); } } _ => problem.add_clause(clause_id), @@ -838,7 +821,7 @@ impl> Sol tracing::info!("=== ANALYZE UNSOLVABLE"); let mut involved = HashSet::new(); - self.clauses[clause_id.index()].kind.visit_literals( + self.clauses[clause_id].kind.visit_literals( &self.learnt_clauses, &self.version_set_to_sorted_candidates, |literal| { @@ -850,7 +833,6 @@ impl> Sol Self::analyze_unsolvable_clause( &self.clauses, &self.learnt_why, - self.learnt_clauses_start, clause_id, &mut problem, &mut seen, @@ -872,13 +854,12 @@ impl> Sol Self::analyze_unsolvable_clause( &self.clauses, &self.learnt_why, - self.learnt_clauses_start, why, &mut problem, &mut seen, ); - self.clauses[why.index()].kind.visit_literals( + self.clauses[why].kind.visit_literals( &self.learnt_clauses, &self.version_set_to_sorted_candidates, |literal| { @@ -922,7 +903,7 @@ impl> Sol loop { learnt_why.push(clause_id); - self.clauses[clause_id.index()].kind.visit_literals( + self.clauses[clause_id].kind.visit_literals( &self.learnt_clauses, &self.version_set_to_sorted_candidates, |literal| { @@ -988,19 +969,17 @@ impl> Sol learnt.push(last_literal); // Add the clause - let clause_id = ClauseId::new(self.clauses.len()); let learnt_id = self.learnt_clauses.alloc(learnt.clone()); self.learnt_why.extend(learnt_why); - let mut clause = ClauseState::learnt(learnt_id, &learnt); + let clause_id = self.clauses.alloc(ClauseState::learnt(learnt_id, &learnt)); + self.learnt_clause_ids.push(clause_id); + let clause = &mut self.clauses[clause_id]; if clause.has_watches() { - self.watches.start_watching(&mut clause, clause_id); + self.watches.start_watching(clause, clause_id); } - // Store it - self.clauses.push(clause); - tracing::info!( "Learnt disjunction:\n{}", learnt @@ -1024,13 +1003,13 @@ impl> Sol // Watches are already initialized in the clauses themselves, here we build a linked list for // each package (a clause will be linked to other clauses that are watching the same package) - for (i, clause) in self.clauses.iter_mut().enumerate() { + for (clause_id, clause) in self.clauses.iter_mut() { if !clause.has_watches() { // Skip clauses without watches continue; } - self.watches.start_watching(clause, ClauseId::new(i)); + self.watches.start_watching(clause, clause_id); } } }