diff --git a/batsat/Cargo.toml b/batsat/Cargo.toml index d23d95bb..baaee274 100644 --- a/batsat/Cargo.toml +++ b/batsat/Cargo.toml @@ -14,6 +14,7 @@ readme = "README.md" [dependencies] batsat = "0.5.0" # when changing this version, do not forget to update signature anyhow.workspace = true +cpu-time.workspace = true thiserror.workspace = true rustsat.workspace = true diff --git a/batsat/src/lib.rs b/batsat/src/lib.rs index 5937ffd9..d0ed3dc7 100644 --- a/batsat/src/lib.rs +++ b/batsat/src/lib.rs @@ -11,32 +11,85 @@ #![warn(clippy::pedantic)] #![warn(missing_docs)] -use batsat::{intmap::AsIndex, lbool, BasicSolver, SolverInterface}; +use std::time::Duration; + +use batsat::{intmap::AsIndex, lbool, Callbacks, SolverInterface}; +use cpu_time::ProcessTime; use rustsat::{ - solvers::{Solve, SolveIncremental, SolverResult}, - types::{Clause, Lit, TernaryVal}, + solvers::{Solve, SolveIncremental, SolveStats, SolverResult, SolverStats}, + types::{Clause, Lit, TernaryVal, Var}, }; -use thiserror::Error; -/// API Error from the BatSat library (for example if the solver is in an UNSAT state) -#[derive(Error, Clone, Copy, PartialEq, Eq, Debug)] -#[error("BatSat returned an invalid value: {error}")] -pub struct InvalidApiReturn { - error: &'static str, -} +/// RustSAT wrapper for [`batsat::BasicSolver`] +pub type BasicSolver = Solver; -/// RustSAT wrapper for the [`BasicSolver`] Solver from BatSat +/// RustSAT wrapper for a [`batsat::Solver`] Solver from BatSat #[derive(Default)] -pub struct BatsatBasicSolver(BasicSolver); +pub struct Solver { + internal: batsat::Solver, + n_sat: usize, + n_unsat: usize, + n_terminated: usize, + avg_clause_len: f32, + cpu_time: Duration, +} + +impl Solver { + /// Gets a reference to the internal [`BasicSolver`] + #[must_use] + pub fn batsat_ref(&self) -> &batsat::Solver { + &self.internal + } + + /// Gets a mutable reference to the internal [`BasicSolver`] + #[must_use] + pub fn batsat_mut(&mut self) -> &mut batsat::Solver { + &mut self.internal + } + + #[allow(clippy::cast_precision_loss)] + #[inline] + fn update_avg_clause_len(&mut self, clause: &Clause) { + self.avg_clause_len = (self.avg_clause_len * ((self.n_clauses()) as f32) + + clause.len() as f32) + / (self.n_clauses() + 1) as f32; + } + + fn solve_track_stats(&mut self, assumps: &[Lit]) -> SolverResult { + let a = assumps + .iter() + .map(|l| batsat::Lit::new(self.internal.var_of_int(l.vidx32() + 1), l.is_pos())) + .collect::>(); + + let start = ProcessTime::now(); + let ret = match self.internal.solve_limited(&a) { + x if x == lbool::TRUE => { + self.n_sat += 1; + SolverResult::Sat + } + x if x == lbool::FALSE => { + self.n_unsat += 1; + SolverResult::Unsat + } + x if x == lbool::UNDEF => { + self.n_terminated += 1; + SolverResult::Interrupted + } + _ => unreachable!(), + }; + self.cpu_time += start.elapsed(); + ret + } +} -impl Extend for BatsatBasicSolver { +impl Extend for Solver { fn extend>(&mut self, iter: T) { iter.into_iter() .for_each(|cl| self.add_clause(cl).expect("Error adding clause in extend")); } } -impl<'a> Extend<&'a Clause> for BatsatBasicSolver { +impl<'a, Cb: Callbacks> Extend<&'a Clause> for Solver { fn extend>(&mut self, iter: T) { iter.into_iter().for_each(|cl| { self.add_clause_ref(cl) @@ -45,27 +98,19 @@ impl<'a> Extend<&'a Clause> for BatsatBasicSolver { } } -impl Solve for BatsatBasicSolver { +impl Solve for Solver { fn signature(&self) -> &'static str { "BatSat 0.5.0" } fn solve(&mut self) -> anyhow::Result { - match self.0.solve_limited(&[]) { - x if x == lbool::TRUE => Ok(SolverResult::Sat), - x if x == lbool::FALSE => Ok(SolverResult::Unsat), - x if x == lbool::UNDEF => Err(InvalidApiReturn { - error: "BatSat Solver is in an UNSAT state", - } - .into()), - _ => unreachable!(), - } + Ok(self.solve_track_stats(&[])) } fn lit_val(&self, lit: Lit) -> anyhow::Result { let l = batsat::Lit::new(batsat::Var::from_index(lit.vidx() + 1), lit.is_pos()); - match self.0.value_lit(l) { + match self.internal.value_lit(l) { x if x == lbool::TRUE => Ok(TernaryVal::True), x if x == lbool::FALSE => Ok(TernaryVal::False), x if x == lbool::UNDEF => Ok(TernaryVal::DontCare), @@ -73,53 +118,84 @@ impl Solve for BatsatBasicSolver { } } - fn add_clause(&mut self, clause: Clause) -> anyhow::Result<()> { - let mut c: Vec = clause - .iter() - .map(|l| batsat::Lit::new(self.0.var_of_int(l.vidx32() + 1), l.is_pos())) - .collect::>(); - - self.0.add_clause_reuse(&mut c); - - Ok(()) - } - fn add_clause_ref(&mut self, clause: &Clause) -> anyhow::Result<()> { + self.update_avg_clause_len(clause); + let mut c: Vec = clause .iter() - .map(|l| batsat::Lit::new(self.0.var_of_int(l.vidx32() + 1), l.is_pos())) + .map(|l| batsat::Lit::new(self.internal.var_of_int(l.vidx32() + 1), l.is_pos())) .collect::>(); - self.0.add_clause_reuse(&mut c); + self.internal.add_clause_reuse(&mut c); Ok(()) } } -impl SolveIncremental for BatsatBasicSolver { +impl SolveIncremental for Solver { fn solve_assumps(&mut self, assumps: &[Lit]) -> anyhow::Result { - let a = assumps - .iter() - .map(|l| batsat::Lit::new(self.0.var_of_int(l.vidx32() + 1), l.is_pos())) - .collect::>(); - - match self.0.solve_limited(&a) { - x if x == lbool::TRUE => Ok(SolverResult::Sat), - x if x == lbool::FALSE => Ok(SolverResult::Unsat), - x if x == lbool::UNDEF => Err(InvalidApiReturn { - error: "BatSat Solver is in an UNSAT state", - } - .into()), - _ => unreachable!(), - } + Ok(self.solve_track_stats(assumps)) } fn core(&mut self) -> anyhow::Result> { Ok(self - .0 + .internal .unsat_core() .iter() .map(|l| Lit::new(l.var().idx() - 1, !l.sign())) .collect::>()) } } + +impl SolveStats for Solver { + fn stats(&self) -> SolverStats { + SolverStats { + n_sat: self.n_sat, + n_unsat: self.n_unsat, + n_terminated: self.n_terminated, + n_clauses: self.n_clauses(), + max_var: self.max_var(), + avg_clause_len: self.avg_clause_len, + cpu_solve_time: self.cpu_time, + } + } + + fn n_sat_solves(&self) -> usize { + self.n_sat + } + + fn n_unsat_solves(&self) -> usize { + self.n_unsat + } + + fn n_terminated(&self) -> usize { + self.n_terminated + } + + fn n_clauses(&self) -> usize { + usize::try_from(self.internal.num_clauses()).expect("more than `usize::MAX` clauses") + } + + fn max_var(&self) -> Option { + let num = self.internal.num_vars(); + if num > 0 { + // BatSat returns a value that is off by one + Some(Var::new(num - 2)) + } else { + None + } + } + + fn avg_clause_len(&self) -> f32 { + self.avg_clause_len + } + + fn cpu_solve_time(&self) -> Duration { + self.cpu_time + } +} + +#[cfg(test)] +mod test { + rustsat_solvertests::basic_unittests!(super::BasicSolver, false); +} diff --git a/batsat/tests/incremental.rs b/batsat/tests/incremental.rs index f76eb878..eed81a87 100644 --- a/batsat/tests/incremental.rs +++ b/batsat/tests/incremental.rs @@ -1 +1 @@ -rustsat_solvertests::incremental_tests!(rustsat_batsat::BatsatBasicSolver); +rustsat_solvertests::incremental_tests!(rustsat_batsat::BasicSolver); diff --git a/batsat/tests/small.rs b/batsat/tests/small.rs index a6f444c0..ae1e943c 100644 --- a/batsat/tests/small.rs +++ b/batsat/tests/small.rs @@ -1,3 +1,3 @@ mod base { - rustsat_solvertests::base_tests!(rustsat_batsat::BatsatBasicSolver, false, true); + rustsat_solvertests::base_tests!(rustsat_batsat::BasicSolver, false, true); } diff --git a/solvertests/src/integration.rs b/solvertests/src/integration.rs index 68e0ae65..f2e0afa6 100644 --- a/solvertests/src/integration.rs +++ b/solvertests/src/integration.rs @@ -4,9 +4,9 @@ use proc_macro2::TokenStream; use quote::quote; use syn::{parse_quote, Attribute}; -use super::MacroInput; +use super::IntegrationInput; -pub fn base(input: MacroInput) -> TokenStream { +pub fn base(input: IntegrationInput) -> TokenStream { let slv = input.slv; let ignoretok = |idx: usize| -> Option { if input.bools.len() > idx && input.bools[idx] { @@ -76,7 +76,7 @@ pub fn base(input: MacroInput) -> TokenStream { ts } -pub fn incremental(input: MacroInput) -> TokenStream { +pub fn incremental(input: IntegrationInput) -> TokenStream { let slv = input.slv; let ignoretok = |idx: usize| -> Option { if input.bools.len() > idx && input.bools[idx] { @@ -184,7 +184,7 @@ pub fn incremental(input: MacroInput) -> TokenStream { ts } -pub fn phasing(input: MacroInput) -> TokenStream { +pub fn phasing(input: IntegrationInput) -> TokenStream { let slv = input.slv; let ignoretok = |idx: usize| -> Option { if input.bools.len() > idx && input.bools[idx] { @@ -238,7 +238,7 @@ pub fn phasing(input: MacroInput) -> TokenStream { ts } -pub fn flipping(input: MacroInput) -> TokenStream { +pub fn flipping(input: IntegrationInput) -> TokenStream { let slv = input.slv; let ignoretok = |idx: usize| -> Option { if input.bools.len() > idx && input.bools[idx] { diff --git a/solvertests/src/lib.rs b/solvertests/src/lib.rs index 545b97ff..7ed87c99 100644 --- a/solvertests/src/lib.rs +++ b/solvertests/src/lib.rs @@ -31,12 +31,12 @@ impl ToTokens for InitBy { } } -struct MacroInput { +struct IntegrationInput { slv: InitBy, bools: Vec, } -impl Parse for MacroInput { +impl Parse for IntegrationInput { fn parse(input: syn::parse::ParseStream) -> syn::Result { let slv: InitBy = input.parse()?; if input.is_empty() { @@ -49,10 +49,31 @@ impl Parse for MacroInput { } } +struct BasicUnitInput { + slv: Type, + mt: Option, +} + +impl Parse for BasicUnitInput { + fn parse(input: syn::parse::ParseStream) -> syn::Result { + let slv: Type = input.parse()?; + if input.is_empty() { + return Ok(Self { slv, mt: None }); + } + let _: Token![,] = input.parse()?; + let mt: LitBool = input.parse()?; + Ok(Self { + slv, + mt: Some(mt.value), + }) + } +} + #[proc_macro] pub fn basic_unittests(tokens: TokenStream) -> TokenStream { - let slv = parse_macro_input!(tokens as Type); - unit::basic(slv).into() + let input = parse_macro_input!(tokens as BasicUnitInput); + let mt = input.mt.unwrap_or(true); + unit::basic(input.slv, mt).into() } #[proc_macro] @@ -75,24 +96,24 @@ pub fn freezing_unittests(tokens: TokenStream) -> TokenStream { #[proc_macro] pub fn base_tests(tokens: TokenStream) -> TokenStream { - let input = parse_macro_input!(tokens as MacroInput); + let input = parse_macro_input!(tokens as IntegrationInput); integration::base(input).into() } #[proc_macro] pub fn incremental_tests(tokens: TokenStream) -> TokenStream { - let input = parse_macro_input!(tokens as MacroInput); + let input = parse_macro_input!(tokens as IntegrationInput); integration::incremental(input).into() } #[proc_macro] pub fn phasing_tests(tokens: TokenStream) -> TokenStream { - let input = parse_macro_input!(tokens as MacroInput); + let input = parse_macro_input!(tokens as IntegrationInput); integration::phasing(input).into() } #[proc_macro] pub fn flipping_tests(tokens: TokenStream) -> TokenStream { - let input = parse_macro_input!(tokens as MacroInput); + let input = parse_macro_input!(tokens as IntegrationInput); integration::flipping(input).into() } diff --git a/solvertests/src/unit.rs b/solvertests/src/unit.rs index 596ef627..758bf275 100644 --- a/solvertests/src/unit.rs +++ b/solvertests/src/unit.rs @@ -4,8 +4,8 @@ use proc_macro2::TokenStream; use quote::quote; use syn::Type; -pub fn basic(slv: Type) -> TokenStream { - quote! { +pub fn basic(slv: Type, multi_threaded: bool) -> TokenStream { + let mut ts = quote! { #[test] fn build_destroy() { let _solver = #slv::default(); @@ -78,47 +78,51 @@ pub fn basic(slv: Type) -> TokenStream { Ok(res) => assert_eq!(res, SolverResult::Unsat), } } - - #[test] - fn tiny_instance_multithreaded_sat() { - use std::{sync::{Arc, Mutex}, thread}; - use rustsat::{lit, var, types::TernaryVal, solvers::{Solve, SolverResult}}; - - let mutex_solver = Arc::new(Mutex::new(#slv::default())); - - { - // Build in one thread - let mut solver = mutex_solver.lock().unwrap(); - solver.add_binary(lit![0], !lit![1]).unwrap(); - solver.add_unit(lit![0]).unwrap(); - solver.add_binary(lit![1], !lit![2]).unwrap(); - } - - // Now in another thread - let s = mutex_solver.clone(); - let ret = thread::spawn(move || { - let mut solver = s.lock().unwrap(); - solver.solve() - }) - .join() - .unwrap(); - match ret { - Err(e) => panic!("got error when solving: {}", e), - Ok(res) => assert_eq!(res, SolverResult::Sat), + }; + if multi_threaded { + ts.extend(quote! { + #[test] + fn tiny_instance_multithreaded_sat() { + use std::{sync::{Arc, Mutex}, thread}; + use rustsat::{lit, var, types::TernaryVal, solvers::{Solve, SolverResult}}; + + let mutex_solver = Arc::new(Mutex::new(#slv::default())); + + { + // Build in one thread + let mut solver = mutex_solver.lock().unwrap(); + solver.add_binary(lit![0], !lit![1]).unwrap(); + solver.add_unit(lit![0]).unwrap(); + solver.add_binary(lit![1], !lit![2]).unwrap(); + } + + // Now in another thread + let s = mutex_solver.clone(); + let ret = thread::spawn(move || { + let mut solver = s.lock().unwrap(); + solver.solve() + }) + .join() + .unwrap(); + match ret { + Err(e) => panic!("got error when solving: {}", e), + Ok(res) => assert_eq!(res, SolverResult::Sat), + } + + // Finally, back in the main thread + let ret = { + let solver = mutex_solver.lock().unwrap(); + solver.full_solution() + }; + + match ret { + Err(e) => panic!("got error when solving: {}", e), + Ok(res) => assert_eq!(res.var_value(var![0]), TernaryVal::True), + } } - - // Finally, back in the main thread - let ret = { - let solver = mutex_solver.lock().unwrap(); - solver.full_solution() - }; - - match ret { - Err(e) => panic!("got error when solving: {}", e), - Ok(res) => assert_eq!(res.var_value(var![0]), TernaryVal::True), - } - } - } + }); + }; + ts } pub fn termination(slv: Type) -> TokenStream { diff --git a/src/solvers.rs b/src/solvers.rs index 74e38e6a..77a7e545 100644 --- a/src/solvers.rs +++ b/src/solvers.rs @@ -76,6 +76,22 @@ //! - Fork used in solver crate: //! [https://github.com/chrjabs/glucose4](https://github.com/chrjabs/glucose4) //! +//! ### BatSat +//! +//! [BatSat](https://github.com/c-cube/batsat) is a SAT solver based on Minisat but fully +//! implemented in Rust. Because it is fully implemented in Rust, it is a good choice for +//! restricted compilation scenarios like WebAssembly. BatSat is available through the +//! [`rustsat-batsat`](httpe://crates.io/crates/rustsat-batsat) crate. +//! +//! #### References +//! +//! - Solver interface crate: +//! [https://crates.io/crates/rustsat-batsat](https://crates.io/crate/rustsat-batsat) +//! - BatSat crate: +//! [https://crate.io/crates/batsat](https://crates.io/crate/batsat) +//! - BatSat repository: +//! [https://github.com/c-cube/batsat](https://github.com/c-cube/batsat) +//! //! ### External Solvers //! //! RustSAT provides an interface for calling external solver binaries by passing them DIMACS input