From e001348ae28178bbf97270977a39b1bb7d9732d9 Mon Sep 17 00:00:00 2001 From: Oleg Grenrus Date: Tue, 23 Jul 2024 23:56:23 +0300 Subject: [PATCH] dpll --- examples/sat-simple-sudoku.hs | 26 ++++++------ sat-simple-pure/DPLL.hs | 75 ++++++++++++++++++++++++++++------- 2 files changed, 75 insertions(+), 26 deletions(-) diff --git a/examples/sat-simple-sudoku.hs b/examples/sat-simple-sudoku.hs index 43a201d..9226d87 100644 --- a/examples/sat-simple-sudoku.hs +++ b/examples/sat-simple-sudoku.hs @@ -60,24 +60,26 @@ main = do initValues :: Nine (Nine Int) initValues = N9 -{- -- From https://en.wikipedia.org/w/index.php?title=Sudoku&oldid=543290082 + + (N9 5 3 4 6 7 8 9 1 2) + (N9 6 7 2 1 9 5 3 4 8) + (N9 1 9 8 3 4 2 5 6 7) + (N9 8 5 9 7 6 1 4 2 3) +{- + (N9 4 2 6 8 5 3 7 9 1) + (N9 7 1 3 9 2 4 8 5 6) + (N9 9 6 1 5 3 7 2 8 4) + (N9 2 8 7 4 1 9 6 3 5) + (N9 3 4 5 2 8 6 1 7 9) +-} + +{- (N9 5 3 0 0 7 0 0 0 0) (N9 6 0 0 1 9 5 0 0 0) (N9 0 9 8 0 0 0 0 6 0) (N9 8 0 0 0 6 0 0 0 3) - (N9 4 0 0 8 0 3 0 0 1) - (N9 7 0 0 0 2 0 0 0 6) - (N9 0 6 0 0 0 0 2 8 0) - (N9 0 0 0 4 1 9 0 0 5) - (N9 0 0 0 0 8 0 0 7 9) -} - (N9 5 3 4 6 7 8 9 1 2) - (N9 6 7 2 1 9 5 3 4 8) - (N9 1 9 8 3 4 2 5 6 7) - (N9 8 5 9 7 6 1 4 2 3) - -- (N9 8 0 0 0 6 0 0 0 3) - -- (N9 4 2 6 8 5 3 7 9 1) (N9 4 0 0 8 0 3 0 0 1) (N9 7 0 0 0 2 0 0 0 6) (N9 0 6 0 0 0 0 2 8 0) diff --git a/sat-simple-pure/DPLL.hs b/sat-simple-pure/DPLL.hs index 10be1be..c43b28b 100644 --- a/sat-simple-pure/DPLL.hs +++ b/sat-simple-pure/DPLL.hs @@ -1,3 +1,4 @@ +{-# LANGUAGE CPP #-} {-# LANGUAGE DerivingStrategies #-} {-# LANGUAGE GeneralizedNewtypeDeriving #-} {-# LANGUAGE LambdaCase #-} @@ -17,12 +18,13 @@ module DPLL ( modelValue, ) where -import Control.Monad (forM_) +-- This is buggy +-- #define TWO_WATCHED_LITERALS + import Control.Monad.ST (ST) import Data.Bits import Data.Functor ((<&>)) import Data.IntSet (IntSet) -import Data.Primitive.Array import Data.Primitive.ByteArray import Data.Primitive.PrimArray import Data.Primitive.Types (Prim) @@ -31,11 +33,16 @@ import Data.Word (Word8) import Lifted import UnliftedSTRef + +#ifdef TWO_WATCHED_LITERALS import Vec +import Control.Monad (forM_) +import Data.Primitive.Array +#endif import qualified Data.IntSet as IntSet -import Debug.Trace +-- import Debug.Trace ------------------------------------------------------------------------------- -- Literals @@ -174,6 +181,8 @@ type Clauses = [Clause2] -- ClauseDB ------------------------------------------------------------------------------- +#ifdef TWO_WATCHED_LITERALS + newtype ClauseDB s = CDB (MutableArray s (Vec s Watch)) data Watch = W !Lit !Clause2 @@ -190,8 +199,8 @@ newClauseDB size = do insertClauseDB :: Clause2 -> ClauseDB s -> ST s () insertClauseDB clause@(MkClause2 l1 l2 _) cdb = do - insertWatch (neg l1) (W l2 clause) cdb - insertWatch (neg l2) (W l1 clause) cdb + insertWatch l1 (W l2 clause) cdb + insertWatch l2 (W l1 clause) cdb insertWatch :: Lit -> Watch -> ClauseDB s -> ST s () insertWatch (MkLit l) !w (CDB cdb) = do @@ -203,6 +212,23 @@ lookupClauseDB :: Lit -> ClauseDB s -> ST s (Vec s Watch) lookupClauseDB (MkLit l) (CDB arr) = do readArray arr l +_sizeofClauseDB :: ClauseDB s -> ST s Int +_sizeofClauseDB (CDB arr) = go 0 0 (sizeofMutableArray arr) + where + go !acc !i !size + | i < size + = do + vec <- readArray arr i + elm <- sizeofVec vec + go (acc + elm) (i + 1) size + + | otherwise + = return acc + +#else +type ClauseDB s = [Clause2] +#endif + ------------------------------------------------------------------------------- -- Clause ------------------------------------------------------------------------------- @@ -373,12 +399,18 @@ data Trail solve :: Solver s -> ST s Bool solve solver@Solver {..} = whenOk_ (simplify solver) $ do - litCount <- readSTRef nextLit clauses' <- readSTRef clauses vars <- readSTRef variables -- traceM $ "solve " ++ show (length clauses') + +#ifdef TWO_WATCHED_LITERALS + litCount <- readSTRef nextLit clauseDB <- newClauseDB litCount forM_ clauses' $ \c -> insertClauseDB c clauseDB +#else + let clauseDB = clauses' +#endif + solveLoop clauseDB End emptyLitSet solution vars >>= \case False -> conflict solver True -> return True @@ -404,12 +436,14 @@ solveLoop !clauseDb !trail !units !pa !vars | otherwise = return True -{-# SCC unitPropagate #-} - unitPropagate :: forall s. Lit -> ClauseDB s -> Trail -> LitSet -> PartialAssignment s -> VarSet -> ST s Bool + +#ifdef TWO_WATCHED_LITERALS + unitPropagate !l !clauseDb !trail !units !pa !vars = do - -- traceM $ "unitPropagate " ++ show (l, _sizeClauseDB clauseDb) - watches <- lookupClauseDB l clauseDb + -- dbSize <- _sizeofClauseDB clauseDb + -- traceM $ "unitPropagate " ++ show (l, dbSize) + watches <- lookupClauseDB (neg l) clauseDb size <- sizeofVec watches go units watches 0 0 size where @@ -444,19 +478,32 @@ unitPropagate !l !clauseDb !trail !units !pa !vars = do writeVec watches j w go (insertLitSet u us) watches (i + 1) (j + 1) size Unresolved_ l1 l2 - | l2 /= l' + | l2 /= l', l2 /= l -> do - insertWatch (neg l2) w clauseDb + insertWatch l2 w clauseDb go us watches (i + 1) j size - | l1 /= l' + | l1 /= l', l1 /= l -> do - insertWatch (neg l1) w clauseDb + insertWatch l1 w clauseDb go us watches (i + 1) j size | otherwise -> error ("watch" ++ show (l, l1, l2, l')) +#else + +unitPropagate !_ !clauseDb !trail !units !pa !vars = go units clauseDb + where + go :: LitSet -> [Clause2] -> ST s Bool + go us [] = solveLoop clauseDb trail us pa vars + go us (c:cs) = satisfied2_ pa c >>= \case + Conflicting_ -> backtrack clauseDb trail pa vars + Satisfied_ -> go us cs + Unit_ u -> go (insertLitSet u us) cs + Unresolved_ _ _ -> go us cs +#endif + backtrack :: ClauseDB s -> Trail -> PartialAssignment s -> VarSet -> ST s Bool backtrack !_clauseDb End !_pa !_vars = return False backtrack clauseDb (Deduced l trail) pa vars = do