Skip to content

Commit

Permalink
feat: generalize batsat interface
Browse files Browse the repository at this point in the history
  • Loading branch information
chrjabs committed Jul 8, 2024
1 parent 4d15c0f commit 98af54b
Show file tree
Hide file tree
Showing 8 changed files with 229 additions and 111 deletions.
1 change: 1 addition & 0 deletions batsat/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ readme = "README.md"
[dependencies]
batsat = "0.5.0" # when changing this version, do not forget to update signature
anyhow.workspace = true
cpu-time.workspace = true
thiserror.workspace = true
rustsat.workspace = true

Expand Down
184 changes: 130 additions & 54 deletions batsat/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,32 +11,85 @@
#![warn(clippy::pedantic)]
#![warn(missing_docs)]

use batsat::{intmap::AsIndex, lbool, BasicSolver, SolverInterface};
use std::time::Duration;

use batsat::{intmap::AsIndex, lbool, Callbacks, SolverInterface};
use cpu_time::ProcessTime;
use rustsat::{
solvers::{Solve, SolveIncremental, SolverResult},
types::{Clause, Lit, TernaryVal},
solvers::{Solve, SolveIncremental, SolveStats, SolverResult, SolverStats},
types::{Clause, Lit, TernaryVal, Var},
};
use thiserror::Error;

/// API Error from the BatSat library (for example if the solver is in an UNSAT state)
#[derive(Error, Clone, Copy, PartialEq, Eq, Debug)]
#[error("BatSat returned an invalid value: {error}")]
pub struct InvalidApiReturn {
error: &'static str,
}
/// RustSAT wrapper for [`batsat::BasicSolver`]
pub type BasicSolver = Solver<batsat::BasicCallbacks>;

/// RustSAT wrapper for the [`BasicSolver`] Solver from BatSat
/// RustSAT wrapper for a [`batsat::Solver`] Solver from BatSat
#[derive(Default)]
pub struct BatsatBasicSolver(BasicSolver);
pub struct Solver<Cb: Callbacks> {
internal: batsat::Solver<Cb>,
n_sat: usize,
n_unsat: usize,
n_terminated: usize,
avg_clause_len: f32,
cpu_time: Duration,
}

impl<Cb: Callbacks> Solver<Cb> {
/// Gets a reference to the internal [`BasicSolver`]
#[must_use]
pub fn batsat_ref(&self) -> &batsat::Solver<Cb> {
&self.internal
}

/// Gets a mutable reference to the internal [`BasicSolver`]
#[must_use]
pub fn batsat_mut(&mut self) -> &mut batsat::Solver<Cb> {
&mut self.internal
}

#[allow(clippy::cast_precision_loss)]
#[inline]
fn update_avg_clause_len(&mut self, clause: &Clause) {
self.avg_clause_len = (self.avg_clause_len * ((self.n_clauses()) as f32)
+ clause.len() as f32)
/ (self.n_clauses() + 1) as f32;
}

fn solve_track_stats(&mut self, assumps: &[Lit]) -> SolverResult {
let a = assumps
.iter()
.map(|l| batsat::Lit::new(self.internal.var_of_int(l.vidx32() + 1), l.is_pos()))
.collect::<Vec<_>>();

let start = ProcessTime::now();
let ret = match self.internal.solve_limited(&a) {
x if x == lbool::TRUE => {
self.n_sat += 1;
SolverResult::Sat
}
x if x == lbool::FALSE => {
self.n_unsat += 1;
SolverResult::Unsat
}
x if x == lbool::UNDEF => {
self.n_terminated += 1;
SolverResult::Interrupted
}
_ => unreachable!(),
};
self.cpu_time += start.elapsed();
ret
}
}

impl Extend<Clause> for BatsatBasicSolver {
impl<Cb: Callbacks> Extend<Clause> for Solver<Cb> {
fn extend<T: IntoIterator<Item = Clause>>(&mut self, iter: T) {
iter.into_iter()
.for_each(|cl| self.add_clause(cl).expect("Error adding clause in extend"));
}
}

impl<'a> Extend<&'a Clause> for BatsatBasicSolver {
impl<'a, Cb: Callbacks> Extend<&'a Clause> for Solver<Cb> {
fn extend<T: IntoIterator<Item = &'a Clause>>(&mut self, iter: T) {
iter.into_iter().for_each(|cl| {
self.add_clause_ref(cl)
Expand All @@ -45,81 +98,104 @@ impl<'a> Extend<&'a Clause> for BatsatBasicSolver {
}
}

impl Solve for BatsatBasicSolver {
impl<Cb: Callbacks> Solve for Solver<Cb> {
fn signature(&self) -> &'static str {
"BatSat 0.5.0"
}

fn solve(&mut self) -> anyhow::Result<SolverResult> {
match self.0.solve_limited(&[]) {
x if x == lbool::TRUE => Ok(SolverResult::Sat),
x if x == lbool::FALSE => Ok(SolverResult::Unsat),
x if x == lbool::UNDEF => Err(InvalidApiReturn {
error: "BatSat Solver is in an UNSAT state",
}
.into()),
_ => unreachable!(),
}
Ok(self.solve_track_stats(&[]))
}

fn lit_val(&self, lit: Lit) -> anyhow::Result<TernaryVal> {
let l = batsat::Lit::new(batsat::Var::from_index(lit.vidx() + 1), lit.is_pos());

match self.0.value_lit(l) {
match self.internal.value_lit(l) {
x if x == lbool::TRUE => Ok(TernaryVal::True),
x if x == lbool::FALSE => Ok(TernaryVal::False),
x if x == lbool::UNDEF => Ok(TernaryVal::DontCare),
_ => unreachable!(),
}
}

fn add_clause(&mut self, clause: Clause) -> anyhow::Result<()> {
let mut c: Vec<batsat::Lit> = clause
.iter()
.map(|l| batsat::Lit::new(self.0.var_of_int(l.vidx32() + 1), l.is_pos()))
.collect::<Vec<batsat::Lit>>();

self.0.add_clause_reuse(&mut c);

Ok(())
}

fn add_clause_ref(&mut self, clause: &Clause) -> anyhow::Result<()> {
self.update_avg_clause_len(clause);

let mut c: Vec<batsat::Lit> = clause
.iter()
.map(|l| batsat::Lit::new(self.0.var_of_int(l.vidx32() + 1), l.is_pos()))
.map(|l| batsat::Lit::new(self.internal.var_of_int(l.vidx32() + 1), l.is_pos()))
.collect::<Vec<batsat::Lit>>();

self.0.add_clause_reuse(&mut c);
self.internal.add_clause_reuse(&mut c);

Ok(())
}
}

impl SolveIncremental for BatsatBasicSolver {
impl<Cb: Callbacks> SolveIncremental for Solver<Cb> {
fn solve_assumps(&mut self, assumps: &[Lit]) -> anyhow::Result<SolverResult> {
let a = assumps
.iter()
.map(|l| batsat::Lit::new(self.0.var_of_int(l.vidx32() + 1), l.is_pos()))
.collect::<Vec<_>>();

match self.0.solve_limited(&a) {
x if x == lbool::TRUE => Ok(SolverResult::Sat),
x if x == lbool::FALSE => Ok(SolverResult::Unsat),
x if x == lbool::UNDEF => Err(InvalidApiReturn {
error: "BatSat Solver is in an UNSAT state",
}
.into()),
_ => unreachable!(),
}
Ok(self.solve_track_stats(assumps))
}

fn core(&mut self) -> anyhow::Result<Vec<Lit>> {
Ok(self
.0
.internal
.unsat_core()
.iter()
.map(|l| Lit::new(l.var().idx() - 1, !l.sign()))
.collect::<Vec<_>>())
}
}

impl<Cb: Callbacks> SolveStats for Solver<Cb> {
fn stats(&self) -> SolverStats {
SolverStats {
n_sat: self.n_sat,
n_unsat: self.n_unsat,
n_terminated: self.n_terminated,
n_clauses: self.n_clauses(),
max_var: self.max_var(),
avg_clause_len: self.avg_clause_len,
cpu_solve_time: self.cpu_time,
}
}

fn n_sat_solves(&self) -> usize {
self.n_sat
}

fn n_unsat_solves(&self) -> usize {
self.n_unsat
}

fn n_terminated(&self) -> usize {
self.n_terminated
}

fn n_clauses(&self) -> usize {
usize::try_from(self.internal.num_clauses()).expect("more than `usize::MAX` clauses")
}

fn max_var(&self) -> Option<Var> {
let num = self.internal.num_vars();
if num > 0 {
// BatSat returns a value that is off by one
Some(Var::new(num - 2))
} else {
None
}
}

fn avg_clause_len(&self) -> f32 {
self.avg_clause_len
}

fn cpu_solve_time(&self) -> Duration {
self.cpu_time
}
}

#[cfg(test)]
mod test {
rustsat_solvertests::basic_unittests!(super::BasicSolver, false);
}
2 changes: 1 addition & 1 deletion batsat/tests/incremental.rs
Original file line number Diff line number Diff line change
@@ -1 +1 @@
rustsat_solvertests::incremental_tests!(rustsat_batsat::BatsatBasicSolver);
rustsat_solvertests::incremental_tests!(rustsat_batsat::BasicSolver);
2 changes: 1 addition & 1 deletion batsat/tests/small.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
mod base {
rustsat_solvertests::base_tests!(rustsat_batsat::BatsatBasicSolver, false, true);
rustsat_solvertests::base_tests!(rustsat_batsat::BasicSolver, false, true);
}
10 changes: 5 additions & 5 deletions solvertests/src/integration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@ use proc_macro2::TokenStream;
use quote::quote;
use syn::{parse_quote, Attribute};

use super::MacroInput;
use super::IntegrationInput;

pub fn base(input: MacroInput) -> TokenStream {
pub fn base(input: IntegrationInput) -> TokenStream {
let slv = input.slv;
let ignoretok = |idx: usize| -> Option<Attribute> {
if input.bools.len() > idx && input.bools[idx] {
Expand Down Expand Up @@ -76,7 +76,7 @@ pub fn base(input: MacroInput) -> TokenStream {
ts
}

pub fn incremental(input: MacroInput) -> TokenStream {
pub fn incremental(input: IntegrationInput) -> TokenStream {
let slv = input.slv;
let ignoretok = |idx: usize| -> Option<Attribute> {
if input.bools.len() > idx && input.bools[idx] {
Expand Down Expand Up @@ -184,7 +184,7 @@ pub fn incremental(input: MacroInput) -> TokenStream {
ts
}

pub fn phasing(input: MacroInput) -> TokenStream {
pub fn phasing(input: IntegrationInput) -> TokenStream {
let slv = input.slv;
let ignoretok = |idx: usize| -> Option<Attribute> {
if input.bools.len() > idx && input.bools[idx] {
Expand Down Expand Up @@ -238,7 +238,7 @@ pub fn phasing(input: MacroInput) -> TokenStream {
ts
}

pub fn flipping(input: MacroInput) -> TokenStream {
pub fn flipping(input: IntegrationInput) -> TokenStream {
let slv = input.slv;
let ignoretok = |idx: usize| -> Option<Attribute> {
if input.bools.len() > idx && input.bools[idx] {
Expand Down
37 changes: 29 additions & 8 deletions solvertests/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,12 @@ impl ToTokens for InitBy {
}
}

struct MacroInput {
struct IntegrationInput {
slv: InitBy,
bools: Vec<bool>,
}

impl Parse for MacroInput {
impl Parse for IntegrationInput {
fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
let slv: InitBy = input.parse()?;
if input.is_empty() {
Expand All @@ -49,10 +49,31 @@ impl Parse for MacroInput {
}
}

struct BasicUnitInput {
slv: Type,
mt: Option<bool>,
}

impl Parse for BasicUnitInput {
fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
let slv: Type = input.parse()?;
if input.is_empty() {
return Ok(Self { slv, mt: None });
}
let _: Token![,] = input.parse()?;
let mt: LitBool = input.parse()?;
Ok(Self {
slv,
mt: Some(mt.value),
})
}
}

#[proc_macro]
pub fn basic_unittests(tokens: TokenStream) -> TokenStream {
let slv = parse_macro_input!(tokens as Type);
unit::basic(slv).into()
let input = parse_macro_input!(tokens as BasicUnitInput);
let mt = input.mt.unwrap_or(true);
unit::basic(input.slv, mt).into()
}

#[proc_macro]
Expand All @@ -75,24 +96,24 @@ pub fn freezing_unittests(tokens: TokenStream) -> TokenStream {

#[proc_macro]
pub fn base_tests(tokens: TokenStream) -> TokenStream {
let input = parse_macro_input!(tokens as MacroInput);
let input = parse_macro_input!(tokens as IntegrationInput);
integration::base(input).into()
}

#[proc_macro]
pub fn incremental_tests(tokens: TokenStream) -> TokenStream {
let input = parse_macro_input!(tokens as MacroInput);
let input = parse_macro_input!(tokens as IntegrationInput);
integration::incremental(input).into()
}

#[proc_macro]
pub fn phasing_tests(tokens: TokenStream) -> TokenStream {
let input = parse_macro_input!(tokens as MacroInput);
let input = parse_macro_input!(tokens as IntegrationInput);
integration::phasing(input).into()
}

#[proc_macro]
pub fn flipping_tests(tokens: TokenStream) -> TokenStream {
let input = parse_macro_input!(tokens as MacroInput);
let input = parse_macro_input!(tokens as IntegrationInput);
integration::flipping(input).into()
}
Loading

0 comments on commit 98af54b

Please sign in to comment.