Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Resolve one level of dispatch for IPASIR callbacks #58

Merged
merged 1 commit into from
Jun 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
140 changes: 86 additions & 54 deletions crates/pindakaas-derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,79 +91,109 @@ pub fn ipasir_solver_derive(input: TokenStream) -> TokenStream {
quote!()
};

let term_callback = if opts.term_callback {
let (term_callback, term_drop) = if opts.term_callback {
let term_cb = match opts.term_callback_ident {
Some(x) => quote! { self. #x },
None => quote! { self.term_cb },
};
quote! {
impl crate::solver::TermCallback for #ident {
fn set_terminate_callback<F: FnMut() -> crate::solver::SlvTermSignal + 'static>(
&mut self,
cb: Option<F>,
) {
if let Some(mut cb) = cb {
#term_cb = crate::solver::libloading::TermCB::new(move || -> std::ffi::c_int {
match cb() {
crate::solver::SlvTermSignal::Continue => std::ffi::c_int::from(0),
crate::solver::SlvTermSignal::Terminate => std::ffi::c_int::from(1),
(
quote! {
impl crate::solver::TermCallback for #ident {
fn set_terminate_callback<F: FnMut() -> crate::solver::SlvTermSignal + 'static>(
&mut self,
cb: Option<F>,
) {
if let Some(mut cb) = cb {
let mut wrapped_cb = move || -> std::ffi::c_int {
match cb() {
crate::solver::SlvTermSignal::Continue => std::ffi::c_int::from(0),
crate::solver::SlvTermSignal::Terminate => std::ffi::c_int::from(1),
}
};
let trampoline = crate::solver::libloading::get_trampoline0(&wrapped_cb);
let layout = std::alloc::Layout::for_value(&wrapped_cb);
let data = Box::leak(Box::new(wrapped_cb)) as *mut _ as *mut std::ffi::c_void;
if layout.size() != 0 {
// Otherwise nothing was leaked.
#term_cb = Some((data, layout));
}
});

unsafe {
#krate::ipasir_set_terminate(
#ptr,
#term_cb .as_ptr(),
Some(crate::solver::libloading::TermCB::exec_callback),
)
unsafe {
#krate::ipasir_set_terminate(
#ptr,
data,
Some(trampoline),
)
}
} else {
if let Some((ptr, layout)) = #term_cb .take() {
unsafe { std::alloc::dealloc(ptr as *mut _, layout) };
}
unsafe { #krate::ipasir_set_terminate(#ptr, std::ptr::null_mut(), None) }
}
} else {
#term_cb = Default::default();
unsafe { #krate::ipasir_set_terminate(#ptr, std::ptr::null_mut(), None) }
}
}
}
}
},
quote! {
if let Some((ptr, layout)) = #term_cb .take() {
unsafe { std::alloc::dealloc(ptr as *mut _, layout) };
}
},
)
} else {
quote!()
(quote!(), quote!())
};

let learn_callback = if opts.learn_callback {
let (learn_callback, learn_drop) = if opts.learn_callback {
let learn_cb = match opts.learn_callback_ident {
Some(x) => quote! { self. #x },
None => quote! { self.learn_cb },
};
quote! {
impl crate::solver::LearnCallback for #ident {
fn set_learn_callback<F: FnMut(&mut dyn Iterator<Item = crate::Lit>) + 'static>(
&mut self,
cb: Option<F>,
) {
const MAX_LEN: std::ffi::c_int = 512;
if let Some(mut cb) = cb {
#learn_cb = crate::solver::libloading::LearnCB::new(move |clause: *const i32| {
let mut iter = crate::solver::libloading::ExplIter(clause)
.map(|i: i32| crate::Lit(std::num::NonZeroI32::new(i).unwrap()));
cb(&mut iter)
});

unsafe {
#krate::ipasir_set_learn(
#ptr,
#learn_cb .as_ptr(),
MAX_LEN,
Some(crate::solver::libloading::LearnCB::exec_callback),
)
(
quote! {
impl crate::solver::LearnCallback for #ident {
fn set_learn_callback<F: FnMut(&mut dyn Iterator<Item = crate::Lit>) + 'static>(
&mut self,
cb: Option<F>,
) {
const MAX_LEN: std::ffi::c_int = 512;
if let Some(mut cb) = cb {
let mut wrapped_cb = move |clause: *const i32| {
let mut iter = crate::solver::libloading::ExplIter(clause)
.map(|i: i32| crate::Lit(std::num::NonZeroI32::new(i).unwrap()));
cb(&mut iter)
};
let trampoline = crate::solver::libloading::get_trampoline1(&wrapped_cb);
let layout = std::alloc::Layout::for_value(&wrapped_cb);
let data = Box::leak(Box::new(wrapped_cb)) as *mut _ as *mut std::ffi::c_void;
if layout.size() != 0 {
// Otherwise nothing was leaked.
#learn_cb = Some((data, layout));
}
unsafe {
#krate::ipasir_set_learn(
#ptr,
data,
MAX_LEN,
Some(trampoline),
)
}
} else {
if let Some((ptr, layout)) = #learn_cb .take() {
unsafe { std::alloc::dealloc(ptr as *mut _, layout) };
}
unsafe { #krate::ipasir_set_learn(#ptr, std::ptr::null_mut(), MAX_LEN, None) }
}
} else {
#learn_cb = Default::default();
unsafe { #krate::ipasir_set_learn(#ptr, std::ptr::null_mut(), MAX_LEN, None) }
}
}
}
}
},
quote! {
if let Some((ptr, layout)) = #learn_cb .take() {
unsafe { std::alloc::dealloc(ptr as *mut _, layout) };
}
},
)
} else {
quote!()
(quote!(), quote!())
};

let sol_ident = format_ident!("{}Sol", ident);
Expand Down Expand Up @@ -356,6 +386,8 @@ pub fn ipasir_solver_derive(input: TokenStream) -> TokenStream {
quote! {
impl Drop for #ident {
fn drop(&mut self) {
#learn_drop
#term_drop
unsafe { #krate::ipasir_release( #ptr ) }
}
}
Expand Down
25 changes: 13 additions & 12 deletions crates/pindakaas/src/solver/cadical.rs
Original file line number Diff line number Diff line change
@@ -1,25 +1,26 @@
use std::{ffi::CString, fmt};
use std::{
alloc::Layout,
ffi::{c_void, CString},
fmt,
};

use pindakaas_cadical::{ccadical_copy, ccadical_phase, ccadical_unphase};
use pindakaas_derive::IpasirSolver;

use super::VarFactory;
use crate::{
solver::libloading::{LearnCB, TermCB},
Lit,
};
use crate::Lit;

#[derive(IpasirSolver)]
#[ipasir(krate = pindakaas_cadical, assumptions, learn_callback, term_callback, ipasir_up)]
pub struct Cadical {
/// The raw pointer to the Cadical solver.
ptr: *mut std::ffi::c_void,
ptr: *mut c_void,
/// The variable factory for this solver.
vars: VarFactory,
/// The callback used when a clause is learned.
learn_cb: LearnCB,
learn_cb: Option<(*mut c_void, Layout)>,
/// The callback used to check whether the solver should terminate.
term_cb: TermCB,
term_cb: Option<(*mut c_void, Layout)>,

#[cfg(feature = "ipasir-up")]
/// The external propagator called by the solver
Expand All @@ -31,8 +32,8 @@ impl Default for Cadical {
Self {
ptr: unsafe { pindakaas_cadical::ipasir_init() },
vars: VarFactory::default(),
learn_cb: LearnCB::default(),
term_cb: TermCB::default(),
learn_cb: None,
term_cb: None,
#[cfg(feature = "ipasir-up")]
prop: None,
}
Expand All @@ -45,8 +46,8 @@ impl Clone for Cadical {
Self {
ptr,
vars: self.vars,
learn_cb: LearnCB::default(),
term_cb: TermCB::default(),
learn_cb: None,
term_cb: None,
#[cfg(feature = "ipasir-up")]
prop: None,
}
Expand Down
13 changes: 7 additions & 6 deletions crates/pindakaas/src/solver/intel_sat.rs
Original file line number Diff line number Diff line change
@@ -1,28 +1,29 @@
use std::{alloc::Layout, ffi::c_void};

use pindakaas_derive::IpasirSolver;

use super::VarFactory;
use crate::solver::libloading::{LearnCB, TermCB};

#[derive(Debug, IpasirSolver)]
#[ipasir(krate = pindakaas_intel_sat, assumptions, learn_callback, term_callback)]
pub struct IntelSat {
/// The raw pointer to the Intel SAT solver.
ptr: *mut std::ffi::c_void,
ptr: *mut c_void,
/// The variable factory for this solver.
vars: VarFactory,
/// The callback used when a clause is learned.
learn_cb: LearnCB,
learn_cb: Option<(*mut c_void, Layout)>,
/// The callback used to check whether the solver should terminate.
term_cb: TermCB,
term_cb: Option<(*mut c_void, Layout)>,
}

impl Default for IntelSat {
fn default() -> Self {
Self {
ptr: unsafe { pindakaas_intel_sat::ipasir_init() },
vars: VarFactory::default(),
term_cb: TermCB::default(),
learn_cb: LearnCB::default(),
term_cb: None,
learn_cb: None,
}
}
}
Expand Down
Loading