Skip to content

Commit

Permalink
Merge pull request #37 from phadej/dpll-b
Browse files Browse the repository at this point in the history
DPLL: Use SparseSet for unit literals set
  • Loading branch information
phadej authored Jul 24, 2024
2 parents e4ceaea + 5c221b4 commit bb6e02f
Show file tree
Hide file tree
Showing 7 changed files with 213 additions and 82 deletions.
123 changes: 67 additions & 56 deletions sat-simple-pure/DPLL.hs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ module DPLL (

import Control.Monad.ST (ST)
import Data.Bits (complementBit, testBit, unsafeShiftL, unsafeShiftR)
import Data.Coerce (coerce)
import Data.Functor ((<&>))
import Data.IntSet (IntSet)
import Data.Primitive.PrimArray (PrimArray, indexPrimArray, primArrayFromList, sizeofPrimArray)
Expand All @@ -33,6 +34,7 @@ import Data.Primitive.ByteArray
resizeMutableByteArray, shrinkMutableByteArray, writeByteArray)

import Lifted
import SparseSet
import UnliftedSTRef

#ifdef TWO_WATCHED_LITERALS
Expand Down Expand Up @@ -159,18 +161,21 @@ minViewVarSet (VS xs) = case IntSet.minView xs of
-- LitSet
-------------------------------------------------------------------------------

newtype LitSet = LS IntSet
newtype LitSet s = LS (SparseSet s)

emptyLitSet :: LitSet
emptyLitSet = LS IntSet.empty
newLitSet :: Int -> ST s (LitSet s)
newLitSet n = LS <$> newSparseSet n

insertLitSet :: Lit -> LitSet -> LitSet
insertLitSet (MkLit l) (LS ls) = LS (IntSet.insert l ls)
insertLitSet :: Lit -> LitSet s -> ST s ()
insertLitSet (MkLit l) (LS ls) = insertSparseSet ls l

minViewLitSet :: LitSet -> Maybe (Lit, LitSet)
minViewLitSet (LS xs) = case IntSet.minView xs of
Nothing -> Nothing
Just (x, xs') -> Just (MkLit x, LS xs')
minViewLitSet :: LitSet s -> ST s (Maybe Lit)
minViewLitSet (LS xs) = do
x <- popSparseSet xs
return (coerce x)

clearLitSet :: LitSet s -> ST s ()
clearLitSet (LS xs) = clearSparseSet xs

-------------------------------------------------------------------------------
-- Clauses
Expand Down Expand Up @@ -404,8 +409,10 @@ solve solver@Solver {..} = whenOk_ (simplify solver) $ do
vars <- readSTRef variables
-- traceM $ "solve " ++ show (length clauses')

#ifdef TWO_WATCHED_LITERALS
litCount <- readSTRef nextLit
units <- newLitSet litCount

#ifdef TWO_WATCHED_LITERALS
clauseDB <- newClauseDB litCount
forM_ clauses' $ \c -> satisfied2_ solution c >>= \case
Unresolved_ l1 l2 -> insertClauseDB l1 l2 c clauseDB
Expand All @@ -414,32 +421,32 @@ solve solver@Solver {..} = whenOk_ (simplify solver) $ do
let clauseDB = clauses'
#endif

solveLoop clauseDB End emptyLitSet solution vars >>= \case
solveLoop clauseDB End units solution vars >>= \case
False -> conflict solver
True -> return True

solveLoop :: ClauseDB s -> Trail -> LitSet -> PartialAssignment s -> VarSet -> ST s Bool
solveLoop !clauseDb !trail !units !pa !vars
| Just (l, units') <- minViewLitSet units
= lookupPartialAssignment l pa >>= \case
LUndef -> do
insertPartialAssignment l pa
let !vars' = deleteVarSet (litToVar l) vars
unitPropagate l clauseDb (Deduced l trail) units pa vars'
LTrue -> solveLoop clauseDb trail units' pa vars
LFalse -> backtrack clauseDb trail pa vars

| Just (v, vars') <- minViewVarSet vars
= do
-- traceM $ "decide" ++ show v
let l = varToLit v
insertPartialAssignment l pa
unitPropagate l clauseDb (Decided l trail) emptyLitSet pa vars'

| otherwise
= return True

unitPropagate :: forall s. Lit -> ClauseDB s -> Trail -> LitSet -> PartialAssignment s -> VarSet -> ST s Bool
solveLoop :: ClauseDB s -> Trail -> LitSet s -> PartialAssignment s -> VarSet -> ST s Bool
solveLoop !clauseDb !trail !units !pa !vars = do
minViewLitSet units >>= \case
Just l -> lookupPartialAssignment l pa >>= \case
LUndef -> do
insertPartialAssignment l pa
let !vars' = deleteVarSet (litToVar l) vars
unitPropagate l clauseDb (Deduced l trail) units pa vars'
LTrue -> solveLoop clauseDb trail units pa vars
LFalse -> backtrack clauseDb trail units pa vars
Nothing
| Just (v, vars') <- minViewVarSet vars
-> do
-- traceM $ "decide" ++ show v
let l = varToLit v
insertPartialAssignment l pa
unitPropagate l clauseDb (Decided l trail) units pa vars'

| otherwise
-> return True

unitPropagate :: forall s. Lit -> ClauseDB s -> Trail -> LitSet s -> PartialAssignment s -> VarSet -> ST s Bool

#ifdef TWO_WATCHED_LITERALS

Expand All @@ -448,14 +455,14 @@ unitPropagate !l !clauseDb !trail !units !pa !vars = do
-- traceM $ "unitPropagate " ++ show (l, dbSize)
watches <- lookupClauseDB (neg l) clauseDb
size <- sizeofVec watches
go units watches 0 0 size
go watches 0 0 size
where
go :: LitSet -> Vec s Watch -> Int -> Int -> Int -> ST s Bool
go !us watches i j size
go :: Vec s Watch -> Int -> Int -> Int -> ST s Bool
go watches i j size
| i >= size
= do
shrinkVec watches j
solveLoop clauseDb trail us pa vars
solveLoop clauseDb trail units pa vars

| otherwise
= readVec watches i >>= \ w@(W l' c) -> satisfied2_ pa c >>= \case
Expand All @@ -473,49 +480,53 @@ unitPropagate !l !clauseDb !trail !units !pa !vars = do

copy (i + 1) (j + 1)

backtrack clauseDb trail pa vars
backtrack clauseDb trail units pa vars
Satisfied_ -> do
writeVec watches j w
go us watches (i + 1) (j + 1) size
go watches (i + 1) (j + 1) size
Unit_ u -> do
writeVec watches j w
go (insertLitSet u us) watches (i + 1) (j + 1) size
insertLitSet u units
go watches (i + 1) (j + 1) size
Unresolved_ l1 l2
| l2 /= l', l2 /= l
-> do
insertWatch l2 w clauseDb
go us watches (i + 1) j size
go watches (i + 1) j size

| l1 /= l', l1 /= l
-> do
insertWatch l1 w clauseDb
go us watches (i + 1) j size
go watches (i + 1) j size

| otherwise
-> error ("watch" ++ show (l, l1, l2, l'))

#else

unitPropagate !_ !clauseDb !trail !units !pa !vars = go units clauseDb
unitPropagate !_ !clauseDb !trail !units !pa !vars = go clauseDb
where
go :: LitSet -> [Clause2] -> ST s Bool
go us [] = solveLoop clauseDb trail us pa vars
go us (c:cs) = satisfied2_ pa c >>= \case
Conflicting_ -> backtrack clauseDb trail pa vars
Satisfied_ -> go us cs
Unit_ u -> go (insertLitSet u us) cs
Unresolved_ _ _ -> go us cs
go :: [Clause2] -> ST s Bool
go [] = solveLoop clauseDb trail units pa vars
go (c:cs) = satisfied2_ pa c >>= \case
Conflicting_ -> backtrack clauseDb trail units pa vars
Satisfied_ -> go cs
Unit_ u -> do
insertLitSet u units
go cs
Unresolved_ _ _ -> go cs
#endif

backtrack :: ClauseDB s -> Trail -> PartialAssignment s -> VarSet -> ST s Bool
backtrack !_clauseDb End !_pa !_vars = return False
backtrack clauseDb (Deduced l trail) pa vars = do
backtrack :: ClauseDB s -> Trail -> LitSet s -> PartialAssignment s -> VarSet -> ST s Bool
backtrack !_clauseDb End !_units !_pa !_vars = return False
backtrack clauseDb (Deduced l trail) units pa vars = do
deletePartialAssignment l pa
backtrack clauseDb trail pa (insertVarSet (litToVar l) vars)
backtrack clauseDb (Decided l trail) pa vars = do
backtrack clauseDb trail units pa (insertVarSet (litToVar l) vars)
backtrack clauseDb (Decided l trail) units pa vars = do
deletePartialAssignment l pa
insertPartialAssignment (neg l) pa
unitPropagate (neg l) clauseDb (Deduced (neg l) trail) emptyLitSet pa vars
clearLitSet units
unitPropagate (neg l) clauseDb (Deduced (neg l) trail) units pa vars

-------------------------------------------------------------------------------
-- simplify
Expand Down
4 changes: 1 addition & 3 deletions sat-simple-pure/EST.hs
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,7 @@ module EST (
earlyExitEST,
) where

import GHC.Exts
(PromptTag#, State#, control0#, newPromptTag#, oneShot, prompt#, runRW#,
unsafeCoerce#)
import GHC.Exts (PromptTag#, State#, control0#, newPromptTag#, oneShot, prompt#, runRW#, unsafeCoerce#)
import GHC.ST (ST (..))

control0##
Expand Down
4 changes: 3 additions & 1 deletion sat-simple-pure/Lifted.hs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
module Lifted where
module Lifted (
Lifted (..),
) where

import Data.Kind (Type)
import GHC.Exts (UnliftedType)
Expand Down
112 changes: 112 additions & 0 deletions sat-simple-pure/SparseSet.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
module SparseSet (
SparseSet,
newSparseSet,
memberSparseSet,
insertSparseSet,
popSparseSet,
elemsSparseSet,
clearSparseSet,
) where

import Control.Monad.ST (ST)
import Data.Primitive.PrimArray
import Data.Primitive.PrimVar

-- $setup
-- >>> import Control.Monad.ST (runST)

-- | https://research.swtch.com/sparse
--
-- An 'Int' set which support efficient popping ('popSparseSet').
data SparseSet s = SS (PrimVar s Int) (MutablePrimArray s Int) (MutablePrimArray s Int)

-- | Create new sparse set
--
-- >>> runST $ newSparseSet 100 >>= elemsSparseSet
-- []
newSparseSet
:: Int -- ^ max integer
-> ST s (SparseSet s)
newSparseSet capacity = do
size <- newPrimVar 0
dense <- newPrimArray capacity
sparse <- newPrimArray capacity
return (SS size dense sparse)

-- | Test for membership
--
-- >>> runST $ do { set <- newSparseSet 100; mapM_ (insertSparseSet set) [3,5,7,11,13,11]; memberSparseSet set 10 }
-- False
--
-- >>> runST $ do { set <- newSparseSet 100; mapM_ (insertSparseSet set) [3,5,7,11,13,11]; memberSparseSet set 13 }
-- True
--
memberSparseSet :: SparseSet s -> Int -> ST s Bool
memberSparseSet (SS size dense sparse) x = do
n <- readPrimVar size
i <- readPrimArray sparse x
if i < n
then do
x' <- readPrimArray dense i
return (x' == x)
else return False

-- | Insert into set
--
-- >>> runST $ do { set <- newSparseSet 100; mapM_ (insertSparseSet set) [3,5,7,11,13,11]; elemsSparseSet set }
-- [13,11,7,5,3]
--
insertSparseSet :: SparseSet s -> Int -> ST s ()
insertSparseSet (SS size dense sparse) x = do
n <- readPrimVar size
i <- readPrimArray sparse x
if i < n
then do
x' <- readPrimArray dense i
if x == x' then return () else insert n
else insert n
where
insert n = do
writePrimArray dense n x
writePrimArray sparse x n
writePrimVar size (n + 1)

-- | Pop element from the set.
--
-- >>> runST $ do { set <- newSparseSet 100; mapM_ (insertSparseSet set) [3,5,7,11,13,11]; popSparseSet set }
-- Just 13
--
popSparseSet :: SparseSet s -> ST s (Maybe Int)
popSparseSet (SS size dense _sparse) = do
n <- readPrimVar size
if n <= 0
then return Nothing
else do
let !n' = n - 1
i <- readPrimArray dense n'
writePrimVar size n'
return (Just i)

-- | Clear sparse set.
--
-- >>> runST $ do { set <- newSparseSet 100; mapM_ (insertSparseSet set) [3,5,7,11,13,11]; clearSparseSet set; elemsSparseSet set }
-- []
--
clearSparseSet :: SparseSet s -> ST s ()
clearSparseSet (SS size _ _) = do
writePrimVar size 0

-- | Elements of the set
elemsSparseSet :: SparseSet s -> ST s [Int]
elemsSparseSet (SS size dense _sparse) = do
n <- readPrimVar size
go [] 0 n
where
go !acc !i !n
| i < n
= do
x <- readPrimArray dense i
go (x : acc) (i + 1) n

| otherwise
= return acc
11 changes: 8 additions & 3 deletions sat-simple-pure/UnliftedSTRef.hs
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE UnboxedTuples #-}
module UnliftedSTRef where
module UnliftedSTRef (
USTRef,
newUSTRef,
readUSTRef,
writeUSTRef,
) where

import Data.Kind (Type)
import GHC.Exts (MutVar#, UnliftedType, newMutVar#, readMutVar#, writeMutVar#)
Expand All @@ -12,8 +17,8 @@ type USTRef :: Type -> UnliftedType -> Type
data USTRef s a = USTRef (MutVar# s a)

newUSTRef :: a -> ST s (USTRef s a)
newUSTRef init = ST $ \s1# ->
case newMutVar# init s1# of { (# s2#, var# #) ->
newUSTRef x = ST $ \s1# ->
case newMutVar# x s1# of { (# s2#, var# #) ->
(# s2#, USTRef var# #) }

readUSTRef :: USTRef s a -> ST s (Lifted a)
Expand Down
Loading

0 comments on commit bb6e02f

Please sign in to comment.