Skip to content

Commit

Permalink
✨ Add more rewrites
Browse files Browse the repository at this point in the history
  • Loading branch information
lsrcz committed Feb 10, 2025
1 parent ce8842b commit 4ebfdf0
Show file tree
Hide file tree
Showing 3 changed files with 285 additions and 55 deletions.
142 changes: 116 additions & 26 deletions src/Grisette/Internal/SymPrim/Prim/Internal/Instances/PEvalShiftTerm.hs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
Expand All @@ -23,11 +24,13 @@ where

import Data.Bits (Bits (isSigned, shiftR, zeroBits), FiniteBits (finiteBitSize))
import Data.Proxy (Proxy (Proxy))
import GHC.TypeLits (KnownNat, type (<=))
import Data.Typeable ((:~:) (Refl))
import GHC.TypeLits (KnownNat, type (+), type (<=))
import Grisette.Internal.Core.Data.Class.SymShift (SymShift (symShift))
import Grisette.Internal.SymPrim.BV (IntN, WordN)
import Grisette.Internal.SymPrim.Prim.Internal.Term
( PEvalShiftTerm
( PEvalBVTerm (pevalBVConcatTerm, pevalBVExtendTerm),
PEvalShiftTerm
( pevalShiftLeftTerm,
pevalShiftRightTerm,
withSbvShiftTermConstraint
Expand All @@ -39,30 +42,57 @@ import Grisette.Internal.SymPrim.Prim.Internal.Term
conTerm,
shiftLeftTerm,
shiftRightTerm,
unsafePevalBVSelectTerm,
pattern ConTerm,
pattern SupportedTerm,
)
import Grisette.Internal.SymPrim.Prim.Internal.Unfold (unaryUnfoldOnce)
import Grisette.Internal.Utils.Parameterized
( LeqProof (LeqProof),
NatRepr,
SomePositiveNatRepr (SomePositiveNatRepr),
mkPositiveNatRepr,
natRepr,
subNat,
unsafeAxiom,
unsafeLeqProof,
)

-- | Partial evaluation of symbolic shift left term for finite bits types.
pevalFiniteBitsSymShiftShiftLeftTerm ::
forall a.
(Integral a, SymShift a, FiniteBits a, PEvalShiftTerm a) =>
Term a ->
Term a ->
Term a
forall bv n.
( forall m. (KnownNat m, 1 <= m) => Integral (bv m),
forall m. (KnownNat m, 1 <= m) => SymShift (bv m),
forall m. (KnownNat m, 1 <= m) => FiniteBits (bv m),
forall m. (KnownNat m, 1 <= m) => SupportedPrim (bv m),
forall m. (KnownNat m, 1 <= m) => PEvalShiftTerm (bv m),
PEvalBVTerm bv,
KnownNat n,
1 <= n
) =>
Term (bv n) ->
Term (bv n) ->
Term (bv n)
pevalFiniteBitsSymShiftShiftLeftTerm t@SupportedTerm n =
unaryUnfoldOnce
(`doPevalFiniteBitsSymShiftShiftLeftTerm` n)
(`shiftLeftTerm` n)
t

doPevalFiniteBitsSymShiftShiftLeftTerm ::
forall a.
(Integral a, SymShift a, FiniteBits a, SupportedPrim a) =>
Term a ->
Term a ->
Maybe (Term a)
forall bv n.
( forall m. (KnownNat m, 1 <= m) => Integral (bv m),
forall m. (KnownNat m, 1 <= m) => SymShift (bv m),
forall m. (KnownNat m, 1 <= m) => FiniteBits (bv m),
forall m. (KnownNat m, 1 <= m) => SupportedPrim (bv m),
forall m. (KnownNat m, 1 <= m) => PEvalShiftTerm (bv m),
PEvalBVTerm bv,
KnownNat n,
1 <= n
) =>
Term (bv n) ->
Term (bv n) ->
Maybe (Term (bv n))
doPevalFiniteBitsSymShiftShiftLeftTerm (ConTerm a) (ConTerm n)
| n >= 0 =
if (fromIntegral n :: Integer) >= fromIntegral (finiteBitSize n)
Expand All @@ -73,29 +103,65 @@ doPevalFiniteBitsSymShiftShiftLeftTerm x (ConTerm 0) = Just x
-- doPevalShiftLeftTerm (ShiftLeftTerm _ x (ConTerm n)) (ConTerm n1)
-- | n >= 0 && n1 >= 0 = Just $ pevalShiftLeftTerm x (conTerm $ n + n1)
doPevalFiniteBitsSymShiftShiftLeftTerm _ (ConTerm n)
| n >= 0 && (fromIntegral n :: Integer) >= fromIntegral (finiteBitSize n) =
| n > 0 && (fromIntegral n :: Integer) >= fromIntegral (finiteBitSize n) =
Just $ conTerm zeroBits
doPevalFiniteBitsSymShiftShiftLeftTerm x (ConTerm shiftAmount)
| shiftAmount > 0 =
case (namount, nremaining) of
( SomePositiveNatRepr (_ :: NatRepr amount),
SomePositiveNatRepr (nremaining :: NatRepr remaining)
) ->
case ( unsafeLeqProof @remaining @n,
unsafeAxiom @(remaining + amount) @n
) of
(LeqProof, Refl) ->
Just $
pevalBVConcatTerm
(unsafePevalBVSelectTerm nn (natRepr @0) nremaining x)
(conTerm zeroBits :: Term (bv amount))
where
nn = natRepr @n
namount = mkPositiveNatRepr $ fromIntegral shiftAmount
nremaining =
mkPositiveNatRepr $
fromIntegral (finiteBitSize shiftAmount) - fromIntegral shiftAmount
doPevalFiniteBitsSymShiftShiftLeftTerm _ _ = Nothing

-- | Partial evaluation of symbolic shift right term for finite bits types.
pevalFiniteBitsSymShiftShiftRightTerm ::
forall a.
(Integral a, SymShift a, FiniteBits a, PEvalShiftTerm a) =>
Term a ->
Term a ->
Term a
forall bv n.
( forall m. (KnownNat m, 1 <= m) => Integral (bv m),
forall m. (KnownNat m, 1 <= m) => SymShift (bv m),
forall m. (KnownNat m, 1 <= m) => FiniteBits (bv m),
forall m. (KnownNat m, 1 <= m) => SupportedPrim (bv m),
forall m. (KnownNat m, 1 <= m) => PEvalShiftTerm (bv m),
PEvalBVTerm bv,
KnownNat n,
1 <= n
) =>
Term (bv n) ->
Term (bv n) ->
Term (bv n)
pevalFiniteBitsSymShiftShiftRightTerm t@SupportedTerm n =
unaryUnfoldOnce
(`doPevalFiniteBitsSymShiftShiftRightTerm` n)
(`shiftRightTerm` n)
t

doPevalFiniteBitsSymShiftShiftRightTerm ::
forall a.
(Integral a, SymShift a, FiniteBits a, SupportedPrim a) =>
Term a ->
Term a ->
Maybe (Term a)
forall bv n.
( forall m. (KnownNat m, 1 <= m) => Integral (bv m),
forall m. (KnownNat m, 1 <= m) => SymShift (bv m),
forall m. (KnownNat m, 1 <= m) => FiniteBits (bv m),
forall m. (KnownNat m, 1 <= m) => SupportedPrim (bv m),
forall m. (KnownNat m, 1 <= m) => PEvalShiftTerm (bv m),
PEvalBVTerm bv,
KnownNat n,
1 <= n
) =>
Term (bv n) ->
Term (bv n) ->
Maybe (Term (bv n))
doPevalFiniteBitsSymShiftShiftRightTerm (ConTerm a) (ConTerm n)
| n >= 0 && not (isSigned a) =
if (fromIntegral n :: Integer) >= fromIntegral (finiteBitSize n)
Expand All @@ -107,10 +173,34 @@ doPevalFiniteBitsSymShiftShiftRightTerm (ConTerm a) (ConTerm n)
doPevalFiniteBitsSymShiftShiftRightTerm x (ConTerm 0) = Just x
-- doPevalFiniteBitsSymShiftShiftRightTerm (ShiftRightTerm _ x (ConTerm n)) (ConTerm n1)
-- | n >= 0 && n1 >= 0 = Just $ pevalFiniteBitsSymShiftShiftRightTerm x (conTerm $ n + n1)
doPevalFiniteBitsSymShiftShiftRightTerm _ (ConTerm n)
| not (isSigned n)
&& (fromIntegral n :: Integer) >= fromIntegral (finiteBitSize n) =
doPevalFiniteBitsSymShiftShiftRightTerm x (ConTerm shiftAmount)
| not (isSigned shiftAmount)
&& (fromIntegral shiftAmount :: Integer) >= fromIntegral (finiteBitSize shiftAmount) =
Just $ conTerm zeroBits
| isSigned shiftAmount
&& (fromIntegral shiftAmount :: Integer) >= fromIntegral (finiteBitSize shiftAmount) =
Just $ pevalBVExtendTerm True nn $ unsafePevalBVSelectTerm nn nnp1 none x
where
nn = natRepr @n
none = natRepr @1
nnp1 = subNat nn none
doPevalFiniteBitsSymShiftShiftRightTerm x (ConTerm shiftAmount)
| shiftAmount > 0 =
case (namount, nremaining) of
( SomePositiveNatRepr namount,
SomePositiveNatRepr (nremaining :: NatRepr remaining)
) ->
case unsafeLeqProof @remaining @n of
LeqProof ->
Just $
pevalBVExtendTerm (isSigned shiftAmount) nn $
unsafePevalBVSelectTerm nn namount nremaining x
where
nn = natRepr @n
namount = mkPositiveNatRepr $ fromIntegral shiftAmount
nremaining =
mkPositiveNatRepr $
fromIntegral (finiteBitSize shiftAmount) - fromIntegral shiftAmount
doPevalFiniteBitsSymShiftShiftRightTerm _ _ = Nothing

instance (KnownNat n, 1 <= n) => PEvalShiftTerm (IntN n) where
Expand Down
83 changes: 67 additions & 16 deletions src/Grisette/Internal/SymPrim/Prim/Internal/Term.hs
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,7 @@ module Grisette.Internal.SymPrim.Prim.Internal.Term
unsafePevalBVConcatTerm,
unsafePevalBVExtendTerm,
unsafePevalBVSelectTerm,
boolToBVTerm,

-- * num
pevalDefaultAddNumTerm,
Expand Down Expand Up @@ -328,7 +329,10 @@ import qualified Control.Monad.Writer.Lazy as Lazy
import qualified Control.Monad.Writer.Strict as Strict
import Data.Atomics (atomicModifyIORefCAS_)
import qualified Data.Binary as Binary
import Data.Bits (Bits (complement, isSigned, xor, zeroBits, (.&.), (.|.)), FiniteBits (countLeadingZeros))
import Data.Bits
( Bits (complement, isSigned, xor, zeroBits, (.&.), (.|.)),
FiniteBits (countLeadingZeros),
)
import Data.Bytes.Serial (Serial (deserialize, serialize))
import Data.Coerce (coerce)
import qualified Data.HashMap.Strict as HM
Expand Down Expand Up @@ -5751,7 +5755,8 @@ leOrdTerm = unsafeInCurThread2 curThreadLeOrdTerm

-- | Construct and internalizing a 'AndBitsTerm'.
andBitsTerm :: (PEvalBitwiseTerm a) => Term a -> Term a -> Term a
andBitsTerm = unsafeInCurThread2 curThreadAndBitsTerm
andBitsTerm a b =
unsafeInCurThread2 curThreadAndBitsTerm a b
{-# NOINLINE andBitsTerm #-}

-- | Construct and internalizing a 'OrBitsTerm'.
Expand Down Expand Up @@ -6863,33 +6868,33 @@ pevalITEBVTerm cond (AndBitsTerm a b) (AndBitsTerm c d)
| b == c = Just $ andBitsTerm b $ pevalITETerm cond a d
| b == d = Just $ andBitsTerm b $ pevalITETerm cond a c
pevalITEBVTerm cond (AndBitsTerm a b) c
| a == c = Just $ andBitsTerm c $ orBitsTerm (expandCond $ pevalNotTerm cond) b
| b == c = Just $ andBitsTerm c $ orBitsTerm (expandCond $ pevalNotTerm cond) a
| a == c = Just $ andBitsTerm c $ pevalOrBitsTerm (boolToBVTerm $ pevalNotTerm cond) b
| b == c = Just $ andBitsTerm c $ pevalOrBitsTerm (boolToBVTerm $ pevalNotTerm cond) a
pevalITEBVTerm cond a (AndBitsTerm b c)
| a == b = Just $ andBitsTerm a $ orBitsTerm (expandCond cond) c
| a == c = Just $ andBitsTerm a $ orBitsTerm (expandCond cond) b
| a == b = Just $ andBitsTerm a $ pevalOrBitsTerm (boolToBVTerm cond) c
| a == c = Just $ andBitsTerm a $ pevalOrBitsTerm (boolToBVTerm cond) b
pevalITEBVTerm cond (OrBitsTerm a b) (OrBitsTerm c d)
| a == c = Just $ orBitsTerm a $ pevalITETerm cond b d
| a == d = Just $ orBitsTerm a $ pevalITETerm cond b c
| b == c = Just $ orBitsTerm b $ pevalITETerm cond a d
| b == d = Just $ orBitsTerm b $ pevalITETerm cond a c
pevalITEBVTerm cond (OrBitsTerm a b) c
| a == c = Just $ orBitsTerm c $ andBitsTerm (expandCond cond) b
| b == c = Just $ orBitsTerm c $ andBitsTerm (expandCond cond) a
| a == c = Just $ orBitsTerm c $ pevalAndBitsTerm (boolToBVTerm cond) b
| b == c = Just $ orBitsTerm c $ pevalAndBitsTerm (boolToBVTerm cond) a
pevalITEBVTerm cond a (OrBitsTerm b c)
| a == b = Just $ orBitsTerm a $ andBitsTerm (expandCond $ pevalNotTerm cond) c
| a == c = Just $ orBitsTerm a $ andBitsTerm (expandCond $ pevalNotTerm cond) b
| a == b = Just $ orBitsTerm a $ pevalAndBitsTerm (boolToBVTerm $ pevalNotTerm cond) c
| a == c = Just $ orBitsTerm a $ pevalAndBitsTerm (boolToBVTerm $ pevalNotTerm cond) b
pevalITEBVTerm _ _ _ = Nothing

expandCond ::
boolToBVTerm ::
forall bv n.
( PEvalBVTerm bv,
KnownNat n,
1 <= n,
forall m. (KnownNat m, 1 <= m) => SupportedPrim (bv m)
) =>
Term Bool -> Term (bv n)
expandCond cond =
boolToBVTerm cond =
let bv =
case cond of
NotTerm c -> iteTerm c (conTerm 0) (conTerm 1)
Expand Down Expand Up @@ -7264,10 +7269,34 @@ pevalDefaultAndBitsTerm = binaryUnfoldOnce doPevalAndBitsTerm andBitsTerm
acok = ac .&. (ac + 1) == 0
doPevalAndBitsTerm a b@(ConTerm _) = doPevalAndBitsTerm b a
doPevalAndBitsTerm a b | a == b = Just a
doPevalAndBitsTerm (ITETerm cond a@(ConTerm _) b@(ConTerm _)) c =
Just $ pevalITETerm cond (pevalAndBitsTerm a c) (pevalAndBitsTerm b c)
doPevalAndBitsTerm a (ITETerm cond b@(ConTerm _) c@(ConTerm _)) =
Just $ pevalITETerm cond (pevalAndBitsTerm a b) (pevalAndBitsTerm a c)
doPevalAndBitsTerm (ITETerm cond a@(ConTerm av) b@(ConTerm bv)) c
| av `elem` [0, -1] || bv `elem` [0, -1] =
Just $ pevalITETerm cond (pevalAndBitsTerm a c) (pevalAndBitsTerm b c)
doPevalAndBitsTerm a (ITETerm cond b@(ConTerm bv) c@(ConTerm cv))
| bv `elem` [0, -1] || cv `elem` [0, -1] =
Just $ pevalITETerm cond (pevalAndBitsTerm a b) (pevalAndBitsTerm a c)
doPevalAndBitsTerm (ITETerm cond a@(ConTerm v) b) c
| v == 0 = Just $ pevalITETerm cond a (pevalAndBitsTerm b c)
doPevalAndBitsTerm (ITETerm cond a b@(ConTerm v)) c
| v == 0 = Just $ pevalITETerm cond (pevalAndBitsTerm a c) b
doPevalAndBitsTerm a (ITETerm cond b@(ConTerm v) c)
| v == 0 = Just $ pevalITETerm cond b (pevalAndBitsTerm a c)
doPevalAndBitsTerm a (ITETerm cond b c@(ConTerm v))
| v == 0 = Just $ pevalITETerm cond (pevalAndBitsTerm a b) c
doPevalAndBitsTerm (BVExtendTerm True pl (ITETerm cond at@(ConTerm a) bt@(ConTerm b))) c
| a `elem` [0, -1] && b `elem` [0, -1] =
Just $
pevalITETerm
cond
(pevalAndBitsTerm (pevalBVExtendTerm True pl at) c)
(pevalAndBitsTerm (pevalBVExtendTerm True pl bt) c)
doPevalAndBitsTerm a (BVExtendTerm True pl (ITETerm cond bt@(ConTerm b) ct@(ConTerm c)))
| b `elem` [0, -1] && c `elem` [0, -1] =
Just $
pevalITETerm
cond
(pevalAndBitsTerm a (pevalBVExtendTerm True pl bt))
(pevalAndBitsTerm a (pevalBVExtendTerm True pl ct))
doPevalAndBitsTerm a b = bitOpOnConcat @bv @m pevalDefaultAndBitsTerm a b

pevalDefaultOrBitsTerm ::
Expand Down Expand Up @@ -7322,6 +7351,28 @@ pevalDefaultOrBitsTerm = binaryUnfoldOnce doPevalOrBitsTerm orBitsTerm
Just $ pevalITETerm cond (pevalOrBitsTerm a c) (pevalOrBitsTerm b c)
doPevalOrBitsTerm a (ITETerm cond b@(ConTerm _) c@(ConTerm _)) =
Just $ pevalITETerm cond (pevalOrBitsTerm a b) (pevalOrBitsTerm a c)
doPevalOrBitsTerm (ITETerm cond a@(ConTerm v) b) c
| v == -1 = Just $ pevalITETerm cond a (pevalOrBitsTerm b c)
doPevalOrBitsTerm (ITETerm cond a b@(ConTerm v)) c
| v == -1 = Just $ pevalITETerm cond (pevalOrBitsTerm a c) b
doPevalOrBitsTerm a (ITETerm cond b@(ConTerm v) c)
| v == -1 = Just $ pevalITETerm cond b (pevalOrBitsTerm a c)
doPevalOrBitsTerm a (ITETerm cond b c@(ConTerm v))
| v == -1 = Just $ pevalITETerm cond (pevalOrBitsTerm a b) c
doPevalOrBitsTerm (BVExtendTerm True pl (ITETerm cond at@(ConTerm a) bt@(ConTerm b))) c
| a `elem` [0, -1] && b `elem` [0, -1] =
Just $
pevalITETerm
cond
(pevalOrBitsTerm (pevalBVExtendTerm True pl at) c)
(pevalOrBitsTerm (pevalBVExtendTerm True pl bt) c)
doPevalOrBitsTerm a (BVExtendTerm True pl (ITETerm cond bt@(ConTerm b) ct@(ConTerm c)))
| b `elem` [0, -1] && c `elem` [0, -1] =
Just $
pevalITETerm
cond
(pevalOrBitsTerm a (pevalBVExtendTerm True pl bt))
(pevalOrBitsTerm a (pevalBVExtendTerm True pl ct))
doPevalOrBitsTerm a b = bitOpOnConcat @bv @m pevalDefaultOrBitsTerm a b

pevalDefaultXorBitsTerm ::
Expand Down
Loading

0 comments on commit 4ebfdf0

Please sign in to comment.