Skip to content

Commit

Permalink
dpll
Browse files Browse the repository at this point in the history
  • Loading branch information
phadej committed Jul 23, 2024
1 parent 7e45adf commit e001348
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 26 deletions.
26 changes: 14 additions & 12 deletions examples/sat-simple-sudoku.hs
Original file line number Diff line number Diff line change
Expand Up @@ -60,24 +60,26 @@ main = do

initValues :: Nine (Nine Int)
initValues = N9
{-
-- From https://en.wikipedia.org/w/index.php?title=Sudoku&oldid=543290082

(N9 5 3 4 6 7 8 9 1 2)
(N9 6 7 2 1 9 5 3 4 8)
(N9 1 9 8 3 4 2 5 6 7)
(N9 8 5 9 7 6 1 4 2 3)
{-
(N9 4 2 6 8 5 3 7 9 1)
(N9 7 1 3 9 2 4 8 5 6)
(N9 9 6 1 5 3 7 2 8 4)
(N9 2 8 7 4 1 9 6 3 5)
(N9 3 4 5 2 8 6 1 7 9)
-}

{-
(N9 5 3 0 0 7 0 0 0 0)
(N9 6 0 0 1 9 5 0 0 0)
(N9 0 9 8 0 0 0 0 6 0)
(N9 8 0 0 0 6 0 0 0 3)
(N9 4 0 0 8 0 3 0 0 1)
(N9 7 0 0 0 2 0 0 0 6)
(N9 0 6 0 0 0 0 2 8 0)
(N9 0 0 0 4 1 9 0 0 5)
(N9 0 0 0 0 8 0 0 7 9)
-}
(N9 5 3 4 6 7 8 9 1 2)
(N9 6 7 2 1 9 5 3 4 8)
(N9 1 9 8 3 4 2 5 6 7)
(N9 8 5 9 7 6 1 4 2 3)
-- (N9 8 0 0 0 6 0 0 0 3)
-- (N9 4 2 6 8 5 3 7 9 1)
(N9 4 0 0 8 0 3 0 0 1)
(N9 7 0 0 0 2 0 0 0 6)
(N9 0 6 0 0 0 0 2 8 0)
Expand Down
75 changes: 61 additions & 14 deletions sat-simple-pure/DPLL.hs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
{-# LANGUAGE CPP #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE LambdaCase #-}
Expand All @@ -17,12 +18,13 @@ module DPLL (
modelValue,
) where

import Control.Monad (forM_)
-- This is buggy
-- #define TWO_WATCHED_LITERALS

import Control.Monad.ST (ST)
import Data.Bits
import Data.Functor ((<&>))
import Data.IntSet (IntSet)
import Data.Primitive.Array
import Data.Primitive.ByteArray
import Data.Primitive.PrimArray
import Data.Primitive.Types (Prim)
Expand All @@ -31,11 +33,16 @@ import Data.Word (Word8)

import Lifted
import UnliftedSTRef

#ifdef TWO_WATCHED_LITERALS
import Vec
import Control.Monad (forM_)
import Data.Primitive.Array
#endif

import qualified Data.IntSet as IntSet

import Debug.Trace
-- import Debug.Trace

-------------------------------------------------------------------------------
-- Literals
Expand Down Expand Up @@ -174,6 +181,8 @@ type Clauses = [Clause2]
-- ClauseDB
-------------------------------------------------------------------------------

#ifdef TWO_WATCHED_LITERALS

newtype ClauseDB s = CDB (MutableArray s (Vec s Watch))

data Watch = W !Lit !Clause2
Expand All @@ -190,8 +199,8 @@ newClauseDB size = do

insertClauseDB :: Clause2 -> ClauseDB s -> ST s ()
insertClauseDB clause@(MkClause2 l1 l2 _) cdb = do
insertWatch (neg l1) (W l2 clause) cdb
insertWatch (neg l2) (W l1 clause) cdb
insertWatch l1 (W l2 clause) cdb
insertWatch l2 (W l1 clause) cdb

insertWatch :: Lit -> Watch -> ClauseDB s -> ST s ()
insertWatch (MkLit l) !w (CDB cdb) = do
Expand All @@ -203,6 +212,23 @@ lookupClauseDB :: Lit -> ClauseDB s -> ST s (Vec s Watch)
lookupClauseDB (MkLit l) (CDB arr) = do
readArray arr l

_sizeofClauseDB :: ClauseDB s -> ST s Int
_sizeofClauseDB (CDB arr) = go 0 0 (sizeofMutableArray arr)
where
go !acc !i !size
| i < size
= do
vec <- readArray arr i
elm <- sizeofVec vec
go (acc + elm) (i + 1) size

| otherwise
= return acc

#else
type ClauseDB s = [Clause2]
#endif

-------------------------------------------------------------------------------
-- Clause
-------------------------------------------------------------------------------
Expand Down Expand Up @@ -373,12 +399,18 @@ data Trail

solve :: Solver s -> ST s Bool
solve solver@Solver {..} = whenOk_ (simplify solver) $ do
litCount <- readSTRef nextLit
clauses' <- readSTRef clauses
vars <- readSTRef variables
-- traceM $ "solve " ++ show (length clauses')

#ifdef TWO_WATCHED_LITERALS
litCount <- readSTRef nextLit
clauseDB <- newClauseDB litCount
forM_ clauses' $ \c -> insertClauseDB c clauseDB
#else
let clauseDB = clauses'
#endif

solveLoop clauseDB End emptyLitSet solution vars >>= \case
False -> conflict solver
True -> return True
Expand All @@ -404,12 +436,14 @@ solveLoop !clauseDb !trail !units !pa !vars
| otherwise
= return True

{-# SCC unitPropagate #-}

unitPropagate :: forall s. Lit -> ClauseDB s -> Trail -> LitSet -> PartialAssignment s -> VarSet -> ST s Bool

#ifdef TWO_WATCHED_LITERALS

unitPropagate !l !clauseDb !trail !units !pa !vars = do
-- traceM $ "unitPropagate " ++ show (l, _sizeClauseDB clauseDb)
watches <- lookupClauseDB l clauseDb
-- dbSize <- _sizeofClauseDB clauseDb
-- traceM $ "unitPropagate " ++ show (l, dbSize)
watches <- lookupClauseDB (neg l) clauseDb
size <- sizeofVec watches
go units watches 0 0 size
where
Expand Down Expand Up @@ -444,19 +478,32 @@ unitPropagate !l !clauseDb !trail !units !pa !vars = do
writeVec watches j w
go (insertLitSet u us) watches (i + 1) (j + 1) size
Unresolved_ l1 l2
| l2 /= l'
| l2 /= l', l2 /= l
-> do
insertWatch (neg l2) w clauseDb
insertWatch l2 w clauseDb
go us watches (i + 1) j size

| l1 /= l'
| l1 /= l', l1 /= l
-> do
insertWatch (neg l1) w clauseDb
insertWatch l1 w clauseDb
go us watches (i + 1) j size

| otherwise
-> error ("watch" ++ show (l, l1, l2, l'))

#else

unitPropagate !_ !clauseDb !trail !units !pa !vars = go units clauseDb
where
go :: LitSet -> [Clause2] -> ST s Bool
go us [] = solveLoop clauseDb trail us pa vars
go us (c:cs) = satisfied2_ pa c >>= \case
Conflicting_ -> backtrack clauseDb trail pa vars
Satisfied_ -> go us cs
Unit_ u -> go (insertLitSet u us) cs
Unresolved_ _ _ -> go us cs
#endif

backtrack :: ClauseDB s -> Trail -> PartialAssignment s -> VarSet -> ST s Bool
backtrack !_clauseDb End !_pa !_vars = return False
backtrack clauseDb (Deduced l trail) pa vars = do
Expand Down

0 comments on commit e001348

Please sign in to comment.