diff --git a/accelerate.cabal b/accelerate.cabal index 64aec140c..29957e9d5 100644 --- a/accelerate.cabal +++ b/accelerate.cabal @@ -333,7 +333,6 @@ Library Data.Array.Accelerate.Error Data.Array.Accelerate.Lifetime Data.Array.Accelerate.Pretty - Data.Array.Accelerate.Product Data.Array.Accelerate.Smart Data.Array.Accelerate.Trafo Data.Array.Accelerate.Type @@ -381,6 +380,7 @@ Library Data.Array.Accelerate.Trafo.Base Data.Array.Accelerate.Trafo.Config Data.Array.Accelerate.Trafo.Fusion + Data.Array.Accelerate.Trafo.LetSplit Data.Array.Accelerate.Trafo.Sharing Data.Array.Accelerate.Trafo.Shrink Data.Array.Accelerate.Trafo.Simplify diff --git a/src/Data/Array/Accelerate/Product.hs b/icebox/Product.hs similarity index 100% rename from src/Data/Array/Accelerate/Product.hs rename to icebox/Product.hs diff --git a/src/Data/Array/Accelerate.hs b/src/Data/Array/Accelerate.hs index 461e52546..1bf0f522a 100644 --- a/src/Data/Array/Accelerate.hs +++ b/src/Data/Array/Accelerate.hs @@ -303,6 +303,9 @@ module Data.Array.Accelerate ( -- ** Scalar data types Exp, + -- ** SIMD vectors + Vec, VecElt, + -- ** Type classes -- *** Basic type classes Eq(..), @@ -333,7 +336,7 @@ module Data.Array.Accelerate ( -- ** Pattern synonyms -- $pattern_synonyms -- - pattern Pattern, IsProduct, IsTuple, + pattern Pattern, pattern T2, pattern T3, pattern T4, pattern T5, pattern T6, pattern T7, pattern T8, pattern T9, pattern T10, pattern T11, pattern T12, pattern T13, pattern T14, pattern T15, pattern T16, @@ -342,6 +345,10 @@ module Data.Array.Accelerate ( pattern I0, pattern I1, pattern I2, pattern I3, pattern I4, pattern I5, pattern I6, pattern I7, pattern I8, pattern I9, + pattern V2, pattern V2_, pattern V3, pattern V3_, + pattern V4, pattern V4_, pattern V8, pattern V8_, + pattern V16, pattern V16_, + pattern True_, pattern False_, -- ** Scalar operations @@ -417,7 +424,6 @@ import Data.Array.Accelerate.Classes import Data.Array.Accelerate.Language import Data.Array.Accelerate.Pattern import Data.Array.Accelerate.Prelude -import Data.Array.Accelerate.Product import Data.Array.Accelerate.Pretty () -- show instances import Data.Array.Accelerate.Type import qualified Data.Array.Accelerate.Array.Sugar as S diff --git a/src/Data/Array/Accelerate/AST.hs b/src/Data/Array/Accelerate/AST.hs index 3da133ce4..425f94da8 100644 --- a/src/Data/Array/Accelerate/AST.hs +++ b/src/Data/Array/Accelerate/AST.hs @@ -1,6 +1,7 @@ {-# LANGUAGE BangPatterns #-} {-# LANGUAGE CPP #-} -{-# LANGUAGE DeriveDataTypeable #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE EmptyCase #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GADTs #-} @@ -83,48 +84,53 @@ module Data.Array.Accelerate.AST ( -- * Typed de Bruijn indices - Idx(..), idxToInt, tupleIdxToInt, ArrayVar(..), ArrayVars(..), + Idx(..), idxToInt, Var(..), Vars(..), TupR(..), ArrayVar, ArrayVars, ExpVar, ExpVars, + evars, varsType, LeftHandSide(..), ALeftHandSide, ELeftHandSide, -- * Valuation environment - Val(..), ValElt(..), push, prj, prjElt, + Val(..), push, prj, -- * Accelerated array expressions PreOpenAfun(..), OpenAfun, PreAfun, Afun, PreOpenAcc(..), OpenAcc(..), Acc, - PreBoundary(..), Boundary, Stencil(..), StencilR(..), - LeftHandSide(..), HasArraysRepr(..), lhsToArraysR, + Boundary(..), StencilR(..), + HasArraysRepr(..), arrayRepr, lhsToTupR, PairIdx(..), + ArrayR(..), ArraysR, ShapeR(..), SliceIndex(..), VecR(..), vecRvector, vecRtuple, -- * Accelerated sequences -- PreOpenSeq(..), Seq, -- Producer(..), Consumer(..), -- * Scalar expressions - PreOpenFun(..), OpenFun, PreFun, Fun, PreOpenExp(..), OpenExp, PreExp, Exp, PrimConst(..), - PrimFun(..), + OpenFun(..), Fun, OpenExp(..), Exp, PrimConst(..), + PrimFun(..), expType, primConstType, primFunType, -- NFData NFDataAcc, - rnfPreOpenAfun, rnfPreOpenAcc, rnfPreOpenFun, rnfPreOpenExp, - rnfArrays, + rnfPreOpenAfun, rnfPreOpenAcc, rnfOpenFun, rnfOpenExp, + rnfArrays, rnfArrayR, -- TemplateHaskell LiftAcc, - liftIdx, liftTupleIdx, + liftIdx, liftConst, liftSliceIndex, liftPrimConst, liftPrimFun, - liftPreOpenAfun, liftPreOpenAcc, liftPreOpenFun, liftPreOpenExp, - liftArray, liftArraysR, liftLHS, + liftPreOpenAfun, liftPreOpenAcc, liftOpenFun, liftOpenExp, + liftALhs, liftELhs, liftArray, liftArraysR, liftTupleType, liftArrayR, + liftScalarType, liftShapeR, liftVecR, liftIntegralType, -- Utilities - Exists(..), weakenWithLHS, (:>), + Exists(..), weakenWithLHS, (:>), weakenId, weakenSucc, weakenSucc', weakenEmpty, (.>), (>:>), + sink, sinkWithLHS, -- debugging - showPreAccOp, showPreExpOp, + showPreAccOp, showPreExpOp, showShortendArr, showElement ) where --standard library import Control.DeepSeq import Control.Monad.ST -import Data.Typeable +import Data.List ( intercalate ) +import Data.Kind import Foreign.ForeignPtr import Foreign.Marshal import Foreign.Ptr @@ -142,13 +148,13 @@ import GHC.Int ( Int(..) ) import GHC.Prim ( (<#), (+#), indexWord8Array#, sizeofByteArray# ) import GHC.Ptr ( Ptr(..) ) import GHC.Word ( Word8(..) ) +import GHC.TypeNats -- friends import Data.Array.Accelerate.Array.Data -import Data.Array.Accelerate.Array.Representation ( SliceIndex(..), size ) -import Data.Array.Accelerate.Array.Sugar hiding ( size ) +import Data.Array.Accelerate.Array.Representation +import qualified Data.Array.Accelerate.Array.Sugar as Sugar import Data.Array.Accelerate.Array.Unique -import Data.Array.Accelerate.Product import Data.Array.Accelerate.Type #if __GLASGOW_HASKELL__ < 800 import Data.Array.Accelerate.Error @@ -171,10 +177,6 @@ idxToInt :: Idx env t -> Int idxToInt ZeroIdx = 0 idxToInt (SuccIdx idx) = 1 + idxToInt idx -tupleIdxToInt :: TupleIdx tup e -> Int -tupleIdxToInt ZeroTupIdx = 0 -tupleIdxToInt (SuccTupIdx idx) = 1 + tupleIdxToInt idx - -- Environments -- ------------ @@ -184,20 +186,12 @@ tupleIdxToInt (SuccTupIdx idx) = 1 + tupleIdxToInt idx data Val env where Empty :: Val () Push :: Val env -> t -> Val (env, t) -deriving instance Typeable Val -push :: Val env -> (LeftHandSide arrs env env', arrs) -> Val env' +push :: Val env -> (LeftHandSide s arrs env env', arrs) -> Val env' push env (LeftHandSideWildcard _, _ ) = env -push env (LeftHandSideArray , a ) = env `Push` a +push env (LeftHandSideSingle _ , a ) = env `Push` a push env (LeftHandSidePair l1 l2, (a, b)) = push env (l1, a) `push` (l2, b) --- Valuation for an environment of array elements --- -data ValElt env where - EmptyElt :: ValElt () - PushElt :: Elt t - => ValElt env -> EltRepr t -> ValElt (env, t) - -- Projection of a value from a valuation using a de Bruijn index -- prj :: Idx env t -> Val env -> t @@ -207,23 +201,14 @@ prj (SuccIdx idx) (Push val _) = prj idx val prj _ _ = $internalError "prj" "inconsistent valuation" #endif --- Projection of a value from a valuation of array elements using a de Bruijn index --- -prjElt :: Idx env t -> ValElt env -> t -prjElt ZeroIdx (PushElt _ v) = toElt v -prjElt (SuccIdx idx) (PushElt val _) = prjElt idx val -#if __GLASGOW_HASKELL__ < 800 -prjElt _ _ = $internalError "prjElt" "inconsistent valuation" -#endif - -- Array expressions -- ----------------- -- | Function abstraction over parametrised array computations -- data PreOpenAfun acc aenv t where - Abody :: acc aenv t -> PreOpenAfun acc aenv t - Alam :: LeftHandSide a aenv aenv' -> PreOpenAfun acc aenv' t -> PreOpenAfun acc aenv (a -> t) + Abody :: acc aenv t -> PreOpenAfun acc aenv t + Alam :: ALeftHandSide a aenv aenv' -> PreOpenAfun acc aenv' t -> PreOpenAfun acc aenv (a -> t) -- Function abstraction over vanilla open array computations -- @@ -245,50 +230,102 @@ newtype OpenAcc aenv t = OpenAcc (PreOpenAcc OpenAcc aenv t) -- type Acc = OpenAcc () -deriving instance Typeable PreOpenAcc -deriving instance Typeable OpenAcc +type ALeftHandSide = LeftHandSide ArrayR + +type ELeftHandSide = LeftHandSide ScalarType -data LeftHandSide arrs env env' where - LeftHandSideArray - :: (Shape sh, Elt e) - => LeftHandSide (Array sh e) env (env, Array sh e) +data LeftHandSide (s :: Type -> Type) v env env' where + LeftHandSideSingle + :: s v + -> LeftHandSide s v env (env, v) - -- Note: a unit is represented as LeftHandSide ArraysRunit + -- Note: a unit is represented as LeftHandSideWildcard TupRunit LeftHandSideWildcard - :: ArraysR arrs - -> LeftHandSide arrs env env + :: TupR s v + -> LeftHandSide s v env env LeftHandSidePair - :: LeftHandSide arrs1 env env' - -> LeftHandSide arrs2 env' env'' - -> LeftHandSide (arrs1, arrs2) env env'' + :: LeftHandSide s v1 env env' + -> LeftHandSide s v2 env' env'' + -> LeftHandSide s (v1, v2) env env'' -lhsToArraysR :: LeftHandSide arrs aenv aenv' -> ArraysR arrs -lhsToArraysR LeftHandSideArray = ArraysRarray -lhsToArraysR (LeftHandSideWildcard r) = r -lhsToArraysR (LeftHandSidePair as bs) = ArraysRpair (lhsToArraysR as) (lhsToArraysR bs) +lhsToTupR :: LeftHandSide s arrs aenv aenv' -> TupR s arrs +lhsToTupR (LeftHandSideSingle s) = TupRsingle s +lhsToTupR (LeftHandSideWildcard r) = r +lhsToTupR (LeftHandSidePair as bs) = TupRpair (lhsToTupR as) (lhsToTupR bs) -- The type of shifting terms from one context into another +-- This is defined as a newtype, as a type synonym containing a forall quantifier +-- may give issues with impredicative polymorphism which GHC does not support. -- -type env :> env' = forall t'. Idx env t' -> Idx env' t' +newtype env :> env' = Weaken (forall t'. Idx env t' -> Idx env' t') -- Weak or Weaken + +weakenId :: env :> env +weakenId = Weaken id + +(>:>) :: env :> env' -> Idx env t -> Idx env' t +(>:>) (Weaken k) ix = k ix -weakenWithLHS :: LeftHandSide arrs env env' -> env :> env' -weakenWithLHS (LeftHandSideWildcard _) = id -weakenWithLHS LeftHandSideArray = SuccIdx -weakenWithLHS (LeftHandSidePair lhs1 lhs2) = weakenWithLHS lhs2 . weakenWithLHS lhs1 +weakenSucc' :: env :> env' -> env :> (env', t) +weakenSucc' (Weaken f) = Weaken (SuccIdx . f) + +weakenSucc :: (env, t) :> env' -> env :> env' +weakenSucc (Weaken f) = Weaken (f . SuccIdx) + +weakenEmpty :: () :> env' +weakenEmpty = Weaken (\x -> case x of {}) + +sink :: forall env env' t. env :> env' -> (env, t) :> (env', t) +sink (Weaken f) = Weaken g + where + g :: Idx (env, t) t' -> Idx (env', t) t' + g ZeroIdx = ZeroIdx + g (SuccIdx ix) = SuccIdx $ f ix + +infixr 9 .> +(.>) :: env2 :> env3 -> env1 :> env2 -> env1 :> env3 +(.>) (Weaken f) (Weaken g) = Weaken (f . g) + +sinkWithLHS :: LeftHandSide s t env1 env1' -> LeftHandSide s t env2 env2' -> env1 :> env2 -> env1' :> env2' +sinkWithLHS (LeftHandSideWildcard _) (LeftHandSideWildcard _) k = k +sinkWithLHS (LeftHandSideSingle _) (LeftHandSideSingle _) k = sink k +sinkWithLHS (LeftHandSidePair a1 b1) (LeftHandSidePair a2 b2) k = sinkWithLHS b1 b2 $ sinkWithLHS a1 a2 k +sinkWithLHS _ _ _ = error "sinkWithLHS: left hand sides do not match" + +weakenWithLHS :: forall s t env env'. LeftHandSide s t env env' -> env :> env' +weakenWithLHS = go weakenId + where + go :: env2 :> env' -> LeftHandSide s arrs env1 env2 -> env1 :> env' + go k (LeftHandSideWildcard _) = k + go k (LeftHandSideSingle _) = weakenSucc k + go k (LeftHandSidePair l1 l2) = go (go k l2) l1 -- Often useful when working with LeftHandSide, when you need to -- existentially quantify on the resulting environment type. data Exists f where Exists :: f a -> Exists f -data ArrayVar aenv arr where - ArrayVar :: (Shape sh, Elt e) => Idx aenv (Array sh e) -> ArrayVar aenv (Array sh e) +type ArrayVar = Var ArrayR +type ArrayVars = Vars ArrayR -data ArrayVars aenv arrs where - ArrayVarsArray :: ArrayVar aenv a -> ArrayVars aenv a - ArrayVarsNil :: ArrayVars aenv () - ArrayVarsPair :: ArrayVars aenv a -> ArrayVars aenv b -> ArrayVars aenv (a, b) +type ExpVar = Var ScalarType +type ExpVars = Vars ScalarType + +data Var s env t = Var (s t) (Idx env t) +data Vars s env t where + VarsSingle :: Var s env a -> Vars s env a + VarsNil :: Vars s aenv () + VarsPair :: Vars s aenv a -> Vars s aenv b -> Vars s aenv (a, b) + +evars :: ExpVars env tp -> OpenExp env aenv tp +evars VarsNil = Nil +evars (VarsSingle var) = Evar var +evars (VarsPair v1 v2) = evars v1 `Pair` evars v2 + +varsType :: Vars s env t -> TupR s t +varsType (VarsSingle (Var tp _)) = TupRsingle tp +varsType VarsNil = TupRunit +varsType (VarsPair v1 v2) = varsType v1 `TupRpair` varsType v2 -- | Collective array computations parametrised over array variables -- represented with de Bruijn indices. @@ -309,12 +346,12 @@ data ArrayVars aenv arrs where -- We use a non-recursive variant parametrised over the recursive closure, -- to facilitate attribute calculation in the backend. -- -data PreOpenAcc acc aenv a where +data PreOpenAcc (acc :: Type -> Type -> Type) aenv a where -- Local non-recursive binding to represent sharing and demand -- explicitly. Note this is an eager binding! -- - Alet :: LeftHandSide bndArrs aenv aenv' + Alet :: ALeftHandSide bndArrs aenv aenv' -> acc aenv bndArrs -- bound expression -> acc aenv' bodyArrs -- the bound expression scope -> PreOpenAcc acc aenv bodyArrs @@ -337,7 +374,8 @@ data PreOpenAcc acc aenv a where -- The array function is not closed at the core level because we need access -- to free variables introduced by 'run1' style evaluators. See Issue#95. -- - Apply :: PreOpenAfun acc aenv (arrs1 -> arrs2) + Apply :: ArraysR arrs2 + -> PreOpenAfun acc aenv (arrs1 -> arrs2) -> acc aenv arrs1 -> PreOpenAcc acc aenv arrs2 @@ -345,15 +383,16 @@ data PreOpenAcc acc aenv a where -- Accelerate version for use with other backends. The functions must be -- closed. -- - Aforeign :: (Arrays as, Arrays bs, Foreign asm) - => asm (as -> bs) -- The foreign function for a given backend - -> PreAfun acc (ArrRepr as -> ArrRepr bs) -- Fallback implementation(s) - -> acc aenv (ArrRepr as) -- Arguments to the function - -> PreOpenAcc acc aenv (ArrRepr bs) + Aforeign :: Sugar.Foreign asm + => ArraysR bs + -> asm (as -> bs) -- The foreign function for a given backend + -> PreAfun acc (as -> bs) -- Fallback implementation(s) + -> acc aenv as -- Arguments to the function + -> PreOpenAcc acc aenv bs -- If-then-else for array-level computations -- - Acond :: PreExp acc aenv Bool + Acond :: Exp aenv Bool -> acc aenv arrs -> acc aenv arrs -> PreOpenAcc acc aenv arrs @@ -369,14 +408,14 @@ data PreOpenAcc acc aenv a where -- Array inlet. Triggers (possibly) asynchronous host->device transfer if -- necessary. -- - Use :: (Shape sh, Elt e) - => Array sh e + Use :: ArrayR (Array sh e) + -> Array sh e -> PreOpenAcc acc aenv (Array sh e) -- Capture a scalar (or a tuple of scalars) in a singleton array -- - Unit :: Elt e - => PreExp acc aenv e + Unit :: TupleType e + -> Exp aenv e -> PreOpenAcc acc aenv (Scalar e) -- Change the shape of an array without altering its contents. @@ -384,56 +423,54 @@ data PreOpenAcc acc aenv a where -- -- > dim == size dim' -- - Reshape :: (Shape sh, Shape sh', Elt e) - => PreExp acc aenv sh -- new shape + Reshape :: ShapeR sh + -> Exp aenv sh -- new shape -> acc aenv (Array sh' e) -- array to be reshaped -> PreOpenAcc acc aenv (Array sh e) -- Construct a new array by applying a function to each index. -- - Generate :: (Shape sh, Elt e) - => PreExp acc aenv sh -- output shape - -> PreFun acc aenv (sh -> e) -- representation function + Generate :: ArrayR (Array sh e) + -> Exp aenv sh -- output shape + -> Fun aenv (sh -> e) -- representation function -> PreOpenAcc acc aenv (Array sh e) -- Hybrid map/backpermute, where we separate the index and value -- transformations. -- - Transform :: (Elt a, Elt b, Shape sh, Shape sh') - => PreExp acc aenv sh' -- dimension of the result - -> PreFun acc aenv (sh' -> sh) -- index permutation function - -> PreFun acc aenv (a -> b) -- function to apply at each element + Transform :: ArrayR (Array sh' b) + -> Exp aenv sh' -- dimension of the result + -> Fun aenv (sh' -> sh) -- index permutation function + -> Fun aenv (a -> b) -- function to apply at each element -> acc aenv (Array sh a) -- source array -> PreOpenAcc acc aenv (Array sh' b) -- Replicate an array across one or more dimensions as given by the first -- argument -- - Replicate :: (Shape sh, Shape sl, Elt slix, Elt e) - => SliceIndex (EltRepr slix) -- slice type specification - (EltRepr sl) + Replicate :: SliceIndex slix -- slice type specification + sl co - (EltRepr sh) - -> PreExp acc aenv slix -- slice value specification + sh + -> Exp aenv slix -- slice value specification -> acc aenv (Array sl e) -- data to be replicated -> PreOpenAcc acc aenv (Array sh e) -- Index a sub-array out of an array; i.e., the dimensions not indexed -- are returned whole -- - Slice :: (Shape sh, Shape sl, Elt slix, Elt e) - => SliceIndex (EltRepr slix) -- slice type specification - (EltRepr sl) + Slice :: SliceIndex slix -- slice type specification + sl co - (EltRepr sh) + sh -> acc aenv (Array sh e) -- array to be indexed - -> PreExp acc aenv slix -- slice value specification + -> Exp aenv slix -- slice value specification -> PreOpenAcc acc aenv (Array sl e) -- Apply the given unary function to all elements of the given array -- - Map :: (Shape sh, Elt e, Elt e') - => PreFun acc aenv (e -> e') + Map :: TupleType e' + -> Fun aenv (e -> e') -> acc aenv (Array sh e) -> PreOpenAcc acc aenv (Array sh e') @@ -441,8 +478,8 @@ data PreOpenAcc acc aenv a where -- arrays. The length of the result is the length of the shorter of the -- two argument arrays. -- - ZipWith :: (Shape sh, Elt e1, Elt e2, Elt e3) - => PreFun acc aenv (e1 -> e2 -> e3) + ZipWith :: TupleType e3 + -> Fun aenv (e1 -> e2 -> e3) -> acc aenv (Array sh e1) -> acc aenv (Array sh e2) -> PreOpenAcc acc aenv (Array sh e3) @@ -450,86 +487,78 @@ data PreOpenAcc acc aenv a where -- Fold along the innermost dimension of an array with a given -- /associative/ function. -- - Fold :: (Shape sh, Elt e) - => PreFun acc aenv (e -> e -> e) -- combination function - -> PreExp acc aenv e -- default value - -> acc aenv (Array (sh:.Int) e) -- folded array + Fold :: Fun aenv (e -> e -> e) -- combination function + -> Exp aenv e -- default value + -> acc aenv (Array (sh, Int) e) -- folded array -> PreOpenAcc acc aenv (Array sh e) -- As 'Fold' without a default value -- - Fold1 :: (Shape sh, Elt e) - => PreFun acc aenv (e -> e -> e) -- combination function - -> acc aenv (Array (sh:.Int) e) -- folded array + Fold1 :: Fun aenv (e -> e -> e) -- combination function + -> acc aenv (Array (sh, Int) e) -- folded array -> PreOpenAcc acc aenv (Array sh e) -- Segmented fold along the innermost dimension of an array with a given -- /associative/ function -- - FoldSeg :: (Shape sh, Elt e, Elt i, IsIntegral i) - => PreFun acc aenv (e -> e -> e) -- combination function - -> PreExp acc aenv e -- default value - -> acc aenv (Array (sh:.Int) e) -- folded array + FoldSeg :: IntegralType i + -> Fun aenv (e -> e -> e) -- combination function + -> Exp aenv e -- default value + -> acc aenv (Array (sh, Int) e) -- folded array -> acc aenv (Segments i) -- segment descriptor - -> PreOpenAcc acc aenv (Array (sh:.Int) e) + -> PreOpenAcc acc aenv (Array (sh, Int) e) -- As 'FoldSeg' without a default value -- - Fold1Seg :: (Shape sh, Elt e, Elt i, IsIntegral i) - => PreFun acc aenv (e -> e -> e) -- combination function - -> acc aenv (Array (sh:.Int) e) -- folded array + Fold1Seg :: IntegralType i + -> Fun aenv (e -> e -> e) -- combination function + -> acc aenv (Array (sh, Int) e) -- folded array -> acc aenv (Segments i) -- segment descriptor - -> PreOpenAcc acc aenv (Array (sh:.Int) e) + -> PreOpenAcc acc aenv (Array (sh, Int) e) -- Left-to-right Haskell-style scan of a linear array with a given -- /associative/ function and an initial element (which does not need to -- be the neutral of the associative operations) -- - Scanl :: (Shape sh, Elt e) - => PreFun acc aenv (e -> e -> e) -- combination function - -> PreExp acc aenv e -- initial value - -> acc aenv (Array (sh:.Int) e) - -> PreOpenAcc acc aenv (Array (sh:.Int) e) + Scanl :: Fun aenv (e -> e -> e) -- combination function + -> Exp aenv e -- initial value + -> acc aenv (Array (sh, Int) e) + -> PreOpenAcc acc aenv (Array (sh, Int) e) -- Like 'Scan', but produces a rightmost fold value and an array with the -- same length as the input array (the fold value would be the rightmost -- element in a Haskell-style scan) -- - Scanl' :: (Shape sh, Elt e) - => PreFun acc aenv (e -> e -> e) -- combination function - -> PreExp acc aenv e -- initial value - -> acc aenv (Array (sh:.Int) e) - -> PreOpenAcc acc aenv (ArrRepr (Array (sh:.Int) e, Array sh e)) + Scanl' :: Fun aenv (e -> e -> e) -- combination function + -> Exp aenv e -- initial value + -> acc aenv (Array (sh, Int) e) + -> PreOpenAcc acc aenv (Array (sh, Int) e, Array sh e) -- Haskell-style scan without an initial value -- - Scanl1 :: (Shape sh, Elt e) - => PreFun acc aenv (e -> e -> e) -- combination function - -> acc aenv (Array (sh:.Int) e) - -> PreOpenAcc acc aenv (Array (sh:.Int) e) + Scanl1 :: Fun aenv (e -> e -> e) -- combination function + -> acc aenv (Array (sh, Int) e) + -> PreOpenAcc acc aenv (Array (sh, Int) e) -- Right-to-left version of 'Scanl' -- - Scanr :: (Shape sh, Elt e) - => PreFun acc aenv (e -> e -> e) -- combination function - -> PreExp acc aenv e -- initial value - -> acc aenv (Array (sh:.Int) e) - -> PreOpenAcc acc aenv (Array (sh:.Int) e) + Scanr :: Fun aenv (e -> e -> e) -- combination function + -> Exp aenv e -- initial value + -> acc aenv (Array (sh, Int) e) + -> PreOpenAcc acc aenv (Array (sh, Int) e) -- Right-to-left version of 'Scanl\'' -- - Scanr' :: (Shape sh, Elt e) - => PreFun acc aenv (e -> e -> e) -- combination function - -> PreExp acc aenv e -- initial value - -> acc aenv (Array (sh:.Int) e) - -> PreOpenAcc acc aenv (ArrRepr (Array (sh:.Int) e, Array sh e)) + Scanr' :: Fun aenv (e -> e -> e) -- combination function + -> Exp aenv e -- initial value + -> acc aenv (Array (sh, Int) e) + -> PreOpenAcc acc aenv (Array (sh, Int) e, Array sh e) -- Right-to-left version of 'Scanl1' -- - Scanr1 :: (Shape sh, Elt e) - => PreFun acc aenv (e -> e -> e) -- combination function - -> acc aenv (Array (sh:.Int) e) - -> PreOpenAcc acc aenv (Array (sh:.Int) e) + Scanr1 :: Fun aenv (e -> e -> e) -- combination function + -> acc aenv (Array (sh, Int) e) + -> PreOpenAcc acc aenv (Array (sh, Int) e) -- Generalised forward permutation is characterised by a permutation function -- that determines for each element of the source array where it should go in @@ -551,38 +580,40 @@ data PreOpenAcc acc aenv a where -- function is used to combine elements, which needs to be /associative/ -- and /commutative/. -- - Permute :: (Shape sh, Shape sh', Elt e) - => PreFun acc aenv (e -> e -> e) -- combination function + Permute :: Fun aenv (e -> e -> e) -- combination function -> acc aenv (Array sh' e) -- default values - -> PreFun acc aenv (sh -> sh') -- permutation function + -> Fun aenv (sh -> sh') -- permutation function -> acc aenv (Array sh e) -- source array -> PreOpenAcc acc aenv (Array sh' e) -- Generalised multi-dimensional backwards permutation; the permutation can -- be between arrays of varying shape; the permutation function must be total -- - Backpermute :: (Shape sh, Shape sh', Elt e) - => PreExp acc aenv sh' -- dimensions of the result - -> PreFun acc aenv (sh' -> sh) -- permutation function + Backpermute :: ShapeR sh' + -> Exp aenv sh' -- dimensions of the result + -> Fun aenv (sh' -> sh) -- permutation function -> acc aenv (Array sh e) -- source array -> PreOpenAcc acc aenv (Array sh' e) -- Map a stencil over an array. In contrast to 'map', the domain of -- a stencil function is an entire /neighbourhood/ of each array element. -- - Stencil :: (Elt e, Elt e', Stencil sh e stencil) - => PreFun acc aenv (stencil -> e') -- stencil function - -> PreBoundary acc aenv (Array sh e) -- boundary condition + Stencil :: StencilR sh e stencil + -> TupleType e' + -> Fun aenv (stencil -> e') -- stencil function + -> Boundary aenv (Array sh e) -- boundary condition -> acc aenv (Array sh e) -- source array -> PreOpenAcc acc aenv (Array sh e') -- Map a binary stencil over an array. -- - Stencil2 :: (Elt a, Elt b, Elt c, Stencil sh a stencil1, Stencil sh b stencil2) - => PreFun acc aenv (stencil1 -> stencil2 -> c) -- stencil function - -> PreBoundary acc aenv (Array sh a) -- boundary condition #1 + Stencil2 :: StencilR sh a stencil1 + -> StencilR sh b stencil2 + -> TupleType c + -> Fun aenv (stencil1 -> stencil2 -> c) -- stencil function + -> Boundary aenv (Array sh a) -- boundary condition #1 -> acc aenv (Array sh a) -- source array #1 - -> PreBoundary acc aenv (Array sh b) -- boundary condition #2 + -> Boundary aenv (Array sh b) -- boundary condition #2 -> acc aenv (Array sh b) -- source array #2 -> PreOpenAcc acc aenv (Array sh c) @@ -709,367 +740,251 @@ type Seq = PreOpenSeq OpenAcc () () --} --- | Vanilla stencil boundary condition --- -type Boundary = PreBoundary OpenAcc - --- | Boundary condition specification for stencil operations +-- | Vanilla boundary condition specification for stencil operations -- -data PreBoundary acc aenv t where +data Boundary aenv t where -- Clamp coordinates to the extent of the array - Clamp :: PreBoundary acc aenv t + Clamp :: Boundary aenv t -- Mirror coordinates beyond the array extent - Mirror :: PreBoundary acc aenv t + Mirror :: Boundary aenv t -- Wrap coordinates around on each dimension - Wrap :: PreBoundary acc aenv t + Wrap :: Boundary aenv t -- Use a constant value for outlying coordinates - Constant :: Elt e - => EltRepr e - -> PreBoundary acc aenv (Array sh e) + Constant :: e + -> Boundary aenv (Array sh e) -- Apply the given function to outlying coordinates - Function :: (Shape sh, Elt e) - => PreFun acc aenv (sh -> e) - -> PreBoundary acc aenv (Array sh e) + Function :: Fun aenv (sh -> e) + -> Boundary aenv (Array sh e) - --- | Operations on stencils --- -class (Shape sh, Elt e, IsTuple stencil, Elt stencil) => Stencil sh e stencil where - stencil :: StencilR sh e stencil - --- | GADT reifying the 'Stencil' class --- -data StencilR sh e pat where - StencilRunit3 :: Elt e => StencilR DIM1 e (e,e,e) - StencilRunit5 :: Elt e => StencilR DIM1 e (e,e,e,e,e) - StencilRunit7 :: Elt e => StencilR DIM1 e (e,e,e,e,e,e,e) - StencilRunit9 :: Elt e => StencilR DIM1 e (e,e,e,e,e,e,e,e,e) - - StencilRtup3 :: (Shape sh, Elt e) - => StencilR sh e pat1 - -> StencilR sh e pat2 - -> StencilR sh e pat3 - -> StencilR (sh:.Int) e (pat1,pat2,pat3) - - StencilRtup5 :: (Shape sh, Elt e) - => StencilR sh e pat1 - -> StencilR sh e pat2 - -> StencilR sh e pat3 - -> StencilR sh e pat4 - -> StencilR sh e pat5 - -> StencilR (sh:.Int) e (pat1,pat2,pat3,pat4,pat5) - - StencilRtup7 :: (Shape sh, Elt e) - => StencilR sh e pat1 - -> StencilR sh e pat2 - -> StencilR sh e pat3 - -> StencilR sh e pat4 - -> StencilR sh e pat5 - -> StencilR sh e pat6 - -> StencilR sh e pat7 - -> StencilR (sh:.Int) e (pat1,pat2,pat3,pat4,pat5,pat6,pat7) - - StencilRtup9 :: (Shape sh, Elt e) - => StencilR sh e pat1 - -> StencilR sh e pat2 - -> StencilR sh e pat3 - -> StencilR sh e pat4 - -> StencilR sh e pat5 - -> StencilR sh e pat6 - -> StencilR sh e pat7 - -> StencilR sh e pat8 - -> StencilR sh e pat9 - -> StencilR (sh:.Int) e (pat1,pat2,pat3,pat4,pat5,pat6,pat7,pat8,pat9) - - --- Note: [Stencil reification class] --- --- We cannot start with 'DIM0'. The 'IsTuple stencil' superclass would at --- 'DIM0' imply that the types of individual array elements are in 'IsTuple'. --- (That would only possible if we could have (degenerate) 1-tuple, but we can't --- as we can't distinguish between a 1-tuple of a pair and a simple pair.) --- Hence, we need to start from 'DIM1' and use 'sh:.Int:.Int' in the recursive --- case (to avoid overlapping instances). - --- DIM1 -instance Elt e => Stencil DIM1 e (e, e, e) where - stencil = StencilRunit3 - -instance Elt e => Stencil DIM1 e (e, e, e, e, e) where - stencil = StencilRunit5 - -instance Elt e => Stencil DIM1 e (e, e, e, e, e, e, e) where - stencil = StencilRunit7 - -instance Elt e => Stencil DIM1 e (e, e, e, e, e, e, e, e, e) where - stencil = StencilRunit9 - --- DIM(n+1), where n>1 -instance (Stencil (sh:.Int) a row1, - Stencil (sh:.Int) a row2, - Stencil (sh:.Int) a row3) => Stencil (sh:.Int:.Int) a (row1, row2, row3) where - stencil = StencilRtup3 stencil stencil stencil - -instance (Stencil (sh:.Int) a row1, - Stencil (sh:.Int) a row2, - Stencil (sh:.Int) a row3, - Stencil (sh:.Int) a row4, - Stencil (sh:.Int) a row5) => Stencil (sh:.Int:.Int) a (row1, row2, row3, row4, row5) where - stencil = StencilRtup5 stencil stencil stencil stencil stencil - -instance (Stencil (sh:.Int) a row1, - Stencil (sh:.Int) a row2, - Stencil (sh:.Int) a row3, - Stencil (sh:.Int) a row4, - Stencil (sh:.Int) a row5, - Stencil (sh:.Int) a row6, - Stencil (sh:.Int) a row7) - => Stencil (sh:.Int:.Int) a (row1, row2, row3, row4, row5, row6, row7) where - stencil = StencilRtup7 stencil stencil stencil stencil stencil stencil stencil - -instance (Stencil (sh:.Int) a row1, - Stencil (sh:.Int) a row2, - Stencil (sh:.Int) a row3, - Stencil (sh:.Int) a row4, - Stencil (sh:.Int) a row5, - Stencil (sh:.Int) a row6, - Stencil (sh:.Int) a row7, - Stencil (sh:.Int) a row8, - Stencil (sh:.Int) a row9) - => Stencil (sh:.Int:.Int) a (row1, row2, row3, row4, row5, row6, row7, row8, row9) where - stencil = StencilRtup9 stencil stencil stencil stencil stencil stencil stencil stencil stencil +data PairIdx p a where + PairIdxLeft :: PairIdx (a, b) a + PairIdxRight :: PairIdx (a, b) b class HasArraysRepr f where arraysRepr :: f aenv a -> ArraysR a +arrayRepr :: HasArraysRepr f => f aenv (Array sh e) -> ArrayR (Array sh e) +arrayRepr a = case arraysRepr a of + TupRsingle repr -> repr + instance HasArraysRepr acc => HasArraysRepr (PreOpenAcc acc) where arraysRepr (Alet _ _ body) = arraysRepr body - arraysRepr (Avar ArrayVar{}) = ArraysRarray - arraysRepr (Apair as bs) = ArraysRpair (arraysRepr as) (arraysRepr bs) - arraysRepr Anil = ArraysRunit - arraysRepr (Apply (Alam _ (Abody a)) _) = arraysRepr a - arraysRepr (Apply _ _) = error "Tomorrow will arrive, on time" - arraysRepr (Aforeign _ (Alam _ (Abody a)) _) = arraysRepr a - arraysRepr (Aforeign _ (Abody _) _) = error "And what have you got, at the end of the day?" - arraysRepr (Aforeign _ (Alam _ (Alam _ _)) _) = error "A bottle of whisky. And a new set of lies." + arraysRepr (Avar (Var repr _)) = TupRsingle repr + arraysRepr (Apair as bs) = TupRpair (arraysRepr as) (arraysRepr bs) + arraysRepr Anil = TupRunit + arraysRepr (Apply repr _ _) = repr + arraysRepr (Aforeign r _ _ _) = r arraysRepr (Acond _ whenTrue _) = arraysRepr whenTrue - arraysRepr (Awhile _ (Alam lhs _) _) = lhsToArraysR lhs + arraysRepr (Awhile _ (Alam lhs _) _) = lhsToTupR lhs arraysRepr (Awhile _ _ _) = error "I want my, I want my MTV!" - arraysRepr Use{} = ArraysRarray - arraysRepr Unit{} = ArraysRarray - arraysRepr Reshape{} = ArraysRarray - arraysRepr Generate{} = ArraysRarray - arraysRepr Transform{} = ArraysRarray - arraysRepr Replicate{} = ArraysRarray - arraysRepr Slice{} = ArraysRarray - arraysRepr Map{} = ArraysRarray - arraysRepr ZipWith{} = ArraysRarray - arraysRepr Fold{} = ArraysRarray - arraysRepr Fold1{} = ArraysRarray - arraysRepr FoldSeg{} = ArraysRarray - arraysRepr Fold1Seg{} = ArraysRarray - arraysRepr Scanl{} = ArraysRarray - arraysRepr Scanl'{} = arraysRtuple2 - arraysRepr Scanl1{} = ArraysRarray - arraysRepr Scanr{} = ArraysRarray - arraysRepr Scanr'{} = arraysRtuple2 - arraysRepr Scanr1{} = ArraysRarray - arraysRepr Permute{} = ArraysRarray - arraysRepr Backpermute{} = ArraysRarray - arraysRepr Stencil{} = ArraysRarray - arraysRepr Stencil2{} = ArraysRarray + arraysRepr (Use repr _) = TupRsingle repr + arraysRepr (Unit tp _) = arraysRarray ShapeRz tp + arraysRepr (Reshape sh _ a) = let ArrayR _ tp = arrayRepr a + in arraysRarray sh tp + arraysRepr (Generate repr _ _) = TupRsingle repr + arraysRepr (Transform repr _ _ _ _) = TupRsingle repr + arraysRepr (Replicate slice _ a) = let ArrayR _ tp = arrayRepr a + in arraysRarray (sliceDomainR slice) tp + arraysRepr (Slice slice a _) = let ArrayR _ tp = arrayRepr a + in arraysRarray (sliceShapeR slice) tp + arraysRepr (Map tp _ a) = let ArrayR sh _ = arrayRepr a + in arraysRarray sh tp + arraysRepr (ZipWith tp _ a _) = let ArrayR sh _ = arrayRepr a + in arraysRarray sh tp + arraysRepr (Fold _ _ a) = let ArrayR (ShapeRsnoc sh) tp = arrayRepr a + in arraysRarray sh tp + arraysRepr (Fold1 _ a) = let ArrayR (ShapeRsnoc sh) tp = arrayRepr a + in arraysRarray sh tp + arraysRepr (FoldSeg _ _ _ a _) = arraysRepr a + arraysRepr (Fold1Seg _ _ a _) = arraysRepr a + arraysRepr (Scanl _ _ a) = arraysRepr a + arraysRepr (Scanl' _ _ a) = let repr@(ArrayR (ShapeRsnoc sh) tp) = arrayRepr a + in TupRsingle repr `TupRpair` TupRsingle (ArrayR sh tp) + arraysRepr (Scanl1 _ a) = arraysRepr a + arraysRepr (Scanr _ _ a) = arraysRepr a + arraysRepr (Scanr' _ _ a) = let repr@(ArrayR (ShapeRsnoc sh) tp) = arrayRepr a + in TupRsingle repr `TupRpair` TupRsingle (ArrayR sh tp) + arraysRepr (Scanr1 _ a) = arraysRepr a + arraysRepr (Permute _ a _ _) = arraysRepr a + arraysRepr (Backpermute sh _ _ a) = let ArrayR _ tp = arrayRepr a + in arraysRarray sh tp + arraysRepr (Stencil _ tp _ _ a) = let ArrayR sh _ = arrayRepr a + in arraysRarray sh tp + arraysRepr (Stencil2 _ _ tp _ _ a _ _) = let ArrayR sh _ = arrayRepr a + in arraysRarray sh tp instance HasArraysRepr OpenAcc where arraysRepr (OpenAcc a) = arraysRepr a -- Embedded expressions -- -------------------- --- |Parametrised open function abstraction --- -data PreOpenFun acc env aenv t where - Body :: Elt t => PreOpenExp acc env aenv t -> PreOpenFun acc env aenv t - Lam :: Elt a => PreOpenFun acc (env, a) aenv t -> PreOpenFun acc env aenv (a -> t) - -- |Vanilla open function abstraction -- -type OpenFun = PreOpenFun OpenAcc - --- |Parametrised function without free scalar variables --- -type PreFun acc = PreOpenFun acc () +data OpenFun env aenv t where + Body :: OpenExp env aenv t -> OpenFun env aenv t + Lam :: ELeftHandSide a env env' -> OpenFun env' aenv t -> OpenFun env aenv (a -> t) -- |Vanilla function without free scalar variables -- type Fun = OpenFun () --- |Vanilla open expression --- -type OpenExp = PreOpenExp OpenAcc - --- |Parametrised expression without free scalar variables --- -type PreExp acc = PreOpenExp acc () - -- |Vanilla expression without free scalar variables -- type Exp = OpenExp () --- |Parametrised open expressions using de Bruijn indices for variables ranging over tuples +-- |Vanilla open expressions using de Bruijn indices for variables ranging over tuples -- of scalars and arrays of tuples. All code, except Cond, is evaluated eagerly. N-tuples are -- represented as nested pairs. -- --- The data type is parametrised over the surface types (not the representation type). +-- The data type is parametrised over the representation type (not the surface types). -- -data PreOpenExp acc env aenv t where +data OpenExp env aenv t where -- Local binding of a scalar expression - Let :: (Elt bnd_t, Elt body_t) - => PreOpenExp acc env aenv bnd_t - -> PreOpenExp acc (env, bnd_t) aenv body_t - -> PreOpenExp acc env aenv body_t + Let :: ELeftHandSide bnd_t env env' + -> OpenExp env aenv bnd_t + -> OpenExp env' aenv body_t + -> OpenExp env aenv body_t -- Variable index, ranging only over tuples or scalars - Var :: Elt t - => Idx env t - -> PreOpenExp acc env aenv t + Evar :: ExpVar env t + -> OpenExp env aenv t -- Apply a backend-specific foreign function - Foreign :: (Foreign asm, Elt x, Elt y) - => asm (x -> y) -- foreign function - -> PreFun acc () (x -> y) -- alternate implementation (for other backends) - -> PreOpenExp acc env aenv x - -> PreOpenExp acc env aenv y + Foreign :: Sugar.Foreign asm + => TupleType y + -> asm (x -> y) -- foreign function + -> Fun () (x -> y) -- alternate implementation (for other backends) + -> OpenExp env aenv x + -> OpenExp env aenv y -- Tuples - Tuple :: (Elt t, IsTuple t) - => Tuple (PreOpenExp acc env aenv) (TupleRepr t) - -> PreOpenExp acc env aenv t + Pair :: OpenExp env aenv t1 + -> OpenExp env aenv t2 + -> OpenExp env aenv (t1, t2) - Prj :: (Elt t, IsTuple t, Elt e) - => TupleIdx (TupleRepr t) e - -> PreOpenExp acc env aenv t - -> PreOpenExp acc env aenv e + Nil :: OpenExp env aenv () - -- Array indices & shapes - IndexNil :: PreOpenExp acc env aenv Z + -- SIMD vectors + VecPack :: KnownNat n + => VecR n s tup + -> OpenExp env aenv tup + -> OpenExp env aenv (Vec n s) - IndexCons :: (Elt sl, Elt a) - => PreOpenExp acc env aenv sl - -> PreOpenExp acc env aenv a - -> PreOpenExp acc env aenv (sl:.a) + VecUnpack :: KnownNat n + => VecR n s tup + -> OpenExp env aenv (Vec n s) + -> OpenExp env aenv tup - IndexHead :: (Elt sl, Elt a) - => PreOpenExp acc env aenv (sl:.a) - -> PreOpenExp acc env aenv a - - IndexTail :: (Elt sl, Elt a) - => PreOpenExp acc env aenv (sl:.a) - -> PreOpenExp acc env aenv sl - - IndexAny :: Shape sh - => PreOpenExp acc env aenv (Any sh) - - IndexSlice :: (Shape sh, Shape sl, Elt slix) - => SliceIndex (EltRepr slix) (EltRepr sl) co (EltRepr sh) - -> PreOpenExp acc env aenv slix - -> PreOpenExp acc env aenv sh - -> PreOpenExp acc env aenv sl + -- Array indices & shapes + IndexSlice :: SliceIndex slix sl co sh + -> OpenExp env aenv slix + -> OpenExp env aenv sh + -> OpenExp env aenv sl - IndexFull :: (Shape sh, Shape sl, Elt slix) - => SliceIndex (EltRepr slix) (EltRepr sl) co (EltRepr sh) - -> PreOpenExp acc env aenv slix - -> PreOpenExp acc env aenv sl - -> PreOpenExp acc env aenv sh + IndexFull :: SliceIndex slix sl co sh + -> OpenExp env aenv slix + -> OpenExp env aenv sl + -> OpenExp env aenv sh -- Shape and index conversion - ToIndex :: Shape sh - => PreOpenExp acc env aenv sh -- shape of the array - -> PreOpenExp acc env aenv sh -- index into the array - -> PreOpenExp acc env aenv Int + ToIndex :: ShapeR sh + -> OpenExp env aenv sh -- shape of the array + -> OpenExp env aenv sh -- index into the array + -> OpenExp env aenv Int - FromIndex :: Shape sh - => PreOpenExp acc env aenv sh -- shape of the array - -> PreOpenExp acc env aenv Int -- index into linear representation - -> PreOpenExp acc env aenv sh + FromIndex :: ShapeR sh + -> OpenExp env aenv sh -- shape of the array + -> OpenExp env aenv Int -- index into linear representation + -> OpenExp env aenv sh -- Conditional expression (non-strict in 2nd and 3rd argument) - Cond :: Elt t - => PreOpenExp acc env aenv Bool - -> PreOpenExp acc env aenv t - -> PreOpenExp acc env aenv t - -> PreOpenExp acc env aenv t + Cond :: OpenExp env aenv Bool + -> OpenExp env aenv t + -> OpenExp env aenv t + -> OpenExp env aenv t -- Value recursion - While :: Elt a - => PreOpenFun acc env aenv (a -> Bool) -- continue while true - -> PreOpenFun acc env aenv (a -> a) -- function to iterate - -> PreOpenExp acc env aenv a -- initial value - -> PreOpenExp acc env aenv a + While :: OpenFun env aenv (a -> Bool) -- continue while true + -> OpenFun env aenv (a -> a) -- function to iterate + -> OpenExp env aenv a -- initial value + -> OpenExp env aenv a -- Constant values - Const :: Elt t - => EltRepr t - -> PreOpenExp acc env aenv t + Const :: ScalarType t + -> t + -> OpenExp env aenv t - PrimConst :: Elt t - => PrimConst t - -> PreOpenExp acc env aenv t + PrimConst :: PrimConst t + -> OpenExp env aenv t -- Primitive scalar operations - PrimApp :: (Elt a, Elt r) - => PrimFun (a -> r) - -> PreOpenExp acc env aenv a - -> PreOpenExp acc env aenv r + PrimApp :: PrimFun (a -> r) + -> OpenExp env aenv a + -> OpenExp env aenv r -- Project a single scalar from an array. -- The array expression can not contain any free scalar variables. - Index :: (Shape dim, Elt t) - => acc aenv (Array dim t) - -> PreOpenExp acc env aenv dim - -> PreOpenExp acc env aenv t + Index :: ArrayVar aenv (Array dim t) + -> OpenExp env aenv dim + -> OpenExp env aenv t - LinearIndex :: (Shape dim, Elt t) - => acc aenv (Array dim t) - -> PreOpenExp acc env aenv Int - -> PreOpenExp acc env aenv t + LinearIndex :: ArrayVar aenv (Array dim t) + -> OpenExp env aenv Int + -> OpenExp env aenv t -- Array shape. -- The array expression can not contain any free scalar variables. - Shape :: (Shape dim, Elt e) - => acc aenv (Array dim e) - -> PreOpenExp acc env aenv dim + Shape :: ArrayVar aenv (Array dim e) + -> OpenExp env aenv dim -- Number of elements of an array given its shape - ShapeSize :: Shape dim - => PreOpenExp acc env aenv dim - -> PreOpenExp acc env aenv Int - - -- Intersection of two shapes - Intersect :: Shape dim - => PreOpenExp acc env aenv dim - -> PreOpenExp acc env aenv dim - -> PreOpenExp acc env aenv dim - - -- Union of two shapes - Union :: Shape dim - => PreOpenExp acc env aenv dim - -> PreOpenExp acc env aenv dim - -> PreOpenExp acc env aenv dim + ShapeSize :: ShapeR dim + -> OpenExp env aenv dim + -> OpenExp env aenv Int -- Unsafe operations (may fail or result in undefined behaviour) -- An unspecified bit pattern - Undef :: Elt t - => PreOpenExp acc env aenv t + Undef :: ScalarType t + -> OpenExp env aenv t -- Reinterpret the bits of a value as a different type - Coerce :: (Elt a, Elt b) - => PreOpenExp acc env aenv a - -> PreOpenExp acc env aenv b - + Coerce :: BitSizeEq a b + => ScalarType a + -> ScalarType b + -> OpenExp env aenv a + -> OpenExp env aenv b + + +expType :: OpenExp aenv env t -> TupleType t +expType expr = case expr of + Let _ _ body -> expType body + Evar (Var tp _) -> TupRsingle tp + Foreign tp _ _ _ -> tp + Pair e1 e2 -> TupRpair (expType e1) (expType e2) + Nil -> TupRunit + VecPack vecR _ -> TupRsingle $ VectorScalarType $ vecRvector vecR + VecUnpack vecR _ -> vecRtuple vecR + IndexSlice si _ _ -> shapeType $ sliceShapeR si + IndexFull si _ _ -> shapeType $ sliceDomainR si + ToIndex _ _ _ -> TupRsingle $ SingleScalarType $ NumSingleType $ IntegralNumType $ TypeInt + FromIndex shr _ _ -> shapeType shr + Cond _ e _ -> expType e + While _ (Lam lhs _) _ -> lhsToTupR lhs + While _ _ _ -> error "What's the matter, you're running in the shadows" + Const tp _ -> TupRsingle tp + PrimConst c -> TupRsingle $ SingleScalarType $ primConstType c + PrimApp f _ -> snd $ primFunType f + Index (Var repr _) _ -> arrayRtype repr + LinearIndex (Var repr _) _ -> arrayRtype repr + Shape (Var repr _) -> shapeType $ arrayRshape repr + ShapeSize _ _ -> TupRsingle $ SingleScalarType $ NumSingleType $ IntegralNumType $ TypeInt + Undef tp -> TupRsingle tp + Coerce _ tp _ -> TupRsingle tp -- |Primitive constant values -- @@ -1161,6 +1076,12 @@ data PrimFun sig where PrimMin :: SingleType a -> PrimFun ((a, a) -> a ) -- logical operators + -- Note that these operators are strict in both arguments, + -- eg the second argument of PrimLAnd is always evaluated + -- even when the first argument is false. We thus define + -- (&&) and (||) using if-then-else to enable short-circuiting. + -- (&&!) and (||!) are strict versions of these operators, + -- which are defined using PrimLAnd and PrimLOr. PrimLAnd :: PrimFun ((Bool, Bool) -> Bool) PrimLOr :: PrimFun ((Bool, Bool) -> Bool) PrimLNot :: PrimFun (Bool -> Bool) @@ -1176,6 +1097,147 @@ data PrimFun sig where PrimFromIntegral :: IntegralType a -> NumType b -> PrimFun (a -> b) PrimToFloating :: NumType a -> FloatingType b -> PrimFun (a -> b) +primConstType :: PrimConst a -> SingleType a +primConstType prim = case prim of + PrimMinBound t -> boundedTp t + PrimMaxBound t -> boundedTp t + PrimPi t -> floatingTp t + where + boundedTp :: BoundedType a -> SingleType a + boundedTp (IntegralBoundedType t) = NumSingleType $ IntegralNumType t + boundedTp (NonNumBoundedType t) = NonNumSingleType t + + floatingTp :: FloatingType t -> SingleType t + floatingTp = NumSingleType . FloatingNumType + +primFunType :: PrimFun (a -> b) -> (TupleType a, TupleType b) +primFunType prim = case prim of + -- Num + PrimAdd t -> binary' $ numTp t + PrimSub t -> binary' $ numTp t + PrimMul t -> binary' $ numTp t + PrimNeg t -> unary' $ numTp t + PrimAbs t -> unary' $ numTp t + PrimSig t -> unary' $ numTp t + + -- Integral + PrimQuot t -> binary' $ integralTp t + PrimRem t -> binary' $ integralTp t + PrimQuotRem t -> divModT t + PrimIDiv t -> binary' $ integralTp t + PrimMod t -> binary' $ integralTp t + PrimDivMod t -> divModT t + + -- Bits & FiniteBits + PrimBAnd t -> binary' $ integralTp t + PrimBOr t -> binary' $ integralTp t + PrimBXor t -> binary' $ integralTp t + PrimBNot t -> unary' $ integralTp t + PrimBShiftL t -> (integralTp t `TupRpair` typeInt, integralTp t) + PrimBShiftR t -> (integralTp t `TupRpair` typeInt, integralTp t) + PrimBRotateL t -> (integralTp t `TupRpair` typeInt, integralTp t) + PrimBRotateR t -> (integralTp t `TupRpair` typeInt, integralTp t) + PrimPopCount t -> unary (integralTp t) typeInt + PrimCountLeadingZeros t -> unary (integralTp t) typeInt + PrimCountTrailingZeros t -> unary (integralTp t) typeInt + + -- Fractional, Floating + PrimFDiv t -> binary' $ floatingTp t + PrimRecip t -> unary' $ floatingTp t + PrimSin t -> unary' $ floatingTp t + PrimCos t -> unary' $ floatingTp t + PrimTan t -> unary' $ floatingTp t + PrimAsin t -> unary' $ floatingTp t + PrimAcos t -> unary' $ floatingTp t + PrimAtan t -> unary' $ floatingTp t + PrimSinh t -> unary' $ floatingTp t + PrimCosh t -> unary' $ floatingTp t + PrimTanh t -> unary' $ floatingTp t + PrimAsinh t -> unary' $ floatingTp t + PrimAcosh t -> unary' $ floatingTp t + PrimAtanh t -> unary' $ floatingTp t + PrimExpFloating t -> unary' $ floatingTp t + PrimSqrt t -> unary' $ floatingTp t + PrimLog t -> unary' $ floatingTp t + PrimFPow t -> binary' $ floatingTp t + PrimLogBase t -> binary' $ floatingTp t + + -- RealFrac + PrimTruncate a b -> unary (floatingTp a) (integralTp b) + PrimRound a b -> unary (floatingTp a) (integralTp b) + PrimFloor a b -> unary (floatingTp a) (integralTp b) + PrimCeiling a b -> unary (floatingTp a) (integralTp b) + + -- RealFloat + PrimAtan2 t -> binary' $ floatingTp t + PrimIsNaN t -> unary (floatingTp t) typeBool + PrimIsInfinite t -> unary (floatingTp t) typeBool + + -- Relational and equality + PrimLt t -> compare' t + PrimGt t -> compare' t + PrimLtEq t -> compare' t + PrimGtEq t -> compare' t + PrimEq t -> compare' t + PrimNEq t -> compare' t + PrimMax t -> binary' $ singleTp t + PrimMin t -> binary' $ singleTp t + + -- Logical + PrimLAnd -> binary' typeBool + PrimLOr -> binary' typeBool + PrimLNot -> unary' typeBool + + -- character conversions + PrimOrd -> unary typeChar typeInt + PrimChr -> unary typeInt typeChar + + -- boolean conversion + PrimBoolToInt -> unary typeBool typeInt + + -- general conversion between types + PrimFromIntegral a b -> unary (integralTp a) (numTp b) + PrimToFloating a b -> unary (numTp a) (floatingTp b) + + where + unary :: TupleType a -> TupleType b -> (TupleType a, TupleType b) + unary a b = (a, b) + + unary' :: TupleType a -> (TupleType a, TupleType a) + unary' a = unary a a + + binary :: TupleType a -> TupleType b -> (TupleType (a, a), TupleType b) + binary a b = (a `TupRpair` a, b) + + binary' :: TupleType a -> (TupleType (a, a), TupleType a) + binary' a = binary a a + + compare' :: SingleType a -> (TupleType (a, a), TupleType Bool) + compare' a = binary (singleTp a) typeBool + + singleTp :: SingleType t -> TupleType t + singleTp = TupRsingle . SingleScalarType + + numTp :: NumType t -> TupleType t + numTp = TupRsingle . SingleScalarType . NumSingleType + + integralTp :: IntegralType t -> TupleType t + integralTp = numTp . IntegralNumType + + floatingTp :: FloatingType t -> TupleType t + floatingTp = numTp . FloatingNumType + + divModT :: IntegralType t -> (TupleType (t, t), TupleType (t, t)) + divModT t = unary' $ integralTp t `TupRpair` integralTp t + + typeBool :: TupleType Bool + typeBool = TupRsingle $ SingleScalarType $ NonNumSingleType $ TypeBool + + typeChar :: TupleType Char + typeChar = TupRsingle $ SingleScalarType $ NonNumSingleType $ TypeChar + + typeInt :: TupleType Int + typeInt = TupRsingle $ SingleScalarType $ NumSingleType $ IntegralNumType TypeInt -- NFData instances -- ================ @@ -1190,10 +1252,10 @@ instance NFData (OpenAcc aenv t) where -- rnf = rnfPreOpenSeq rnfOpenAcc instance NFData (OpenExp env aenv t) where - rnf = rnfPreOpenExp rnfOpenAcc + rnf = rnfOpenExp instance NFData (OpenFun env aenv t) where - rnf = rnfPreOpenFun rnfOpenAcc + rnf = rnfOpenFun -- Array expressions @@ -1205,10 +1267,6 @@ rnfIdx :: Idx env t -> () rnfIdx ZeroIdx = () rnfIdx (SuccIdx ix) = rnfIdx ix -rnfTupleIdx :: TupleIdx t e -> () -rnfTupleIdx ZeroTupIdx = () -rnfTupleIdx (SuccTupIdx tix) = rnfTupleIdx tix - rnfOpenAfun :: OpenAfun aenv t -> () rnfOpenAfun = rnfPreOpenAfun rnfOpenAcc @@ -1217,48 +1275,48 @@ rnfOpenAcc (OpenAcc pacc) = rnfPreOpenAcc rnfOpenAcc pacc rnfPreOpenAfun :: NFDataAcc acc -> PreOpenAfun acc aenv t -> () rnfPreOpenAfun rnfA (Abody b) = rnfA b -rnfPreOpenAfun rnfA (Alam lhs f) = rnfLHS lhs `seq` rnfPreOpenAfun rnfA f +rnfPreOpenAfun rnfA (Alam lhs f) = rnfALhs lhs `seq` rnfPreOpenAfun rnfA f -rnfPreOpenAcc :: forall acc aenv t. NFDataAcc acc -> PreOpenAcc acc aenv t -> () +rnfPreOpenAcc :: forall acc aenv t. HasArraysRepr acc => NFDataAcc acc -> PreOpenAcc acc aenv t -> () rnfPreOpenAcc rnfA pacc = let rnfAF :: PreOpenAfun acc aenv' t' -> () rnfAF = rnfPreOpenAfun rnfA - rnfE :: PreOpenExp acc env' aenv' t' -> () - rnfE = rnfPreOpenExp rnfA + rnfE :: OpenExp env' aenv' t' -> () + rnfE = rnfOpenExp - rnfF :: PreOpenFun acc env' aenv' t' -> () - rnfF = rnfPreOpenFun rnfA + rnfF :: OpenFun env' aenv' t' -> () + rnfF = rnfOpenFun -- rnfS :: PreOpenSeq acc aenv' senv' t' -> () -- rnfS = rnfPreOpenSeq rnfA - rnfB :: PreBoundary acc aenv' (Array sh e) -> () - rnfB = rnfBoundary rnfA + rnfB :: ArrayR (Array sh e) -> Boundary aenv' (Array sh e) -> () + rnfB = rnfBoundary in case pacc of - Alet lhs bnd body -> rnfLHS lhs `seq` rnfA bnd `seq` rnfA body - Avar (ArrayVar ix) -> rnfIdx ix + Alet lhs bnd body -> rnfALhs lhs `seq` rnfA bnd `seq` rnfA body + Avar var -> rnfArrayVar var Apair as bs -> rnfA as `seq` rnfA bs Anil -> () - Apply afun acc -> rnfAF afun `seq` rnfA acc - Aforeign asm afun a -> rnf (strForeign asm) `seq` rnfAF afun `seq` rnfA a + Apply repr afun acc -> rnfTupR rnfArrayR repr `seq` rnfAF afun `seq` rnfA acc + Aforeign repr asm afun a -> rnfTupR rnfArrayR repr `seq` rnf (Sugar.strForeign asm) `seq` rnfAF afun `seq` rnfA a Acond p a1 a2 -> rnfE p `seq` rnfA a1 `seq` rnfA a2 Awhile p f a -> rnfAF p `seq` rnfAF f `seq` rnfA a - Use arr -> rnf arr - Unit x -> rnfE x - Reshape sh a -> rnfE sh `seq` rnfA a - Generate sh f -> rnfE sh `seq` rnfF f - Transform sh p f a -> rnfE sh `seq` rnfF p `seq` rnfF f `seq` rnfA a + Use repr arr -> rnfArray repr arr + Unit tp x -> rnfTupleType tp `seq` rnfE x + Reshape shr sh a -> rnfShapeR shr `seq` rnfE sh `seq` rnfA a + Generate repr sh f -> rnfArrayR repr `seq` rnfE sh `seq` rnfF f + Transform repr sh p f a -> rnfArrayR repr `seq` rnfE sh `seq` rnfF p `seq` rnfF f `seq` rnfA a Replicate slice sh a -> rnfSliceIndex slice `seq` rnfE sh `seq` rnfA a Slice slice a sh -> rnfSliceIndex slice `seq` rnfE sh `seq` rnfA a - Map f a -> rnfF f `seq` rnfA a - ZipWith f a1 a2 -> rnfF f `seq` rnfA a1 `seq` rnfA a2 + Map tp f a -> rnfTupleType tp `seq` rnfF f `seq` rnfA a + ZipWith tp f a1 a2 -> rnfTupleType tp `seq` rnfF f `seq` rnfA a1 `seq` rnfA a2 Fold f z a -> rnfF f `seq` rnfE z `seq` rnfA a Fold1 f a -> rnfF f `seq` rnfA a - FoldSeg f z a s -> rnfF f `seq` rnfE z `seq` rnfA a `seq` rnfA s - Fold1Seg f a s -> rnfF f `seq` rnfA a `seq` rnfA s + FoldSeg i f z a s -> rnfIntegralType i `seq` rnfF f `seq` rnfE z `seq` rnfA a `seq` rnfA s + Fold1Seg i f a s -> rnfIntegralType i `seq` rnfF f `seq` rnfA a `seq` rnfA s Scanl f z a -> rnfF f `seq` rnfE z `seq` rnfA a Scanl1 f a -> rnfF f `seq` rnfA a Scanl' f z a -> rnfF f `seq` rnfE z `seq` rnfA a @@ -1266,32 +1324,74 @@ rnfPreOpenAcc rnfA pacc = Scanr1 f a -> rnfF f `seq` rnfA a Scanr' f z a -> rnfF f `seq` rnfE z `seq` rnfA a Permute f d p a -> rnfF f `seq` rnfA d `seq` rnfF p `seq` rnfA a - Backpermute sh f a -> rnfE sh `seq` rnfF f `seq` rnfA a - Stencil f b a -> rnfF f `seq` rnfB b `seq` rnfA a - Stencil2 f b1 a1 b2 a2 -> rnfF f `seq` rnfB b1 `seq` rnfB b2 `seq` rnfA a1 `seq` rnfA a2 + Backpermute shr sh f a -> rnfShapeR shr `seq` rnfE sh `seq` rnfF f `seq` rnfA a + Stencil sr tp f b a -> + let + TupRsingle (ArrayR shr _) = arraysRepr a + repr = ArrayR shr $ stencilElt sr + in rnfStencilR sr `seq` rnfTupR rnfScalarType tp `seq` rnfF f `seq` rnfB repr b `seq` rnfA a + Stencil2 sr1 sr2 tp f b1 a1 b2 a2 -> + let + TupRsingle (ArrayR shr _) = arraysRepr a1 + repr1 = ArrayR shr $ stencilElt sr1 + repr2 = ArrayR shr $ stencilElt sr2 + in rnfStencilR sr1 `seq` rnfStencilR sr2 `seq` rnfTupR rnfScalarType tp `seq` rnfF f `seq` rnfB repr1 b1 `seq` rnfB repr2 b2 `seq` rnfA a1 `seq` rnfA a2 -- Collect s -> rnfS s -rnfLHS :: LeftHandSide arrs aenv aenv' -> () -rnfLHS (LeftHandSideWildcard r) = rnfArraysR r -rnfLHS LeftHandSideArray = () -rnfLHS (LeftHandSidePair ar1 ar2) = rnfLHS ar1 `seq` rnfLHS ar2 +rnfArrayVar :: ArrayVar aenv a -> () +rnfArrayVar (Var repr ix) = rnfArrayR repr `seq` rnfIdx ix + +rnfLhs :: (forall b. s b -> ()) -> LeftHandSide s arrs env env' -> () +rnfLhs rnfS (LeftHandSideWildcard r) = rnfTupR rnfS r +rnfLhs rnfS (LeftHandSideSingle s) = rnfS s +rnfLhs rnfS (LeftHandSidePair ar1 ar2) = rnfLhs rnfS ar1 `seq` rnfLhs rnfS ar2 + +rnfALhs :: ALeftHandSide arrs aenv aenv' -> () +rnfALhs = rnfLhs rnfArrayR -rnfArraysR :: ArraysR arrs -> () -rnfArraysR ArraysRunit = () -rnfArraysR ArraysRarray = () -rnfArraysR (ArraysRpair ar1 ar2) = rnfArraysR ar1 `seq` rnfArraysR ar2 +rnfELhs :: ELeftHandSide t env env' -> () +rnfELhs = rnfLhs rnfScalarType + +rnfTupR :: (forall b. s b -> ()) -> TupR s a -> () +rnfTupR _ TupRunit = () +rnfTupR rnfS (TupRsingle s) = rnfS s +rnfTupR rnfS (TupRpair t1 t2) = rnfTupR rnfS t1 `seq` rnfTupR rnfS t2 + +rnfArrayR :: ArrayR arr -> () +rnfArrayR (ArrayR shr tp) = rnfShapeR shr `seq` rnfTupR rnfScalarType tp rnfArrays :: ArraysR arrs -> arrs -> () -rnfArrays ArraysRunit () = () -rnfArrays ArraysRarray arr = rnf arr -rnfArrays (ArraysRpair ar1 ar2) (a1,a2) = rnfArrays ar1 a1 `seq` rnfArrays ar2 a2 +rnfArrays TupRunit () = () +rnfArrays (TupRsingle repr) arr = rnfArray repr arr +rnfArrays (TupRpair ar1 ar2) (a1,a2) = rnfArrays ar1 a1 `seq` rnfArrays ar2 a2 + +rnfShapeR :: ShapeR sh -> () +rnfShapeR ShapeRz = () +rnfShapeR (ShapeRsnoc shr) = rnfShapeR shr + +rnfStencilR :: StencilR sh e pat -> () +rnfStencilR (StencilRunit3 tp) = rnfTupleType tp +rnfStencilR (StencilRunit5 tp) = rnfTupleType tp +rnfStencilR (StencilRunit7 tp) = rnfTupleType tp +rnfStencilR (StencilRunit9 tp) = rnfTupleType tp +rnfStencilR (StencilRtup3 s1 s2 s3) + = rnfStencilR s1 `seq` rnfStencilR s2 `seq` rnfStencilR s3 +rnfStencilR (StencilRtup5 s1 s2 s3 s4 s5) + = rnfStencilR s1 `seq` rnfStencilR s2 `seq` rnfStencilR s3 `seq` rnfStencilR s4 `seq` rnfStencilR s5 +rnfStencilR (StencilRtup7 s1 s2 s3 s4 s5 s6 s7) + = rnfStencilR s1 `seq` rnfStencilR s2 `seq` rnfStencilR s3 `seq` rnfStencilR s4 `seq` rnfStencilR s5 + `seq` rnfStencilR s6 `seq` rnfStencilR s7 +rnfStencilR (StencilRtup9 s1 s2 s3 s4 s5 s6 s7 s8 s9) + = rnfStencilR s1 `seq` rnfStencilR s2 `seq` rnfStencilR s3 `seq` rnfStencilR s4 `seq` rnfStencilR s5 + `seq` rnfStencilR s6 `seq` rnfStencilR s7 `seq` rnfStencilR s8 `seq` rnfStencilR s9 + +rnfBoundary :: forall aenv sh e. ArrayR (Array sh e) -> Boundary aenv (Array sh e) -> () +rnfBoundary _ Clamp = () +rnfBoundary _ Mirror = () +rnfBoundary _ Wrap = () +rnfBoundary (ArrayR _ tp) (Constant c) = rnfConst tp c +rnfBoundary _ (Function f) = rnfOpenFun f -rnfBoundary :: forall acc aenv sh e. NFDataAcc acc -> PreBoundary acc aenv (Array sh e) -> () -rnfBoundary _ Clamp = () -rnfBoundary _ Mirror = () -rnfBoundary _ Wrap = () -rnfBoundary _ (Constant c) = rnfConst (eltType @e) c -rnfBoundary rnfA (Function f) = rnfPreOpenFun rnfA f {-- @@ -1325,11 +1425,11 @@ rnfSeqProducer rnfA topSeq = rnfAF :: PreOpenAfun acc aenv' t' -> () rnfAF = rnfPreOpenAfun rnfA - rnfF :: PreOpenFun acc env' aenv' t' -> () - rnfF = rnfPreOpenFun rnfA + rnfF :: OpenFun env' aenv' t' -> () + rnfF = rnfOpenFun rnfA - rnfE :: PreOpenExp acc env' aenv' t' -> () - rnfE = rnfPreOpenExp rnfA + rnfE :: OpenExp env' aenv' t' -> () + rnfE = rnfOpenExp rnfA in case topSeq of StreamIn as -> rnfArrs as @@ -1345,11 +1445,11 @@ rnfSeqConsumer rnfA topSeq = rnfAF :: PreOpenAfun acc aenv' t' -> () rnfAF = rnfPreOpenAfun rnfA - rnfF :: PreOpenFun acc env' aenv' t' -> () - rnfF = rnfPreOpenFun rnfA + rnfF :: OpenFun env' aenv' t' -> () + rnfF = rnfOpenFun rnfA - rnfE :: PreOpenExp acc env' aenv' t' -> () - rnfE = rnfPreOpenExp rnfA + rnfE :: OpenExp env' aenv' t' -> () + rnfE = rnfOpenExp rnfA in case topSeq of FoldSeq f z ix -> rnfF f `seq` rnfE z `seq` rnfIdx ix @@ -1364,56 +1464,47 @@ rnfStuple rnfA (SnocAtup tup c) = rnfStuple rnfA tup `seq` rnfSeqConsumer rnfA c -- Scalar expressions -- ------------------ -rnfPreOpenFun :: NFDataAcc acc -> PreOpenFun acc env aenv t -> () -rnfPreOpenFun rnfA (Body b) = rnfPreOpenExp rnfA b -rnfPreOpenFun rnfA (Lam f) = rnfPreOpenFun rnfA f +rnfOpenFun :: OpenFun env aenv t -> () +rnfOpenFun (Body b) = rnfOpenExp b +rnfOpenFun (Lam lhs f) = rnfELhs lhs `seq` rnfOpenFun f -rnfPreOpenExp :: forall acc env aenv t. NFDataAcc acc -> PreOpenExp acc env aenv t -> () -rnfPreOpenExp rnfA topExp = +rnfOpenExp :: forall env aenv t. OpenExp env aenv t -> () +rnfOpenExp topExp = let - rnfF :: PreOpenFun acc env' aenv' t' -> () - rnfF = rnfPreOpenFun rnfA + rnfF :: OpenFun env' aenv' t' -> () + rnfF = rnfOpenFun - rnfE :: PreOpenExp acc env' aenv' t' -> () - rnfE = rnfPreOpenExp rnfA + rnfE :: OpenExp env' aenv' t' -> () + rnfE = rnfOpenExp in case topExp of - Let bnd body -> rnfE bnd `seq` rnfE body - Var ix -> rnfIdx ix - Foreign asm f x -> rnf (strForeign asm) `seq` rnfF f `seq` rnfE x - Const t -> rnfConst (eltType @t) t - Undef -> () - Tuple t -> rnfTuple rnfA t - Prj ix e -> rnfTupleIdx ix `seq` rnfE e - IndexNil -> () - IndexCons sh sz -> rnfE sh `seq` rnfE sz - IndexHead sh -> rnfE sh - IndexTail sh -> rnfE sh - IndexAny -> () + Let lhs bnd body -> rnfELhs lhs `seq` rnfE bnd `seq` rnfE body + Evar (Var tp ix) -> rnfScalarType tp `seq` rnfIdx ix + Foreign tp asm f x -> rnfTupleType tp `seq` rnf (Sugar.strForeign asm) `seq` rnfF f `seq` rnfE x + Const tp c -> c `seq` rnfScalarType tp -- scalars should have (nf == whnf) + Undef tp -> rnfScalarType tp + Pair a b -> rnfE a `seq` rnfE b + Nil -> () + VecPack vecr e -> rnfVecR vecr `seq` rnfE e + VecUnpack vecr e -> rnfVecR vecr `seq` rnfE e IndexSlice slice slix sh -> rnfSliceIndex slice `seq` rnfE slix `seq` rnfE sh IndexFull slice slix sl -> rnfSliceIndex slice `seq` rnfE slix `seq` rnfE sl - ToIndex sh ix -> rnfE sh `seq` rnfE ix - FromIndex sh ix -> rnfE sh `seq` rnfE ix + ToIndex shr sh ix -> rnfShapeR shr `seq` rnfE sh `seq` rnfE ix + FromIndex shr sh ix -> rnfShapeR shr `seq` rnfE sh `seq` rnfE ix Cond p e1 e2 -> rnfE p `seq` rnfE e1 `seq` rnfE e2 While p f x -> rnfF p `seq` rnfF f `seq` rnfE x PrimConst c -> rnfPrimConst c PrimApp f x -> rnfPrimFun f `seq` rnfE x - Index a ix -> rnfA a `seq` rnfE ix - LinearIndex a ix -> rnfA a `seq` rnfE ix - Shape a -> rnfA a - ShapeSize sh -> rnfE sh - Intersect sh1 sh2 -> rnfE sh1 `seq` rnfE sh2 - Union sh1 sh2 -> rnfE sh1 `seq` rnfE sh2 - Coerce e -> rnfE e - -rnfTuple :: NFDataAcc acc -> Tuple (PreOpenExp acc env aenv) t -> () -rnfTuple _ NilTup = () -rnfTuple rnfA (SnocTup t e) = rnfTuple rnfA t `seq` rnfPreOpenExp rnfA e + Index a ix -> rnfArrayVar a `seq` rnfE ix + LinearIndex a ix -> rnfArrayVar a `seq` rnfE ix + Shape a -> rnfArrayVar a + ShapeSize shr sh -> rnfShapeR shr `seq` rnfE sh + Coerce t1 t2 e -> rnfScalarType t1 `seq` rnfScalarType t2 `seq` rnfE e rnfConst :: TupleType t -> t -> () -rnfConst TypeRunit () = () -rnfConst (TypeRscalar t) !_ = rnfScalarType t -- scalars should have (nf == whnf) -rnfConst (TypeRpair ta tb) (a,b) = rnfConst ta a `seq` rnfConst tb b +rnfConst TupRunit () = () +rnfConst (TupRsingle t) !_ = rnfScalarType t -- scalars should have (nf == whnf) +rnfConst (TupRpair ta tb) (a,b) = rnfConst ta a `seq` rnfConst tb b rnfPrimConst :: PrimConst c -> () rnfPrimConst (PrimMinBound t) = rnfBoundedType t @@ -1492,6 +1583,9 @@ rnfSliceIndex SliceNil = () rnfSliceIndex (SliceAll sh) = rnfSliceIndex sh rnfSliceIndex (SliceFixed sh) = rnfSliceIndex sh +rnfTupleType :: TupleType t -> () +rnfTupleType = rnfTupR rnfScalarType + rnfScalarType :: ScalarType t -> () rnfScalarType (SingleScalarType t) = rnfSingleType t rnfScalarType (VectorScalarType t) = rnfVectorType t @@ -1512,26 +1606,29 @@ rnfNumType (IntegralNumType t) = rnfIntegralType t rnfNumType (FloatingNumType t) = rnfFloatingType t rnfNonNumType :: NonNumType t -> () -rnfNonNumType (TypeBool NonNumDict) = () -rnfNonNumType (TypeChar NonNumDict) = () +rnfNonNumType TypeBool = () +rnfNonNumType TypeChar = () rnfIntegralType :: IntegralType t -> () -rnfIntegralType (TypeInt IntegralDict) = () -rnfIntegralType (TypeInt8 IntegralDict) = () -rnfIntegralType (TypeInt16 IntegralDict) = () -rnfIntegralType (TypeInt32 IntegralDict) = () -rnfIntegralType (TypeInt64 IntegralDict) = () -rnfIntegralType (TypeWord IntegralDict) = () -rnfIntegralType (TypeWord8 IntegralDict) = () -rnfIntegralType (TypeWord16 IntegralDict) = () -rnfIntegralType (TypeWord32 IntegralDict) = () -rnfIntegralType (TypeWord64 IntegralDict) = () +rnfIntegralType TypeInt = () +rnfIntegralType TypeInt8 = () +rnfIntegralType TypeInt16 = () +rnfIntegralType TypeInt32 = () +rnfIntegralType TypeInt64 = () +rnfIntegralType TypeWord = () +rnfIntegralType TypeWord8 = () +rnfIntegralType TypeWord16 = () +rnfIntegralType TypeWord32 = () +rnfIntegralType TypeWord64 = () rnfFloatingType :: FloatingType t -> () -rnfFloatingType (TypeHalf FloatingDict) = () -rnfFloatingType (TypeFloat FloatingDict) = () -rnfFloatingType (TypeDouble FloatingDict) = () +rnfFloatingType TypeHalf = () +rnfFloatingType TypeFloat = () +rnfFloatingType TypeDouble = () +rnfVecR :: VecR n single tuple -> () +rnfVecR (VecRnil tp) = rnfSingleType tp +rnfVecR (VecRsucc vec) = rnfVecR vec -- Template Haskell -- ================ @@ -1542,57 +1639,54 @@ liftIdx :: Idx env t -> Q (TExp (Idx env t)) liftIdx ZeroIdx = [|| ZeroIdx ||] liftIdx (SuccIdx ix) = [|| SuccIdx $$(liftIdx ix) ||] -liftTupleIdx :: TupleIdx t e -> Q (TExp (TupleIdx t e)) -liftTupleIdx ZeroTupIdx = [|| ZeroTupIdx ||] -liftTupleIdx (SuccTupIdx tix) = [|| SuccTupIdx $$(liftTupleIdx tix) ||] - liftPreOpenAfun :: LiftAcc acc -> PreOpenAfun acc aenv t -> Q (TExp (PreOpenAfun acc aenv t)) -liftPreOpenAfun liftA (Alam lhs f) = [|| Alam $$(liftLHS lhs) $$(liftPreOpenAfun liftA f) ||] +liftPreOpenAfun liftA (Alam lhs f) = [|| Alam $$(liftALhs lhs) $$(liftPreOpenAfun liftA f) ||] liftPreOpenAfun liftA (Abody b) = [|| Abody $$(liftA b) ||] liftPreOpenAcc :: forall acc aenv a. - LiftAcc acc + HasArraysRepr acc + => LiftAcc acc -> PreOpenAcc acc aenv a -> Q (TExp (PreOpenAcc acc aenv a)) liftPreOpenAcc liftA pacc = let - liftE :: PreOpenExp acc env aenv t -> Q (TExp (PreOpenExp acc env aenv t)) - liftE = liftPreOpenExp liftA + liftE :: OpenExp env aenv t -> Q (TExp (OpenExp env aenv t)) + liftE = liftOpenExp - liftF :: PreOpenFun acc env aenv t -> Q (TExp (PreOpenFun acc env aenv t)) - liftF = liftPreOpenFun liftA + liftF :: OpenFun env aenv t -> Q (TExp (OpenFun env aenv t)) + liftF = liftOpenFun liftAF :: PreOpenAfun acc aenv f -> Q (TExp (PreOpenAfun acc aenv f)) liftAF = liftPreOpenAfun liftA - liftB :: PreBoundary acc aenv (Array sh e) -> Q (TExp (PreBoundary acc aenv (Array sh e))) - liftB = liftBoundary liftA + liftB :: ArrayR (Array sh e) -> Boundary aenv (Array sh e) -> Q (TExp (Boundary aenv (Array sh e))) + liftB = liftBoundary in case pacc of - Alet lhs bnd body -> [|| Alet $$(liftLHS lhs) $$(liftA bnd) $$(liftA body) ||] - Avar (ArrayVar ix) -> [|| Avar (ArrayVar $$(liftIdx ix)) ||] + Alet lhs bnd body -> [|| Alet $$(liftALhs lhs) $$(liftA bnd) $$(liftA body) ||] + Avar var -> [|| Avar $$(liftArrayVar var) ||] Apair as bs -> [|| Apair $$(liftA as) $$(liftA bs) ||] Anil -> [|| Anil ||] - Apply f a -> [|| Apply $$(liftAF f) $$(liftA a) ||] - Aforeign asm f a -> [|| Aforeign $$(liftForeign asm) $$(liftPreOpenAfun liftA f) $$(liftA a) ||] + Apply repr f a -> [|| Apply $$(liftArraysR repr) $$(liftAF f) $$(liftA a) ||] + Aforeign repr asm f a -> [|| Aforeign $$(liftArraysR repr) $$(Sugar.liftForeign asm) $$(liftPreOpenAfun liftA f) $$(liftA a) ||] Acond p t e -> [|| Acond $$(liftE p) $$(liftA t) $$(liftA e) ||] Awhile p f a -> [|| Awhile $$(liftAF p) $$(liftAF f) $$(liftA a) ||] - Use a -> [|| Use $$(liftArray a) ||] - Unit e -> [|| Unit $$(liftE e) ||] - Reshape sh a -> [|| Reshape $$(liftE sh) $$(liftA a) ||] - Generate sh f -> [|| Generate $$(liftE sh) $$(liftF f) ||] - Transform sh p f a -> [|| Transform $$(liftE sh) $$(liftF p) $$(liftF f) $$(liftA a) ||] + Use repr a -> [|| Use $$(liftArrayR repr) $$(liftArray repr a) ||] + Unit tp e -> [|| Unit $$(liftTupleType tp) $$(liftE e) ||] + Reshape shr sh a -> [|| Reshape $$(liftShapeR shr) $$(liftE sh) $$(liftA a) ||] + Generate repr sh f -> [|| Generate $$(liftArrayR repr) $$(liftE sh) $$(liftF f) ||] + Transform repr sh p f a -> [|| Transform $$(liftArrayR repr) $$(liftE sh) $$(liftF p) $$(liftF f) $$(liftA a) ||] Replicate slix sl a -> [|| Replicate $$(liftSliceIndex slix) $$(liftE sl) $$(liftA a) ||] Slice slix a sh -> [|| Slice $$(liftSliceIndex slix) $$(liftA a) $$(liftE sh) ||] - Map f a -> [|| Map $$(liftF f) $$(liftA a) ||] - ZipWith f a b -> [|| ZipWith $$(liftF f) $$(liftA a) $$(liftA b) ||] + Map tp f a -> [|| Map $$(liftTupleType tp) $$(liftF f) $$(liftA a) ||] + ZipWith tp f a b -> [|| ZipWith $$(liftTupleType tp) $$(liftF f) $$(liftA a) $$(liftA b) ||] Fold f z a -> [|| Fold $$(liftF f) $$(liftE z) $$(liftA a) ||] Fold1 f a -> [|| Fold1 $$(liftF f) $$(liftA a) ||] - FoldSeg f z a s -> [|| FoldSeg $$(liftF f) $$(liftE z) $$(liftA a) $$(liftA s) ||] - Fold1Seg f a s -> [|| Fold1Seg $$(liftF f) $$(liftA a) $$(liftA s) ||] + FoldSeg i f z a s -> [|| FoldSeg $$(liftIntegralType i) $$(liftF f) $$(liftE z) $$(liftA a) $$(liftA s) ||] + Fold1Seg i f a s -> [|| Fold1Seg $$(liftIntegralType i) $$(liftF f) $$(liftA a) $$(liftA s) ||] Scanl f z a -> [|| Scanl $$(liftF f) $$(liftE z) $$(liftA a) ||] Scanl1 f a -> [|| Scanl1 $$(liftF f) $$(liftA a) ||] Scanl' f z a -> [|| Scanl' $$(liftF f) $$(liftE z) $$(liftA a) ||] @@ -1600,94 +1694,145 @@ liftPreOpenAcc liftA pacc = Scanr1 f a -> [|| Scanr1 $$(liftF f) $$(liftA a) ||] Scanr' f z a -> [|| Scanr' $$(liftF f) $$(liftE z) $$(liftA a) ||] Permute f d p a -> [|| Permute $$(liftF f) $$(liftA d) $$(liftF p) $$(liftA a) ||] - Backpermute sh p a -> [|| Backpermute $$(liftE sh) $$(liftF p) $$(liftA a) ||] - Stencil f b a -> [|| Stencil $$(liftF f) $$(liftB b) $$(liftA a) ||] - Stencil2 f b1 a1 b2 a2 -> [|| Stencil2 $$(liftF f) $$(liftB b1) $$(liftA a1) $$(liftB b2) $$(liftA a2) ||] - -liftLHS :: LeftHandSide arrs aenv aenv' -> Q (TExp (LeftHandSide arrs aenv aenv')) -liftLHS LeftHandSideArray = [|| LeftHandSideArray ||] -liftLHS (LeftHandSideWildcard r) = [|| LeftHandSideWildcard $$(liftArraysR r) ||] -liftLHS (LeftHandSidePair a b) = [|| LeftHandSidePair $$(liftLHS a) $$(liftLHS b) ||] + Backpermute shr sh p a -> [|| Backpermute $$(liftShapeR shr) $$(liftE sh) $$(liftF p) $$(liftA a) ||] + Stencil sr tp f b a -> + let + TupRsingle (ArrayR shr _) = arraysRepr a + repr = ArrayR shr $ stencilElt sr + in [|| Stencil $$(liftStencilR sr) $$(liftTupleType tp) $$(liftF f) $$(liftB repr b) $$(liftA a) ||] + Stencil2 sr1 sr2 tp f b1 a1 b2 a2 -> + let + TupRsingle (ArrayR shr _) = arraysRepr a1 + repr1 = ArrayR shr $ stencilElt sr1 + repr2 = ArrayR shr $ stencilElt sr2 + in [|| Stencil2 $$(liftStencilR sr1) $$(liftStencilR sr2) $$(liftTupleType tp) $$(liftF f) $$(liftB repr1 b1) $$(liftA a1) $$(liftB repr2 b2) $$(liftA a2) ||] + +liftALhs :: ALeftHandSide arrs aenv aenv' -> Q (TExp (ALeftHandSide arrs aenv aenv')) +liftALhs (LeftHandSideSingle repr) = [|| LeftHandSideSingle $$(liftArrayR repr) ||] +liftALhs (LeftHandSideWildcard r) = [|| LeftHandSideWildcard $$(liftArraysR r) ||] +liftALhs (LeftHandSidePair a b) = [|| LeftHandSidePair $$(liftALhs a) $$(liftALhs b) ||] + +liftELhs :: ELeftHandSide t env env' -> Q (TExp (ELeftHandSide t env env')) +liftELhs (LeftHandSideSingle t) = [|| LeftHandSideSingle $$(liftScalarType t) ||] +liftELhs (LeftHandSideWildcard r) = [|| LeftHandSideWildcard $$(liftTupleType r) ||] +liftELhs (LeftHandSidePair a b) = [|| LeftHandSidePair $$(liftELhs a) $$(liftELhs b) ||] + +liftShapeR :: ShapeR sh -> Q (TExp (ShapeR sh)) +liftShapeR ShapeRz = [|| ShapeRz ||] +liftShapeR (ShapeRsnoc sh) = [|| ShapeRsnoc $$(liftShapeR sh) ||] + +liftArrayR :: ArrayR a -> Q (TExp (ArrayR a)) +liftArrayR (ArrayR shr tp) = [|| ArrayR $$(liftShapeR shr) $$(liftTupleType tp) ||] liftArraysR :: ArraysR arrs -> Q (TExp (ArraysR arrs)) -liftArraysR ArraysRunit = [|| ArraysRunit ||] -liftArraysR ArraysRarray = [|| ArraysRarray ||] -liftArraysR (ArraysRpair a b) = [|| ArraysRpair $$(liftArraysR a) $$(liftArraysR b) ||] - -liftPreOpenFun - :: LiftAcc acc - -> PreOpenFun acc env aenv t - -> Q (TExp (PreOpenFun acc env aenv t)) -liftPreOpenFun liftA (Lam f) = [|| Lam $$(liftPreOpenFun liftA f) ||] -liftPreOpenFun liftA (Body b) = [|| Body $$(liftPreOpenExp liftA b) ||] - -liftPreOpenExp - :: forall acc env aenv t. - LiftAcc acc - -> PreOpenExp acc env aenv t - -> Q (TExp (PreOpenExp acc env aenv t)) -liftPreOpenExp liftA pexp = +liftArraysR TupRunit = [|| TupRunit ||] +liftArraysR (TupRsingle repr) = [|| TupRsingle $$(liftArrayR repr) ||] +liftArraysR (TupRpair a b) = [|| TupRpair $$(liftArraysR a) $$(liftArraysR b) ||] + +liftStencilR :: StencilR sh e pat -> Q (TExp (StencilR sh e pat)) +liftStencilR (StencilRunit3 tp) = [|| StencilRunit3 $$(liftTupleType tp) ||] +liftStencilR (StencilRunit5 tp) = [|| StencilRunit5 $$(liftTupleType tp) ||] +liftStencilR (StencilRunit7 tp) = [|| StencilRunit7 $$(liftTupleType tp) ||] +liftStencilR (StencilRunit9 tp) = [|| StencilRunit9 $$(liftTupleType tp) ||] +liftStencilR (StencilRtup3 s1 s2 s3) + = [|| StencilRtup3 $$(liftStencilR s1) $$(liftStencilR s2) $$(liftStencilR s3) ||] +liftStencilR (StencilRtup5 s1 s2 s3 s4 s5) + = [|| StencilRtup5 $$(liftStencilR s1) $$(liftStencilR s2) $$(liftStencilR s3) $$(liftStencilR s4) $$(liftStencilR s5) ||] +liftStencilR (StencilRtup7 s1 s2 s3 s4 s5 s6 s7) + = [|| StencilRtup7 $$(liftStencilR s1) $$(liftStencilR s2) $$(liftStencilR s3) $$(liftStencilR s4) $$(liftStencilR s5) + $$(liftStencilR s6) $$(liftStencilR s7) ||] +liftStencilR (StencilRtup9 s1 s2 s3 s4 s5 s6 s7 s8 s9) + = [|| StencilRtup9 $$(liftStencilR s1) $$(liftStencilR s2) $$(liftStencilR s3) $$(liftStencilR s4) $$(liftStencilR s5) + $$(liftStencilR s6) $$(liftStencilR s7) $$(liftStencilR s8) $$(liftStencilR s9) ||] + +liftOpenFun + :: OpenFun env aenv t + -> Q (TExp (OpenFun env aenv t)) +liftOpenFun (Lam lhs f) = [|| Lam $$(liftELhs lhs) $$(liftOpenFun f) ||] +liftOpenFun (Body b) = [|| Body $$(liftOpenExp b) ||] + +liftOpenExp + :: forall env aenv t. + OpenExp env aenv t + -> Q (TExp (OpenExp env aenv t)) +liftOpenExp pexp = let - liftE :: PreOpenExp acc env aenv e -> Q (TExp (PreOpenExp acc env aenv e)) - liftE = liftPreOpenExp liftA + liftE :: OpenExp env aenv e -> Q (TExp (OpenExp env aenv e)) + liftE = liftOpenExp - liftF :: PreOpenFun acc env aenv f -> Q (TExp (PreOpenFun acc env aenv f)) - liftF = liftPreOpenFun liftA - - liftT :: Tuple (PreOpenExp acc env aenv) e -> Q (TExp (Tuple (PreOpenExp acc env aenv) e)) - liftT NilTup = [|| NilTup ||] - liftT (SnocTup tup e) = [|| SnocTup $$(liftT tup) $$(liftE e) ||] + liftF :: OpenFun env aenv f -> Q (TExp (OpenFun env aenv f)) + liftF = liftOpenFun in case pexp of - Let bnd body -> [|| Let $$(liftPreOpenExp liftA bnd) $$(liftPreOpenExp liftA body) ||] - Var ix -> [|| Var $$(liftIdx ix) ||] - Foreign asm f x -> [|| Foreign $$(liftForeign asm) $$(liftPreOpenFun liftA f) $$(liftE x) ||] - Const c -> [|| Const $$(liftConst (eltType @t) c) ||] - Undef -> [|| Undef ||] - Tuple tup -> [|| Tuple $$(liftT tup) ||] - Prj tix e -> [|| Prj $$(liftTupleIdx tix) $$(liftE e) ||] - IndexNil -> [|| IndexNil ||] - IndexCons sh sz -> [|| IndexCons $$(liftE sh) $$(liftE sz) ||] - IndexHead sh -> [|| IndexHead $$(liftE sh) ||] - IndexTail sh -> [|| IndexTail $$(liftE sh) ||] - IndexAny -> [|| IndexAny ||] + Let lhs bnd body -> [|| Let $$(liftELhs lhs) $$(liftOpenExp bnd) $$(liftOpenExp body) ||] + Evar var -> [|| Evar $$(liftExpVar var) ||] + Foreign repr asm f x -> [|| Foreign $$(liftTupleType repr) $$(Sugar.liftForeign asm) $$(liftOpenFun f) $$(liftE x) ||] + Const tp c -> [|| Const $$(liftScalarType tp) $$(liftConst (TupRsingle tp) c) ||] + Undef tp -> [|| Undef $$(liftScalarType tp) ||] + Pair a b -> [|| Pair $$(liftE a) $$(liftE b) ||] + Nil -> [|| Nil ||] + VecPack vecr e -> [|| VecPack $$(liftVecR vecr) $$(liftE e) ||] + VecUnpack vecr e -> [|| VecUnpack $$(liftVecR vecr) $$(liftE e) ||] IndexSlice slice slix sh -> [|| IndexSlice $$(liftSliceIndex slice) $$(liftE slix) $$(liftE sh) ||] IndexFull slice slix sl -> [|| IndexFull $$(liftSliceIndex slice) $$(liftE slix) $$(liftE sl) ||] - ToIndex sh ix -> [|| ToIndex $$(liftE sh) $$(liftE ix) ||] - FromIndex sh ix -> [|| FromIndex $$(liftE sh) $$(liftE ix) ||] + ToIndex shr sh ix -> [|| ToIndex $$(liftShapeR shr) $$(liftE sh) $$(liftE ix) ||] + FromIndex shr sh ix -> [|| FromIndex $$(liftShapeR shr) $$(liftE sh) $$(liftE ix) ||] Cond p t e -> [|| Cond $$(liftE p) $$(liftE t) $$(liftE e) ||] While p f x -> [|| While $$(liftF p) $$(liftF f) $$(liftE x) ||] PrimConst t -> [|| PrimConst $$(liftPrimConst t) ||] PrimApp f x -> [|| PrimApp $$(liftPrimFun f) $$(liftE x) ||] - Index a ix -> [|| Index $$(liftA a) $$(liftE ix) ||] - LinearIndex a ix -> [|| LinearIndex $$(liftA a) $$(liftE ix) ||] - Shape a -> [|| Shape $$(liftA a) ||] - ShapeSize ix -> [|| ShapeSize $$(liftE ix) ||] - Intersect sh1 sh2 -> [|| Intersect $$(liftE sh1) $$(liftE sh2) ||] - Union sh1 sh2 -> [|| Union $$(liftE sh1) $$(liftE sh2) ||] - Coerce e -> [|| Coerce $$(liftE e) ||] - - -liftArray :: forall sh e. (Shape sh, Elt e) => Array sh e -> Q (TExp (Array sh e)) -liftArray (Array sh adata) = - [|| Array $$(liftConst (eltType @sh) sh) $$(go arrayElt adata) ||] `sigE` typeRepToType (typeOf (undefined::Array sh e)) + Index a ix -> [|| Index $$(liftArrayVar a) $$(liftE ix) ||] + LinearIndex a ix -> [|| LinearIndex $$(liftArrayVar a) $$(liftE ix) ||] + Shape a -> [|| Shape $$(liftArrayVar a) ||] + ShapeSize shr ix -> [|| ShapeSize $$(liftShapeR shr) $$(liftE ix) ||] + Coerce t1 t2 e -> [|| Coerce $$(liftScalarType t1) $$(liftScalarType t2) $$(liftE e) ||] + +liftExpVar :: ExpVar env t -> Q (TExp (ExpVar env t)) +liftExpVar (Var tp ix) = [|| Var $$(liftScalarType tp) $$(liftIdx ix) ||] + +liftArrayVar :: ArrayVar aenv a -> Q (TExp (ArrayVar aenv a)) +liftArrayVar (Var repr ix) = [|| Var $$(liftArrayR repr) $$(liftIdx ix) ||] + +liftArray :: forall sh e. ArrayR (Array sh e) -> Array sh e -> Q (TExp (Array sh e)) +liftArray (ArrayR shr tp) (Array sh adata) = + [|| Array $$(liftConst (shapeType shr) sh) $$(go tp adata) ||] `sigE` [t| Array $(typeToQType $ shapeType shr) $(typeToQType tp) |] where sz :: Int - sz = size sh + sz = size shr sh sigE :: Q (TExp t) -> Q TH.Type -> Q (TExp t) sigE e t = TH.unsafeTExpCoerce $ TH.sigE (TH.unTypeQ e) t - typeRepToType :: TypeRep -> Q TH.Type - typeRepToType trep = do - let (con, args) = splitTyConApp trep - name = TH.Name (TH.OccName (tyConName con)) (TH.NameG TH.TcClsName (TH.PkgName (tyConPackage con)) (TH.ModName (tyConModule con))) - -- - appsT x [] = x - appsT x (y:xs) = appsT (TH.AppT x y) xs - -- - resultArgs <- mapM typeRepToType args - return (appsT (TH.ConT name) resultArgs) + typeToQType :: TupleType t -> Q TH.Type + typeToQType TupRunit = [t| () |] + typeToQType (TupRpair t1 t2) = [t| ($(typeToQType t1), $(typeToQType t2)) |] + typeToQType (TupRsingle t) = scalarTypeToQType t + + scalarTypeToQType :: ScalarType t -> Q TH.Type + scalarTypeToQType (SingleScalarType t) = singleTypeToQType t + scalarTypeToQType (VectorScalarType t) = vectorTypeToQType t + + singleTypeToQType :: SingleType t -> Q TH.Type + singleTypeToQType (NumSingleType (IntegralNumType t)) = case t of + TypeInt -> [t| Int |] + TypeInt8 -> [t| Int8 |] + TypeInt16 -> [t| Int16 |] + TypeInt32 -> [t| Int32 |] + TypeInt64 -> [t| Int64 |] + TypeWord -> [t| Word |] + TypeWord8 -> [t| Word8 |] + TypeWord16 -> [t| Word16 |] + TypeWord32 -> [t| Word32 |] + TypeWord64 -> [t| Word64 |] + singleTypeToQType (NumSingleType (FloatingNumType t)) = case t of + TypeHalf -> [t| Half |] + TypeFloat -> [t| Float |] + TypeDouble -> [t| Double |] + singleTypeToQType (NonNumSingleType TypeBool) = [t| Bool |] + singleTypeToQType (NonNumSingleType TypeChar) = [t| Char |] + + vectorTypeToQType :: VectorType (Vec n a) -> Q TH.Type + vectorTypeToQType (VectorType _ stp) = [t| Vec $(undefined) $(singleTypeToQType stp) |] -- TODO: make sure that the resulting array is 16-byte aligned... arr :: forall a. Storable a => UniqueArray a -> Q (TExp (UniqueArray a)) @@ -1699,37 +1844,67 @@ liftArray (Array sh adata) = return ua' ||] - go :: ArrayEltR e' -> ArrayData e' -> Q (TExp (ArrayData e')) - go ArrayEltRunit AD_Unit = [|| AD_Unit ||] - go ArrayEltRint (AD_Int ua) = [|| AD_Int $$(arr ua) ||] - go ArrayEltRint8 (AD_Int8 ua) = [|| AD_Int8 $$(arr ua) ||] - go ArrayEltRint16 (AD_Int16 ua) = [|| AD_Int16 $$(arr ua) ||] - go ArrayEltRint32 (AD_Int32 ua) = [|| AD_Int32 $$(arr ua) ||] - go ArrayEltRint64 (AD_Int64 ua) = [|| AD_Int64 $$(arr ua) ||] - go ArrayEltRword (AD_Word ua) = [|| AD_Word $$(arr ua) ||] - go ArrayEltRword8 (AD_Word8 ua) = [|| AD_Word8 $$(arr ua) ||] - go ArrayEltRword16 (AD_Word16 ua) = [|| AD_Word16 $$(arr ua) ||] - go ArrayEltRword32 (AD_Word32 ua) = [|| AD_Word32 $$(arr ua) ||] - go ArrayEltRword64 (AD_Word64 ua) = [|| AD_Word64 $$(arr ua) ||] - go ArrayEltRhalf (AD_Half ua) = [|| AD_Half $$(arr ua) ||] - go ArrayEltRfloat (AD_Float ua) = [|| AD_Float $$(arr ua) ||] - go ArrayEltRdouble (AD_Double ua) = [|| AD_Double $$(arr ua) ||] - go ArrayEltRbool (AD_Bool ua) = [|| AD_Bool $$(arr ua) ||] - go ArrayEltRchar (AD_Char ua) = [|| AD_Char $$(arr ua) ||] - go (ArrayEltRpair r1 r2) (AD_Pair a1 a2) = [|| AD_Pair $$(go r1 a1) $$(go r2 a2) ||] - go (ArrayEltRvec r) (AD_Vec w# a) = TH.unsafeTExpCoerce $ [| AD_Vec $(liftInt# w#) $(TH.unTypeQ (go r a)) |] - + go :: TupleType e' -> ArrayData e' -> Q (TExp (ArrayData e')) + go TupRunit () = [|| () ||] + go (TupRpair t1 t2) (a1, a2) = [|| ($$(go t1 a1), $$(go t2 a2)) ||] + go (TupRsingle stp) a = goScalar stp a + + goScalar :: ScalarType e' -> ArrayData e' -> Q (TExp (ArrayData e')) + goScalar (SingleScalarType stp) a = goSingle stp a + goScalar (VectorScalarType (VectorType _ stp)) a = goVector stp a + + goSingle :: SingleType e' -> ArrayData e' -> Q (TExp (ArrayData e')) + goSingle (NumSingleType (IntegralNumType itp)) = case itp of + TypeInt -> arr + TypeInt8 -> arr + TypeInt16 -> arr + TypeInt32 -> arr + TypeInt64 -> arr + TypeWord -> arr + TypeWord8 -> arr + TypeWord16 -> arr + TypeWord32 -> arr + TypeWord64 -> arr + goSingle (NumSingleType (FloatingNumType ftp)) = case ftp of + TypeHalf -> arr + TypeFloat -> arr + TypeDouble -> arr + goSingle (NonNumSingleType TypeChar) = arr + goSingle (NonNumSingleType TypeBool) = arr + + -- This function has the same implementation as goSingle, but different types. + -- We could convince the type system to have this written as a single function, + -- as ArrayData uses a type family to create a structure of arrays, containing + -- scalars, where the scalars are again handled by a type family (ScalarDataRepr) + goVector :: SingleType e' -> ArrayData (Vec n e') -> Q (TExp (ArrayData (Vec n e'))) + goVector (NumSingleType (IntegralNumType itp)) = case itp of + TypeInt -> arr + TypeInt8 -> arr + TypeInt16 -> arr + TypeInt32 -> arr + TypeInt64 -> arr + TypeWord -> arr + TypeWord8 -> arr + TypeWord16 -> arr + TypeWord32 -> arr + TypeWord64 -> arr + goVector (NumSingleType (FloatingNumType ftp)) = case ftp of + TypeHalf -> arr + TypeFloat -> arr + TypeDouble -> arr + goVector (NonNumSingleType TypeChar) = arr + goVector (NonNumSingleType TypeBool) = arr liftBoundary - :: forall acc aenv sh e. - LiftAcc acc - -> PreBoundary acc aenv (Array sh e) - -> Q (TExp (PreBoundary acc aenv (Array sh e))) -liftBoundary _ Clamp = [|| Clamp ||] -liftBoundary _ Mirror = [|| Mirror ||] -liftBoundary _ Wrap = [|| Wrap ||] -liftBoundary _ (Constant v) = [|| Constant $$(liftConst (eltType @e) v) ||] -liftBoundary liftA (Function f) = [|| Function $$(liftPreOpenFun liftA f) ||] + :: forall aenv sh e. + ArrayR (Array sh e) + -> Boundary aenv (Array sh e) + -> Q (TExp (Boundary aenv (Array sh e))) +liftBoundary _ Clamp = [|| Clamp ||] +liftBoundary _ Mirror = [|| Mirror ||] +liftBoundary _ Wrap = [|| Wrap ||] +liftBoundary (ArrayR _ tp) (Constant v) = [|| Constant $$(liftConst tp v) ||] +liftBoundary _ (Function f) = [|| Function $$(liftOpenFun f) ||] liftSliceIndex :: SliceIndex ix slice coSlice sliceDim -> Q (TExp (SliceIndex ix slice coSlice sliceDim)) liftSliceIndex SliceNil = [|| SliceNil ||] @@ -1808,11 +1983,15 @@ liftPrimFun PrimBoolToInt = [|| PrimBoolToInt ||] liftPrimFun (PrimFromIntegral ta tb) = [|| PrimFromIntegral $$(liftIntegralType ta) $$(liftNumType tb) ||] liftPrimFun (PrimToFloating ta tb) = [|| PrimToFloating $$(liftNumType ta) $$(liftFloatingType tb) ||] +liftTupleType :: TupleType t -> Q (TExp (TupleType t)) +liftTupleType TupRunit = [|| TupRunit ||] +liftTupleType (TupRsingle t) = [|| TupRsingle $$(liftScalarType t) ||] +liftTupleType (TupRpair ta tb) = [|| TupRpair $$(liftTupleType ta) $$(liftTupleType tb) ||] liftConst :: TupleType t -> t -> Q (TExp t) -liftConst TypeRunit () = [|| () ||] -liftConst (TypeRscalar t) x = [|| $$(liftScalar t x) ||] -liftConst (TypeRpair ta tb) (a,b) = [|| ($$(liftConst ta a), $$(liftConst tb b)) ||] +liftConst TupRunit () = [|| () ||] +liftConst (TupRsingle t) x = [|| $$(liftScalar t x) ||] +liftConst (TupRpair ta tb) (a,b) = [|| ($$(liftConst ta a), $$(liftConst tb b)) ||] liftScalar :: ScalarType t -> t -> Q (TExp t) liftScalar (SingleScalarType t) x = liftSingle t x @@ -1825,6 +2004,10 @@ liftSingle (NonNumSingleType t) x = liftNonNum t x liftVector :: VectorType t -> t -> Q (TExp t) liftVector VectorType{} x = liftVec x +liftVecR :: VecR n single tuple -> Q (TExp (VecR n single tuple)) +liftVecR (VecRnil tp) = [|| VecRnil $$(liftSingleType tp) ||] +liftVecR (VecRsucc vec) = [|| VecRsucc $$(liftVecR vec) ||] + -- O(n) at runtime to copy from the Addr# to the ByteArray#. We should be able -- to do this without copying, but I don't think the definition of ByteArray# is -- exported (or it is deeply magical). @@ -1884,25 +2067,25 @@ liftFloating TypeDouble{} x = [|| x ||] liftIntegralType :: IntegralType t -> Q (TExp (IntegralType t)) -liftIntegralType TypeInt{} = [|| TypeInt IntegralDict ||] -liftIntegralType TypeInt8{} = [|| TypeInt8 IntegralDict ||] -liftIntegralType TypeInt16{} = [|| TypeInt16 IntegralDict ||] -liftIntegralType TypeInt32{} = [|| TypeInt32 IntegralDict ||] -liftIntegralType TypeInt64{} = [|| TypeInt64 IntegralDict ||] -liftIntegralType TypeWord{} = [|| TypeWord IntegralDict ||] -liftIntegralType TypeWord8{} = [|| TypeWord8 IntegralDict ||] -liftIntegralType TypeWord16{} = [|| TypeWord16 IntegralDict ||] -liftIntegralType TypeWord32{} = [|| TypeWord32 IntegralDict ||] -liftIntegralType TypeWord64{} = [|| TypeWord64 IntegralDict ||] +liftIntegralType TypeInt{} = [|| TypeInt ||] +liftIntegralType TypeInt8{} = [|| TypeInt8 ||] +liftIntegralType TypeInt16{} = [|| TypeInt16 ||] +liftIntegralType TypeInt32{} = [|| TypeInt32 ||] +liftIntegralType TypeInt64{} = [|| TypeInt64 ||] +liftIntegralType TypeWord{} = [|| TypeWord ||] +liftIntegralType TypeWord8{} = [|| TypeWord8 ||] +liftIntegralType TypeWord16{} = [|| TypeWord16 ||] +liftIntegralType TypeWord32{} = [|| TypeWord32 ||] +liftIntegralType TypeWord64{} = [|| TypeWord64 ||] liftFloatingType :: FloatingType t -> Q (TExp (FloatingType t)) -liftFloatingType TypeHalf{} = [|| TypeHalf FloatingDict ||] -liftFloatingType TypeFloat{} = [|| TypeFloat FloatingDict ||] -liftFloatingType TypeDouble{} = [|| TypeDouble FloatingDict ||] +liftFloatingType TypeHalf{} = [|| TypeHalf ||] +liftFloatingType TypeFloat{} = [|| TypeFloat ||] +liftFloatingType TypeDouble{} = [|| TypeDouble ||] liftNonNumType :: NonNumType t -> Q (TExp (NonNumType t)) -liftNonNumType TypeBool{} = [|| TypeBool NonNumDict ||] -liftNonNumType TypeChar{} = [|| TypeChar NonNumDict ||] +liftNonNumType TypeBool{} = [|| TypeBool ||] +liftNonNumType TypeChar{} = [|| TypeChar ||] liftNumType :: NumType t -> Q (TExp (NumType t)) liftNumType (IntegralNumType t) = [|| IntegralNumType $$(liftIntegralType t) ||] @@ -1912,16 +2095,16 @@ liftBoundedType :: BoundedType t -> Q (TExp (BoundedType t)) liftBoundedType (IntegralBoundedType t) = [|| IntegralBoundedType $$(liftIntegralType t) ||] liftBoundedType (NonNumBoundedType t) = [|| NonNumBoundedType $$(liftNonNumType t) ||] --- liftScalarType :: ScalarType t -> Q (TExp (ScalarType t)) --- liftScalarType (SingleScalarType t) = [|| SingleScalarType $$(liftSingleType t) ||] --- liftScalarType (VectorScalarType t) = [|| VectorScalarType $$(liftVectorType t) ||] +liftScalarType :: ScalarType t -> Q (TExp (ScalarType t)) +liftScalarType (SingleScalarType t) = [|| SingleScalarType $$(liftSingleType t) ||] +liftScalarType (VectorScalarType t) = [|| VectorScalarType $$(liftVectorType t) ||] liftSingleType :: SingleType t -> Q (TExp (SingleType t)) liftSingleType (NumSingleType t) = [|| NumSingleType $$(liftNumType t) ||] liftSingleType (NonNumSingleType t) = [|| NonNumSingleType $$(liftNonNumType t) ||] --- liftVectorType :: VectorType t -> Q (TExp (VectorType t)) --- liftVectorType (VectorType n t) = [|| VectorType n $$(liftSingleType t) ||] +liftVectorType :: VectorType t -> Q (TExp (VectorType t)) +liftVectorType (VectorType n t) = [|| VectorType n $$(liftSingleType t) ||] -- Debugging @@ -1929,8 +2112,8 @@ liftSingleType (NonNumSingleType t) = [|| NonNumSingleType $$(liftNonNumType t) showPreAccOp :: forall acc aenv arrs. PreOpenAcc acc aenv arrs -> String showPreAccOp Alet{} = "Alet" -showPreAccOp (Avar (ArrayVar ix)) = "Avar a" ++ show (idxToInt ix) -showPreAccOp (Use a) = "Use " ++ showShortendArr a +showPreAccOp (Avar (Var _ ix)) = "Avar a" ++ show (idxToInt ix) +showPreAccOp (Use repr a) = "Use " ++ showShortendArr repr a showPreAccOp Apply{} = "Apply" showPreAccOp Aforeign{} = "Aforeign" showPreAccOp Acond{} = "Acond" @@ -1962,40 +2145,36 @@ showPreAccOp Stencil2{} = "Stencil2" -- showPreAccOp Collect{} = "Collect" -showShortendArr :: (Shape sh, Elt e) => Array sh e -> String -showShortendArr arr - = show (take cutoff l) ++ if length l > cutoff then ".." else "" +showShortendArr :: ArrayR (Array sh e) -> Array sh e -> String +showShortendArr repr@(ArrayR _ tp) arr + | length l > cutoff = "[" ++ elements ++ ", ..]" + | otherwise = "[" ++ elements ++ "]" where - l = toList arr + l = toList repr arr cutoff = 5 - - -showPreExpOp :: forall acc env aenv t. PreOpenExp acc env aenv t -> String -showPreExpOp Let{} = "Let" -showPreExpOp (Var ix) = "Var x" ++ show (idxToInt ix) -showPreExpOp (Const c) = "Const " ++ show (toElt c :: t) -showPreExpOp Undef = "Undef" -showPreExpOp Foreign{} = "Foreign" -showPreExpOp Tuple{} = "Tuple" -showPreExpOp Prj{} = "Prj" -showPreExpOp IndexNil = "IndexNil" -showPreExpOp IndexCons{} = "IndexCons" -showPreExpOp IndexHead{} = "IndexHead" -showPreExpOp IndexTail{} = "IndexTail" -showPreExpOp IndexAny = "IndexAny" -showPreExpOp IndexSlice{} = "IndexSlice" -showPreExpOp IndexFull{} = "IndexFull" -showPreExpOp ToIndex{} = "ToIndex" -showPreExpOp FromIndex{} = "FromIndex" -showPreExpOp Cond{} = "Cond" -showPreExpOp While{} = "While" -showPreExpOp PrimConst{} = "PrimConst" -showPreExpOp PrimApp{} = "PrimApp" -showPreExpOp Index{} = "Index" -showPreExpOp LinearIndex{} = "LinearIndex" -showPreExpOp Shape{} = "Shape" -showPreExpOp ShapeSize{} = "ShapeSize" -showPreExpOp Intersect{} = "Intersect" -showPreExpOp Union{} = "Union" -showPreExpOp Coerce{} = "Coerce" - + elements = intercalate ", " $ map (showElement tp) $ take cutoff l + + +showPreExpOp :: forall aenv env t. OpenExp aenv env t -> String +showPreExpOp Let{} = "Let" +showPreExpOp (Evar (Var _ ix)) = "Var x" ++ show (idxToInt ix) +showPreExpOp (Const tp c) = "Const " ++ showElement (TupRsingle tp) c +showPreExpOp Undef{} = "Undef" +showPreExpOp Foreign{} = "Foreign" +showPreExpOp Pair{} = "Pair" +showPreExpOp Nil{} = "Nil" +showPreExpOp VecPack{} = "VecPack" +showPreExpOp VecUnpack{} = "VecUnpack" +showPreExpOp IndexSlice{} = "IndexSlice" +showPreExpOp IndexFull{} = "IndexFull" +showPreExpOp ToIndex{} = "ToIndex" +showPreExpOp FromIndex{} = "FromIndex" +showPreExpOp Cond{} = "Cond" +showPreExpOp While{} = "While" +showPreExpOp PrimConst{} = "PrimConst" +showPreExpOp PrimApp{} = "PrimApp" +showPreExpOp Index{} = "Index" +showPreExpOp LinearIndex{} = "LinearIndex" +showPreExpOp Shape{} = "Shape" +showPreExpOp ShapeSize{} = "ShapeSize" +showPreExpOp Coerce{} = "Coerce" diff --git a/src/Data/Array/Accelerate/Analysis/Hash.hs b/src/Data/Array/Accelerate/Analysis/Hash.hs index 9f994325e..f1eb6b5e1 100644 --- a/src/Data/Array/Accelerate/Analysis/Hash.hs +++ b/src/Data/Array/Accelerate/Analysis/Hash.hs @@ -23,14 +23,13 @@ module Data.Array.Accelerate.Analysis.Hash ( Hash, HashOptions(..), defaultHashOptions, hashPreOpenAcc, hashPreOpenAccWith, - hashPreOpenFun, hashPreOpenFunWith, - hashPreOpenExp, hashPreOpenExpWith, + hashOpenFun, hashOpenExp, -- auxiliary EncodeAcc, encodePreOpenAcc, - encodePreOpenExp, - encodePreOpenFun, + encodeOpenExp, + encodeOpenFun, encodeArraysType, hashQ, @@ -38,9 +37,7 @@ module Data.Array.Accelerate.Analysis.Hash ( import Data.Array.Accelerate.AST import Data.Array.Accelerate.Analysis.Hash.TH -import Data.Array.Accelerate.Array.Sugar -import Data.Array.Accelerate.Array.Representation ( SliceIndex(..) ) -import Data.Array.Accelerate.Product +import Data.Array.Accelerate.Array.Representation import Data.Array.Accelerate.Type import Crypto.Hash @@ -94,37 +91,29 @@ defaultHashOptions = HashOptions True {-# INLINEABLE hashPreOpenAcc #-} -hashPreOpenAcc :: EncodeAcc acc -> PreOpenAcc acc aenv a -> Hash +hashPreOpenAcc :: HasArraysRepr acc => EncodeAcc acc -> PreOpenAcc acc aenv a -> Hash hashPreOpenAcc = hashPreOpenAccWith defaultHashOptions -{-# INLINEABLE hashPreOpenFun #-} -hashPreOpenFun :: EncodeAcc acc -> PreOpenFun acc env aenv f -> Hash -hashPreOpenFun = hashPreOpenFunWith defaultHashOptions - -{-# INLINEABLE hashPreOpenExp #-} -hashPreOpenExp :: EncodeAcc acc -> PreOpenExp acc env aenv t -> Hash -hashPreOpenExp = hashPreOpenExpWith defaultHashOptions - {-# INLINEABLE hashPreOpenAccWith #-} -hashPreOpenAccWith :: HashOptions -> EncodeAcc acc -> PreOpenAcc acc aenv a -> Hash +hashPreOpenAccWith :: HasArraysRepr acc => HashOptions -> EncodeAcc acc -> PreOpenAcc acc aenv a -> Hash hashPreOpenAccWith options encodeAcc = hashlazy . toLazyByteString . encodePreOpenAcc options encodeAcc -{-# INLINEABLE hashPreOpenFunWith #-} -hashPreOpenFunWith :: HashOptions -> EncodeAcc acc -> PreOpenFun acc env aenv f -> Hash -hashPreOpenFunWith options encodeAcc +{-# INLINEABLE hashOpenFun #-} +hashOpenFun :: OpenFun env aenv f -> Hash +hashOpenFun = hashlazy . toLazyByteString - . encodePreOpenFun options encodeAcc + . encodeOpenFun -{-# INLINEABLE hashPreOpenExpWith #-} -hashPreOpenExpWith :: HashOptions -> EncodeAcc acc -> PreOpenExp acc env aenv t -> Hash -hashPreOpenExpWith options encodeAcc +{-# INLINEABLE hashOpenExp #-} +hashOpenExp :: OpenExp env aenv t -> Hash +hashOpenExp = hashlazy . toLazyByteString - . encodePreOpenExp options encodeAcc + . encodeOpenExp -- Array computations @@ -134,8 +123,8 @@ type EncodeAcc acc = forall aenv a. HashOptions -> acc aenv a -> Builder {-# INLINEABLE encodePreOpenAcc #-} encodePreOpenAcc - :: forall acc aenv arrs. - HashOptions + :: forall acc aenv arrs. HasArraysRepr acc + => HashOptions -> EncodeAcc acc -> PreOpenAcc acc aenv arrs -> Builder @@ -147,61 +136,55 @@ encodePreOpenAcc options encodeAcc pacc = travAF :: PreOpenAfun acc aenv' f -> Builder travAF = encodePreOpenAfun options encodeAcc - travE :: PreOpenExp acc env' aenv' e -> Builder - travE = encodePreOpenExp options encodeAcc + travE :: OpenExp env' aenv' e -> Builder + travE = encodeOpenExp - travF :: PreOpenFun acc env' aenv' f -> Builder - travF = encodePreOpenFun options encodeAcc - - travB :: PreBoundary acc aenv' (Array sh e) -> Builder - travB = encodePreBoundary options encodeAcc + travF :: OpenFun env' aenv' f -> Builder + travF = encodeOpenFun deep :: Builder -> Builder deep | perfect options = id | otherwise = const mempty - deepE :: forall env' aenv' e. Elt e => PreOpenExp acc env' aenv' e -> Builder + deepE :: forall env' aenv' e. OpenExp env' aenv' e -> Builder deepE e | perfect options = travE e - | otherwise = encodeTupleType (eltType @e) - - arrayHash :: (Shape sh, Elt e, arrs ~ Array sh e) => Builder - arrayHash = encodeArrayType @arrs + | otherwise = encodeTupleType $ expType e in case pacc of - Alet lhs bnd body -> intHost $(hashQ "Alet") <> encodeLeftHandSide lhs <> travA bnd <> travA body - Avar (ArrayVar v) -> intHost $(hashQ "Avar") <> arrayHash <> deep (encodeIdx v) - Apair a1 a2 -> intHost $(hashQ "Apair") <> travA a1 <> travA a2 - Anil -> intHost $(hashQ "Anil") - Apply f a -> intHost $(hashQ "Apply") <> travAF f <> travA a - Aforeign _ f a -> intHost $(hashQ "Aforeign") <> travAF f <> travA a - Use a -> intHost $(hashQ "Use") <> arrayHash <> deep (encodeArray a) - Awhile p f a -> intHost $(hashQ "Awhile") <> travAF f <> travAF p <> travA a - Unit e -> intHost $(hashQ "Unit") <> travE e - Generate e f -> intHost $(hashQ "Generate") <> deepE e <> travF f + Alet lhs bnd body -> intHost $(hashQ "Alet") <> encodeLeftHandSide encodeArrayType lhs <> travA bnd <> travA body + Avar (Var repr v) -> intHost $(hashQ "Avar") <> encodeArrayType repr <> deep (encodeIdx v) + Apair a1 a2 -> intHost $(hashQ "Apair") <> travA a1 <> travA a2 + Anil -> intHost $(hashQ "Anil") + Apply _ f a -> intHost $(hashQ "Apply") <> travAF f <> travA a + Aforeign _ _ f a -> intHost $(hashQ "Aforeign") <> travAF f <> travA a + Use repr a -> intHost $(hashQ "Use") <> encodeArrayType repr <> deep (encodeArray a) + Awhile p f a -> intHost $(hashQ "Awhile") <> travAF f <> travAF p <> travA a + Unit _ e -> intHost $(hashQ "Unit") <> travE e + Generate _ e f -> intHost $(hashQ "Generate") <> deepE e <> travF f -- We don't need to encode the type of 'e' when perfect is False, as 'e' is an expression of type Bool. -- We thus use `deep (travE e)` instead of `deepE e`. - Acond e a1 a2 -> intHost $(hashQ "Acond") <> deep (travE e) <> travA a1 <> travA a2 - Reshape sh a -> intHost $(hashQ "Reshape") <> deepE sh <> travA a - Backpermute sh f a -> intHost $(hashQ "Backpermute") <> deepE sh <> travF f <> travA a - Transform sh f1 f2 a -> intHost $(hashQ "Transform") <> deepE sh <> travF f1 <> travF f2 <> travA a - Replicate spec ix a -> intHost $(hashQ "Replicate") <> deepE ix <> travA a <> encodeSliceIndex spec - Slice spec a ix -> intHost $(hashQ "Slice") <> deepE ix <> travA a <> encodeSliceIndex spec - Map f a -> intHost $(hashQ "Map") <> travF f <> travA a - ZipWith f a1 a2 -> intHost $(hashQ "ZipWith") <> travF f <> travA a1 <> travA a2 - Fold f e a -> intHost $(hashQ "Fold") <> travF f <> travE e <> travA a - Fold1 f a -> intHost $(hashQ "Fold1") <> travF f <> travA a - FoldSeg f e a s -> intHost $(hashQ "FoldSeg") <> travF f <> travE e <> travA a <> travA s - Fold1Seg f a s -> intHost $(hashQ "Fold1Seg") <> travF f <> travA a <> travA s - Scanl f e a -> intHost $(hashQ "Scanl") <> travF f <> travE e <> travA a - Scanl' f e a -> intHost $(hashQ "Scanl'") <> travF f <> travE e <> travA a - Scanl1 f a -> intHost $(hashQ "Scanl1") <> travF f <> travA a - Scanr f e a -> intHost $(hashQ "Scanr") <> travF f <> travE e <> travA a - Scanr' f e a -> intHost $(hashQ "Scanr'") <> travF f <> travE e <> travA a - Scanr1 f a -> intHost $(hashQ "Scanr1") <> travF f <> travA a - Permute f1 a1 f2 a2 -> intHost $(hashQ "Permute") <> travF f1 <> travA a1 <> travF f2 <> travA a2 - Stencil f b a -> intHost $(hashQ "Stencil") <> travF f <> travB b <> travA a - Stencil2 f b1 a1 b2 a2 -> intHost $(hashQ "Stencil2") <> travF f <> travB b1 <> travA a1 <> travB b2 <> travA a2 + Acond e a1 a2 -> intHost $(hashQ "Acond") <> deep (travE e) <> travA a1 <> travA a2 + Reshape _ sh a -> intHost $(hashQ "Reshape") <> deepE sh <> travA a + Backpermute _ sh f a -> intHost $(hashQ "Backpermute") <> deepE sh <> travF f <> travA a + Transform _ sh f1 f2 a -> intHost $(hashQ "Transform") <> deepE sh <> travF f1 <> travF f2 <> travA a + Replicate spec ix a -> intHost $(hashQ "Replicate") <> deepE ix <> travA a <> encodeSliceIndex spec + Slice spec a ix -> intHost $(hashQ "Slice") <> deepE ix <> travA a <> encodeSliceIndex spec + Map _ f a -> intHost $(hashQ "Map") <> travF f <> travA a + ZipWith _ f a1 a2 -> intHost $(hashQ "ZipWith") <> travF f <> travA a1 <> travA a2 + Fold f e a -> intHost $(hashQ "Fold") <> travF f <> travE e <> travA a + Fold1 f a -> intHost $(hashQ "Fold1") <> travF f <> travA a + FoldSeg _ f e a s -> intHost $(hashQ "FoldSeg") <> travF f <> travE e <> travA a <> travA s + Fold1Seg _ f a s -> intHost $(hashQ "Fold1Seg") <> travF f <> travA a <> travA s + Scanl f e a -> intHost $(hashQ "Scanl") <> travF f <> travE e <> travA a + Scanl' f e a -> intHost $(hashQ "Scanl'") <> travF f <> travE e <> travA a + Scanl1 f a -> intHost $(hashQ "Scanl1") <> travF f <> travA a + Scanr f e a -> intHost $(hashQ "Scanr") <> travF f <> travE e <> travA a + Scanr' f e a -> intHost $(hashQ "Scanr'") <> travF f <> travE e <> travA a + Scanr1 f a -> intHost $(hashQ "Scanr1") <> travF f <> travA a + Permute f1 a1 f2 a2 -> intHost $(hashQ "Permute") <> travF f1 <> travA a1 <> travF f2 <> travA a2 + Stencil s _ f b a -> intHost $(hashQ "Stencil") <> travF f <> encodeBoundary (stencilElt s) b <> travA a + Stencil2 s1 s2 _ f b1 a1 b2 a2 -> intHost $(hashQ "Stencil2") <> travF f <> encodeBoundary (stencilElt s1) b1 <> travA a1 <> encodeBoundary (stencilElt s2) b2 <> travA a2 {-- {-# INLINEABLE encodePreOpenSeq #-} @@ -211,14 +194,14 @@ encodePreOpenSeq encodeAcc s = travA :: acc aenv' a -> Builder travA = encodeAcc -- XXX: plus type information? - travE :: PreOpenExp acc env' aenv' e -> Builder - travE = encodePreOpenExp encodeAcc + travE :: OpenExp env' aenv' e -> Builder + travE = encodeOpenExp encodeAcc travAF :: PreOpenAfun acc aenv' f -> Builder travAF = encodePreOpenAfun encodeAcc - travF :: PreOpenFun acc env' aenv' f -> Builder - travF = encodePreOpenFun encodeAcc + travF :: OpenFun env' aenv' f -> Builder + travF = encodeOpenFun encodeAcc travS :: PreOpenSeq acc aenv senv' arrs' -> Builder travS = encodePreOpenSeq encodeAcc @@ -252,24 +235,27 @@ encodePreOpenSeq encodeAcc s = encodeIdx :: Idx env t -> Builder encodeIdx = intHost . idxToInt -encodeTupleIdx :: TupleIdx tup e -> Builder -encodeTupleIdx = intHost . tupleIdxToInt - -encodeArray :: (Shape sh, Elt e) => Array sh e -> Builder +encodeArray :: Array sh e -> Builder encodeArray ad = intHost . unsafePerformIO $! hashStableName <$> makeStableName ad -encodeArraysType :: forall a. ArraysR a -> Builder -encodeArraysType ArraysRunit = intHost $(hashQ "ArraysRunit") -encodeArraysType (ArraysRpair r1 r2) = intHost $(hashQ "ArraysRpair") <> encodeArraysType r1 <> encodeArraysType r2 -encodeArraysType ArraysRarray = intHost $(hashQ "ArraysRarray") <> encodeArrayType @a +encodeTupR :: (forall b. s b -> Builder) -> TupR s a -> Builder +encodeTupR _ TupRunit = intHost $(hashQ "TupRunit") +encodeTupR f (TupRpair r1 r2) = intHost $(hashQ "TupRpair") <> encodeTupR f r1 <> encodeTupR f r2 +encodeTupR f (TupRsingle s) = intHost $(hashQ "TupRsingle") <> f s -encodeLeftHandSide :: forall a env env'. LeftHandSide a env env' -> Builder -encodeLeftHandSide (LeftHandSideWildcard r) = intHost $(hashQ "LeftHandSideWildcard") <> encodeArraysType r -encodeLeftHandSide (LeftHandSidePair r1 r2) = intHost $(hashQ "LeftHandSidePair") <> encodeLeftHandSide r1 <> encodeLeftHandSide r2 -encodeLeftHandSide LeftHandSideArray = intHost $(hashQ "LeftHandSideArray") <> encodeArrayType @a +encodeLeftHandSide :: (forall b. s b -> Builder) -> LeftHandSide s a env env' -> Builder +encodeLeftHandSide f (LeftHandSideWildcard r) = intHost $(hashQ "LeftHandSideWildcard") <> encodeTupR f r +encodeLeftHandSide f (LeftHandSidePair r1 r2) = intHost $(hashQ "LeftHandSidePair") <> encodeLeftHandSide f r1 <> encodeLeftHandSide f r2 +encodeLeftHandSide f (LeftHandSideSingle s) = intHost $(hashQ "LeftHandSideArray") <> f s -encodeArrayType :: forall array sh e. (array ~ Array sh e, Shape sh, Elt e) => Builder -encodeArrayType = encodeTupleType (eltType @sh) <> encodeTupleType (eltType @e) +encodeArrayType :: ArrayR a -> Builder +encodeArrayType (ArrayR shr tp) = encodeShapeR shr <> encodeTupleType tp + +encodeArraysType :: ArraysR arrs -> Builder +encodeArraysType = encodeTupR encodeArrayType + +encodeShapeR :: ShapeR sh -> Builder +encodeShapeR = intHost . rank encodePreOpenAfun :: forall acc aenv f. @@ -279,25 +265,23 @@ encodePreOpenAfun -> Builder encodePreOpenAfun options travA afun = let - travL :: forall aenv1 aenv2 a b. LeftHandSide a aenv1 aenv2 -> PreOpenAfun acc aenv2 b -> Builder - travL lhs l = encodeLeftHandSide lhs <> encodePreOpenAfun options travA l + travL :: forall aenv1 aenv2 a b. ALeftHandSide a aenv1 aenv2 -> PreOpenAfun acc aenv2 b -> Builder + travL lhs l = encodeLeftHandSide encodeArrayType lhs <> encodePreOpenAfun options travA l in case afun of Abody b -> intHost $(hashQ "Abody") <> travA options b Alam lhs l -> intHost $(hashQ "Alam") <> travL lhs l -encodePreBoundary - :: forall acc aenv sh e. - HashOptions - -> EncodeAcc acc - -> PreBoundary acc aenv (Array sh e) +encodeBoundary + :: TupleType e + -> Boundary aenv (Array sh e) -> Builder -encodePreBoundary _ _ Wrap = intHost $(hashQ "Wrap") -encodePreBoundary _ _ Clamp = intHost $(hashQ "Clamp") -encodePreBoundary _ _ Mirror = intHost $(hashQ "Mirror") -encodePreBoundary _ _ (Constant c) = intHost $(hashQ "Constant") <> encodeConst (eltType @e) c -encodePreBoundary o h (Function f) = intHost $(hashQ "Function") <> encodePreOpenFun o h f +encodeBoundary _ Wrap = intHost $(hashQ "Wrap") +encodeBoundary _ Clamp = intHost $(hashQ "Clamp") +encodeBoundary _ Mirror = intHost $(hashQ "Mirror") +encodeBoundary tp (Constant c) = intHost $(hashQ "Constant") <> encodeConst tp c +encodeBoundary _ (Function f) = intHost $(hashQ "Function") <> encodeOpenFun f encodeSliceIndex :: SliceIndex slix sl co sh -> Builder encodeSliceIndex SliceNil = intHost $(hashQ "SliceNil") @@ -308,97 +292,58 @@ encodeSliceIndex (SliceFixed r) = intHost $(hashQ "sliceFixed") <> encodeSlice -- Scalar expressions -- ------------------ -{-# INLINEABLE encodePreOpenExp #-} -encodePreOpenExp - :: forall acc env aenv exp. - HashOptions - -> EncodeAcc acc - -> PreOpenExp acc env aenv exp +{-# INLINEABLE encodeOpenExp #-} +encodeOpenExp + :: forall env aenv exp. + OpenExp env aenv exp -> Builder -encodePreOpenExp options encodeAcc exp = +encodeOpenExp exp = let - -- XXX: Temporary fix for hashing expressions which only depend on - -- free array variables. For the code generating backends it will - -- never pick up expressions which differ only at free array - -- variables. We know that this will always be an Avar (we depend on - -- array expressions being floated out already) so we should change - -- this in the AST. This problem occurred in the Quickhull program. - -- -- TLM 2020-01-08 - -- - travA :: forall aenv' a. Arrays a => acc aenv' a -> Builder - travA a = encodeArraysType (arrays @a) <> encodeAcc (options {perfect=True}) a - - travE :: forall env' aenv' e. Elt e => PreOpenExp acc env' aenv' e -> Builder - travE e = encodeTupleType (eltType @e) <> encodePreOpenExp options encodeAcc e - - travF :: PreOpenFun acc env' aenv' f -> Builder - travF = encodePreOpenFun options encodeAcc - - nacl :: Elt exp => Builder - nacl = encodeTupleType (eltType @exp) + travE :: forall env' aenv' e. OpenExp env' aenv' e -> Builder + travE e = encodeOpenExp e + + travF :: OpenFun env' aenv' f -> Builder + travF = encodeOpenFun in case exp of - Let bnd body -> intHost $(hashQ "Let") <> travE bnd <> travE body - Var ix -> intHost $(hashQ "Var") <> nacl <> encodeIdx ix - Tuple t -> intHost $(hashQ "Tuple") <> nacl <> encodeTuple options encodeAcc t - Prj i e -> intHost $(hashQ "Prj") <> nacl <> encodeTupleIdx i <> travE e -- XXX: here multiplied nacl by hashTupleIdx - Const c -> intHost $(hashQ "Const") <> encodeConst (eltType @exp) c - Undef -> intHost $(hashQ "Undef") - IndexAny -> intHost $(hashQ "IndexAny") <> nacl - IndexNil -> intHost $(hashQ "IndexNil") - IndexCons sh sz -> intHost $(hashQ "IndexCons") <> travE sh <> travE sz - IndexHead sl -> intHost $(hashQ "IndexHead") <> travE sl - IndexTail sl -> intHost $(hashQ "IndexTail") <> travE sl + Let lhs bnd body -> intHost $(hashQ "Let") <> encodeLeftHandSide encodeScalarType lhs <> travE bnd <> travE body + Evar (Var tp ix) -> intHost $(hashQ "Evar") <> encodeScalarType tp <> encodeIdx ix + Nil -> intHost $(hashQ "Nil") + Pair e1 e2 -> intHost $(hashQ "Pair") <> travE e1 <> travE e2 + VecPack _ e -> intHost $(hashQ "VecPack") <> travE e + VecUnpack _ e -> intHost $(hashQ "VecUnpack") <> travE e + Const tp c -> intHost $(hashQ "Const") <> encodeScalarConst tp c + Undef tp -> intHost $(hashQ "Undef") <> encodeScalarType tp IndexSlice spec ix sh -> intHost $(hashQ "IndexSlice") <> travE ix <> travE sh <> encodeSliceIndex spec IndexFull spec ix sl -> intHost $(hashQ "IndexFull") <> travE ix <> travE sl <> encodeSliceIndex spec - ToIndex sh i -> intHost $(hashQ "ToIndex") <> travE sh <> travE i - FromIndex sh i -> intHost $(hashQ "FromIndex") <> travE sh <> travE i + ToIndex _ sh i -> intHost $(hashQ "ToIndex") <> travE sh <> travE i + FromIndex _ sh i -> intHost $(hashQ "FromIndex") <> travE sh <> travE i Cond c t e -> intHost $(hashQ "Cond") <> travE c <> travE t <> travE e While p f x -> intHost $(hashQ "While") <> travF p <> travF f <> travE x PrimApp f x -> intHost $(hashQ "PrimApp") <> encodePrimFun f <> travE x PrimConst c -> intHost $(hashQ "PrimConst") <> encodePrimConst c - Index a ix -> intHost $(hashQ "Index") <> travA a <> travE ix - LinearIndex a ix -> intHost $(hashQ "LinearIndex") <> travA a <> travE ix - Shape a -> intHost $(hashQ "Shape") <> travA a - ShapeSize sh -> intHost $(hashQ "ShapeSize") <> travE sh - Intersect sa sb -> intHost $(hashQ "Intersect") <> travE sa <> travE sb - Union sa sb -> intHost $(hashQ "Union") <> travE sa <> travE sb - Foreign _ f e -> intHost $(hashQ "Foreign") <> travF f <> travE e - Coerce e -> intHost $(hashQ "Coerce") <> travE e - - -{-# INLINEABLE encodePreOpenFun #-} -encodePreOpenFun - :: forall acc env aenv f. - HashOptions - -> EncodeAcc acc - -> PreOpenFun acc env aenv f - -> Builder -encodePreOpenFun options travA fun = - let - travB :: forall env' aenv' e. Elt e => PreOpenExp acc env' aenv' e -> Builder - travB b = encodeTupleType (eltType @e) <> encodePreOpenExp options travA b - - travL :: forall env' aenv' a b. Elt a => PreOpenFun acc (env',a) aenv' b -> Builder - travL l = encodeTupleType (eltType @a) <> encodePreOpenFun options travA l - in - case fun of - Body b -> intHost $(hashQ "Body") <> travB b - Lam l -> intHost $(hashQ "Lam") <> travL l - -encodeTuple - :: HashOptions - -> EncodeAcc acc - -> Tuple (PreOpenExp acc env aenv) e + Index a ix -> intHost $(hashQ "Index") <> encodeArrayVar a <> travE ix + LinearIndex a ix -> intHost $(hashQ "LinearIndex") <> encodeArrayVar a <> travE ix + Shape a -> intHost $(hashQ "Shape") <> encodeArrayVar a + ShapeSize _ sh -> intHost $(hashQ "ShapeSize") <> travE sh + Foreign _ _ f e -> intHost $(hashQ "Foreign") <> travF f <> travE e + Coerce _ tp e -> intHost $(hashQ "Coerce") <> encodeScalarType tp <> travE e + +encodeArrayVar :: ArrayVar aenv a -> Builder +encodeArrayVar (Var repr v) = encodeArrayType repr <> encodeIdx v + +{-# INLINEABLE encodeOpenFun #-} +encodeOpenFun + :: OpenFun env aenv f -> Builder -encodeTuple _ _ NilTup = intHost $(hashQ "NilTup") -encodeTuple o h (SnocTup t e) = intHost $(hashQ "SnocTup") <> encodeTuple o h t <> encodePreOpenExp o h e +encodeOpenFun (Body b) = intHost $(hashQ "Body") <> encodeOpenExp b +encodeOpenFun (Lam lhs l) = intHost $(hashQ "Lam") <> encodeLeftHandSide encodeScalarType lhs <> encodeOpenFun l encodeConst :: TupleType t -> t -> Builder -encodeConst TypeRunit () = mempty -encodeConst (TypeRscalar t) c = encodeScalarConst t c -encodeConst (TypeRpair ta tb) (a,b) = encodeConst ta a <> encodeConst tb b +encodeConst TupRunit () = intHost $(hashQ "nil") +encodeConst (TupRsingle t) c = encodeScalarConst t c +encodeConst (TupRpair ta tb) (a,b) = intHost $(hashQ "pair") <> encodeConst ta a <> encodeConst tb b encodeScalarConst :: ScalarType t -> t -> Builder encodeScalarConst (SingleScalarType t) = encodeSingleConst t @@ -514,15 +459,15 @@ encodePrimFun PrimBoolToInt = intHost $(hashQ "PrimBoolToInt") encodeTupleType :: TupleType t -> Builder -encodeTupleType TypeRunit = intHost $(hashQ "TypeRunit") -encodeTupleType (TypeRscalar t) = intHost $(hashQ "TypeRscalar") <> encodeScalarType t -encodeTupleType (TypeRpair a b) = intHost $(hashQ "TypeRpair") <> encodeTupleType a <> intHost (depthTypeR a) - <> encodeTupleType b <> intHost (depthTypeR b) +encodeTupleType TupRunit = intHost $(hashQ "TupRunit") +encodeTupleType (TupRsingle t) = intHost $(hashQ "TupRsingle") <> encodeScalarType t +encodeTupleType (TupRpair a b) = intHost $(hashQ "TupRpair") <> encodeTupleType a <> intHost (depthTypeR a) + <> encodeTupleType b <> intHost (depthTypeR b) depthTypeR :: TupleType t -> Int -depthTypeR TypeRunit = 0 -depthTypeR TypeRscalar{} = 1 -depthTypeR (TypeRpair a b) = depthTypeR a + depthTypeR b +depthTypeR TupRunit = 0 +depthTypeR TupRsingle{} = 1 +depthTypeR (TupRpair a b) = depthTypeR a + depthTypeR b encodeScalarType :: ScalarType t -> Builder encodeScalarType (SingleScalarType t) = intHost $(hashQ "SingleScalarType") <> encodeSingleType t diff --git a/src/Data/Array/Accelerate/Analysis/Match.hs b/src/Data/Array/Accelerate/Analysis/Match.hs index ecbae8dde..1779504e3 100644 --- a/src/Data/Array/Accelerate/Analysis/Match.hs +++ b/src/Data/Array/Accelerate/Analysis/Match.hs @@ -24,14 +24,14 @@ module Data.Array.Accelerate.Analysis.Match ( (:~:)(..), matchPreOpenAcc, matchPreOpenAfun, - matchPreOpenExp, - matchPreOpenFun, + matchOpenExp, + matchOpenFun, matchPrimFun, matchPrimFun', -- auxiliary - matchIdx, matchArrayVar, matchArrayVars, matchTupleType, matchShapeType, - matchIntegralType, matchFloatingType, matchNumType, matchScalarType, - matchLeftHandSide, matchLeftHandSide', + matchIdx, matchVar, matchVars, matchArrayR, matchArraysR, matchTupleType, matchShapeR, + matchShapeType, matchIntegralType, matchFloatingType, matchNumType, matchScalarType, + matchLeftHandSide, matchALeftHandSide, matchELeftHandSide, matchSingleType, matchTupR ) where @@ -45,10 +45,9 @@ import Prelude hiding ( exp ) -- friends import Data.Array.Accelerate.Analysis.Hash -import Data.Array.Accelerate.Array.Representation ( SliceIndex(..) ) -import Data.Array.Accelerate.Array.Sugar +import Data.Array.Accelerate.Array.Representation +import qualified Data.Array.Accelerate.Array.Sugar as Sugar import Data.Array.Accelerate.AST -import Data.Array.Accelerate.Product import Data.Array.Accelerate.Type @@ -62,29 +61,28 @@ type MatchAcc acc = forall aenv s t. acc aenv s -> acc aenv t -> Maybe (s :~: t) -- {-# INLINEABLE matchPreOpenAcc #-} matchPreOpenAcc - :: forall acc aenv s t. - MatchAcc acc - -> EncodeAcc acc + :: forall acc aenv s t. HasArraysRepr acc + => MatchAcc acc -> PreOpenAcc acc aenv s -> PreOpenAcc acc aenv t -> Maybe (s :~: t) -matchPreOpenAcc matchAcc encodeAcc = match +matchPreOpenAcc matchAcc = match where - matchFun :: PreOpenFun acc env' aenv' u -> PreOpenFun acc env' aenv' v -> Maybe (u :~: v) - matchFun = matchPreOpenFun matchAcc encodeAcc + matchFun :: OpenFun env' aenv' u -> OpenFun env' aenv' v -> Maybe (u :~: v) + matchFun = matchOpenFun - matchExp :: PreOpenExp acc env' aenv' u -> PreOpenExp acc env' aenv' v -> Maybe (u :~: v) - matchExp = matchPreOpenExp matchAcc encodeAcc + matchExp :: OpenExp env' aenv' u -> OpenExp env' aenv' v -> Maybe (u :~: v) + matchExp = matchOpenExp match :: PreOpenAcc acc aenv s -> PreOpenAcc acc aenv t -> Maybe (s :~: t) match (Alet lhs1 x1 a1) (Alet lhs2 x2 a2) - | Just Refl <- matchLeftHandSide lhs1 lhs2 + | Just Refl <- matchALeftHandSide lhs1 lhs2 , Just Refl <- matchAcc x1 x2 , Just Refl <- matchAcc a1 a2 = Just Refl - match (Avar (ArrayVar v1)) (Avar (ArrayVar v2)) - = matchIdx v1 v2 + match (Avar v1) (Avar v2) + = matchVar v1 v2 match (Apair a1 a2) (Apair b1 b2) | Just Refl <- matchAcc a1 b1 @@ -94,18 +92,19 @@ matchPreOpenAcc matchAcc encodeAcc = match match Anil Anil = Just Refl - match (Apply f1 a1) (Apply f2 a2) + match (Apply _ f1 a1) (Apply _ f2 a2) | Just Refl <- matchPreOpenAfun matchAcc f1 f2 , Just Refl <- matchAcc a1 a2 = Just Refl - match (Aforeign ff1 _ a1) (Aforeign ff2 _ a2) + match (Aforeign _ ff1 f1 a1) (Aforeign _ ff2 f2 a2) | Just Refl <- matchAcc a1 a2 , unsafePerformIO $ do sn1 <- makeStableName ff1 sn2 <- makeStableName ff2 return $! hashStableName sn1 == hashStableName sn2 - = gcast Refl + , Just Refl <- matchPreOpenAfun matchAcc f1 f2 + = Just Refl match (Acond p1 t1 e1) (Acond p2 t2 e2) | Just Refl <- matchExp p1 p2 @@ -119,47 +118,50 @@ matchPreOpenAcc matchAcc encodeAcc = match , Just Refl <- matchPreOpenAfun matchAcc f1 f2 = Just Refl - match (Use a1) (Use a2) - | Just Refl <- matchArray a1 a2 + match (Use repr1 a1) (Use repr2 a2) + | Just Refl <- matchArray repr1 repr2 a1 a2 = Just Refl - match (Unit e1) (Unit e2) - | Just Refl <- matchExp e1 e2 + match (Unit t1 e1) (Unit t2 e2) + | Just Refl <- matchTupleType t1 t2 + , Just Refl <- matchExp e1 e2 = Just Refl - match (Reshape sh1 a1) (Reshape sh2 a2) + match (Reshape _ sh1 a1) (Reshape _ sh2 a2) | Just Refl <- matchExp sh1 sh2 , Just Refl <- matchAcc a1 a2 = Just Refl - match (Generate sh1 f1) (Generate sh2 f2) + match (Generate _ sh1 f1) (Generate _ sh2 f2) | Just Refl <- matchExp sh1 sh2 , Just Refl <- matchFun f1 f2 = Just Refl - match (Transform sh1 ix1 f1 a1) (Transform sh2 ix2 f2 a2) + match (Transform _ sh1 ix1 f1 a1) (Transform _ sh2 ix2 f2 a2) | Just Refl <- matchExp sh1 sh2 , Just Refl <- matchFun ix1 ix2 , Just Refl <- matchFun f1 f2 , Just Refl <- matchAcc a1 a2 = Just Refl - match (Replicate _ ix1 a1) (Replicate _ ix2 a2) - | Just Refl <- matchExp ix1 ix2 + match (Replicate si1 ix1 a1) (Replicate si2 ix2 a2) + | Just Refl <- matchSliceIndex si1 si2 + , Just Refl <- matchExp ix1 ix2 , Just Refl <- matchAcc a1 a2 - = gcast Refl -- slice specification ?? + = Just Refl - match (Slice _ a1 ix1) (Slice _ a2 ix2) - | Just Refl <- matchAcc a1 a2 + match (Slice si1 a1 ix1) (Slice si2 a2 ix2) + | Just Refl <- matchSliceIndex si1 si2 + , Just Refl <- matchAcc a1 a2 , Just Refl <- matchExp ix1 ix2 - = gcast Refl -- slice specification ?? + = Just Refl - match (Map f1 a1) (Map f2 a2) + match (Map _ f1 a1) (Map _ f2 a2) | Just Refl <- matchFun f1 f2 , Just Refl <- matchAcc a1 a2 = Just Refl - match (ZipWith f1 a1 b1) (ZipWith f2 a2 b2) + match (ZipWith _ f1 a1 b1) (ZipWith _ f2 a2 b2) | Just Refl <- matchFun f1 f2 , Just Refl <- matchAcc a1 a2 , Just Refl <- matchAcc b1 b2 @@ -176,14 +178,14 @@ matchPreOpenAcc matchAcc encodeAcc = match , Just Refl <- matchAcc a1 a2 = Just Refl - match (FoldSeg f1 z1 a1 s1) (FoldSeg f2 z2 a2 s2) + match (FoldSeg _ f1 z1 a1 s1) (FoldSeg _ f2 z2 a2 s2) | Just Refl <- matchFun f1 f2 , Just Refl <- matchExp z1 z2 , Just Refl <- matchAcc a1 a2 , Just Refl <- matchAcc s1 s2 = Just Refl - match (Fold1Seg f1 a1 s1) (Fold1Seg f2 a2 s2) + match (Fold1Seg _ f1 a1 s1) (Fold1Seg _ f2 a2 s2) | Just Refl <- matchFun f1 f2 , Just Refl <- matchAcc a1 a2 , Just Refl <- matchAcc s1 s2 @@ -230,24 +232,24 @@ matchPreOpenAcc matchAcc encodeAcc = match , Just Refl <- matchAcc a1 a2 = Just Refl - match (Backpermute sh1 ix1 a1) (Backpermute sh2 ix2 a2) + match (Backpermute _ sh1 ix1 a1) (Backpermute _ sh2 ix2 a2) | Just Refl <- matchExp sh1 sh2 , Just Refl <- matchFun ix1 ix2 , Just Refl <- matchAcc a1 a2 = Just Refl - match (Stencil f1 b1 a1) (Stencil f2 b2 a2) + match (Stencil s1 _ f1 b1 a1) (Stencil _ _ f2 b2 a2) | Just Refl <- matchFun f1 f2 , Just Refl <- matchAcc a1 a2 - , matchBoundary matchAcc encodeAcc b1 b2 + , matchBoundary (stencilElt s1) b1 b2 = Just Refl - match (Stencil2 f1 b1 a1 b2 a2) (Stencil2 f2 b1' a1' b2' a2') + match (Stencil2 s1 s2 _ f1 b1 a1 b2 a2) (Stencil2 _ _ _ f2 b1' a1' b2' a2') | Just Refl <- matchFun f1 f2 , Just Refl <- matchAcc a1 a1' , Just Refl <- matchAcc a2 a2' - , matchBoundary matchAcc encodeAcc b1 b1' - , matchBoundary matchAcc encodeAcc b2 b2' + , matchBoundary (stencilElt s1) b1 b1' + , matchBoundary (stencilElt s2) b2 b2' = Just Refl -- match (Collect s1) (Collect s2) @@ -265,56 +267,47 @@ matchPreOpenAfun -> PreOpenAfun acc aenv t -> Maybe (s :~: t) matchPreOpenAfun m (Alam lhs1 s) (Alam lhs2 t) - | Just Refl <- matchLeftHandSide lhs1 lhs2 + | Just Refl <- matchALeftHandSide lhs1 lhs2 , Just Refl <- matchPreOpenAfun m s t = Just Refl matchPreOpenAfun m (Abody s) (Abody t) = m s t matchPreOpenAfun _ _ _ = Nothing -matchLeftHandSide :: forall aenv aenv1 aenv2 arr1 arr2. LeftHandSide arr1 aenv aenv1 -> LeftHandSide arr2 aenv aenv2 -> Maybe (LeftHandSide arr1 aenv aenv1 :~: LeftHandSide arr2 aenv aenv2) -matchLeftHandSide (LeftHandSideWildcard repr1) (LeftHandSideWildcard repr2) - | Just Refl <- matchArraysR repr1 repr2 - = Just Refl -matchLeftHandSide LeftHandSideArray LeftHandSideArray - | Just Refl <- gcast @arr1 @arr2 Refl - = Just Refl -matchLeftHandSide (LeftHandSidePair a1 a2) (LeftHandSidePair b1 b2) - | Just Refl <- matchLeftHandSide a1 b1 - , Just Refl <- matchLeftHandSide a2 b2 - = Just Refl -matchLeftHandSide _ _ = Nothing +matchALeftHandSide :: forall aenv aenv1 aenv2 t1 t2. ALeftHandSide t1 aenv aenv1 -> ALeftHandSide t2 aenv aenv2 -> Maybe (ALeftHandSide t1 aenv aenv1 :~: ALeftHandSide t2 aenv aenv2) +matchALeftHandSide = matchLeftHandSide matchArrayR -matchLeftHandSide' :: forall aenv aenv1 aenv2 arr1 arr2. LeftHandSide arr1 aenv1 aenv -> LeftHandSide arr2 aenv2 aenv -> Maybe (LeftHandSide arr1 aenv1 aenv :~: LeftHandSide arr2 aenv2 aenv) -matchLeftHandSide' (LeftHandSideWildcard repr1) (LeftHandSideWildcard repr2) - | Just Refl <- matchArraysR repr1 repr2 +matchELeftHandSide :: forall env env1 env2 t1 t2. ELeftHandSide t1 env env1 -> ELeftHandSide t2 env env2 -> Maybe (ELeftHandSide t1 env env1 :~: ELeftHandSide t2 env env2) +matchELeftHandSide = matchLeftHandSide matchScalarType + +matchLeftHandSide :: forall s env env1 env2 t1 t2. (forall x y. s x -> s y -> Maybe (x :~: y)) -> LeftHandSide s t1 env env1 -> LeftHandSide s t2 env env2 -> Maybe (LeftHandSide s t1 env env1 :~: LeftHandSide s t2 env env2) +matchLeftHandSide f (LeftHandSideWildcard repr1) (LeftHandSideWildcard repr2) + | Just Refl <- matchTupR f repr1 repr2 = Just Refl -matchLeftHandSide' LeftHandSideArray LeftHandSideArray - | Just Refl <- gcast @arr1 @arr2 Refl +matchLeftHandSide f (LeftHandSideSingle x) (LeftHandSideSingle y) + | Just Refl <- f x y = Just Refl -matchLeftHandSide' (LeftHandSidePair a1 a2) (LeftHandSidePair b1 b2) - | Just Refl <- matchLeftHandSide' a2 b2 - , Just Refl <- matchLeftHandSide' a1 b1 +matchLeftHandSide f (LeftHandSidePair a1 a2) (LeftHandSidePair b1 b2) + | Just Refl <- matchLeftHandSide f a1 b1 + , Just Refl <- matchLeftHandSide f a2 b2 = Just Refl -matchLeftHandSide' _ _ = Nothing +matchLeftHandSide _ _ _ = Nothing -- Match stencil boundaries -- matchBoundary - :: forall acc aenv sh t. Elt t - => MatchAcc acc - -> EncodeAcc acc - -> PreBoundary acc aenv (Array sh t) - -> PreBoundary acc aenv (Array sh t) + :: TupleType t + -> Boundary aenv (Array sh t) + -> Boundary aenv (Array sh t) -> Bool -matchBoundary _ _ Clamp Clamp = True -matchBoundary _ _ Mirror Mirror = True -matchBoundary _ _ Wrap Wrap = True -matchBoundary _ _ (Constant s) (Constant t) = matchConst (eltType @t) s t -matchBoundary m h (Function f) (Function g) - | Just Refl <- matchPreOpenFun m h f g +matchBoundary _ Clamp Clamp = True +matchBoundary _ Mirror Mirror = True +matchBoundary _ Wrap Wrap = True +matchBoundary tp (Constant s) (Constant t) = matchConst tp s t +matchBoundary _ (Function f) (Function g) + | Just Refl <- matchOpenFun f g = True -matchBoundary _ _ _ _ +matchBoundary _ _ _ = False @@ -330,11 +323,11 @@ matchSeq -> Maybe (s :~: t) matchSeq m h = match where - matchFun :: PreOpenFun acc env' aenv' u -> PreOpenFun acc env' aenv' v -> Maybe (u :~: v) - matchFun = matchPreOpenFun m h + matchFun :: OpenFun env' aenv' u -> OpenFun env' aenv' v -> Maybe (u :~: v) + matchFun = matchOpenFun m h - matchExp :: PreOpenExp acc env' aenv' u -> PreOpenExp acc env' aenv' v -> Maybe (u :~: v) - matchExp = matchPreOpenExp m h + matchExp :: OpenExp env' aenv' u -> OpenExp env' aenv' v -> Maybe (u :~: v) + matchExp = matchOpenExp m h match :: PreOpenSeq acc aenv senv' u -> PreOpenSeq acc aenv senv' v -> Maybe (u :~: v) match (Producer p1 s1) (Producer p2 s2) @@ -401,33 +394,39 @@ matchSeq m h = match -- As a convenience, we are just comparing the stable names, but we could also -- walk the structure comparing the underlying ptrsOfArrayData. -- -matchArray :: (Shape sh1, Elt e1, Shape sh2, Elt e2) - => Array sh1 e1 -> Array sh2 e2 -> Maybe (Array sh1 e1 :~: Array sh2 e2) -matchArray (Array _ ad1) (Array _ ad2) - | unsafePerformIO $ do +matchArray :: ArrayR (Array sh1 e1) + -> ArrayR (Array sh2 e2) + -> Array sh1 e1 + -> Array sh2 e2 + -> Maybe (Array sh1 e1 :~: Array sh2 e2) +matchArray repr1 repr2 (Array _ ad1) (Array _ ad2) + | Just Refl <- matchArrayR repr1 repr2 + , unsafePerformIO $ do + sn1 <- makeStableName ad1 sn2 <- makeStableName ad2 return $! hashStableName sn1 == hashStableName sn2 - = gcast Refl - -matchArray _ _ - = Nothing - -matchArraysR :: ArraysR s -> ArraysR t -> Maybe (s :~: t) -matchArraysR ArraysRunit ArraysRunit = Just Refl -matchArraysR (ArraysRpair a1 b1) (ArraysRpair a2 b2) - | Just Refl <- matchArraysR a1 a2 - , Just Refl <- matchArraysR b1 b2 - = Just Refl +matchArray _ _ _ _ + = Nothing -matchArraysR ArraysRarray ArraysRarray - = gcast Refl +matchTupR :: (forall u1 u2. s u1 -> s u2 -> Maybe (u1 :~: u2)) -> TupR s t1 -> TupR s t2 -> Maybe (t1 :~: t2) +matchTupR _ TupRunit TupRunit = Just Refl +matchTupR f (TupRsingle x) (TupRsingle y) = f x y +matchTupR f (TupRpair x1 x2) (TupRpair y1 y2) + | Just Refl <- matchTupR f x1 y1 + , Just Refl <- matchTupR f x2 y2 = Just Refl +matchTupR _ _ _ = Nothing -matchArraysR _ _ - = Nothing +matchArraysR :: ArraysR s -> ArraysR t -> Maybe (s :~: t) +matchArraysR = matchTupR matchArrayR +matchArrayR :: ArrayR s -> ArrayR t -> Maybe (s :~: t) +matchArrayR (ArrayR shr1 tp1) (ArrayR shr2 tp2) + | Just Refl <- matchShapeR shr1 shr2 + , Just Refl <- matchTupleType tp1 tp2 = Just Refl +matchArrayR _ _ = Nothing -- Compute the congruence of two scalar expressions. Two nodes are congruent if @@ -439,184 +438,142 @@ matchArraysR _ _ -- The below attempts to use real typed equality, but occasionally still needs -- to use a cast, particularly when we can only match the representation types. -- -{-# INLINEABLE matchPreOpenExp #-} -matchPreOpenExp - :: forall acc env aenv s t. - MatchAcc acc - -> EncodeAcc acc - -> PreOpenExp acc env aenv s - -> PreOpenExp acc env aenv t +{-# INLINEABLE matchOpenExp #-} +matchOpenExp + :: forall env aenv s t. + OpenExp env aenv s + -> OpenExp env aenv t -> Maybe (s :~: t) -matchPreOpenExp matchAcc encodeAcc = match - where - match :: forall env' aenv' s' t'. - PreOpenExp acc env' aenv' s' - -> PreOpenExp acc env' aenv' t' - -> Maybe (s' :~: t') - match (Let x1 e1) (Let x2 e2) - | Just Refl <- match x1 x2 - , Just Refl <- match e1 e2 - = Just Refl - - match (Var v1) (Var v2) - = matchIdx v1 v2 - - match (Foreign ff1 _ e1) (Foreign ff2 _ e2) - | Just Refl <- match e1 e2 - , unsafePerformIO $ do - sn1 <- makeStableName ff1 - sn2 <- makeStableName ff2 - return $! hashStableName sn1 == hashStableName sn2 - = gcast Refl - - match (Const c1) (Const c2) - | Just Refl <- matchTupleType (eltType @s') (eltType @t') - , matchConst (eltType @s') c1 c2 - = gcast Refl -- surface/representation type - match Undef Undef - | Just Refl <- matchTupleType (eltType @s') (eltType @t') - = gcast Refl - - match (Coerce e1) (Coerce e2) - | Just Refl <- matchTupleType (eltType @s') (eltType @t') - , Just Refl <- match e1 e2 - = gcast Refl +matchOpenExp (Let lhs1 x1 e1) (Let lhs2 x2 e2) + | Just Refl <- matchELeftHandSide lhs1 lhs2 + , Just Refl <- matchOpenExp x1 x2 + , Just Refl <- matchOpenExp e1 e2 + = Just Refl - match (Tuple t1) (Tuple t2) - | Just Refl <- matchTuple matchAcc encodeAcc t1 t2 - = gcast Refl -- surface/representation type +matchOpenExp (Evar v1) (Evar v2) + = matchVar v1 v2 - match (Prj ix1 t1) (Prj ix2 t2) - | Just Refl <- match t1 t2 - , Just Refl <- matchTupleIdx ix1 ix2 - = Just Refl +matchOpenExp (Foreign _ ff1 f1 e1) (Foreign _ ff2 f2 e2) + | Just Refl <- matchOpenExp e1 e2 + , unsafePerformIO $ do + sn1 <- makeStableName ff1 + sn2 <- makeStableName ff2 + return $! hashStableName sn1 == hashStableName sn2 + , Just Refl <- matchOpenFun f1 f2 + = Just Refl - match IndexAny IndexAny - = gcast Refl -- ??? +matchOpenExp (Const t1 c1) (Const t2 c2) + | Just Refl <- matchScalarType t1 t2 + , matchConst (TupRsingle t1) c1 c2 + = Just Refl - match IndexNil IndexNil - = Just Refl +matchOpenExp (Undef t1) (Undef t2) = matchScalarType t1 t2 - match (IndexCons sl1 a1) (IndexCons sl2 a2) - | Just Refl <- match sl1 sl2 - , Just Refl <- match a1 a2 - = Just Refl +matchOpenExp (Coerce _ t1 e1) (Coerce _ t2 e2) + | Just Refl <- matchScalarType t1 t2 + , Just Refl <- matchOpenExp e1 e2 + = Just Refl - match (IndexHead sl1) (IndexHead sl2) - | Just Refl <- match sl1 sl2 - = Just Refl +matchOpenExp (Pair a1 b1) (Pair a2 b2) + | Just Refl <- matchOpenExp a1 a2 + , Just Refl <- matchOpenExp b1 b2 + = Just Refl - match (IndexTail sl1) (IndexTail sl2) - | Just Refl <- match sl1 sl2 - = Just Refl +matchOpenExp Nil Nil + = Just Refl - match (IndexSlice sliceIndex1 ix1 sh1) (IndexSlice sliceIndex2 ix2 sh2) - | Just Refl <- match ix1 ix2 - , Just Refl <- match sh1 sh2 - , Just Refl <- matchSliceRestrict sliceIndex1 sliceIndex2 - = gcast Refl -- SliceIndex representation/surface type - - match (IndexFull sliceIndex1 ix1 sl1) (IndexFull sliceIndex2 ix2 sl2) - | Just Refl <- match ix1 ix2 - , Just Refl <- match sl1 sl2 - , Just Refl <- matchSliceExtend sliceIndex1 sliceIndex2 - = gcast Refl -- SliceIndex representation/surface type - - match (ToIndex sh1 i1) (ToIndex sh2 i2) - | Just Refl <- match sh1 sh2 - , Just Refl <- match i1 i2 - = Just Refl +matchOpenExp (IndexSlice sliceIndex1 ix1 sh1) (IndexSlice sliceIndex2 ix2 sh2) + | Just Refl <- matchOpenExp ix1 ix2 + , Just Refl <- matchOpenExp sh1 sh2 + , Just Refl <- matchSliceIndex sliceIndex1 sliceIndex2 + = Just Refl - match (FromIndex sh1 i1) (FromIndex sh2 i2) - | Just Refl <- match i1 i2 - , Just Refl <- match sh1 sh2 - = Just Refl +matchOpenExp (IndexFull sliceIndex1 ix1 sl1) (IndexFull sliceIndex2 ix2 sl2) + | Just Refl <- matchOpenExp ix1 ix2 + , Just Refl <- matchOpenExp sl1 sl2 + , Just Refl <- matchSliceIndex sliceIndex1 sliceIndex2 + = Just Refl - match (Cond p1 t1 e1) (Cond p2 t2 e2) - | Just Refl <- match p1 p2 - , Just Refl <- match t1 t2 - , Just Refl <- match e1 e2 - = Just Refl +matchOpenExp (ToIndex _ sh1 i1) (ToIndex _ sh2 i2) + | Just Refl <- matchOpenExp sh1 sh2 + , Just Refl <- matchOpenExp i1 i2 + = Just Refl - match (While p1 f1 x1) (While p2 f2 x2) - | Just Refl <- match x1 x2 - , Just Refl <- matchPreOpenFun matchAcc encodeAcc p1 p2 - , Just Refl <- matchPreOpenFun matchAcc encodeAcc f1 f2 - = Just Refl +matchOpenExp (FromIndex _ sh1 i1) (FromIndex _ sh2 i2) + | Just Refl <- matchOpenExp i1 i2 + , Just Refl <- matchOpenExp sh1 sh2 + = Just Refl - match (PrimConst c1) (PrimConst c2) - = matchPrimConst c1 c2 +matchOpenExp (Cond p1 t1 e1) (Cond p2 t2 e2) + | Just Refl <- matchOpenExp p1 p2 + , Just Refl <- matchOpenExp t1 t2 + , Just Refl <- matchOpenExp e1 e2 + = Just Refl - match (PrimApp f1 x1) (PrimApp f2 x2) - | Just x1' <- commutes encodeAcc f1 x1 - , Just x2' <- commutes encodeAcc f2 x2 - , Just Refl <- match x1' x2' - , Just Refl <- matchPrimFun f1 f2 - = Just Refl +matchOpenExp (While p1 f1 x1) (While p2 f2 x2) + | Just Refl <- matchOpenExp x1 x2 + , Just Refl <- matchOpenFun p1 p2 + , Just Refl <- matchOpenFun f1 f2 + = Just Refl - | Just Refl <- match x1 x2 - , Just Refl <- matchPrimFun f1 f2 - = Just Refl +matchOpenExp (PrimConst c1) (PrimConst c2) + = matchPrimConst c1 c2 - match (Index a1 x1) (Index a2 x2) - | Just Refl <- matchAcc a1 a2 -- should only be array indices - , Just Refl <- match x1 x2 - = Just Refl +matchOpenExp (PrimApp f1 x1) (PrimApp f2 x2) + | Just x1' <- commutes f1 x1 + , Just x2' <- commutes f2 x2 + , Just Refl <- matchOpenExp x1' x2' + , Just Refl <- matchPrimFun f1 f2 + = Just Refl - match (LinearIndex a1 x1) (LinearIndex a2 x2) - | Just Refl <- matchAcc a1 a2 - , Just Refl <- match x1 x2 - = Just Refl + | Just Refl <- matchOpenExp x1 x2 + , Just Refl <- matchPrimFun f1 f2 + = Just Refl - match (Shape a1) (Shape a2) - | Just Refl <- matchAcc a1 a2 -- should only be array indices - = Just Refl +matchOpenExp (Index a1 x1) (Index a2 x2) + | Just Refl <- matchVar a1 a2 -- should only be array indices + , Just Refl <- matchOpenExp x1 x2 + = Just Refl - match (ShapeSize sh1) (ShapeSize sh2) - | Just Refl <- match sh1 sh2 - = Just Refl +matchOpenExp (LinearIndex a1 x1) (LinearIndex a2 x2) + | Just Refl <- matchVar a1 a2 + , Just Refl <- matchOpenExp x1 x2 + = Just Refl - match (Intersect sa1 sb1) (Intersect sa2 sb2) - | Just Refl <- match sa1 sa2 - , Just Refl <- match sb1 sb2 - = Just Refl +matchOpenExp (Shape a1) (Shape a2) + | Just Refl <- matchVar a1 a2 -- should only be array indices + = Just Refl - match (Union sa1 sb1) (Union sa2 sb2) - | Just Refl <- match sa1 sa2 - , Just Refl <- match sb1 sb2 - = Just Refl +matchOpenExp (ShapeSize _ sh1) (ShapeSize _ sh2) + | Just Refl <- matchOpenExp sh1 sh2 + = Just Refl - match _ _ - = Nothing +matchOpenExp _ _ + = Nothing -- Match scalar functions -- -{-# INLINEABLE matchPreOpenFun #-} -matchPreOpenFun - :: MatchAcc acc - -> EncodeAcc acc - -> PreOpenFun acc env aenv s - -> PreOpenFun acc env aenv t +{-# INLINEABLE matchOpenFun #-} +matchOpenFun + :: OpenFun env aenv s + -> OpenFun env aenv t -> Maybe (s :~: t) -matchPreOpenFun m h (Lam s) (Lam t) - | Just Refl <- matchEnvTop s t - , Just Refl <- matchPreOpenFun m h s t +matchOpenFun (Lam lhs1 s) (Lam lhs2 t) + | Just Refl <- matchELeftHandSide lhs1 lhs2 + , Just Refl <- matchOpenFun s t = Just Refl - where - matchEnvTop :: (Elt s, Elt t) => PreOpenFun acc (env, s) aenv f -> PreOpenFun acc (env, t) aenv g -> Maybe (s :~: t) - matchEnvTop _ _ = gcast Refl -- ??? -matchPreOpenFun m h (Body s) (Body t) = matchPreOpenExp m h s t -matchPreOpenFun _ _ _ _ = Nothing +matchOpenFun (Body s) (Body t) = matchOpenExp s t +matchOpenFun _ _ = Nothing -- Matching constants -- matchConst :: TupleType a -> a -> a -> Bool -matchConst TypeRunit () () = True -matchConst (TypeRscalar ty) a b = evalEq ty (a,b) -matchConst (TypeRpair ta tb) (a1,b1) (a2,b2) = matchConst ta a1 a2 && matchConst tb b1 b2 +matchConst TupRunit () () = True +matchConst (TupRsingle ty) a b = evalEq ty (a,b) +matchConst (TupRpair ta tb) (a1,b1) (a2,b2) = matchConst ta a1 a2 && matchConst tb b1 b2 evalEq :: ScalarType a -> (a, a) -> Bool evalEq (SingleScalarType t) = evalEqSingle t @@ -642,87 +599,38 @@ matchIdx ZeroIdx ZeroIdx = Just Refl matchIdx (SuccIdx u) (SuccIdx v) = matchIdx u v matchIdx _ _ = Nothing -{-# INLINEABLE matchArrayVar #-} -matchArrayVar :: ArrayVar env s -> ArrayVar env t -> Maybe (s :~: t) -matchArrayVar (ArrayVar v1) (ArrayVar v2) = matchIdx v1 v2 - -{-# INLINEABLE matchArrayVars #-} -matchArrayVars :: ArrayVars env s -> ArrayVars env t -> Maybe (s :~: t) -matchArrayVars ArrayVarsNil ArrayVarsNil = Just Refl -matchArrayVars (ArrayVarsArray v1) (ArrayVarsArray v2) - | Just Refl <- matchArrayVar v1 v2 = Just Refl -matchArrayVars (ArrayVarsPair v w) (ArrayVarsPair x y) - | Just Refl <- matchArrayVars v x - , Just Refl <- matchArrayVars w y = Just Refl -matchArrayVars _ _ = Nothing - - --- Tuple projection indices. Given the same tuple expression structure (tup), --- check that the indices project identical elements. --- -{-# INLINEABLE matchTupleIdx #-} -matchTupleIdx :: TupleIdx tup s -> TupleIdx tup t -> Maybe (s :~: t) -matchTupleIdx ZeroTupIdx ZeroTupIdx = Just Refl -matchTupleIdx (SuccTupIdx s) (SuccTupIdx t) = matchTupleIdx s t -matchTupleIdx _ _ = Nothing - --- Tuples --- -matchTuple - :: MatchAcc acc - -> EncodeAcc acc - -> Tuple (PreOpenExp acc env aenv) s - -> Tuple (PreOpenExp acc env aenv) t - -> Maybe (s :~: t) -matchTuple _ _ NilTup NilTup = Just Refl -matchTuple m h (SnocTup t1 e1) (SnocTup t2 e2) - | Just Refl <- matchTuple m h t1 t2 - , Just Refl <- matchPreOpenExp m h e1 e2 - = Just Refl +{-# INLINEABLE matchVar #-} +matchVar :: Var s env t1 -> Var s env t2 -> Maybe (t1 :~: t2) +matchVar (Var _ v1) (Var _ v2) = matchIdx v1 v2 -matchTuple _ _ _ _ = Nothing +{-# INLINEABLE matchVars #-} +matchVars :: Vars s env t1 -> Vars s env t2 -> Maybe (t1 :~: t2) +matchVars VarsNil VarsNil = Just Refl +matchVars (VarsSingle v1) (VarsSingle v2) + | Just Refl <- matchVar v1 v2 = Just Refl +matchVars (VarsPair v w) (VarsPair x y) + | Just Refl <- matchVars v x + , Just Refl <- matchVars w y = Just Refl +matchVars _ _ = Nothing -- Slice specifications -- -matchSliceRestrict - :: SliceIndex slix s co sh - -> SliceIndex slix t co' sh - -> Maybe (s :~: t) -matchSliceRestrict SliceNil SliceNil +matchSliceIndex :: SliceIndex slix1 sl1 co1 sh1 -> SliceIndex slix2 sl2 co2 sh2 -> Maybe (SliceIndex slix1 sl1 co1 sh1 :~: SliceIndex slix2 sl2 co2 sh2) +matchSliceIndex SliceNil SliceNil = Just Refl -matchSliceRestrict (SliceAll sl1) (SliceAll sl2) - | Just Refl <- matchSliceRestrict sl1 sl2 +matchSliceIndex (SliceAll sl1) (SliceAll sl2) + | Just Refl <- matchSliceIndex sl1 sl2 = Just Refl -matchSliceRestrict (SliceFixed sl1) (SliceFixed sl2) - | Just Refl <- matchSliceRestrict sl1 sl2 +matchSliceIndex (SliceFixed sl1) (SliceFixed sl2) + | Just Refl <- matchSliceIndex sl1 sl2 = Just Refl -matchSliceRestrict _ _ +matchSliceIndex _ _ = Nothing - -matchSliceExtend - :: SliceIndex slix sl co s - -> SliceIndex slix sl co' t - -> Maybe (s :~: t) -matchSliceExtend SliceNil SliceNil - = Just Refl - -matchSliceExtend (SliceAll sl1) (SliceAll sl2) - | Just Refl <- matchSliceExtend sl1 sl2 - = Just Refl - -matchSliceExtend (SliceFixed sl1) (SliceFixed sl2) - | Just Refl <- matchSliceExtend sl1 sl2 - = Just Refl - -matchSliceExtend _ _ - = Nothing - - -- Primitive constants and functions -- matchPrimConst :: PrimConst s -> PrimConst t -> Maybe (s :~: t) @@ -902,15 +810,7 @@ matchPrimFun' _ _ -- {-# INLINEABLE matchTupleType #-} matchTupleType :: TupleType s -> TupleType t -> Maybe (s :~: t) -matchTupleType TypeRunit TypeRunit = Just Refl -matchTupleType (TypeRscalar s) (TypeRscalar t) = matchScalarType s t -matchTupleType (TypeRpair s1 s2) (TypeRpair t1 t2) - | Just Refl <- matchTupleType s1 t1 - , Just Refl <- matchTupleType s2 t2 - = Just Refl - -matchTupleType _ _ - = Nothing +matchTupleType = matchTupR matchScalarType -- Match shapes (dimensionality) @@ -922,9 +822,9 @@ matchTupleType _ _ -- a known branch. -- {-# INLINEABLE matchShapeType #-} -matchShapeType :: forall s t. (Shape s, Shape t) => Maybe (s :~: t) +matchShapeType :: forall s t. (Sugar.Shape s, Sugar.Shape t) => Maybe (s :~: t) matchShapeType - | Just Refl <- matchTupleType (eltType @s) (eltType @t) + | Just Refl <- matchShapeR (Sugar.shapeR @s) (Sugar.shapeR @t) #ifdef ACCELERATE_INTERNAL_CHECKS = gcast Refl #else @@ -933,6 +833,13 @@ matchShapeType | otherwise = Nothing +{-# INLINEABLE matchShapeR #-} +matchShapeR :: forall s t. ShapeR s -> ShapeR t -> Maybe (s :~: t) +matchShapeR ShapeRz ShapeRz = Just Refl +matchShapeR (ShapeRsnoc shr1) (ShapeRsnoc shr2) + | Just Refl <- matchShapeR shr1 shr2 = Just Refl +matchShapeR _ _ = Nothing + -- Match reified type dictionaries -- @@ -1007,12 +914,11 @@ matchNonNumType _ _ = Nothing -- commutativity. -- commutes - :: forall acc env aenv a r. - EncodeAcc acc - -> PrimFun (a -> r) - -> PreOpenExp acc env aenv a - -> Maybe (PreOpenExp acc env aenv a) -commutes h f x = case f of + :: forall env aenv a r. + PrimFun (a -> r) + -> OpenExp env aenv a + -> Maybe (OpenExp env aenv a) +commutes f x = case f of PrimAdd{} -> Just (swizzle x) PrimMul{} -> Just (swizzle x) PrimBAnd{} -> Just (swizzle x) @@ -1026,10 +932,10 @@ commutes h f x = case f of PrimLOr -> Just (swizzle x) _ -> Nothing where - swizzle :: PreOpenExp acc env aenv (a',a') -> PreOpenExp acc env aenv (a',a') + swizzle :: OpenExp env aenv (a',a') -> OpenExp env aenv (a',a') swizzle exp - | Tuple (NilTup `SnocTup` a `SnocTup` b) <- exp - , hashPreOpenExp h a > hashPreOpenExp h b = Tuple (NilTup `SnocTup` b `SnocTup` a) + | (a `Pair` b) <- exp + , hashOpenExp a > hashOpenExp b = b `Pair` a -- | otherwise = exp diff --git a/src/Data/Array/Accelerate/Analysis/Shape.hs b/src/Data/Array/Accelerate/Analysis/Shape.hs index 14f7c549d..416e77389 100644 --- a/src/Data/Array/Accelerate/Analysis/Shape.hs +++ b/src/Data/Array/Accelerate/Analysis/Shape.hs @@ -23,25 +23,22 @@ module Data.Array.Accelerate.Analysis.Shape ( ) where import Data.Array.Accelerate.AST -import Data.Array.Accelerate.Type -import Data.Array.Accelerate.Array.Sugar +import Data.Array.Accelerate.Array.Representation -- |Reify the dimensionality of the result type of an array computation -- accDim :: forall acc aenv sh e. HasArraysRepr acc => acc aenv (Array sh e) -> Int -accDim acc = case arraysRepr acc of - ArraysRarray -> rank @sh +accDim = rank . arrayRshape . arrayRepr -- |Reify dimensionality of a scalar expression yielding a shape -- -expDim :: forall acc env aenv sh. Elt sh => PreOpenExp acc env aenv sh -> Int -expDim _ = ndim (eltType @sh) - +expDim :: forall env aenv sh. OpenExp env aenv sh -> Int +expDim = ndim . expType -- Count the number of components to a tuple type -- -ndim :: TupleType a -> Int -ndim TypeRunit = 0 -ndim TypeRscalar{} = 1 -ndim (TypeRpair a b) = ndim a + ndim b +ndim :: TupR s a -> Int +ndim TupRunit = 0 +ndim TupRsingle{} = 1 +ndim (TupRpair a b) = ndim a + ndim b diff --git a/src/Data/Array/Accelerate/Analysis/Stencil.hs b/src/Data/Array/Accelerate/Analysis/Stencil.hs index 283f7b2c1..cf465f565 100644 --- a/src/Data/Array/Accelerate/Analysis/Stencil.hs +++ b/src/Data/Array/Accelerate/Analysis/Stencil.hs @@ -1,5 +1,6 @@ {-# LANGUAGE GADTs #-} {-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TupleSections #-} {-# LANGUAGE TypeOperators #-} {-# OPTIONS_HADDOCK hide #-} -- | @@ -12,10 +13,10 @@ -- Portability : non-portable (GHC extensions) -- -module Data.Array.Accelerate.Analysis.Stencil (offsets, offsets2) where +module Data.Array.Accelerate.Analysis.Stencil (positionsR) where import Data.Array.Accelerate.AST -import Data.Array.Accelerate.Array.Sugar +import Data.Array.Accelerate.Array.Representation -- |Calculate the offset coordinates for each stencil element relative to the @@ -23,68 +24,58 @@ import Data.Array.Accelerate.Array.Sugar -- bottom-left element to the top-right. This ordering matches the Var indexing -- order. -- -offsets :: forall a b sh aenv stencil. Stencil sh a stencil - => {- dummy -} Fun aenv (stencil -> b) - -> {- dummy -} OpenAcc aenv (Array sh a) - -> [sh] -offsets _ _ = positionsR (stencil :: StencilR sh a stencil) - -offsets2 :: forall a b c sh aenv stencil1 stencil2. (Stencil sh a stencil1, Stencil sh b stencil2) - => {- dummy -} Fun aenv (stencil1 -> stencil2 -> c) - -> {- dummy -} OpenAcc aenv (Array sh a) - -> {- dummy -} OpenAcc aenv (Array sh b) - -> ([sh], [sh]) -offsets2 _ _ _ = - ( positionsR (stencil :: StencilR sh a stencil1) - , positionsR (stencil :: StencilR sh b stencil2) ) - - --- |Position calculation on reified stencil values. --- positionsR :: StencilR sh e pat -> [sh] -positionsR StencilRunit3 = map (Z:.) [ -1, 0, 1 ] -positionsR StencilRunit5 = map (Z:.) [ -2,-1, 0, 1, 2 ] -positionsR StencilRunit7 = map (Z:.) [ -3,-2,-1, 0, 1, 2, 3 ] -positionsR StencilRunit9 = map (Z:.) [-4,-3,-2,-1, 0, 1, 2, 3, 4 ] +positionsR StencilRunit3{} = map ((), ) [ -1, 0, 1 ] +positionsR StencilRunit5{} = map ((), ) [ -2,-1, 0, 1, 2 ] +positionsR StencilRunit7{} = map ((), ) [ -3,-2,-1, 0, 1, 2, 3 ] +positionsR StencilRunit9{} = map ((), ) [-4,-3,-2,-1, 0, 1, 2, 3, 4 ] positionsR (StencilRtup3 c b a) = concat - [ map (innermost (:. -1)) $ positionsR c - , map (innermost (:. 0)) $ positionsR b - , map (innermost (:. 1)) $ positionsR a ] + [ map (innermost shr (, -1)) $ positionsR c + , map (innermost shr (, 0)) $ positionsR b + , map (innermost shr (, 1)) $ positionsR a ] + where + shr = stencilShape a positionsR (StencilRtup5 e d c b a) = concat - [ map (innermost (:. -2)) $ positionsR e - , map (innermost (:. -1)) $ positionsR d - , map (innermost (:. 0)) $ positionsR c - , map (innermost (:. 1)) $ positionsR b - , map (innermost (:. 2)) $ positionsR a ] + [ map (innermost shr (, -2)) $ positionsR e + , map (innermost shr (, -1)) $ positionsR d + , map (innermost shr (, 0)) $ positionsR c + , map (innermost shr (, 1)) $ positionsR b + , map (innermost shr (, 2)) $ positionsR a ] + where + shr = stencilShape a positionsR (StencilRtup7 g f e d c b a) = concat - [ map (innermost (:. -3)) $ positionsR g - , map (innermost (:. -2)) $ positionsR f - , map (innermost (:. -1)) $ positionsR e - , map (innermost (:. 0)) $ positionsR d - , map (innermost (:. 1)) $ positionsR c - , map (innermost (:. 2)) $ positionsR b - , map (innermost (:. 3)) $ positionsR a ] + [ map (innermost shr (, -3)) $ positionsR g + , map (innermost shr (, -2)) $ positionsR f + , map (innermost shr (, -1)) $ positionsR e + , map (innermost shr (, 0)) $ positionsR d + , map (innermost shr (, 1)) $ positionsR c + , map (innermost shr (, 2)) $ positionsR b + , map (innermost shr (, 3)) $ positionsR a ] + where + shr = stencilShape a positionsR (StencilRtup9 i h g f e d c b a) = concat - [ map (innermost (:. -4)) $ positionsR i - , map (innermost (:. -3)) $ positionsR h - , map (innermost (:. -2)) $ positionsR g - , map (innermost (:. -1)) $ positionsR f - , map (innermost (:. 0)) $ positionsR e - , map (innermost (:. 1)) $ positionsR d - , map (innermost (:. 2)) $ positionsR c - , map (innermost (:. 3)) $ positionsR b - , map (innermost (:. 4)) $ positionsR a ] + [ map (innermost shr (, -4)) $ positionsR i + , map (innermost shr (, -3)) $ positionsR h + , map (innermost shr (, -2)) $ positionsR g + , map (innermost shr (, -1)) $ positionsR f + , map (innermost shr (, 0)) $ positionsR e + , map (innermost shr (, 1)) $ positionsR d + , map (innermost shr (, 2)) $ positionsR c + , map (innermost shr (, 3)) $ positionsR b + , map (innermost shr (, 4)) $ positionsR a ] + where + shr = stencilShape a -- Inject a dimension component inner-most -- -innermost :: Shape sh => (sh -> sh :. Int) -> sh -> sh :. Int -innermost f = invertShape . f . invertShape +innermost :: ShapeR sh -> (sh -> (sh, Int)) -> sh -> (sh, Int) +innermost shr f = invertShape (ShapeRsnoc shr) . f . invertShape shr -invertShape :: Shape sh => sh -> sh -invertShape = listToShape . reverse . shapeToList +invertShape :: ShapeR sh -> sh -> sh +invertShape shr = listToShape shr . reverse . shapeToList shr diff --git a/src/Data/Array/Accelerate/Analysis/Type.hs b/src/Data/Array/Accelerate/Analysis/Type.hs index 4fb6cd509..677dc71dc 100644 --- a/src/Data/Array/Accelerate/Analysis/Type.hs +++ b/src/Data/Array/Accelerate/Analysis/Type.hs @@ -24,8 +24,7 @@ module Data.Array.Accelerate.Analysis.Type ( - arrayType, - accType, expType, + accType, sizeOf, sizeOfScalarType, @@ -38,73 +37,26 @@ module Data.Array.Accelerate.Analysis.Type ( -- friends import Data.Array.Accelerate.AST -import Data.Array.Accelerate.Array.Sugar +import Data.Array.Accelerate.Array.Representation import Data.Array.Accelerate.Type -- standard library import qualified Foreign.Storable as F --- |Determine an array type --- ------------------------ - --- |Reify the element type of an array. --- -arrayType :: forall sh e. Elt e => Array sh e -> TupleType (EltRepr e) -arrayType _ = eltType @e - - -- |Determine the type of an expressions -- ------------------------------------- -accType :: forall acc aenv sh e. HasArraysRepr acc => acc aenv (Array sh e) -> TupleType (EltRepr e) -accType acc = case arraysRepr acc of - ArraysRarray -> eltType @e - --- |Reify the result types of of a scalar expression using the expression AST before tying the --- knot. --- -expType :: forall acc aenv env t. - HasArraysRepr acc - => PreOpenExp acc aenv env t - -> TupleType (EltRepr t) -expType e = - case e of - Let _ _ -> eltType @t - Var _ -> eltType @t - Const _ -> eltType @t - Undef -> eltType @t - Tuple _ -> eltType @t - Prj _ _ -> eltType @t - IndexNil -> eltType @t - IndexCons _ _ -> eltType @t - IndexHead _ -> eltType @t - IndexTail _ -> eltType @t - IndexAny -> eltType @t - IndexSlice _ _ _ -> eltType @t - IndexFull _ _ _ -> eltType @t - ToIndex _ _ -> eltType @t - FromIndex _ _ -> eltType @t - Cond _ t _ -> expType t - While _ _ _ -> eltType @t - PrimConst _ -> eltType @t - PrimApp _ _ -> eltType @t - Index acc _ -> accType acc - LinearIndex acc _ -> accType acc - Shape _ -> eltType @t - ShapeSize _ -> eltType @t - Intersect _ _ -> eltType @t - Union _ _ -> eltType @t - Foreign _ _ _ -> eltType @t - Coerce _ -> eltType @t +accType :: forall acc aenv sh e. HasArraysRepr acc => acc aenv (Array sh e) -> TupleType e +accType = arrayRtype . arrayRepr -- |Size of a tuple type, in bytes -- sizeOf :: TupleType a -> Int -sizeOf TypeRunit = 0 -sizeOf (TypeRpair a b) = sizeOf a + sizeOf b -sizeOf (TypeRscalar t) = sizeOfScalarType t +sizeOf TupRunit = 0 +sizeOf (TupRpair a b) = sizeOf a + sizeOf b +sizeOf (TupRsingle t) = sizeOfScalarType t sizeOfScalarType :: ScalarType t -> Int sizeOfScalarType (SingleScalarType t) = sizeOfSingleType t diff --git a/src/Data/Array/Accelerate/Array/Data.hs b/src/Data/Array/Accelerate/Array/Data.hs index f1ede7162..ba8bafdd0 100644 --- a/src/Data/Array/Accelerate/Array/Data.hs +++ b/src/Data/Array/Accelerate/Array/Data.hs @@ -1,15 +1,18 @@ -{-# LANGUAGE BangPatterns #-} -{-# LANGUAGE CPP #-} -{-# LANGUAGE DeriveDataTypeable #-} -{-# LANGUAGE FlexibleInstances #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE MagicHash #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE StandaloneDeriving #-} -{-# LANGUAGE TemplateHaskell #-} -{-# LANGUAGE TypeFamilies #-} -{-# LANGUAGE UnboxedTuples #-} -{-# LANGUAGE UndecidableInstances #-} +{-# LANGUAGE AllowAmbiguousTypes #-} +{-# LANGUAGE BangPatterns #-} +{-# LANGUAGE ConstraintKinds #-} +{-# LANGUAGE CPP #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE MagicHash #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE TemplateHaskell #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeFamilyDependencies #-} +{-# LANGUAGE UnboxedTuples #-} +{-# LANGUAGE UndecidableInstances #-} {-# OPTIONS_HADDOCK hide #-} -- | -- Module : Data.Array.Accelerate.Array.Data @@ -29,11 +32,8 @@ module Data.Array.Accelerate.Array.Data ( -- * Array operations and representations - ArrayElt(..), ArrayData, MutableArrayData, runArrayData, - ArrayEltR(..), GArrayData(..), - - -- * Array tuple operations - fstArrayData, sndArrayData, pairArrayData, + ArrayData, MutableArrayData, runArrayData, GArrayData, rnfArrayData, ScalarData, ScalarDataRepr, + unsafeIndexArrayData, ptrOfArrayData, touchArrayData, newArrayData, unsafeReadArrayData, unsafeWriteArrayData, -- * Type macros HTYPE_INT, HTYPE_WORD, HTYPE_CLONG, HTYPE_CULONG, HTYPE_CCHAR, @@ -41,6 +41,9 @@ module Data.Array.Accelerate.Array.Data ( -- * Allocator internals registerForeignPtrAllocator, + -- * Utilities for type classes + ScalarDict(..), scalarDict, singleDict, IsScalarData + ) where -- friends @@ -56,12 +59,10 @@ import Data.Array.Accelerate.Debug.Trace -- standard libraries import Control.Applicative import Control.Monad ( (<=<) ) +import Control.DeepSeq import Data.Bits -import Data.Char import Data.IORef -import Data.Kind import Data.Primitive ( sizeOf# ) -import Data.Typeable ( Typeable ) import Foreign.ForeignPtr import Foreign.Storable import Language.Haskell.TH hiding ( Type ) @@ -72,7 +73,7 @@ import Prelude hiding ( map import GHC.Base import GHC.ForeignPtr import GHC.Ptr -import GHC.TypeLits +import Data.Primitive.Types ( Prim ) -- Determine the underlying type of a Haskell CLong or CULong. @@ -123,152 +124,132 @@ type MutableArrayData e = GArrayData e -- In previous versions this was abstracted over by the mutable/immutable array -- representation, but this is now fixed to our UniqueArray type. -- -data family GArrayData a :: Type -data instance GArrayData () = AD_Unit -data instance GArrayData Int = AD_Int {-# UNPACK #-} !(UniqueArray Int) -data instance GArrayData Int8 = AD_Int8 {-# UNPACK #-} !(UniqueArray Int8) -data instance GArrayData Int16 = AD_Int16 {-# UNPACK #-} !(UniqueArray Int16) -data instance GArrayData Int32 = AD_Int32 {-# UNPACK #-} !(UniqueArray Int32) -data instance GArrayData Int64 = AD_Int64 {-# UNPACK #-} !(UniqueArray Int64) -data instance GArrayData Word = AD_Word {-# UNPACK #-} !(UniqueArray Word) -data instance GArrayData Word8 = AD_Word8 {-# UNPACK #-} !(UniqueArray Word8) -data instance GArrayData Word16 = AD_Word16 {-# UNPACK #-} !(UniqueArray Word16) -data instance GArrayData Word32 = AD_Word32 {-# UNPACK #-} !(UniqueArray Word32) -data instance GArrayData Word64 = AD_Word64 {-# UNPACK #-} !(UniqueArray Word64) -data instance GArrayData Half = AD_Half {-# UNPACK #-} !(UniqueArray Half) -data instance GArrayData Float = AD_Float {-# UNPACK #-} !(UniqueArray Float) -data instance GArrayData Double = AD_Double {-# UNPACK #-} !(UniqueArray Double) -data instance GArrayData Bool = AD_Bool {-# UNPACK #-} !(UniqueArray Word8) -data instance GArrayData Char = AD_Char {-# UNPACK #-} !(UniqueArray Char) -data instance GArrayData (Vec n a) = AD_Vec !Int# !(GArrayData a) -- sad this does not get unpacked ): -data instance GArrayData (a, b) = AD_Pair (GArrayData a) (GArrayData b) -- XXX: non-strict to support lazy device-host copying - -deriving instance Typeable GArrayData - - --- | GADT to reify the 'ArrayElt' class. --- -data ArrayEltR a where - ArrayEltRunit :: ArrayEltR () - ArrayEltRint :: ArrayEltR Int - ArrayEltRint8 :: ArrayEltR Int8 - ArrayEltRint16 :: ArrayEltR Int16 - ArrayEltRint32 :: ArrayEltR Int32 - ArrayEltRint64 :: ArrayEltR Int64 - ArrayEltRword :: ArrayEltR Word - ArrayEltRword8 :: ArrayEltR Word8 - ArrayEltRword16 :: ArrayEltR Word16 - ArrayEltRword32 :: ArrayEltR Word32 - ArrayEltRword64 :: ArrayEltR Word64 - ArrayEltRhalf :: ArrayEltR Half - ArrayEltRfloat :: ArrayEltR Float - ArrayEltRdouble :: ArrayEltR Double - ArrayEltRbool :: ArrayEltR Bool - ArrayEltRchar :: ArrayEltR Char - ArrayEltRpair :: ArrayEltR a -> ArrayEltR b -> ArrayEltR (a,b) - ArrayEltRvec :: (KnownNat n, ArrayPtrs (Vec n a) ~ ArrayPtrs a, ArrayPtrs a ~ Ptr a) => ArrayEltR a -> ArrayEltR (Vec n a) - -- XXX: Do we really require these embedded class constraints? +type family GArrayData a where + GArrayData () = () + GArrayData (a, b) = (GArrayData a, GArrayData b) -- XXX: fields of tuple are non-strict, which enables lazy device-host copying + GArrayData a = ScalarData a + +type ScalarData a = UniqueArray (ScalarDataRepr a) + +-- Mapping from scalar type to the type as represented in memory in an array. +-- Booleans are stored as Word8, other types are represented as itself. +type family ScalarDataRepr tp where + ScalarDataRepr Int = Int + ScalarDataRepr Int8 = Int8 + ScalarDataRepr Int16 = Int16 + ScalarDataRepr Int32 = Int32 + ScalarDataRepr Int64 = Int64 + ScalarDataRepr Word = Word + ScalarDataRepr Word8 = Word8 + ScalarDataRepr Word16 = Word16 + ScalarDataRepr Word32 = Word32 + ScalarDataRepr Word64 = Word64 + ScalarDataRepr Half = Half + ScalarDataRepr Float = Float + ScalarDataRepr Double = Double + ScalarDataRepr Bool = Word8 + ScalarDataRepr Char = Char + ScalarDataRepr (Vec n tp) = ScalarDataRepr tp + +-- Utilities for working with the type families & type class instances +data ScalarDict e where + ScalarDict :: IsScalarData e => ScalarDict e + +type IsScalarData e = (Storable (ScalarDataRepr e), Prim (ScalarDataRepr e), ArrayData e ~ ScalarData e) + +{-# INLINE scalarDict #-} +scalarDict :: ScalarType e -> (Int, ScalarDict e) +scalarDict (SingleScalarType tp) + | (dict, _, _) <- singleDict tp = (1, dict) +scalarDict (VectorScalarType (VectorType n tp)) + | (ScalarDict, _, _) <- singleDict tp = (n, ScalarDict) + +{-# INLINE singleDict #-} +singleDict :: SingleType e -> (ScalarDict e, e -> ScalarDataRepr e, ScalarDataRepr e -> e) +singleDict (NonNumSingleType TypeBool) = (ScalarDict, fromBool, toBool) +singleDict (NonNumSingleType TypeChar) = (ScalarDict, id, id) +singleDict (NumSingleType (IntegralNumType tp)) = case tp of + TypeInt -> (ScalarDict, id, id) + TypeInt8 -> (ScalarDict, id, id) + TypeInt16 -> (ScalarDict, id, id) + TypeInt32 -> (ScalarDict, id, id) + TypeInt64 -> (ScalarDict, id, id) + TypeWord -> (ScalarDict, id, id) + TypeWord8 -> (ScalarDict, id, id) + TypeWord16 -> (ScalarDict, id, id) + TypeWord32 -> (ScalarDict, id, id) + TypeWord64 -> (ScalarDict, id, id) +singleDict (NumSingleType (FloatingNumType tp)) = case tp of + TypeHalf -> (ScalarDict, id, id) + TypeFloat -> (ScalarDict, id, id) + TypeDouble -> (ScalarDict, id, id) -- Array operations -- ---------------- -class ArrayElt e where - type ArrayPtrs e - arrayElt :: ArrayEltR e - -- - unsafeIndexArrayData :: ArrayData e -> Int -> e - ptrsOfArrayData :: ArrayData e -> ArrayPtrs e - touchArrayData :: ArrayData e -> IO () - -- - newArrayData :: Int -> IO (MutableArrayData e) - unsafeReadArrayData :: MutableArrayData e -> Int -> IO e - unsafeWriteArrayData :: MutableArrayData e -> Int -> e -> IO () - unsafeFreezeArrayData :: MutableArrayData e -> IO (ArrayData e) - ptrsOfMutableArrayData :: MutableArrayData e -> IO (ArrayPtrs e) - -- - {-# INLINE unsafeFreezeArrayData #-} - {-# INLINE ptrsOfMutableArrayData #-} - unsafeFreezeArrayData = return - ptrsOfMutableArrayData = return . ptrsOfArrayData - -instance ArrayElt () where - type ArrayPtrs () = () - arrayElt = ArrayEltRunit - {-# INLINE arrayElt #-} - {-# INLINE newArrayData #-} - {-# INLINE ptrsOfArrayData #-} - {-# INLINE touchArrayData #-} - {-# INLINE unsafeIndexArrayData #-} - {-# INLINE unsafeReadArrayData #-} - {-# INLINE unsafeWriteArrayData #-} - newArrayData !_ = return AD_Unit - ptrsOfArrayData AD_Unit = () - touchArrayData AD_Unit = return () - unsafeIndexArrayData AD_Unit !_ = () - unsafeReadArrayData AD_Unit !_ = return () - unsafeWriteArrayData AD_Unit !_ () = return () - --- Bool arrays are stored as arrays of bytes. While this is memory inefficient, --- it is better suited to parallel backends than a packed bit-vector --- representation. --- --- XXX: Currently there are _no_ (Vec n Bool) instances. We could use efficient --- bit-packed representations for these cases... --- -instance ArrayElt Bool where - type ArrayPtrs Bool = Ptr Word8 - arrayElt = ArrayEltRbool - {-# INLINE arrayElt #-} - {-# INLINE newArrayData #-} - {-# INLINE ptrsOfArrayData #-} - {-# INLINE touchArrayData #-} - {-# INLINE unsafeIndexArrayData #-} - {-# INLINE unsafeReadArrayData #-} - {-# INLINE unsafeWriteArrayData #-} - newArrayData size = AD_Bool <$> newArrayData' size - ptrsOfArrayData (AD_Bool ba) = unsafeUniqueArrayPtr ba - touchArrayData (AD_Bool ba) = touchUniqueArray ba - unsafeIndexArrayData (AD_Bool ba) i = toBool $! unsafeIndexArray ba i - unsafeReadArrayData (AD_Bool ba) i = toBool <$> unsafeReadArray ba i - unsafeWriteArrayData (AD_Bool ba) i e = unsafeWriteArray ba i (fromBool e) - -instance (ArrayElt a, ArrayElt b) => ArrayElt (a, b) where - type ArrayPtrs (a, b) = (ArrayPtrs a, ArrayPtrs b) - arrayElt = ArrayEltRpair arrayElt arrayElt - {-# INLINEABLE arrayElt #-} - {-# INLINEABLE newArrayData #-} - {-# INLINEABLE ptrsOfArrayData #-} - {-# INLINEABLE ptrsOfMutableArrayData #-} - {-# INLINEABLE touchArrayData #-} - {-# INLINEABLE unsafeFreezeArrayData #-} - {-# INLINEABLE unsafeIndexArrayData #-} - {-# INLINEABLE unsafeReadArrayData #-} - {-# INLINEABLE unsafeWriteArrayData #-} - newArrayData size = AD_Pair <$> newArrayData size <*> newArrayData size - touchArrayData (AD_Pair a b) = touchArrayData a >> touchArrayData b - ptrsOfArrayData (AD_Pair a b) = (ptrsOfArrayData a, ptrsOfArrayData b) - ptrsOfMutableArrayData (AD_Pair a b) = (,) <$> ptrsOfMutableArrayData a <*> ptrsOfMutableArrayData b - unsafeReadArrayData (AD_Pair a b) i = (,) <$> unsafeReadArrayData a i <*> unsafeReadArrayData b i - unsafeIndexArrayData (AD_Pair a b) i = (unsafeIndexArrayData a i, unsafeIndexArrayData b i) - unsafeWriteArrayData (AD_Pair a b) i (x, y) = unsafeWriteArrayData a i x >> unsafeWriteArrayData b i y - unsafeFreezeArrayData (AD_Pair a b) = AD_Pair <$> unsafeFreezeArrayData a <*> unsafeFreezeArrayData b - - --- Array tuple operations --- ---------------------- - -{-# INLINE fstArrayData #-} -fstArrayData :: ArrayData (a, b) -> ArrayData a -fstArrayData (AD_Pair x _) = x - -{-# INLINE sndArrayData #-} -sndArrayData :: ArrayData (a, b) -> ArrayData b -sndArrayData (AD_Pair _ y) = y - -{-# INLINE pairArrayData #-} -pairArrayData :: ArrayData a -> ArrayData b -> ArrayData (a, b) -pairArrayData = AD_Pair +-- Reads an element from an array +unsafeIndexArrayData :: TupleType e -> ArrayData e -> Int -> e +unsafeIndexArrayData TupRunit () !_ = () +unsafeIndexArrayData (TupRpair t1 t2) (a1, a2) !ix = (unsafeIndexArrayData t1 a1 ix, unsafeIndexArrayData t2 a2 ix) +unsafeIndexArrayData (TupRsingle (SingleScalarType tp)) arr ix + | (ScalarDict, _, to) <- singleDict tp = to $! unsafeIndexArray arr ix +-- VectorScalarType is handled in unsafeReadArrayData +unsafeIndexArrayData !tp !arr !ix = unsafePerformIO $! unsafeReadArrayData tp arr ix + +ptrOfArrayData :: ScalarType e -> ArrayData e -> Ptr (ScalarDataRepr e) +ptrOfArrayData tp arr + | (_, ScalarDict) <- scalarDict tp = unsafeUniqueArrayPtr arr + +touchArrayData :: TupleType e -> ArrayData e -> IO () +touchArrayData TupRunit () = return () +touchArrayData (TupRpair t1 t2) (a1, a2) = touchArrayData t1 a1 >> touchArrayData t2 a2 +touchArrayData (TupRsingle tp) arr + | (_, ScalarDict) <- scalarDict tp = touchUniqueArray arr + +newArrayData :: TupleType e -> Int -> IO (MutableArrayData e) +newArrayData TupRunit !_ = return () +newArrayData (TupRpair t1 t2) !size = (,) <$> newArrayData t1 size <*> newArrayData t2 size +newArrayData (TupRsingle tp) !size + | (n, ScalarDict) <- scalarDict tp = newArrayData' (n * size) + +unsafeReadArrayData :: forall e. TupleType e -> MutableArrayData e -> Int -> IO e +unsafeReadArrayData TupRunit () !_ = return () +unsafeReadArrayData (TupRpair t1 t2) (a1, a2) !ix = (,) <$> unsafeReadArrayData t1 a1 ix <*> unsafeReadArrayData t2 a2 ix +unsafeReadArrayData (TupRsingle (SingleScalarType tp)) arr !ix + | (ScalarDict, _, to) <- singleDict tp = to <$> unsafeReadArray arr ix +unsafeReadArrayData (TupRsingle (VectorScalarType (VectorType (I# w#) tp))) arr (I# ix#) + | (ScalarDict, _, _) <- singleDict tp = + let + !bytes# = w# *# sizeOf# (undefined :: ScalarDataRepr e) + !addr# = unPtr# (unsafeUniqueArrayPtr arr) `plusAddr#` (ix# *# bytes#) + in + IO $ \s -> + case newByteArray# bytes# s of { (# s1, mba# #) -> + case copyAddrToByteArray# addr# mba# 0# bytes# s1 of { s2 -> + case unsafeFreezeByteArray# mba# s2 of { (# s3, ba# #) -> + (# s3, Vec ba# #) + }}} + +unsafeWriteArrayData :: forall e. TupleType e -> MutableArrayData e -> Int -> e -> IO () +unsafeWriteArrayData TupRunit () !_ () = return () +unsafeWriteArrayData (TupRpair t1 t2) (a1, a2) !ix (v1, v2) + = unsafeWriteArrayData t1 a1 ix v1 + >> unsafeWriteArrayData t2 a2 ix v2 +unsafeWriteArrayData (TupRsingle (SingleScalarType tp)) arr !ix !val + | (ScalarDict, from, _) <- singleDict tp = unsafeWriteArray arr ix (from val) +unsafeWriteArrayData (TupRsingle (VectorScalarType (VectorType (I# w#) tp))) arr !(I# ix#) (Vec ba# :: Vec n t) + | (ScalarDict, _, _) <- singleDict tp = + let + !bytes# = w# *# sizeOf# (undefined :: ScalarDataRepr e) + !addr# = unPtr# (unsafeUniqueArrayPtr arr) `plusAddr#` (ix# *# bytes#) + in + IO $ \s -> case copyByteArrayToAddr# ba# 0# addr# bytes# s of + s1 -> (# s1, () #) +rnfArrayData :: TupleType e -> ArrayData e -> () +rnfArrayData TupRunit () = () +rnfArrayData (TupRpair t1 t2) (a1, a2) = rnfArrayData t1 a1 `seq` rnfArrayData t2 a2 +rnfArrayData (TupRsingle tp) arr = rnf $ ptrOfArrayData tp arr -- Auxiliary functions -- ------------------- @@ -371,110 +352,3 @@ mallocPlainForeignPtrBytesAligned (I# size) = IO $ \s -> case newAlignedPinnedByteArray# size 64# s of (# s', mbarr# #) -> (# s', ForeignPtr (byteArrayContents# (unsafeCoerce# mbarr#)) (PlainPtr mbarr#) #) - --- Instances --- --------- --- -$(runQ $ do - let - integralTypes :: [Name] - integralTypes = - [ ''Int - , ''Int8 - , ''Int16 - , ''Int32 - , ''Int64 - , ''Word - , ''Word8 - , ''Word16 - , ''Word32 - , ''Word64 - ] - - floatingTypes :: [Name] - floatingTypes = - [ ''Half - , ''Float - , ''Double - ] - - nonNumTypes :: [Name] - nonNumTypes = - [ ''Char -- wide characters are 4-bytes - --''Bool -- handled explicitly; stored as Word8 - ] - - allTypes :: [Name] - allTypes = integralTypes ++ floatingTypes ++ nonNumTypes - - mkSingleElt :: Name -> Q [Dec] - mkSingleElt name = - let - n = nameBase name - t = conT name - con = conE (mkName ("AD_" ++ n)) - pat = conP (mkName ("AD_" ++ n)) [varP (mkName "ba")] - in - [d| instance ArrayElt $t where - type ArrayPtrs $t = Ptr $t - arrayElt = $(conE (mkName ("ArrayEltR" ++ map toLower n))) - {-# INLINE arrayElt #-} - {-# INLINE newArrayData #-} - {-# INLINE ptrsOfArrayData #-} - {-# INLINE touchArrayData #-} - {-# INLINE unsafeIndexArrayData #-} - {-# INLINE unsafeReadArrayData #-} - {-# INLINE unsafeWriteArrayData #-} - newArrayData size = $con <$> newArrayData' size - ptrsOfArrayData $pat = unsafeUniqueArrayPtr ba - touchArrayData $pat = touchUniqueArray ba - unsafeIndexArrayData $pat i = unsafeIndexArray ba i - unsafeReadArrayData $pat i = unsafeReadArray ba i - unsafeWriteArrayData $pat i e = unsafeWriteArray ba i e - |] - - mkVectorElt :: Name -> Q [Dec] - mkVectorElt name = - let t = conT name - in - [d| instance KnownNat n => ArrayElt (Vec n $t) where - type ArrayPtrs (Vec n $t) = ArrayPtrs $t - arrayElt = ArrayEltRvec arrayElt - {-# INLINE arrayElt #-} - {-# INLINE newArrayData #-} - {-# INLINE ptrsOfArrayData #-} - {-# INLINE touchArrayData #-} - {-# INLINE unsafeIndexArrayData #-} - {-# INLINE unsafeReadArrayData #-} - {-# INLINE unsafeWriteArrayData #-} - newArrayData size = - let !w@(I# w#) = fromIntegral (natVal' (proxy# :: Proxy# n)) - in AD_Vec w# <$> newArrayData (w * size) - - ptrsOfArrayData (AD_Vec _ ba) = ptrsOfArrayData ba - touchArrayData (AD_Vec _ ba) = touchArrayData ba - unsafeIndexArrayData vec ix = unsafePerformIO $! unsafeReadArrayData vec ix - unsafeReadArrayData (AD_Vec w# ad) (I# ix#) = - let !bytes# = w# *# sizeOf# (undefined :: $t) - !addr# = unPtr# (ptrsOfArrayData ad) `plusAddr#` (ix# *# bytes#) - in - IO $ \s -> - case newByteArray# bytes# s of { (# s1, mba# #) -> - case copyAddrToByteArray# addr# mba# 0# bytes# s1 of { s2 -> - case unsafeFreezeByteArray# mba# s2 of { (# s3, ba# #) -> - (# s3, Vec ba# #) - }}} - unsafeWriteArrayData (AD_Vec w# ad) (I# ix#) (Vec ba#) = - let !bytes# = w# *# sizeOf# (undefined :: $t) - !addr# = unPtr# (ptrsOfArrayData ad) `plusAddr#` (ix# *# bytes#) - in - IO $ \s -> - case copyByteArrayToAddr# ba# 0# addr# bytes# s of - s1 -> (# s1, () #) - |] - -- - ss <- mapM mkSingleElt allTypes - vv <- mapM mkVectorElt allTypes - return (concat ss ++ concat vv) - ) - diff --git a/src/Data/Array/Accelerate/Array/Lifted.hs b/src/Data/Array/Accelerate/Array/Lifted.hs index 4001ae5d8..957e4695b 100644 --- a/src/Data/Array/Accelerate/Array/Lifted.hs +++ b/src/Data/Array/Accelerate/Array/Lifted.hs @@ -1,5 +1,4 @@ {-# LANGUAGE ConstraintKinds #-} -{-# LANGUAGE DeriveDataTypeable #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE PatternGuards #-} @@ -35,10 +34,8 @@ module Data.Array.Accelerate.Array.Lifted ( ) where import Prelude hiding ( concat ) -import Data.Typeable -- friends -import Data.Array.Accelerate.Product import Data.Array.Accelerate.Array.Sugar import qualified Data.Array.Accelerate.Array.Representation as Repr @@ -51,7 +48,6 @@ import qualified Data.Array.Accelerate.Array.Representation as Repr -- of arrays, are still members of the 'Arrays' class. newtype Vector' a = Vector' (LiftedRepr (ArrRepr a) a) - deriving Typeable type family LiftedRepr r a where LiftedRepr () () = ((),Scalar Int) @@ -64,24 +60,7 @@ type instance LiftedTupleRepr (b, a) = (LiftedTupleRepr b, Vector' a) type LiftedArray sh e = Vector' (Array sh e) -instance Arrays t => IsProduct Arrays (Vector' t) where - type ProdRepr (Vector' t) = LiftedRepr (ArrRepr t) t - fromProd _ (Vector' t) = t - toProd _ = Vector' - prod _ _ = case flavour (undefined :: t) of - ArraysFunit -> ProdRsnoc ProdRunit - ArraysFarray -> ProdRsnoc (ProdRsnoc ProdRunit) - ArraysFtuple -> tup $ prod (Proxy :: Proxy Arrays) (undefined :: t) - where - tup :: forall a. ProdR Arrays a -> ProdR Arrays (LiftedTupleRepr a) - tup ProdRunit = ProdRunit - tup (ProdRsnoc t) = swiz - where - swiz :: forall l r. (a ~ (l,r), Arrays r) => ProdR Arrays (LiftedTupleRepr a) - swiz | IsC <- isArraysFlat (undefined :: r) - = ProdRsnoc (tup t) - -instance (Arrays t, Typeable (ArrRepr (Vector' t))) => Arrays (Vector' t) where +instance Arrays t => Arrays (Vector' t) where type ArrRepr (Vector' t) = ArrRepr (TupleRepr (Vector' t)) arrays _ = arrs (prod (Proxy :: Proxy Arrays) (undefined :: Vector' t)) where diff --git a/src/Data/Array/Accelerate/Array/Remote/Class.hs b/src/Data/Array/Accelerate/Array/Remote/Class.hs index 86b93b6a2..a09236399 100644 --- a/src/Data/Array/Accelerate/Array/Remote/Class.hs +++ b/src/Data/Array/Accelerate/Array/Remote/Class.hs @@ -27,27 +27,21 @@ module Data.Array.Accelerate.Array.Remote.Class ( - RemoteMemory(..), PrimElt + RemoteMemory(..) ) where import Data.Array.Accelerate.Array.Data +import Data.Array.Accelerate.Type (SingleType) import Control.Applicative import Control.Monad.Catch import Data.Int import Data.Kind -import Data.Typeable import Data.Word -import Foreign.Ptr -import Foreign.Storable import Prelude --- | Matches array element types to primitive types. --- -type PrimElt e a = (ArrayElt e, Storable a, ArrayPtrs e ~ Ptr a, Typeable e, Typeable a) - -- | Accelerate backends can provide an instance of this class in order to take -- advantage of the automated memory managers we provide as part of the base -- package. @@ -62,10 +56,10 @@ class (Applicative m, Monad m, MonadCatch m, MonadMask m) => RemoteMemory m wher mallocRemote :: Int -> m (Maybe (RemotePtr m Word8)) -- | Copy the given number of elements from the host array into remote memory. - pokeRemote :: PrimElt e a => Int -> RemotePtr m a -> ArrayData e -> m () + pokeRemote :: SingleType e -> Int -> RemotePtr m (ScalarDataRepr e) -> ArrayData e -> m () -- | Copy the given number of elements from remote memory to the host array. - peekRemote :: PrimElt e a => Int -> RemotePtr m a -> MutableArrayData e -> m () + peekRemote :: SingleType e -> Int -> RemotePtr m (ScalarDataRepr e) -> MutableArrayData e -> m () -- | Cast a remote pointer. castRemotePtr :: RemotePtr m a -> RemotePtr m b diff --git a/src/Data/Array/Accelerate/Array/Remote/LRU.hs b/src/Data/Array/Accelerate/Array/Remote/LRU.hs index 5f414c344..67783ff55 100644 --- a/src/Data/Array/Accelerate/Array/Remote/LRU.hs +++ b/src/Data/Array/Accelerate/Array/Remote/LRU.hs @@ -1,4 +1,5 @@ {-# LANGUAGE AllowAmbiguousTypes #-} +{-# LANGUAGE CPP #-} {-# LANGUAGE BangPatterns #-} {-# LANGUAGE ConstraintKinds #-} {-# LANGUAGE DoAndIfThenElse #-} @@ -41,20 +42,25 @@ import Control.Monad ( filterM ) import Control.Monad.Catch import Control.Monad.IO.Class ( MonadIO, liftIO ) import Data.Functor +#if __GLASGOW_HASKELL__ < 808 import Data.Int ( Int64 ) +#endif import Data.Maybe ( isNothing ) -import Foreign.Storable ( sizeOf ) import System.CPUTime import System.Mem.Weak ( Weak, deRefWeak, finalize ) import Prelude hiding ( lookup ) import qualified Data.HashTable.IO as HT -import Data.Array.Accelerate.Array.Data ( ArrayData, touchArrayData ) +import Data.Array.Accelerate.Type +import Data.Array.Accelerate.Analysis.Type ( sizeOfSingleType ) +import Data.Array.Accelerate.Analysis.Match ( matchSingleType, (:~:)(..) ) +import Data.Array.Accelerate.Array.Data ( ArrayData, ScalarData, ScalarDataRepr, ScalarDict(..), singleDict ) import Data.Array.Accelerate.Array.Remote.Class import Data.Array.Accelerate.Array.Remote.Table ( StableArray, makeWeakArrayData ) import Data.Array.Accelerate.Error ( internalError ) import qualified Data.Array.Accelerate.Array.Remote.Table as Basic import qualified Data.Array.Accelerate.Debug as D +import Data.Array.Accelerate.Array.Unique ( touchUniqueArray ) -- We build cached memory tables on top of a basic memory table. @@ -80,13 +86,14 @@ data Status = Clean -- Array in remote memory matches array in host memory. type Timestamp = Integer data Used task where - Used :: PrimElt e a + Used :: ArrayData e ~ ScalarData e => !Timestamp -> !Status -> {-# UNPACK #-} !Int -- Use count -> ![task] -- Asynchronous tasks using the array - -> {-# UNPACK #-} !Int -- Array size - -> {-# UNPACK #-} !(Weak (ArrayData e)) + -> {-# UNPACK #-} !Int -- Number of elements + -> !(SingleType e) + -> {-# UNPACK #-} !(Weak (ScalarData e)) -> Used task -- | A Task represents a process executing asynchronously that can be polled for @@ -129,55 +136,59 @@ new release = do -- more accesses of the remote pointer. -- withRemote - :: forall task m a b c. (PrimElt a b, Task task, RemoteMemory m, MonadIO m, Functor m) + :: forall task m a c. (Task task, RemoteMemory m, MonadIO m, Functor m) => MemoryTable (RemotePtr m) task + -> SingleType a -> ArrayData a - -> (RemotePtr m b -> m (task, c)) + -> (RemotePtr m (ScalarDataRepr a) -> m (task, c)) -> m (Maybe c) -withRemote (MemoryTable !mt !ref _) !arr run = do - key <- Basic.makeStableArray arr - mp <- withMVar' ref $ \utbl -> do - mu <- liftIO . HT.mutate utbl key $ \case - Nothing -> (Nothing, Nothing) - Just u -> (Just (incCount u), Just u) +withRemote (MemoryTable !mt !ref _) !tp !arr run + | (ScalarDict, _, _) <- singleDict tp = do + key <- Basic.makeStableArray tp arr + mp <- withMVar' ref $ \utbl -> do + mu <- liftIO . HT.mutate utbl key $ \case + Nothing -> (Nothing, Nothing) + Just u -> (Just (incCount u), Just u) + -- + case mu of + Nothing -> do + message ("withRemote/array has never been malloc'd: " ++ show key) + return Nothing -- The array was never in the table + + Just u -> do + mp <- liftIO $ Basic.lookup @m mt tp arr + ptr <- case mp of + Just p -> return p + Nothing + | isEvicted u -> copyBack utbl (incCount u) + | otherwise -> do message ("lost array " ++ show key) + $internalError "withRemote" "non-evicted array has been lost" + return (Just ptr) -- - case mu of - Nothing -> do - message ("withRemote/array has never been malloc'd: " ++ show key) - return Nothing -- The array was never in the table - - Just u -> do - mp <- liftIO $ Basic.lookup mt arr - ptr <- case mp of - Just p -> return p - Nothing - | isEvicted u -> copyBack utbl (incCount u) - | otherwise -> do message ("lost array " ++ show key) - $internalError "withRemote" "non-evicted array has been lost" - return (Just ptr) - -- - case mp of - Nothing -> return Nothing - Just ptr -> Just <$> go key ptr + case mp of + Nothing -> return Nothing + Just ptr -> Just <$> go key ptr where updateTask :: Used task -> task -> IO (Used task) - updateTask (Used _ status count tasks n weak_arr) task = do + updateTask (Used _ status count tasks n tp' weak_arr) task = do ts <- getCPUTime tasks' <- cleanUses tasks - return (Used ts status (count - 1) (task : tasks') n weak_arr) + return (Used ts status (count - 1) (task : tasks') n tp' weak_arr) - copyBack :: UT task -> Used task -> m (RemotePtr m b) - copyBack utbl (Used ts _ count tasks n weak_arr) = do - message "withRemote/reuploading-evicted-array" - p <- mallocWithUsage mt utbl arr (Used ts Clean count tasks n weak_arr) - pokeRemote n p arr - return p + copyBack :: UT task -> Used task -> m (RemotePtr m (ScalarDataRepr a)) + copyBack utbl (Used ts _ count tasks n tp' weak_arr) + | Just Refl <- matchSingleType tp tp' = do + message "withRemote/reuploading-evicted-array" + p <- mallocWithUsage mt utbl tp arr (Used ts Clean count tasks n tp weak_arr) + pokeRemote tp n p arr + return p + | otherwise = $internalError "withRemote" "Type mismatch" -- We can't combine the use of `withMVar ref` above with the one here -- because the `permute` operation from the PTX backend requires nested -- calls to `withRemote` in order to copy the defaults array. -- - go :: StableArray -> RemotePtr m b -> m c + go :: ArrayData a ~ ScalarData a => StableArray -> RemotePtr m (ScalarDataRepr a) -> m c go key ptr = do message ("withRemote/using: " ++ show key) (task, c) <- run ptr @@ -188,7 +199,7 @@ withRemote (MemoryTable !mt !ref _) !arr run = do u' <- updateTask u task return (Just u', ()) -- - touchArrayData arr + touchUniqueArray arr return c @@ -207,15 +218,18 @@ withRemote (MemoryTable !mt !ref _) !arr run = do -- On return, 'True' indicates that we allocated some remote memory, and 'False' -- indicates that we did not need to. -- -malloc :: forall a e m task. (PrimElt e a, RemoteMemory m, MonadIO m, Task task) +malloc :: forall e m task. (RemoteMemory m, MonadIO m, Task task) => MemoryTable (RemotePtr m) task + -> SingleType e -> ArrayData e -> Bool -- ^ True if host array is frozen. - -> Int + -> Int -- ^ Number of elements -> m Bool -- ^ Was the array allocated successfully? -malloc (MemoryTable mt ref weak_utbl) !ad !frozen !n = do +malloc (MemoryTable mt ref weak_utbl) !tp !ad !frozen !n + | (ScalarDict, _, _) <- singleDict tp -- Required for ArrayData e ~ ScalarData e + = do ts <- liftIO $ getCPUTime - key <- Basic.makeStableArray ad + key <- Basic.makeStableArray tp ad -- let status = if frozen then Clean @@ -225,30 +239,32 @@ malloc (MemoryTable mt ref weak_utbl) !ad !frozen !n = do mu <- liftIO $ HT.lookup utbl key if isNothing mu then do - weak_arr <- liftIO $ makeWeakArrayData ad ad (Just $ finalizer key weak_utbl) - _ <- mallocWithUsage mt utbl ad (Used ts status 0 [] n weak_arr) + weak_arr <- liftIO $ makeWeakArrayData tp ad ad (Just $ finalizer key weak_utbl) + _ <- mallocWithUsage mt utbl tp ad (Used ts status 0 [] n tp weak_arr) return True else return False mallocWithUsage - :: forall a e m task. (PrimElt e a, RemoteMemory m, MonadIO m, Task task) + :: forall e m task. (RemoteMemory m, MonadIO m, Task task, ArrayData e ~ ScalarData e) => Basic.MemoryTable (RemotePtr m) -> UT task + -> SingleType e -> ArrayData e -> Used task - -> m (RemotePtr m a) -mallocWithUsage !mt !utbl !ad !usage@(Used _ _ _ _ n _) = malloc' + -> m (RemotePtr m (ScalarDataRepr e)) +mallocWithUsage !mt !utbl !tp !ad !usage@(Used _ _ _ _ n _ _) = malloc' where + malloc' :: m (RemotePtr m (ScalarDataRepr e)) malloc' = do - mp <- Basic.malloc mt ad n :: m (Maybe (RemotePtr m a)) + mp <- Basic.malloc @e @m mt tp ad n :: m (Maybe (RemotePtr m (ScalarDataRepr e))) case mp of Nothing -> do success <- evictLRU utbl mt if success then malloc' else $internalError "malloc" "Remote memory exhausted" Just p -> liftIO $ do - key <- Basic.makeStableArray ad + key <- Basic.makeStableArray tp ad HT.insert utbl key usage return p @@ -260,7 +276,7 @@ evictLRU evictLRU !utbl !mt = trace "evictLRU/evicting-eldest-array" $ do mused <- liftIO $ HT.foldM eldest Nothing utbl case mused of - Just (sa, Used ts status count tasks n weak_arr) -> do + Just (sa, Used ts status count tasks n tp weak_arr) -> do mad <- liftIO $ deRefWeak weak_arr case mad of Nothing -> liftIO $ do @@ -277,28 +293,28 @@ evictLRU !utbl !mt = trace "evictLRU/evicting-eldest-array" $ do Just arr -> do message ("evictLRU/evicting " ++ show sa) - copyIfNecessary status n arr - liftIO $ D.didEvictBytes (remoteBytes n weak_arr) + copyIfNecessary status n tp arr + liftIO $ D.didEvictBytes (remoteBytes tp n) liftIO $ Basic.freeStable @m mt sa - liftIO $ HT.insert utbl sa (Used ts Evicted count tasks n weak_arr) + liftIO $ HT.insert utbl sa (Used ts Evicted count tasks n tp weak_arr) return True _ -> trace "evictLRU/All arrays in use, unable to evict" $ return False where -- Find the eldest, not currently in use, array. eldest :: (Maybe (StableArray, Used task)) -> (StableArray, Used task) -> IO (Maybe (StableArray, Used task)) - eldest prev (sa, used@(Used ts status count tasks n weak_arr)) | count == 0 + eldest prev (sa, used@(Used ts status count tasks n tp weak_arr)) | count == 0 , evictable status = do tasks' <- cleanUses tasks - HT.insert utbl sa (Used ts status count tasks' n weak_arr) + HT.insert utbl sa (Used ts status count tasks' n tp weak_arr) case tasks' of - [] | Just (_, Used ts' _ _ _ _ _) <- prev + [] | Just (_, Used ts' _ _ _ _ _ _) <- prev , ts < ts' -> return (Just (sa, used)) | Nothing <- prev -> return (Just (sa, used)) _ -> return prev eldest prev _ = return prev - remoteBytes :: forall e a. PrimElt e a => Int -> Weak (ArrayData e) -> Int64 - remoteBytes n _ = fromIntegral n * fromIntegral (sizeOf (undefined::a)) + remoteBytes :: SingleType e -> Int -> Int64 + remoteBytes tp n = fromIntegral (sizeOfSingleType tp) * fromIntegral n evictable :: Status -> Bool evictable Clean = True @@ -306,28 +322,29 @@ evictLRU !utbl !mt = trace "evictLRU/evicting-eldest-array" $ do evictable Unmanaged = False evictable Evicted = False - copyIfNecessary :: PrimElt e a => Status -> Int -> ArrayData e -> m () - copyIfNecessary Clean _ _ = return () - copyIfNecessary Unmanaged _ _ = return () - copyIfNecessary Evicted _ _ = $internalError "evictLRU" "Attempting to evict already evicted array" - copyIfNecessary Dirty n ad = do - mp <- liftIO $ Basic.lookup mt ad + copyIfNecessary :: Status -> Int -> SingleType e -> ArrayData e -> m () + copyIfNecessary Clean _ _ _ = return () + copyIfNecessary Unmanaged _ _ _ = return () + copyIfNecessary Evicted _ _ _ = $internalError "evictLRU" "Attempting to evict already evicted array" + copyIfNecessary Dirty n tp ad = do + mp <- liftIO $ Basic.lookup @m mt tp ad case mp of Nothing -> return () -- RCE: I think this branch is actually impossible. - Just p -> peekRemote n p ad + Just p -> peekRemote tp n p ad -- | Deallocate the device array associated with the given host-side array. -- Typically this should only be called in very specific circumstances. This -- operation is not thread-safe. -- -free :: forall m a b task. (RemoteMemory m, PrimElt a b) +free :: forall m a task. (RemoteMemory m) => MemoryTable (RemotePtr m) task + -> SingleType a -> ArrayData a -> IO () -free (MemoryTable !mt !ref _) !arr +free (MemoryTable !mt !ref _) !tp !arr = withMVar' ref $ \utbl -> do - key <- Basic.makeStableArray arr + key <- Basic.makeStableArray tp arr delete utbl key Basic.freeStable @m mt key @@ -338,20 +355,22 @@ free (MemoryTable !mt !ref _) !arr -- This typically only has use for backends that provide an FFI. -- insertUnmanaged - :: (PrimElt e a, MonadIO m) - => MemoryTable p task + :: (MonadIO m, RemoteMemory m) + => MemoryTable (RemotePtr m) task + -> SingleType e -> ArrayData e - -> p a + -> RemotePtr m (ScalarDataRepr e) -> m () -insertUnmanaged (MemoryTable mt ref weak_utbl) !arr !ptr - = liftIO - . withMVar ref - $ \utbl -> do - key <- Basic.makeStableArray arr - () <- Basic.insertUnmanaged mt arr ptr +insertUnmanaged (MemoryTable mt ref weak_utbl) !tp !arr !ptr + | (ScalarDict, _, _) <- singleDict tp = do -- Gives evidence that ArrayData e ~ ScalarData e + key <- Basic.makeStableArray tp arr + () <- Basic.insertUnmanaged mt tp arr ptr + liftIO + $ withMVar ref + $ \utbl -> do ts <- getCPUTime - weak_arr <- makeWeakArrayData arr arr (Just $ finalizer key weak_utbl) - HT.insert utbl key (Used ts Unmanaged 0 [] 0 weak_arr) + weak_arr <- makeWeakArrayData tp arr arr (Just $ finalizer key weak_utbl) + HT.insert utbl key (Used ts Unmanaged 0 [] 0 tp weak_arr) -- Removing entries @@ -383,7 +402,7 @@ cache_finalizer !tbl $ HT.mapM_ (\(_,u) -> f u) tbl where f :: Used task -> IO () - f (Used _ _ _ _ _ w) = finalize w + f (Used _ _ _ _ _ _ w) = finalize w -- Miscellaneous -- ------------- @@ -392,10 +411,10 @@ cleanUses :: Task task => [task] -> IO [task] cleanUses = filterM (fmap not . completed) incCount :: Used task -> Used task -incCount (Used ts status count uses n weak_arr) = Used ts status (count + 1) uses n weak_arr +incCount (Used ts status count uses n tp weak_arr) = Used ts status (count + 1) uses n tp weak_arr isEvicted :: Used task -> Bool -isEvicted (Used _ status _ _ _ _) = status == Evicted +isEvicted (Used _ status _ _ _ _ _) = status == Evicted {-# INLINE withMVar' #-} withMVar' :: (MonadIO m, MonadMask m) => MVar a -> (a -> m b) -> m b diff --git a/src/Data/Array/Accelerate/Array/Remote/Table.hs b/src/Data/Array/Accelerate/Array/Remote/Table.hs index cf469cb3d..1b4116820 100644 --- a/src/Data/Array/Accelerate/Array/Remote/Table.hs +++ b/src/Data/Array/Accelerate/Array/Remote/Table.hs @@ -10,6 +10,7 @@ {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TemplateHaskell #-} {-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeOperators #-} {-# LANGUAGE UnboxedTuples #-} {-# LANGUAGE ViewPatterns #-} {-# OPTIONS_HADDOCK hide #-} @@ -35,7 +36,7 @@ module Data.Array.Accelerate.Array.Remote.Table ( -- Internals StableArray, makeStableArray, - makeWeakArrayData + makeWeakArrayData, ) where @@ -46,7 +47,6 @@ import Control.Monad.IO.Class ( MonadIO, liftI import Data.Functor import Data.Hashable ( hash, Hashable ) import Data.Maybe ( isJust ) -import Data.Typeable ( Typeable, gcast, typeOf ) import Data.Word import Foreign.Storable ( sizeOf ) import System.Mem ( performGC ) @@ -55,12 +55,10 @@ import Text.Printf import Prelude hiding ( lookup, id ) import qualified Data.HashTable.IO as HT -import GHC.Exts ( Ptr(..) ) - import Data.Array.Accelerate.Error ( internalError ) +import Data.Array.Accelerate.Type import Data.Array.Accelerate.Array.Unique ( UniqueArray(..) ) -import Data.Array.Accelerate.Array.Data ( ArrayData, GArrayData(..), - ArrayPtrs, ArrayElt, arrayElt, ArrayEltR(..) ) +import Data.Array.Accelerate.Array.Data import Data.Array.Accelerate.Array.Remote.Class import Data.Array.Accelerate.Array.Remote.Nursery ( Nursery(..) ) import Data.Array.Accelerate.Lifetime @@ -89,8 +87,7 @@ data MemoryTable p = MemoryTable {-# UNPACK #-} !(MT p) (p Word8 -> IO ()) data RemoteArray p where - RemoteArray :: Typeable e - => !(p e) -- The actual remote pointer + RemoteArray :: !(p Word8) -- The actual remote pointer -> {-# UNPACK #-} !Int -- The array size in bytes -> {-# UNPACK #-} !(Weak ()) -- Keep track of host array liveness -> RemoteArray p @@ -122,36 +119,37 @@ new release = do -- | Look for the remote pointer corresponding to a given host-side array. -- -lookup :: PrimElt a b - => MemoryTable p +lookup :: forall m a. + RemoteMemory m + => MemoryTable (RemotePtr m) + -> SingleType a -> ArrayData a - -> IO (Maybe (p b)) -lookup (MemoryTable !ref _ _ _) !arr = do - sa <- makeStableArray arr - mw <- withMVar ref (`HT.lookup` sa) - case mw of - Nothing -> trace ("lookup/not found: " ++ show sa) $ return Nothing - Just (RemoteArray p _ w) -> do - mv <- deRefWeak w - case mv of - Just{} | Just p' <- gcast p -> trace ("lookup/found: " ++ show sa) $ return (Just p') - | otherwise -> $internalError "lookup" "type mismatch" - - -- Note: [Weak pointer weirdness] - -- - -- After the lookup is successful, there might conceivably be no further - -- references to 'arr'. If that is so, and a garbage collection - -- intervenes, the weak pointer might get tombstoned before 'deRefWeak' - -- gets to it. In that case we throw an error (below). However, because - -- we have used 'arr' in the continuation, this ensures that 'arr' is - -- reachable in the continuation of 'deRefWeak' and thus 'deRefWeak' - -- always succeeds. This sort of weirdness, typical of the world of weak - -- pointers, is why we can not reuse the stable name 'sa' computed - -- above in the error message. - -- - Nothing -> - makeStableArray arr >>= \x -> $internalError "lookup" $ "dead weak pair: " ++ show x - + -> IO (Maybe (RemotePtr m (ScalarDataRepr a))) +lookup (MemoryTable !ref _ _ _) !tp !arr + | (ScalarDict, _, _) <- singleDict tp = do + sa <- makeStableArray tp arr + mw <- withMVar ref (`HT.lookup` sa) + case mw of + Nothing -> trace ("lookup/not found: " ++ show sa) $ return Nothing + Just (RemoteArray p _ w) -> do + mv <- deRefWeak w + case mv of + Just{} -> trace ("lookup/found: " ++ show sa) $ return (Just $ castRemotePtr @m p) + + -- Note: [Weak pointer weirdness] + -- + -- After the lookup is successful, there might conceivably be no further + -- references to 'arr'. If that is so, and a garbage collection + -- intervenes, the weak pointer might get tombstoned before 'deRefWeak' + -- gets to it. In that case we throw an error (below). However, because + -- we have used 'arr' in the continuation, this ensures that 'arr' is + -- reachable in the continuation of 'deRefWeak' and thus 'deRefWeak' + -- always succeeds. This sort of weirdness, typical of the world of weak + -- pointers, is why we can not reuse the stable name 'sa' computed + -- above in the error message. + -- + Nothing -> + makeStableArray tp arr >>= \x -> $internalError "lookup" $ "dead weak pair: " ++ show x -- | Allocate a new device array to be associated with the given host-side array. -- This may not always use the `malloc` provided by the `RemoteMemory` instance. @@ -159,45 +157,46 @@ lookup (MemoryTable !ref _ _ _) !arr = do -- arrays will be re-used. In the event that the remote memory is exhausted, -- 'Nothing' is returned. -- -malloc :: forall a b m. (PrimElt a b, RemoteMemory m, MonadIO m) +malloc :: forall a m. (RemoteMemory m, MonadIO m) => MemoryTable (RemotePtr m) + -> SingleType a -> ArrayData a -> Int - -> m (Maybe (RemotePtr m b)) -malloc mt@(MemoryTable _ _ !nursery _) !ad !n = do - -- Note: [Allocation sizes] - -- - -- Instead of allocating the exact number of elements requested, we round up to - -- a fixed chunk size as specified by RemoteMemory.remoteAllocationSize. This - -- means there is a greater chance the nursery will get a hit, and moreover - -- that we can search the nursery for an exact size. - -- - chunk <- remoteAllocationSize - let -- next highest multiple of f from x - multiple x f = (x + (f-1)) `quot` f - bytes = chunk * multiple (n * sizeOf (undefined::b)) chunk - -- - message $ printf "malloc %d bytes (%d x %d bytes, type=%s, pagesize=%d)" bytes n (sizeOf (undefined::b)) (show (typeOf (undefined::a))) chunk - -- - mp <- - fmap (castRemotePtr @m) - <$> attempt "malloc/nursery" (liftIO $ N.lookup bytes nursery) - `orElse` - attempt "malloc/new" (mallocRemote bytes) - `orElse` do message "malloc/remote-malloc-failed (cleaning)" - clean mt - liftIO $ N.lookup bytes nursery - `orElse` do message "malloc/remote-malloc-failed (purging)" - purge mt - mallocRemote bytes - `orElse` do message "malloc/remote-malloc-failed (non-recoverable)" - return Nothing - case mp of - Nothing -> return Nothing - Just p' -> do - insert mt ad p' bytes - return mp - + -> m (Maybe (RemotePtr m (ScalarDataRepr a))) +malloc mt@(MemoryTable _ _ !nursery _) !tp !ad !n + | (ScalarDict, _, _) <- singleDict tp = do + -- Note: [Allocation sizes] + -- + -- Instead of allocating the exact number of elements requested, we round up to + -- a fixed chunk size as specified by RemoteMemory.remoteAllocationSize. This + -- means there is a greater chance the nursery will get a hit, and moreover + -- that we can search the nursery for an exact size. + -- + chunk <- remoteAllocationSize + let -- next highest multiple of f from x + multiple x f = (x + (f-1)) `quot` f + bytes = chunk * multiple (n * sizeOf (undefined::(ScalarDataRepr a))) chunk + -- + message $ printf "malloc %d bytes (%d x %d bytes, type=%s, pagesize=%d)" bytes n (sizeOf (undefined:: (ScalarDataRepr a))) (show tp) chunk + -- + mp <- + fmap (castRemotePtr @m) + <$> attempt "malloc/nursery" (liftIO $ N.lookup bytes nursery) + `orElse` + attempt "malloc/new" (mallocRemote bytes) + `orElse` do message "malloc/remote-malloc-failed (cleaning)" + clean mt + liftIO $ N.lookup bytes nursery + `orElse` do message "malloc/remote-malloc-failed (purging)" + purge mt + mallocRemote bytes + `orElse` do message "malloc/remote-malloc-failed (non-recoverable)" + return Nothing + case mp of + Nothing -> return Nothing + Just p' -> do + insert mt tp ad p' bytes + return mp where {-# INLINE orElse #-} orElse :: m (Maybe x) -> m (Maybe x) -> m (Maybe x) @@ -220,12 +219,13 @@ malloc mt@(MemoryTable _ _ !nursery _) !ad !n = do -- | Deallocate the device array associated with the given host-side array. -- Typically this should only be called in very specific circumstances. -- -free :: forall m a b. (RemoteMemory m, PrimElt a b) +free :: forall m a. (RemoteMemory m) => MemoryTable (RemotePtr m) + -> SingleType a -> ArrayData a -> IO () -free mt !arr = do - sa <- makeStableArray arr +free mt tp !arr = do + sa <- makeStableArray tp arr freeStable @m mt sa @@ -257,18 +257,20 @@ freeStable (MemoryTable !ref _ !nrs _) !sa = -- collected. -- insert - :: forall m a b. (PrimElt a b, RemoteMemory m, MonadIO m) + :: forall m a. (RemoteMemory m, MonadIO m) => MemoryTable (RemotePtr m) + -> SingleType a -> ArrayData a - -> RemotePtr m b + -> RemotePtr m (ScalarDataRepr a) -> Int -> m () -insert mt@(MemoryTable !ref _ _ _) !arr !ptr !bytes = do - key <- makeStableArray arr - weak <- liftIO $ makeWeakArrayData arr () (Just $ freeStable @m mt key) +insert mt@(MemoryTable !ref _ _ _) !tp !arr !ptr !bytes + | (ScalarDict, _, _) <- singleDict tp = do + key <- makeStableArray tp arr + weak <- liftIO $ makeWeakArrayData tp arr () (Just $ freeStable @m mt key) message $ "insert: " ++ show key liftIO $ D.increaseCurrentBytesRemote (fromIntegral bytes) - liftIO $ withMVar ref $ \tbl -> HT.insert tbl key (RemoteArray ptr bytes weak) + liftIO $ withMVar ref $ \tbl -> HT.insert tbl key (RemoteArray (castRemotePtr @m ptr) bytes weak) -- | Record an association between a host-side array and a remote memory area @@ -278,16 +280,18 @@ insert mt@(MemoryTable !ref _ _ _) !arr !ptr !bytes = do -- This typically only has use for backends that provide an FFI. -- insertUnmanaged - :: (PrimElt a b, MonadIO m) - => MemoryTable p + :: forall m a. (MonadIO m, RemoteMemory m) + => MemoryTable (RemotePtr m) + -> SingleType a -> ArrayData a - -> p b + -> RemotePtr m (ScalarDataRepr a) -> m () -insertUnmanaged (MemoryTable !ref !weak_ref _ _) !arr !ptr = do - key <- makeStableArray arr - weak <- liftIO $ makeWeakArrayData arr () (Just $ remoteFinalizer weak_ref key) - message $ "insertUnmanaged: " ++ show key - liftIO $ withMVar ref $ \tbl -> HT.insert tbl key (RemoteArray ptr 0 weak) +insertUnmanaged (MemoryTable !ref !weak_ref _ _) tp !arr !ptr + | (ScalarDict, _, _) <- singleDict tp = do + key <- makeStableArray tp arr + weak <- liftIO $ makeWeakArrayData tp arr () (Just $ remoteFinalizer weak_ref key) + message $ "insertUnmanaged: " ++ show key + liftIO $ withMVar ref $ \tbl -> HT.insert tbl key (RemoteArray (castRemotePtr @m ptr) 0 weak) -- Removing entries @@ -351,32 +355,12 @@ remoteFinalizer !weak_ref !key = do -- {-# INLINE makeStableArray #-} makeStableArray - :: (MonadIO m, Typeable a, Typeable e, ArrayPtrs a ~ Ptr e, ArrayElt a) - => ArrayData a + :: MonadIO m + => SingleType a + -> ArrayData a -> m StableArray -makeStableArray !ad = return $! StableArray (id arrayElt ad) - where - id :: (ArrayPtrs e ~ Ptr a) => ArrayEltR e -> ArrayData e -> Unique - id ArrayEltRint (AD_Int ua) = uniqueArrayId ua - id ArrayEltRint8 (AD_Int8 ua) = uniqueArrayId ua - id ArrayEltRint16 (AD_Int16 ua) = uniqueArrayId ua - id ArrayEltRint32 (AD_Int32 ua) = uniqueArrayId ua - id ArrayEltRint64 (AD_Int64 ua) = uniqueArrayId ua - id ArrayEltRword (AD_Word ua) = uniqueArrayId ua - id ArrayEltRword8 (AD_Word8 ua) = uniqueArrayId ua - id ArrayEltRword16 (AD_Word16 ua) = uniqueArrayId ua - id ArrayEltRword32 (AD_Word32 ua) = uniqueArrayId ua - id ArrayEltRword64 (AD_Word64 ua) = uniqueArrayId ua - id ArrayEltRhalf (AD_Half ua) = uniqueArrayId ua - id ArrayEltRfloat (AD_Float ua) = uniqueArrayId ua - id ArrayEltRdouble (AD_Double ua) = uniqueArrayId ua - id ArrayEltRbool (AD_Bool ua) = uniqueArrayId ua - id ArrayEltRchar (AD_Char ua) = uniqueArrayId ua - id (ArrayEltRvec r) (AD_Vec _ a) = id r a -#if __GLASGOW_HASKELL__ < 800 - id _ _ = - error "I do have a cause, though. It is obscenity. I'm for it." -#endif +makeStableArray !tp !ad + | (ScalarDict, _, _) <- singleDict tp = return $! StableArray (uniqueArrayId ad) -- Weak arrays @@ -386,38 +370,15 @@ makeStableArray !ad = return $! StableArray (id arrayElt ad) -- this guarantees finalisers won't fire early. -- makeWeakArrayData - :: forall a e c. (ArrayElt e, ArrayPtrs e ~ Ptr a) - => ArrayData e + :: forall e c. + SingleType e + -> ArrayData e -> c -> Maybe (IO ()) -> IO (Weak c) -makeWeakArrayData !ad !c !mf = mw arrayElt ad - where - mw :: (ArrayPtrs e' ~ Ptr a') => ArrayEltR e' -> ArrayData e' -> IO (Weak c) - mw ArrayEltRint (AD_Int ua) = mkWeak' ua - mw ArrayEltRint8 (AD_Int8 ua) = mkWeak' ua - mw ArrayEltRint16 (AD_Int16 ua) = mkWeak' ua - mw ArrayEltRint32 (AD_Int32 ua) = mkWeak' ua - mw ArrayEltRint64 (AD_Int64 ua) = mkWeak' ua - mw ArrayEltRword (AD_Word ua) = mkWeak' ua - mw ArrayEltRword8 (AD_Word8 ua) = mkWeak' ua - mw ArrayEltRword16 (AD_Word16 ua) = mkWeak' ua - mw ArrayEltRword32 (AD_Word32 ua) = mkWeak' ua - mw ArrayEltRword64 (AD_Word64 ua) = mkWeak' ua - mw ArrayEltRhalf (AD_Half ua) = mkWeak' ua - mw ArrayEltRfloat (AD_Float ua) = mkWeak' ua - mw ArrayEltRdouble (AD_Double ua) = mkWeak' ua - mw ArrayEltRbool (AD_Bool ua) = mkWeak' ua - mw ArrayEltRchar (AD_Char ua) = mkWeak' ua - mw (ArrayEltRvec r) (AD_Vec _ a) = mw r a -#if __GLASGOW_HASKELL__ < 800 - mw _ _ = - error "Base eight is just like base ten really --- if you're missing two fingers." -#endif - - mkWeak' :: UniqueArray a' -> IO (Weak c) - mkWeak' !ua = do - let !uad = uniqueArrayData ua +makeWeakArrayData !tp !ad !c !mf + | (ScalarDict, _, _) <- singleDict tp = do + let !uad = uniqueArrayData ad case mf of Nothing -> return () Just f -> addFinalizer uad f diff --git a/src/Data/Array/Accelerate/Array/Representation.hs b/src/Data/Array/Accelerate/Array/Representation.hs index 4a0d62b31..7c8fdbed4 100644 --- a/src/Data/Array/Accelerate/Array/Representation.hs +++ b/src/Data/Array/Accelerate/Array/Representation.hs @@ -1,7 +1,10 @@ {-# LANGUAGE AllowAmbiguousTypes #-} +{-# LANGUAGE BangPatterns #-} +{-# LANGUAGE DataKinds #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GADTs #-} +{-# LANGUAGE MagicHash #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TemplateHaskell #-} {-# LANGUAGE TupleSections #-} @@ -21,163 +24,313 @@ -- module Data.Array.Accelerate.Array.Representation ( + -- * Array data type in terms of representation types + Array(..), ArrayR(..), arraysRarray, arraysRtuple2, arrayRshape, arrayRtype, rnfArray, rnfShape, + ArraysR, TupleType, Scalar, Vector, Matrix, fromList, toList, Segments, shape, reshape, concatVectors, + showArrayR, showArraysR, fromFunction, fromFunctionM, reduceRank, allocateArray, -- * Array shapes, indices, and slices - Shape(..), Slice(..), SliceIndex(..), + ShapeR(..), Slice(..), SliceIndex(..), + DIM0, DIM1, DIM2, dim0, dim1, dim2, (!), (!!), + + -- * Shape functions + rank, size, empty, ignore, intersect, union, toIndex, fromIndex, iter, iter1, + rangeToShape, shapeToRange, shapeToList, listToShape, listToShape', shapeType, shapeEq, -- * Slice shape functions - sliceShape, enumSlices, + sliceShape, sliceShapeR, sliceDomainR, enumSlices, + + -- * Vec representation & utilities + VecR(..), vecRvector, vecRtuple, vecPack, vecUnpack, + + -- * Stencils + StencilR(..), stencilElt, stencilShape, stencilType, stencilArrayR, stencilHalo, + -- * Show + showShape, showElement, showArray, showArray', ) where -- friends import Data.Array.Accelerate.Error +import Data.Array.Accelerate.Type +import Data.Array.Accelerate.Array.Data -- standard library -import GHC.Base ( quotInt, remInt ) - - --- |Index representation +import GHC.Base ( quotInt, remInt, Int(..), Int#, (-#) ) +import GHC.TypeNats +import Data.Primitive.ByteArray +import Data.Primitive.Types +import Prelude hiding ((!!)) +import Data.List ( intercalate ) +import Text.Show ( showListWith ) +import System.IO.Unsafe ( unsafePerformIO ) +import qualified Data.Vector.Unboxed as U +import Control.Monad.ST + +-- |Array data type, where the type arguments regard the representation types of the shape and elements. +data Array sh e where + Array :: sh -- extent of dimensions = shape + -> ArrayData e -- array payload + -> Array sh e + +{-# INLINE shape #-} +shape :: Array sh e -> sh +shape (Array sh _) = sh + +{-# INLINE reshape #-} +reshape :: ShapeR sh -> sh -> ShapeR sh' -> Array sh' e -> Array sh e +reshape shr sh shr' (Array sh' adata) + = $boundsCheck "reshape" "shape mismatch" (size shr sh == size shr' sh') + $ Array sh adata + +{-# INLINE [1] (!) #-} +(!) :: (ArrayR (Array sh e), Array sh e) -> sh -> e +(!) (ArrayR shr tp, Array sh adata) ix = unsafeIndexArrayData tp adata $ toIndex shr sh ix + +{-# INLINE [1] (!!) #-} +(!!) :: (TupleType e, Array sh e) -> Int -> e +(tp, Array _ adata) !! i = unsafeIndexArrayData tp adata i + +-- | Create an array from its representation function, applied at each index of +-- the array. -- +{-# INLINEABLE fromFunction #-} +fromFunction :: ArrayR (Array sh e) -> sh -> (sh -> e) -> Array sh e +fromFunction repr sh f = unsafePerformIO $! fromFunctionM repr sh (return . f) --- |Class of index representations (which are nested pairs) --- -class (Eq sh, Slice sh) => Shape sh where - -- user-facing methods - rank :: Int -- ^number of dimensions (>= 0); rank of the array - size :: sh -> Int -- ^total number of elements in an array of this /shape/ - empty :: sh -- ^empty shape. - - -- internal methods - intersect :: sh -> sh -> sh -- yield the intersection of two shapes - union :: sh -> sh -> sh -- yield the union of two shapes - ignore :: sh -- identifies ignored elements in 'permute' - toIndex :: sh -> sh -> Int -- yield the index position in a linear, row-major representation of - -- the array (first argument is the shape) - fromIndex :: sh -> Int -> sh -- inverse of `toIndex` - - iter :: sh -> (sh -> a) -> (a -> a -> a) -> a -> a - -- iterate through the entire shape, applying the function in the - -- second argument; third argument combines results and fourth is an - -- initial value that is combined with the results; the index space - -- is traversed in row-major order - - iter1 :: sh -> (sh -> a) -> (a -> a -> a) -> a - -- variant of 'iter' without an initial value - - -- operations to facilitate conversion with IArray - rangeToShape :: (sh, sh) -> sh -- convert a minpoint-maxpoint index - -- into a shape - shapeToRange :: sh -> (sh, sh) -- ...the converse - - - -- other conversions - shapeToList :: sh -> [Int] -- convert a shape into its list of dimensions - listToShape :: [Int] -> sh -- convert a list of dimensions into a shape - listToShape' :: [Int] -> Maybe sh -- attempt to convert a list of dimensions into a shape - - listToShape ds = - case listToShape' ds of - Just sh -> sh - Nothing -> $internalError "listToShape" "unable to convert list to a shape at the specified type" - -instance Shape () where - rank = 0 - empty = () - ignore = () - () `intersect` () = () - () `union` () = () - size () = 1 - toIndex () () = 0 - fromIndex () _ = () - iter () f _ _ = f () - iter1 () f _ = f () - - rangeToShape ((), ()) = () - shapeToRange () = ((), ()) - - shapeToList () = [] - listToShape [] = () - listToShape _ = $internalError "listToShape" "non-empty list when converting to unit" - - listToShape' [] = Just () - listToShape' _ = Nothing - -instance Shape sh => Shape (sh, Int) where - rank = rank @sh + 1 - empty = (empty, 0) - ignore = (ignore, -1) - (sh1, sz1) `intersect` (sh2, sz2) = (sh1 `intersect` sh2, sz1 `min` sz2) - (sh1, sz1) `union` (sh2, sz2) = (sh1 `union` sh2, sz1 `max` sz2) - - size (sh, sz) | sz <= 0 = 0 - | otherwise = size sh * sz - - toIndex (sh, sz) (ix, i) = $indexCheck "toIndex" i sz - $ toIndex sh ix * sz + i - - fromIndex (sh, sz) i = (fromIndex sh (i `quotInt` sz), r) - -- If we assume that the index is in range, there is no point in computing - -- the remainder for the highest dimension since i < sz must hold. +-- | Create an array using a monadic function applied at each index. +-- +-- @since 1.2.0.0 +-- +{-# INLINEABLE fromFunctionM #-} +fromFunctionM :: ArrayR (Array sh e) -> sh -> (sh -> IO e) -> IO (Array sh e) +fromFunctionM (ArrayR shr tp) sh f = do + let !n = size shr sh + arr <- newArrayData tp n + -- + let write !i + | i >= n = return () + | otherwise = do + v <- f (fromIndex shr sh i) + unsafeWriteArrayData tp arr i v + write (i+1) + -- + write 0 + return $! arr `seq` Array sh arr + + +{-# INLINEABLE concatVectors #-} +concatVectors :: forall e. TupleType e -> [Vector e] -> Vector e +concatVectors tp vs = adata `seq` Array ((), len) adata + where + offsets = scanl (+) 0 (map (size dim1 . shape) vs) + len = last offsets + (adata, _) = runArrayData @e $ do + arr <- newArrayData tp len + sequence_ [ unsafeWriteArrayData tp arr (i + k) (unsafeIndexArrayData tp ad i) + | (Array ((), n) ad, k) <- vs `zip` offsets + , i <- [0 .. n - 1] ] + return (arr, undefined) + +-- | Creates a new, uninitialized Accelerate array. +-- +{-# INLINEABLE allocateArray #-} +allocateArray :: ArrayR (Array sh e) -> sh -> IO (Array sh e) +allocateArray (ArrayR shr tp) sh = do + adata <- newArrayData tp (size shr sh) + return $! Array sh adata + +{-# INLINEABLE fromList #-} +fromList :: forall sh e. ArrayR (Array sh e) -> sh -> [e] -> Array sh e +fromList (ArrayR shr tp) sh xs = adata `seq` Array sh adata + where + -- Assume the array is in dense row-major order. This is safe because + -- otherwise backends would not be able to directly memcpy. -- - where - r | rank @sh == 0 = $indexCheck "fromIndex" i sz i - | otherwise = i `remInt` sz - -{-- - bound (sh, sz) (ix, i) bndy - | i < 0 = case bndy of - Clamp -> next `addDim` 0 - Mirror -> next `addDim` (-i) - Wrap -> next `addDim` (sz+i) - Constant e -> Left e - | i >= sz = case bndy of - Clamp -> next `addDim` (sz-1) - Mirror -> next `addDim` (sz-(i-sz+2)) - Wrap -> next `addDim` (i-sz) - Constant e -> Left e - | otherwise = next `addDim` i - where - -- This function is quite difficult to optimize due to the deep recursion - -- that it can generate with high-dimensional arrays. If we let 'next' be - -- inlined into each alternative of the cases above the size of this - -- function on an n-dimensional array will grow as 7^n. This quickly causes - -- GHC's head to explode. See GHC Trac #10491 for more details. - next = bound sh ix bndy - {-# NOINLINE next #-} - - Right ds `addDim` d = Right (ds, d) - Left e `addDim` _ = Left e ---} - - iter (sh, sz) f c r = iter sh (\ix -> iter' (ix,0)) c r - where - iter' (ix,i) | i >= sz = r - | otherwise = f (ix,i) `c` iter' (ix,i+1) + !n = size shr sh + (adata, _) = runArrayData @e $ do + arr <- newArrayData tp n + let go !i _ | i >= n = return () + go !i (v:vs) = unsafeWriteArrayData tp arr i v >> go (i+1) vs + go _ [] = error "Data.Array.Accelerate.fromList: not enough input data" + -- + go 0 xs + return (arr, undefined) + + +-- | Convert an accelerated 'Array' to a list in row-major order. +-- +{-# INLINEABLE toList #-} +toList :: ArrayR (Array sh e) -> Array sh e -> [e] +toList (ArrayR shr tp) (Array sh adata) = go 0 + where + -- Assume underling array is in row-major order. This is safe because + -- otherwise backends would not be able to directly memcpy. + -- + !n = size shr sh + go !i | i >= n = [] + | otherwise = (unsafeIndexArrayData tp adata i) : go (i+1) - iter1 (_, 0) _ _ = $boundsError "iter1" "empty iteration space" - iter1 (sh, sz) f c = iter1 sh (\ix -> iter1' (ix,0)) c - where - iter1' (ix,i) | i == sz-1 = f (ix,i) - | otherwise = f (ix,i) `c` iter1' (ix,i+1) +type ArraysR = TupR ArrayR +data ArrayR a where + ArrayR :: ShapeR sh -> TupleType e -> ArrayR (Array sh e) - rangeToShape ((sh1, sz1), (sh2, sz2)) - = (rangeToShape (sh1, sh2), sz2 - sz1 + 1) +arrayRshape :: ArrayR (Array sh e) -> ShapeR sh +arrayRshape (ArrayR sh _) = sh - shapeToRange (sh, sz) - = let (low, high) = shapeToRange sh - in - ((low, 0), (high, sz - 1)) +arrayRtype :: ArrayR (Array sh e) -> TupleType e +arrayRtype (ArrayR _ tp) = tp - shapeToList (sh,sz) = sz : shapeToList sh +arraysRarray :: ShapeR sh -> TupleType e -> ArraysR (Array sh e) +arraysRarray shr tp = TupRsingle $ ArrayR shr tp - listToShape [] = $internalError "listToShape" "empty list when converting to cons" - listToShape (x:xs) = (listToShape xs,x) +arraysRtuple2 :: ArrayR a -> ArrayR b -> ArraysR (((), a), b) +arraysRtuple2 a b = TupRpair TupRunit (TupRsingle a) `TupRpair` TupRsingle b - listToShape' [] = Nothing - listToShape' (x:xs) = do - xs' <- listToShape' xs - return (xs', x) +showArrayR :: ArrayR a -> ShowS +showArrayR (ArrayR shr tp) = showString "Array DIM" . shows (rank shr) . showString " " . showType tp + +showArraysR :: ArraysR tp -> ShowS +showArraysR TupRunit = showString "()" +showArraysR (TupRsingle repr) = showArrayR repr +showArraysR (TupRpair t1 t2) = showString "(" . showArraysR t1 . showString ", " . showArraysR t2 . showString ")" + +type Scalar = Array DIM0 +type Vector = Array DIM1 +type Matrix = Array DIM2 + +-- | Segment descriptor (vector of segment lengths). +-- +-- To represent nested one-dimensional arrays, we use a flat array of data +-- values in conjunction with a /segment descriptor/, which stores the lengths +-- of the subarrays. +-- +type Segments = Vector + +-- |Index representation +-- +type DIM0 = () +type DIM1 = ((), Int) +type DIM2 = (((), Int), Int) + +dim0 :: ShapeR DIM0 +dim0 = ShapeRz + +dim1 :: ShapeR DIM1 +dim1 = ShapeRsnoc dim0 + +dim2 :: ShapeR DIM2 +dim2 = ShapeRsnoc dim1 + +-- |Index representations (which are nested pairs) +-- + +data ShapeR sh where + ShapeRz :: ShapeR () + ShapeRsnoc :: ShapeR sh -> ShapeR (sh, Int) + +rank :: ShapeR sh -> Int +rank ShapeRz = 0 +rank (ShapeRsnoc shr) = rank shr + 1 + +size :: ShapeR sh -> sh -> Int +size ShapeRz () = 1 +size (ShapeRsnoc shr) (sh, sz) + | sz <= 0 = 0 + | otherwise = size shr sh * sz + +empty :: ShapeR sh -> sh +empty ShapeRz = () +empty (ShapeRsnoc shr) = (empty shr, 0) + +ignore :: ShapeR sh -> sh +ignore ShapeRz = () +ignore (ShapeRsnoc shr) = (ignore shr, -1) + +shapeZip :: (Int -> Int -> Int) -> ShapeR sh -> sh -> sh -> sh +shapeZip _ ShapeRz () () = () +shapeZip f (ShapeRsnoc shr) (as, a) (bs, b) = (shapeZip f shr as bs, f a b) + +intersect, union :: ShapeR sh -> sh -> sh -> sh +intersect = shapeZip min +union = shapeZip max + +toIndex :: ShapeR sh -> sh -> sh -> Int +toIndex ShapeRz () () = 0 +toIndex (ShapeRsnoc shr) (sh, sz) (ix, i) + = $indexCheck "toIndex" i sz + $ toIndex shr sh ix * sz + i + +fromIndex :: ShapeR sh -> sh -> Int -> sh +fromIndex ShapeRz () _ = () +fromIndex (ShapeRsnoc shr) (sh, sz) i + = (fromIndex shr sh (i `quotInt` sz), r) + -- If we assume that the index is in range, there is no point in computing + -- the remainder for the highest dimension since i < sz must hold. + -- + where + r = case shr of -- Check if rank of shr is 0 + ShapeRz -> $indexCheck "fromIndex" i sz i + _ -> i `remInt` sz + +shapeEq :: ShapeR sh -> sh -> sh -> Bool +shapeEq ShapeRz () () = True +shapeEq (ShapeRsnoc shr) (sh, i) (sh', i') = i == i' && shapeEq shr sh sh' + +-- iterate through the entire shape, applying the function in the +-- second argument; third argument combines results and fourth is an +-- initial value that is combined with the results; the index space +-- is traversed in row-major order +iter :: ShapeR sh -> sh -> (sh -> a) -> (a -> a -> a) -> a -> a +iter ShapeRz () f _ _ = f () +iter (ShapeRsnoc shr) (sh, sz) f c r = iter shr sh (\ix -> iter' (ix,0)) c r + where + iter' (ix,i) | i >= sz = r + | otherwise = f (ix,i) `c` iter' (ix,i+1) + +-- variant of 'iter' without an initial value +iter1 :: ShapeR sh -> sh -> (sh -> a) -> (a -> a -> a) -> a +iter1 ShapeRz () f _ = f () +iter1 (ShapeRsnoc _ ) (_, 0) _ _ = $boundsError "iter1" "empty iteration space" +iter1 (ShapeRsnoc shr) (sh, sz) f c = iter1 shr sh (\ix -> iter1' (ix,0)) c + where + iter1' (ix,i) | i == sz-1 = f (ix,i) + | otherwise = f (ix,i) `c` iter1' (ix,i+1) + +-- Operations to facilitate conversion with IArray + +-- convert a minpoint-maxpoint index into a shape +rangeToShape :: ShapeR sh -> (sh, sh) -> sh +rangeToShape ShapeRz ((), ()) = () +rangeToShape (ShapeRsnoc shr) ((sh1, sz1), (sh2, sz2)) = (rangeToShape shr (sh1, sh2), sz2 - sz1 + 1) + +-- the converse +shapeToRange :: ShapeR sh -> sh -> (sh, sh) +shapeToRange ShapeRz () = ((), ()) +shapeToRange (ShapeRsnoc shr) (sh, sz) = let (low, high) = shapeToRange shr sh in ((low, 0), (high, sz - 1)) + +-- Other conversions + +-- Convert a shape into its list of dimensions +shapeToList :: ShapeR sh -> sh -> [Int] +shapeToList ShapeRz () = [] +shapeToList (ShapeRsnoc shr) (sh,sz) = sz : shapeToList shr sh + +-- Convert a list of dimensions into a shape +listToShape :: ShapeR sh -> [Int] -> sh +listToShape shr ds = case listToShape' shr ds of + Just sh -> sh + Nothing -> $internalError "listToShape" "unable to convert list to a shape at the specified type" + +-- Attempt to convert a list of dimensions into a shape +listToShape' :: ShapeR sh -> [Int] -> Maybe sh +listToShape' ShapeRz [] = Just () +listToShape' (ShapeRsnoc shr) (x:xs) = (, x) <$> listToShape' shr xs +listToShape' _ _ = Nothing + +shapeType :: ShapeR sh -> TupleType sh +shapeType ShapeRz = TupRunit +shapeType (ShapeRsnoc shr) = shapeType shr `TupRpair` (TupRsingle $ SingleScalarType $ NumSingleType $ IntegralNumType TypeInt) -- |Slice representation -- @@ -234,6 +387,15 @@ sliceShape SliceNil () = () sliceShape (SliceAll sl) (sh, n) = (sliceShape sl sh, n) sliceShape (SliceFixed sl) (sh, _) = sliceShape sl sh +sliceShapeR :: SliceIndex slix sl co dim -> ShapeR sl +sliceShapeR SliceNil = ShapeRz +sliceShapeR (SliceAll sl) = ShapeRsnoc $ sliceShapeR sl +sliceShapeR (SliceFixed sl) = sliceShapeR sl + +sliceDomainR :: SliceIndex slix sl co dim -> ShapeR dim +sliceDomainR SliceNil = ShapeRz +sliceDomainR (SliceAll sl) = ShapeRsnoc $ sliceDomainR sl +sliceDomainR (SliceFixed sl) = ShapeRsnoc $ sliceDomainR sl -- | Enumerate all slices within a given bound. The innermost dimension changes -- most rapidly. @@ -248,3 +410,273 @@ enumSlices SliceNil () = [()] enumSlices (SliceAll sl) (sh, _) = [ (sh', ()) | sh' <- enumSlices sl sh] enumSlices (SliceFixed sl) (sh, n) = [ (sh', i) | sh' <- enumSlices sl sh, i <- [0..n-1]] + +-- | GADT reifying the 'Stencil' class +-- +data StencilR sh e pat where + StencilRunit3 :: TupleType e -> StencilR DIM1 e (Tup3 e e e) + StencilRunit5 :: TupleType e -> StencilR DIM1 e (Tup5 e e e e e) + StencilRunit7 :: TupleType e -> StencilR DIM1 e (Tup7 e e e e e e e) + StencilRunit9 :: TupleType e -> StencilR DIM1 e (Tup9 e e e e e e e e e) + + StencilRtup3 :: StencilR sh e pat1 + -> StencilR sh e pat2 + -> StencilR sh e pat3 + -> StencilR (sh, Int) e (Tup3 pat1 pat2 pat3) + + StencilRtup5 :: StencilR sh e pat1 + -> StencilR sh e pat2 + -> StencilR sh e pat3 + -> StencilR sh e pat4 + -> StencilR sh e pat5 + -> StencilR (sh, Int) e (Tup5 pat1 pat2 pat3 pat4 pat5) + + StencilRtup7 :: StencilR sh e pat1 + -> StencilR sh e pat2 + -> StencilR sh e pat3 + -> StencilR sh e pat4 + -> StencilR sh e pat5 + -> StencilR sh e pat6 + -> StencilR sh e pat7 + -> StencilR (sh, Int) e (Tup7 pat1 pat2 pat3 pat4 pat5 pat6 pat7) + + StencilRtup9 :: StencilR sh e pat1 + -> StencilR sh e pat2 + -> StencilR sh e pat3 + -> StencilR sh e pat4 + -> StencilR sh e pat5 + -> StencilR sh e pat6 + -> StencilR sh e pat7 + -> StencilR sh e pat8 + -> StencilR sh e pat9 + -> StencilR (sh, Int) e (Tup9 pat1 pat2 pat3 pat4 pat5 pat6 pat7 pat8 pat9) + +stencilElt :: StencilR sh e pat -> TupleType e +stencilElt (StencilRunit3 tp) = tp +stencilElt (StencilRunit5 tp) = tp +stencilElt (StencilRunit7 tp) = tp +stencilElt (StencilRunit9 tp) = tp +stencilElt (StencilRtup3 sr _ _) = stencilElt sr +stencilElt (StencilRtup5 sr _ _ _ _) = stencilElt sr +stencilElt (StencilRtup7 sr _ _ _ _ _ _) = stencilElt sr +stencilElt (StencilRtup9 sr _ _ _ _ _ _ _ _) = stencilElt sr + +stencilShape :: StencilR sh e pat -> ShapeR sh +stencilShape (StencilRunit3 _) = ShapeRsnoc ShapeRz +stencilShape (StencilRunit5 _) = ShapeRsnoc ShapeRz +stencilShape (StencilRunit7 _) = ShapeRsnoc ShapeRz +stencilShape (StencilRunit9 _) = ShapeRsnoc ShapeRz +stencilShape (StencilRtup3 sr _ _) = ShapeRsnoc $ stencilShape sr +stencilShape (StencilRtup5 sr _ _ _ _) = ShapeRsnoc $ stencilShape sr +stencilShape (StencilRtup7 sr _ _ _ _ _ _) = ShapeRsnoc $ stencilShape sr +stencilShape (StencilRtup9 sr _ _ _ _ _ _ _ _) = ShapeRsnoc $ stencilShape sr + +stencilType :: StencilR sh e pat -> TupleType pat +stencilType (StencilRunit3 tp) = tupR3 tp tp tp +stencilType (StencilRunit5 tp) = tupR5 tp tp tp tp tp +stencilType (StencilRunit7 tp) = tupR7 tp tp tp tp tp tp tp +stencilType (StencilRunit9 tp) = tupR9 tp tp tp tp tp tp tp tp tp +stencilType (StencilRtup3 s1 s2 s3) = tupR3 (stencilType s1) (stencilType s2) (stencilType s3) +stencilType (StencilRtup5 s1 s2 s3 s4 s5) = tupR5 (stencilType s1) (stencilType s2) (stencilType s3) + (stencilType s4) (stencilType s5) +stencilType (StencilRtup7 s1 s2 s3 s4 s5 s6 s7) = tupR7 (stencilType s1) (stencilType s2) (stencilType s3) + (stencilType s4) (stencilType s5) (stencilType s6) + (stencilType s7) +stencilType (StencilRtup9 s1 s2 s3 s4 s5 s6 s7 s8 s9) = tupR9 (stencilType s1) (stencilType s2) (stencilType s3) + (stencilType s4) (stencilType s5) (stencilType s6) + (stencilType s7) (stencilType s8) (stencilType s9) + +stencilArrayR :: StencilR sh e pat -> ArrayR (Array sh e) +stencilArrayR stencil = ArrayR (stencilShape stencil) (stencilElt stencil) + +stencilHalo :: StencilR sh e stencil -> (ShapeR sh, sh) +stencilHalo = go' + where + go' :: StencilR sh e stencil -> (ShapeR sh, sh) + go' StencilRunit3{} = (dim1, ((), 1)) + go' StencilRunit5{} = (dim1, ((), 2)) + go' StencilRunit7{} = (dim1, ((), 3)) + go' StencilRunit9{} = (dim1, ((), 4)) + -- + go' (StencilRtup3 a b c ) = (ShapeRsnoc shr, cons shr 1 $ foldl1 (union shr) [a', go b, go c]) + where (shr, a') = go' a + go' (StencilRtup5 a b c d e ) = (ShapeRsnoc shr, cons shr 2 $ foldl1 (union shr) [a', go b, go c, go d, go e]) + where (shr, a') = go' a + go' (StencilRtup7 a b c d e f g ) = (ShapeRsnoc shr, cons shr 3 $ foldl1 (union shr) [a', go b, go c, go d, go e, go f, go g]) + where (shr, a') = go' a + go' (StencilRtup9 a b c d e f g h i) = (ShapeRsnoc shr, cons shr 4 $ foldl1 (union shr) [a', go b, go c, go d, go e, go f, go g, go h, go i]) + where (shr, a') = go' a + + go :: StencilR sh e stencil -> sh + go = snd . go' + + cons :: ShapeR sh -> Int -> sh -> (sh, Int) + cons ShapeRz ix () = ((), ix) + cons (ShapeRsnoc shr) ix (sh, sz) = (cons shr ix sh, sz) + +rnfArray :: ArrayR a -> a -> () +rnfArray (ArrayR shr tp) (Array sh ad) = rnfShape shr sh `seq` rnfArrayData tp ad + +rnfShape :: ShapeR sh -> sh -> () +rnfShape ShapeRz () = () +rnfShape (ShapeRsnoc shr) (sh, s) = s `seq` rnfShape shr sh + +-- | SIMD Vectors (Vec n t) +-- + +-- Declares the size of a SIMD vector and the type of its elements. +-- This data type is used to denote the relation between a vector +-- type (Vec n single) with its tuple representation (tuple). +-- Conversions between those types are exposed through vecPack and +-- vecUnpack. +-- +data VecR (n :: Nat) single tuple where + VecRnil :: SingleType s -> VecR 0 s () + VecRsucc :: VecR n s t -> VecR (n + 1) s (t, s) + +vecRvector :: KnownNat n => VecR n s tuple -> VectorType (Vec n s) +vecRvector = uncurry VectorType . go + where + go :: VecR n s tuple -> (Int, SingleType s) + go (VecRnil tp) = (0, tp) + go (VecRsucc vec) = (n + 1, tp) + where (n, tp) = go vec + +vecRtuple :: VecR n s tuple -> TupleType tuple +vecRtuple = snd . go + where + go :: VecR n s tuple -> (SingleType s, TupleType tuple) + go (VecRnil tp) = (tp, TupRunit) + go (VecRsucc vec) + | (tp, tuple) <- go vec = (tp, TupRpair tuple $ TupRsingle $ SingleScalarType tp) + +vecPack :: forall n single tuple. KnownNat n => VecR n single tuple -> tuple -> Vec n single +vecPack vecR tuple + | VectorType n single <- vecRvector vecR + , PrimDict <- getPrim single = runST $ do + mba <- newByteArray (n * sizeOf (undefined :: single)) + go (n - 1) vecR tuple mba + ByteArray ba# <- unsafeFreezeByteArray mba + return $! Vec ba# + where + go :: Prim single => Int -> VecR n' single tuple' -> tuple' -> MutableByteArray s -> ST s () + go _ (VecRnil _) () _ = return () + go i (VecRsucc r) (xs, x) mba = do + writeByteArray mba i x + go (i - 1) r xs mba + +vecUnpack :: forall n single tuple. KnownNat n => VecR n single tuple -> Vec n single -> tuple +vecUnpack vecR (Vec ba#) + | VectorType n single <- vecRvector vecR + , !(I# n#) <- n + , PrimDict <- getPrim single + = go (n# -# 1#) vecR + where + go :: Prim single => Int# -> VecR n' single tuple' -> tuple' + go _ (VecRnil _) = () + go i# (VecRsucc r) = x `seq` xs `seq` (xs, x) + where + xs = go (i# -# 1#) r + x = indexByteArray# ba# i# + +-- | Nicely format a shape as a string +-- +showShape :: ShapeR sh -> sh -> String +showShape shr = foldr (\sh str -> str ++ " :. " ++ show sh) "Z" . shapeToList shr + +showElement :: TupleType e -> e -> String +showElement tuple value = showElement' tuple value "" + where + showElement' :: TupleType e -> e -> ShowS + showElement' TupRunit () = showString "()" + showElement' (TupRpair t1 t2) (e1, e2) = showString "(" . showElement' t1 e1 . showString ", " . showElement' t2 e2 . showString ")" + showElement' (TupRsingle tp) val = showScalar tp val + + showScalar :: ScalarType e -> e -> ShowS + showScalar (SingleScalarType t) e = showString $ showSingle t e + showScalar (VectorScalarType t) e = showString $ showVector t e + + showSingle :: SingleType e -> e -> String + showSingle (NumSingleType t) e = showNum t e + showSingle (NonNumSingleType t) e = showNonNum t e + + showNum :: NumType e -> e -> String + showNum (IntegralNumType t) e = showIntegral t e + showNum (FloatingNumType t) e = showFloating t e + + showIntegral :: IntegralType e -> e -> String + showIntegral TypeInt{} e = show e + showIntegral TypeInt8{} e = show e + showIntegral TypeInt16{} e = show e + showIntegral TypeInt32{} e = show e + showIntegral TypeInt64{} e = show e + showIntegral TypeWord{} e = show e + showIntegral TypeWord8{} e = show e + showIntegral TypeWord16{} e = show e + showIntegral TypeWord32{} e = show e + showIntegral TypeWord64{} e = show e + + showFloating :: FloatingType e -> e -> String + showFloating TypeHalf{} e = show e + showFloating TypeFloat{} e = show e + showFloating TypeDouble{} e = show e + + showNonNum :: NonNumType e -> e -> String + showNonNum TypeChar e = show e + showNonNum TypeBool e = show e + + showVector :: VectorType (Vec n a) -> Vec n a -> String + showVector (VectorType _ single) vec + | PrimDict <- getPrim single = "<" ++ (intercalate ", " $ showSingle single <$> vecToArray vec) ++ ">" + +showArray :: ArrayR (Array sh e) -> Array sh e -> String +showArray repr@(ArrayR _ tp) = showArray' (showString . showElement tp) repr + +{-# INLINE showArray' #-} +showArray' :: (e -> ShowS) -> ArrayR (Array sh e) -> Array sh e -> String +showArray' f repr@(ArrayR shr tp) arr@(Array sh _) = case shr of + ShapeRz -> "Scalar Z " ++ list + ShapeRsnoc ShapeRz -> "Vector (" ++ shapeString ++ ") " ++ list + ShapeRsnoc (ShapeRsnoc ShapeRz) -> "Matrix (" ++ shapeString ++ ") " ++ showMatrix f tp arr + _ -> "Array (" ++ shapeString ++ ") " ++ list + where + shapeString = showShape shr sh + list = showListWith f (toList repr arr) "" + +-- TODO: +-- Make special formatting optional? It is more difficult to copy/paste the +-- result, for example. Also it does not look good if the matrix row does +-- not fit on a single line. +-- +showMatrix :: (e -> ShowS) -> TupleType e -> Array DIM2 e -> String +showMatrix f tp arr@(Array sh _) + | rows * cols == 0 = "[]" + | otherwise = "\n [" ++ ppMat 0 0 + where + (((), rows), cols) = sh + lengths = U.generate (rows*cols) (\i -> length (f ((tp, arr) !! i) "")) + widths = U.generate cols (\c -> U.maximum (U.generate rows (\r -> lengths U.! (r*cols+c)))) + -- + ppMat :: Int -> Int -> String + ppMat !r !c | c >= cols = ppMat (r+1) 0 + ppMat !r !c = + let + !i = r*cols+c + !l = lengths U.! i + !w = widths U.! c + !pad = 1 + cell = replicate (w-l+pad) ' ' ++ f ((tp, arr) !! i) "" + -- + before + | r > 0 && c == 0 = "\n " + | otherwise = "" + -- + after + | r >= rows-1 && c >= cols-1 = "]" + | otherwise = ',' : ppMat r (c+1) + in + before ++ cell ++ after + + +reduceRank :: ArrayR (Array (sh, Int) e) -> ArrayR (Array sh e) +reduceRank (ArrayR (ShapeRsnoc shr) tp) = ArrayR shr tp diff --git a/src/Data/Array/Accelerate/Array/Sugar.hs b/src/Data/Array/Accelerate/Array/Sugar.hs index 431b0ecc1..94c4c6397 100644 --- a/src/Data/Array/Accelerate/Array/Sugar.hs +++ b/src/Data/Array/Accelerate/Array/Sugar.hs @@ -3,10 +3,10 @@ {-# LANGUAGE CPP #-} {-# LANGUAGE ConstraintKinds #-} {-# LANGUAGE DefaultSignatures #-} -{-# LANGUAGE DeriveDataTypeable #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GADTs #-} +{-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE TemplateHaskell #-} @@ -31,13 +31,15 @@ -- module Data.Array.Accelerate.Array.Sugar ( + -- * Tuple representation + TupR(..), -- * Array representation - Array(..), Scalar, Vector, Matrix, Segments, - Arrays(..), ArraysR(..), arraysRtuple2, + Array(..), Scalar, Vector, Matrix, Segments, arrayR, + Arrays(..), Repr.ArraysR, Repr.ArrayR(..), Repr.arraysRarray, Repr.arraysRtuple2, -- * Class of supported surface element types and their mapping to representation types - Elt(..), + Elt(..), TupleType, -- * Derived functions liftToElt, liftToElt2, sinkFromElt, sinkFromElt2, @@ -51,12 +53,8 @@ module Data.Array.Accelerate.Array.Sugar ( -- * Array shape query, indexing, and conversions shape, reshape, (!), (!!), allocateArray, fromFunction, fromFunctionM, fromList, toList, concatVectors, - -- * Tuples of expressions - TupleR, TupleRepr, tuple, - Tuple(..), IsTuple, fromTuple, toTuple, - -- * Miscellaneous - showShape, Foreign(..), sliceShape, enumSlices, + showShape, Foreign(..), sliceShape, enumSlices, VecElt, ) where @@ -64,11 +62,11 @@ module Data.Array.Accelerate.Array.Sugar ( import Control.DeepSeq import Data.Kind import Data.Typeable +import Data.Primitive.Types import System.IO.Unsafe ( unsafePerformIO ) import Language.Haskell.TH hiding ( Foreign, Type ) import Language.Haskell.TH.Extra import Prelude hiding ( (!!) ) -import qualified Data.Vector.Unboxed as U import GHC.Exts ( IsList ) import GHC.Generics @@ -76,10 +74,8 @@ import GHC.TypeLits import qualified GHC.Exts as GHC -- friends -import Data.Array.Accelerate.Array.Data import Data.Array.Accelerate.Error import Data.Array.Accelerate.Orphans () -import Data.Array.Accelerate.Product import Data.Array.Accelerate.Type import qualified Data.Array.Accelerate.Array.Representation as Repr @@ -98,14 +94,14 @@ import qualified Data.Array.Accelerate.Array.Representation as Repr -- | Rank-0 index -- data Z = Z - deriving (Typeable, Show, Eq) + deriving (Show, Eq) -- | Increase an index rank by one dimension. The ':.' operator is used to -- construct both values and types. -- infixl 3 :. data tail :. head = !tail :. !head - deriving (Typeable, Eq) + deriving Eq -- We don't we use a derived Show instance for (:.) because this will insert -- parenthesis to demonstrate which order the operator is applied, i.e.: @@ -139,7 +135,7 @@ instance (Show sh, Show sz) => Show (sh :. sz) where -- 'Data.Array.Accelerate.Language.replicate' for examples. -- data All = All - deriving (Typeable, Show, Eq) + deriving (Show, Eq) -- | Marker for arbitrary dimensions in 'Data.Array.Accelerate.Language.slice' -- and 'Data.Array.Accelerate.Language.replicate' descriptors. @@ -151,7 +147,7 @@ data All = All -- 'Data.Array.Accelerate.Language.replicate' for examples. -- data Any sh = Any - deriving (Typeable, Show, Eq) + deriving (Show, Eq) -- | Marker for splitting along an entire dimension in division descriptors. -- @@ -160,7 +156,7 @@ data Any sh = Any -- divided along this dimension forming the elements of the output sequence. -- data Split = Split - deriving (Typeable, Show, Eq) + deriving (Show, Eq) -- | Marker for arbitrary shapes in slices descriptors, where it is desired to -- split along an unknown number of dimensions. @@ -172,8 +168,7 @@ data Split = Split -- > vectors = toSeq (Divide :. All) -- data Divide sh = Divide - deriving (Typeable, Show, Eq) - + deriving (Show, Eq) -- Scalar elements -- --------------- @@ -209,7 +204,7 @@ data Divide sh = Divide -- > data Point = Point Int Float -- > deriving (Show, Generic, Elt) -- -class (Show a, Typeable a, Typeable (EltRepr a), ArrayElt (EltRepr a)) => Elt a where +class Show a => Elt a where -- | Type representation mapping, which explains how to convert a type from -- the surface type into the internal representation type consisting only of -- simple primitive types, unit '()', and pair '(,)'. @@ -225,7 +220,7 @@ class (Show a, Typeable a, Typeable (EltRepr a), ArrayElt (EltRepr a)) => Elt a default eltType :: (GElt (Rep a), EltRepr a ~ GEltRepr () (Rep a)) => TupleType (EltRepr a) - eltType = geltType @(Rep a) TypeRunit + eltType = geltType @(Rep a) TupRunit {-# INLINE [1] fromElt #-} default fromElt @@ -262,7 +257,7 @@ instance GElt a => GElt (M1 i c a) where instance Elt a => GElt (K1 i a) where type GEltRepr t (K1 i a) = (t, EltRepr a) - geltType t = TypeRpair t (eltType @a) + geltType t = TupRpair t (eltType @a) gfromElt t (K1 x) = (t, fromElt x) gtoElt (t, x) = (t, K1 (toElt x)) @@ -289,7 +284,7 @@ instance (GElt a, GElt b) => GElt (a :*: b) where -- > @(TupleType (EltRepr CShort)) -- > (eltType :: TupleType (EltRepr CShort)) -- --- Which yields the error "couldn't match type type 'EltRepr a0' with 'Int16'". +-- Which yields the error "couldn't match type 'EltRepr a0' with 'Int16'". -- Since this function returns a type family type, the type signature on the -- result is not enough to fix the type 'a'. Instead, we require the use of -- (visible) type applications: @@ -312,7 +307,7 @@ instance Elt () where {-# INLINE eltType #-} {-# INLINE toElt #-} {-# INLINE fromElt #-} - eltType = TypeRunit + eltType = TupRunit fromElt = id toElt = id @@ -321,7 +316,7 @@ instance Elt Z where {-# INLINE eltType #-} {-# INLINE [1] toElt #-} {-# INLINE [1] fromElt #-} - eltType = TypeRunit + eltType = TupRunit fromElt Z = () toElt () = Z @@ -330,7 +325,7 @@ instance (Elt t, Elt h) => Elt (t:.h) where {-# INLINE eltType #-} {-# INLINE [1] toElt #-} {-# INLINE [1] fromElt #-} - eltType = TypeRpair (eltType @t) (eltType @h) + eltType = TupRpair (eltType @t) (eltType @h) fromElt (t:.h) = (fromElt t, fromElt h) toElt (t, h) = toElt t :. toElt h @@ -339,33 +334,37 @@ instance Elt All where {-# INLINE eltType #-} {-# INLINE [1] toElt #-} {-# INLINE [1] fromElt #-} - eltType = TypeRunit + eltType = TupRunit fromElt All = () toElt () = All -instance Elt (Any Z) where - type EltRepr (Any Z) = () - {-# INLINE eltType #-} - {-# INLINE [1] toElt #-} - {-# INLINE [1] fromElt #-} - eltType = TypeRunit - fromElt _ = () - toElt _ = Any +type family AnyRepr sh +type instance AnyRepr () = () +type instance AnyRepr (sh, Int) = (AnyRepr sh, ()) + +instance Shape sh => Elt (Any sh) where + type EltRepr (Any sh) = AnyRepr (EltRepr sh) -instance Shape sh => Elt (Any (sh:.Int)) where - type EltRepr (Any (sh:.Int)) = (EltRepr (Any sh), ()) {-# INLINE eltType #-} {-# INLINE [1] toElt #-} {-# INLINE [1] fromElt #-} - eltType = TypeRpair (eltType @(Any sh)) TypeRunit - fromElt _ = (fromElt (Any @sh), ()) + eltType = go $ shapeR @sh + where + go :: Repr.ShapeR sh' -> TupleType (AnyRepr sh') + go Repr.ShapeRz = TupRunit + go (Repr.ShapeRsnoc shr) = TupRpair (go shr) TupRunit + fromElt _ = go $ shapeR @sh + where + go :: Repr.ShapeR sh' -> AnyRepr sh' + go Repr.ShapeRz = () + go (Repr.ShapeRsnoc shr) = (go shr, ()) toElt _ = Any -- Convenience functions -- singletonScalarType :: IsScalar a => TupleType a -singletonScalarType = TypeRscalar scalarType +singletonScalarType = TupRsingle scalarType {-# INLINE liftToElt #-} liftToElt :: (Elt a, Elt b) @@ -396,7 +395,6 @@ sinkFromElt2 f x y = fromElt $ f (toElt x) (toElt y) "toElt/fromElt" forall e. toElt (fromElt e) = e #-} - -- Foreign functions -- ----------------- @@ -417,35 +415,6 @@ class Typeable asm => Foreign asm where liftForeign _ = $internalError "liftForeign" "not supported by this backend" --- Tuple representation --- -------------------- - --- |The tuple representation is equivalent to the product representation. --- -type TupleRepr a = ProdRepr a -type TupleR a = ProdR Elt a -type IsTuple = IsProduct Elt --- type IsAtuple = IsProduct Arrays - --- |We represent tuples as heterogeneous lists, typed by a type list. --- -data Tuple c t where - NilTup :: Tuple c () - SnocTup :: Elt t => Tuple c s -> c t -> Tuple c (s, t) - - --- |Tuple reification --- -tuple :: forall tup. IsTuple tup => TupleR (TupleRepr tup) -tuple = prod @Elt @tup - -fromTuple :: IsTuple tup => tup -> TupleRepr tup -fromTuple = fromProd @Elt - -toTuple :: IsTuple tup => TupleRepr tup -> tup -toTuple = toProd @Elt - - -- Arrays -- ------ @@ -456,7 +425,7 @@ toTuple = toProd @Elt -- 16-elements wide. Accelerate computations can thereby return multiple -- results. -- -class (Typeable a, Typeable (ArrRepr a)) => Arrays a where +class Arrays a where -- | Type representation mapping, which explains how to convert from the -- surface type into the internal representation type, which consists only of -- 'Array', and '()' and '(,)' as type-level nil and snoc. @@ -464,15 +433,15 @@ class (Typeable a, Typeable (ArrRepr a)) => Arrays a where type ArrRepr a :: Type type ArrRepr a = GArrRepr () (Rep a) - arrays :: ArraysR (ArrRepr a) + arrays :: Repr.ArraysR (ArrRepr a) toArr :: ArrRepr a -> a fromArr :: a -> ArrRepr a {-# INLINE arrays #-} default arrays :: (GArrays (Rep a), ArrRepr a ~ GArrRepr () (Rep a)) - => ArraysR (ArrRepr a) - arrays = garrays @(Rep a) ArraysRunit + => Repr.ArraysR (ArrRepr a) + arrays = garrays @(Rep a) TupRunit {-# INLINE [1] toArr #-} default toArr @@ -492,10 +461,12 @@ class (Typeable a, Typeable (ArrRepr a)) => Arrays a where -- => a -> ArraysFlavour a -- flavour _ = gflavour @(Rep a) +arrayR :: forall sh e. (Shape sh, Elt e) => Repr.ArrayR (Repr.Array (EltRepr sh) (EltRepr e)) +arrayR = Repr.ArrayR (shapeR @sh) (eltType @e) class GArrays f where type GArrRepr t f - garrays :: ArraysR t -> ArraysR (GArrRepr t f) + garrays :: Repr.ArraysR t -> Repr.ArraysR (GArrRepr t f) gfromArr :: f a -> t -> GArrRepr t f gtoArr :: GArrRepr t f -> (t, f a) @@ -513,7 +484,7 @@ instance GArrays a => GArrays (M1 i c a) where instance Arrays a => GArrays (K1 i a) where type GArrRepr t (K1 i a) = (t, ArrRepr a) - garrays t = ArraysRpair t (arrays @a) + garrays t = TupRpair t (arrays @a) gfromArr (K1 x) t = (t, fromArr x) gtoArr (t, x) = (t, K1 (toArr x)) @@ -533,28 +504,18 @@ instance Arrays () where {-# INLINE arrays #-} {-# INLINE [1] fromArr #-} {-# INLINE [1] toArr #-} - arrays = ArraysRunit + arrays = TupRunit fromArr = id toArr = id instance (Shape sh, Elt e) => Arrays (Array sh e) where - type ArrRepr (Array sh e) = Array sh e + type ArrRepr (Array sh e) = Repr.Array (EltRepr sh) (EltRepr e) {-# INLINE arrays #-} {-# INLINE [1] fromArr #-} {-# INLINE [1] toArr #-} - arrays = ArraysRarray - fromArr = id - toArr = id - --- Array type reification --- -data ArraysR arrs where - ArraysRunit :: ArraysR () - ArraysRarray :: (Shape sh, Elt e) => ArraysR (Array sh e) - ArraysRpair :: ArraysR arrs1 -> ArraysR arrs2 -> ArraysR (arrs1, arrs2) - -arraysRtuple2 :: (Shape sh1, Elt e1, Shape sh2, Elt e2) => ArraysR (((), Array sh2 e2), Array sh1 e1) -arraysRtuple2 = ArraysRpair ArraysRunit ArraysRarray `ArraysRpair` ArraysRarray + arrays = Repr.arraysRarray (shapeR @sh) (eltType @e) + fromArr (Array arr) = arr + toArr (arr) = Array arr {-# RULES "fromArr/toArr" forall a. fromArr (toArr a) = a @@ -599,10 +560,8 @@ arraysRtuple2 = ArraysRpair ArraysRunit ArraysRarray `ArraysRpair` ArraysRarray -- Section "Getting data in" lists functions for getting data into and out of -- the 'Array' type. -- -data Array sh e where - Array :: EltRepr sh -- extent of dimensions = shape - -> ArrayData (EltRepr e) -- array payload - -> Array sh e +newtype Array sh e = Array (Repr.Array (EltRepr sh) (EltRepr e)) + -- -- Note: [Embedded class constraints on Array] -- @@ -638,47 +597,7 @@ instance (Shape sh, Elt e, Eq sh, Eq e) => Eq (Array sh e) where -- matrices may not always be shown with their appropriate format. -- instance (Shape sh, Elt e) => Show (Array sh e) where - show arr = case shapeToList $ shape arr of - [] -> "Scalar Z " ++ show (toList arr) - [_] -> "Vector (" ++ showShape (shape arr) ++ ") " ++ show (toList arr) - [cols, rows] -> showMatrix rows cols arr - _ -> "Array (" ++ showShape (shape arr) ++ ") " ++ show (toList arr) - --- TODO: --- Make special formatting optional? It is more difficult to copy/paste the --- result, for example. Also it does not look good if the matrix row does --- not fit on a single line. --- -showMatrix :: (Shape sh, Elt e) => Int -> Int -> Array sh e -> String -showMatrix rows cols arr = - "Matrix (" ++ showShape (shape arr) ++ ") " ++ showMat - where - lengths = U.generate (rows*cols) (\i -> length (show (arr !! i))) - widths = U.generate cols (\c -> U.maximum (U.generate rows (\r -> lengths U.! (r*cols+c)))) - -- - showMat - | rows * cols == 0 = "[]" - | otherwise = "\n [" ++ ppMat 0 0 - -- - ppMat :: Int -> Int -> String - ppMat !r !c | c >= cols = ppMat (r+1) 0 - ppMat !r !c = - let - !i = r*cols+c - !l = lengths U.! i - !w = widths U.! c - !pad = 1 - cell = replicate (w-l+pad) ' ' ++ show (arr !! i) - -- - before - | r > 0 && c == 0 = "\n " - | otherwise = "" - -- - after - | r >= rows-1 && c >= cols-1 = "]" - | otherwise = ',' : ppMat r (c+1) - in - before ++ cell ++ after + show (Array arr) = Repr.showArray' (shows . toElt @e) (arrayR @sh @e) arr instance Elt e => IsList (Vector e) where type Item (Vector e) = e @@ -687,28 +606,7 @@ instance Elt e => IsList (Vector e) where fromList xs = GHC.fromListN (length xs) xs instance (Shape sh, Elt e) => NFData (Array sh e) where - rnf (Array sh ad) = Repr.size sh `seq` go arrayElt ad `seq` () - where - go :: ArrayEltR e' -> ArrayData e' -> () - go ArrayEltRunit AD_Unit = () - go ArrayEltRint (AD_Int ua) = rnf ua - go ArrayEltRint8 (AD_Int8 ua) = rnf ua - go ArrayEltRint16 (AD_Int16 ua) = rnf ua - go ArrayEltRint32 (AD_Int32 ua) = rnf ua - go ArrayEltRint64 (AD_Int64 ua) = rnf ua - go ArrayEltRword (AD_Word ua) = rnf ua - go ArrayEltRword8 (AD_Word8 ua) = rnf ua - go ArrayEltRword16 (AD_Word16 ua) = rnf ua - go ArrayEltRword32 (AD_Word32 ua) = rnf ua - go ArrayEltRword64 (AD_Word64 ua) = rnf ua - go ArrayEltRhalf (AD_Half ua) = rnf ua - go ArrayEltRfloat (AD_Float ua) = rnf ua - go ArrayEltRdouble (AD_Double ua) = rnf ua - go ArrayEltRbool (AD_Bool ua) = rnf ua - go ArrayEltRchar (AD_Char ua) = rnf ua - go (ArrayEltRvec r) (AD_Vec !_ a) = go r a `seq` () - go (ArrayEltRpair r1 r2) (AD_Pair a1 a2) = go r1 a1 `seq` go r2 a2 `seq` () - + rnf (Array arr) = Repr.rnfArray (arrayR @sh @e) $ arr -- | Scalar arrays hold a single element -- @@ -749,9 +647,11 @@ type DIM9 = DIM8:.Int -- |Shapes and indices of multi-dimensional arrays -- -class (Elt sh, Elt (Any sh), Repr.Shape (EltRepr sh), FullShape sh ~ sh, CoSliceShape sh ~ sh, SliceShape sh ~ Z) +class (Elt sh, Elt (Any sh), FullShape sh ~ sh, CoSliceShape sh ~ sh, SliceShape sh ~ Z) => Shape sh where + shapeR :: Repr.ShapeR (EltRepr sh) + -- |Number of dimensions of a /shape/ or /index/ (>= 0). rank :: Int @@ -822,37 +722,39 @@ class (Elt sh, Elt (Any sh), Repr.Shape (EltRepr sh), FullShape sh ~ sh, CoSlice {-# INLINE shapeToList #-} {-# INLINE listToShape #-} {-# INLINE listToShape' #-} - rank = Repr.rank @(EltRepr sh) - size = Repr.size . fromElt - empty = toElt Repr.empty + rank = Repr.rank (shapeR @sh) + size = Repr.size (shapeR @sh) . fromElt + empty = toElt $ Repr.empty $ shapeR @sh -- (#) must be individually defined, as it holds for all instances *except* -- the one with the largest arity - ignore = toElt Repr.ignore - intersect sh1 sh2 = toElt (Repr.intersect (fromElt sh1) (fromElt sh2)) - union sh1 sh2 = toElt (Repr.union (fromElt sh1) (fromElt sh2)) - fromIndex sh ix = toElt (Repr.fromIndex (fromElt sh) ix) - toIndex sh ix = Repr.toIndex (fromElt sh) (fromElt ix) + ignore = toElt $ Repr.ignore $ shapeR @sh + intersect sh1 sh2 = toElt (Repr.intersect (shapeR @sh) (fromElt sh1) (fromElt sh2)) + union sh1 sh2 = toElt (Repr.union (shapeR @sh) (fromElt sh1) (fromElt sh2)) + fromIndex sh ix = toElt (Repr.fromIndex (shapeR @sh) (fromElt sh) ix) + toIndex sh ix = Repr.toIndex (shapeR @sh) (fromElt sh) (fromElt ix) - iter sh f c r = Repr.iter (fromElt sh) (f . toElt) c r - iter1 sh f r = Repr.iter1 (fromElt sh) (f . toElt) r + iter sh f c r = Repr.iter (shapeR @sh) (fromElt sh) (f . toElt) c r + iter1 sh f r = Repr.iter1 (shapeR @sh) (fromElt sh) (f . toElt) r rangeToShape (low, high) - = toElt (Repr.rangeToShape (fromElt low, fromElt high)) + = toElt (Repr.rangeToShape (shapeR @sh) (fromElt low, fromElt high)) shapeToRange ix - = let (low, high) = Repr.shapeToRange (fromElt ix) + = let (low, high) = Repr.shapeToRange (shapeR @sh) (fromElt ix) in (toElt low, toElt high) - shapeToList = Repr.shapeToList . fromElt - listToShape = toElt . Repr.listToShape - listToShape' = fmap toElt . Repr.listToShape' + shapeToList = Repr.shapeToList (shapeR @sh) . fromElt + listToShape = toElt . Repr.listToShape (shapeR @sh) + listToShape' = fmap toElt . Repr.listToShape' (shapeR @sh) instance Shape Z where + shapeR = Repr.ShapeRz sliceAnyIndex = Repr.SliceNil sliceNoneIndex = Repr.SliceNil instance Shape sh => Shape (sh:.Int) where + shapeR = Repr.ShapeRsnoc (shapeR @sh) sliceAnyIndex = Repr.SliceAll (sliceAnyIndex @sh) sliceNoneIndex = Repr.SliceFixed (sliceNoneIndex @sh) @@ -932,28 +834,26 @@ instance (Shape sh, Slice sh) => Division (Divide sh) where -- {-# INLINE shape #-} shape :: Shape sh => Array sh e -> sh -shape (Array sh _) = toElt sh +shape (Array arr) = toElt $ Repr.shape arr -- | Change the shape of an array without altering its contents. The 'size' of -- the source and result arrays must be identical. -- {-# INLINE reshape #-} -reshape :: (Shape sh, Shape sh') => sh -> Array sh' e -> Array sh e -reshape sh (Array sh' adata) - = $boundsCheck "reshape" "shape mismatch" (size sh == Repr.size sh') - $ Array (fromElt sh) adata +reshape :: forall sh sh' e. (Shape sh, Shape sh') => sh -> Array sh' e -> Array sh e +reshape sh (Array arr) = Array $ Repr.reshape (shapeR @sh) (fromElt sh) (shapeR @sh') arr -- | Array indexing -- infixl 9 ! {-# INLINE [1] (!) #-} -(!) :: (Shape sh, Elt e) => Array sh e -> sh -> e -(!) (Array sh adata) ix = toElt (adata `unsafeIndexArrayData` toIndex (toElt sh) ix) +(!) :: forall sh e. (Shape sh, Elt e) => Array sh e -> sh -> e +(!) (Array arr) ix = toElt $ (arrayR @sh @e, arr) Repr.! fromElt ix infixl 9 !! {-# INLINE [1] (!!) #-} -(!!) :: Elt e => Array sh e -> Int -> e -(!!) (Array _ adata) i = toElt (adata `unsafeIndexArrayData` i) +(!!) :: forall sh e. Elt e => Array sh e -> Int -> e +(!!) (Array arr) i = toElt $ (eltType @e, arr) Repr.!! i {-# RULES "indexArray/DIM0" forall arr. arr ! Z = arr !! 0 @@ -972,44 +872,26 @@ fromFunction sh f = unsafePerformIO $! fromFunctionM sh (return . f) -- @since 1.2.0.0 -- {-# INLINEABLE fromFunctionM #-} -fromFunctionM :: (Shape sh, Elt e) => sh -> (sh -> IO e) -> IO (Array sh e) -fromFunctionM sh f = do - let !n = size sh - arr <- newArrayData n - -- - let write !i - | i >= n = return () - | otherwise = do - v <- f (fromIndex sh i) - unsafeWriteArrayData arr i (fromElt v) - write (i+1) - -- - write 0 - return $! arr `seq` Array (fromElt sh) arr +fromFunctionM :: forall sh e. (Shape sh, Elt e) => sh -> (sh -> IO e) -> IO (Array sh e) +fromFunctionM sh f = Array <$> Repr.fromFunctionM (arrayR @sh @e) (fromElt sh) f' + where + f' x = do + y <- f $ toElt x + return $ fromElt y -- | Create a vector from the concatenation of the given list of vectors. -- {-# INLINEABLE concatVectors #-} -concatVectors :: Elt e => [Vector e] -> Vector e -concatVectors vs = adata `seq` Array ((), len) adata - where - offsets = scanl (+) 0 (map (size . shape) vs) - len = last offsets - (adata, _) = runArrayData $ do - arr <- newArrayData len - sequence_ [ unsafeWriteArrayData arr (i + k) (unsafeIndexArrayData ad i) - | (Array ((), n) ad, k) <- vs `zip` offsets - , i <- [0 .. n - 1] ] - return (arr, undefined) +concatVectors :: forall e. Elt e => [Vector e] -> Vector e +concatVectors = toArr . Repr.concatVectors (eltType @e) . map fromArr + -- | Creates a new, uninitialized Accelerate array. -- {-# INLINEABLE allocateArray #-} -allocateArray :: (Shape sh, Elt e) => sh -> IO (Array sh e) -allocateArray sh = do - adata <- newArrayData (size sh) - return $! Array (fromElt sh) adata +allocateArray :: forall sh e. (Shape sh, Elt e) => sh -> IO (Array sh e) +allocateArray sh = Array <$> Repr.allocateArray (arrayR @sh @e) (fromElt sh) -- | Convert elements of a list into an Accelerate 'Array'. @@ -1043,34 +925,14 @@ allocateArray sh = do -- thus forcing the spine of the list to be manifest on the heap. -- {-# INLINEABLE fromList #-} -fromList :: (Shape sh, Elt e) => sh -> [e] -> Array sh e -fromList sh xs = adata `seq` Array (fromElt sh) adata - where - -- Assume the array is in dense row-major order. This is safe because - -- otherwise backends would not be able to directly memcpy. - -- - !n = size sh - (adata, _) = runArrayData $ do - arr <- newArrayData n - let go !i _ | i >= n = return () - go !i (v:vs) = unsafeWriteArrayData arr i (fromElt v) >> go (i+1) vs - go _ [] = error "Data.Array.Accelerate.fromList: not enough input data" - -- - go 0 xs - return (arr, undefined) +fromList :: forall sh e. (Shape sh, Elt e) => sh -> [e] -> Array sh e +fromList sh xs = toArr $ Repr.fromList (arrayR @sh @e) (fromElt sh) $ map fromElt xs -- | Convert an accelerated 'Array' to a list in row-major order. -- {-# INLINEABLE toList #-} toList :: forall sh e. (Shape sh, Elt e) => Array sh e -> [e] -toList (Array sh adata) = go 0 - where - -- Assume underling array is in row-major order. This is safe because - -- otherwise backends would not be able to directly memcpy. - -- - !n = Repr.size sh - go !i | i >= n = [] - | otherwise = toElt (adata `unsafeIndexArrayData` i) : go (i+1) +toList = map toElt . Repr.toList (arrayR @sh @e) . fromArr -- | Nicely format a shape as a string -- @@ -1102,6 +964,22 @@ enumSlices :: forall slix co sl dim. (Elt slix, Elt dim) enumSlices slix = map toElt . Repr.enumSlices slix . fromElt +-- Vec +-- --- + +class (Elt a, IsSingle a, Prim a, a ~ EltRepr a) => VecElt a + +-- XXX: Should we fix this to known "good" vector sizes? +-- +instance (KnownNat n, VecElt a) => Elt (Vec n a) where + type EltRepr (Vec n a) = Vec n a + {-# INLINE eltType #-} + {-# INLINE [1] fromElt #-} + {-# INLINE [1] toElt #-} + eltType = TupRsingle $ VectorScalarType $ VectorType (fromIntegral $ natVal (undefined :: Proxy n)) $ singleType @a + fromElt = id + toElt = id + -- Instances -- --------- @@ -1167,21 +1045,11 @@ $(runQ $ do toElt = id |] - -- XXX: Should we fix this to known "good" vector sizes? - -- - mkVector :: Name -> Q [Dec] - mkVector name = + mkVecElt :: Name -> Q [Dec] + mkVecElt name = let t = conT name in - [d| instance KnownNat n => Elt (Vec n $t) where - type EltRepr (Vec n $t) = Vec n $t - {-# INLINE eltType #-} - {-# INLINE [1] fromElt #-} - {-# INLINE [1] toElt #-} - eltType = singletonScalarType - fromElt = id - toElt = id - |] + [d| instance VecElt $t |] -- ghci> $( stringE . show =<< reify ''CFloat ) -- TyConI (NewtypeD [] Foreign.C.Types.CFloat [] Nothing (NormalC Foreign.C.Types.CFloat [(Bang NoSourceUnpackedness NoSourceStrictness,ConT GHC.Types.Float)]) []) @@ -1204,7 +1072,7 @@ $(runQ $ do |] -- ss <- mapM mkSimple ( integralTypes ++ floatingTypes ++ nonNumTypes ) - vs <- mapM mkVector ( integralTypes ++ floatingTypes ++ tail nonNumTypes ) -- not Bool + vs <- mapM mkVecElt ( integralTypes ++ floatingTypes ++ tail nonNumTypes ) -- not Bool ns <- mapM mkNewtype newtypes return (concat ss ++ concat vs ++ concat ns) ) diff --git a/src/Data/Array/Accelerate/Classes/Bounded.hs b/src/Data/Array/Accelerate/Classes/Bounded.hs index 187cf0150..82c31b96a 100644 --- a/src/Data/Array/Accelerate/Classes/Bounded.hs +++ b/src/Data/Array/Accelerate/Classes/Bounded.hs @@ -3,7 +3,7 @@ {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE TemplateHaskell #-} {-# LANGUAGE TypeApplications #-} -{-# OPTIONS_GHC -fno-warn-orphans #-} +{-# OPTIONS_GHC -fno-warn-orphans -freduction-depth=100 #-} -- | -- Module : Data.Array.Accelerate.Classes.Bounded -- Copyright : [2016..2019] The Accelerate Team @@ -135,6 +135,12 @@ instance P.Bounded (Exp CUChar) where minBound = mkBitcast (mkMinBound @Word8) maxBound = mkBitcast (mkMaxBound @Word8) +-- To support 16-tuples, we must set the maximum recursion depth of the type +-- checker higher. The default is 51, which appears to be a problem for +-- 16-tuples (15-tuples do work). Hence we set a compiler flag at the top +-- of this file: -freduction-depth=100 +-- + $(runQ $ do let mkInstance :: Int -> Q [Dec] diff --git a/src/Data/Array/Accelerate/Classes/Eq.hs b/src/Data/Array/Accelerate/Classes/Eq.hs index 28fd413e3..c075aa501 100644 --- a/src/Data/Array/Accelerate/Classes/Eq.hs +++ b/src/Data/Array/Accelerate/Classes/Eq.hs @@ -5,7 +5,7 @@ {-# LANGUAGE TemplateHaskell #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} -{-# OPTIONS_GHC -fno-warn-orphans #-} +{-# OPTIONS_GHC -fno-warn-orphans -freduction-depth=100 #-} -- | -- Module : Data.Array.Accelerate.Classes.Eq -- Copyright : [2016..2019] The Accelerate Team @@ -20,8 +20,8 @@ module Data.Array.Accelerate.Classes.Eq ( Bool(..), pattern True_, pattern False_, Eq(..), - (&&), - (||), + (&&), (&&!), + (||), (||!), not, ) where @@ -39,12 +39,12 @@ import qualified Prelude as P pattern True_ :: Exp Bool -pattern True_ = Exp (Const True) +pattern True_ = Exp (SmartExp (Const (SingleScalarType (NonNumSingleType TypeBool)) True)) pattern False_ :: Exp Bool -pattern False_ = Exp (Const False) -{-# COMPLETE True_, False_ #-} +pattern False_ = Exp (SmartExp (Const (SingleScalarType (NonNumSingleType TypeBool)) False)) +{-# COMPLETE True_, False_ #-} infix 4 == infix 4 /= @@ -54,7 +54,14 @@ infix 4 /= -- infixr 3 && (&&) :: Exp Bool -> Exp Bool -> Exp Bool -(&&) = mkLAnd +(&&) x y = cond x y $ constant False + +-- | Conjunction: True if both arguments are true. This is a strict version of +-- '(&&)': it will always evaluate both arguments, even when the first is false. +-- +infixr 3 &&! +(&&!) :: Exp Bool -> Exp Bool -> Exp Bool +(&&!) = mkLAnd -- | Disjunction: True if either argument is true. This is a short-circuit -- operator, so the second argument will be evaluated only if the first is @@ -62,7 +69,14 @@ infixr 3 && -- infixr 2 || (||) :: Exp Bool -> Exp Bool -> Exp Bool -(||) = mkLOr +(||) x y = cond x (constant True) y + +-- | Disjunction: True if either argument is true. This is a strict version of +-- '(||)': it will always evaluate both arguments, even when the first is true. +-- +infixr 2 ||! +(||!) :: Exp Bool -> Exp Bool -> Exp Bool +(||!) = mkLOr -- | Logical negation -- @@ -106,12 +120,18 @@ instance P.Eq (Exp a) where preludeError :: String -> String -> a preludeError x y = error (printf "Prelude.%s applied to EDSL types: use Data.Array.Accelerate.%s instead" x y) -lift2 :: (Elt a, Elt b, IsScalar b, b ~ EltRepr a) - => (Exp b -> Exp b -> Exp Bool) - -> Exp a - -> Exp a - -> Exp Bool -lift2 f x y = f (mkUnsafeCoerce x) (mkUnsafeCoerce y) +cond :: Elt t + => Exp Bool -- ^ condition + -> Exp t -- ^ then-expression + -> Exp t -- ^ else-expression + -> Exp t +cond (Exp c) (Exp x) (Exp y) = exp $ Cond c x y + +-- To support 16-tuples, we must set the maximum recursion depth of the type +-- checker higher. The default is 51, which appears to be a problem for +-- 16-tuples (15-tuples do work). Hence we set a compiler flag at the top +-- of this file: -freduction-depth=100 +-- $(runQ $ do let @@ -166,13 +186,6 @@ $(runQ $ do (/=) = mkNEq |] - mkCPrim :: Name -> Q [Dec] - mkCPrim t = - [d| instance Eq $(conT t) where - (==) = lift2 mkEq - (/=) = lift2 mkNEq - |] - mkTup :: Int -> Q [Dec] mkTup n = let @@ -190,7 +203,7 @@ $(runQ $ do is <- mapM mkPrim integralTypes fs <- mapM mkPrim floatingTypes ns <- mapM mkPrim nonNumTypes - cs <- mapM mkCPrim cTypes + cs <- mapM mkPrim cTypes ts <- mapM mkTup [2..16] return $ concat (concat [is,fs,ns,cs,ts]) ) diff --git a/src/Data/Array/Accelerate/Classes/Floating.hs b/src/Data/Array/Accelerate/Classes/Floating.hs index 344d3758b..5b451ffb6 100644 --- a/src/Data/Array/Accelerate/Classes/Floating.hs +++ b/src/Data/Array/Accelerate/Classes/Floating.hs @@ -30,7 +30,6 @@ module Data.Array.Accelerate.Classes.Floating ( ) where -import Data.Array.Accelerate.Array.Sugar import Data.Array.Accelerate.Smart import Data.Array.Accelerate.Type @@ -106,54 +105,40 @@ instance P.Floating (Exp Double) where instance P.Floating (Exp CFloat) where pi = mkBitcast (mkPi @Float) - sin = lift1 mkSin - cos = lift1 mkCos - tan = lift1 mkTan - asin = lift1 mkAsin - acos = lift1 mkAcos - atan = lift1 mkAtan - sinh = lift1 mkSinh - cosh = lift1 mkCosh - tanh = lift1 mkTanh - asinh = lift1 mkAsinh - acosh = lift1 mkAcosh - atanh = lift1 mkAtanh - exp = lift1 mkExpFloating - sqrt = lift1 mkSqrt - log = lift1 mkLog - (**) = lift2 mkFPow - logBase = lift2 mkLogBase + sin = mkSin + cos = mkCos + tan = mkTan + asin = mkAsin + acos = mkAcos + atan = mkAtan + sinh = mkSinh + cosh = mkCosh + tanh = mkTanh + asinh = mkAsinh + acosh = mkAcosh + atanh = mkAtanh + exp = mkExpFloating + sqrt = mkSqrt + log = mkLog + (**) = mkFPow + logBase = mkLogBase instance P.Floating (Exp CDouble) where pi = mkBitcast (mkPi @Double) - sin = lift1 mkSin - cos = lift1 mkCos - tan = lift1 mkTan - asin = lift1 mkAsin - acos = lift1 mkAcos - atan = lift1 mkAtan - sinh = lift1 mkSinh - cosh = lift1 mkCosh - tanh = lift1 mkTanh - asinh = lift1 mkAsinh - acosh = lift1 mkAcosh - atanh = lift1 mkAtanh - exp = lift1 mkExpFloating - sqrt = lift1 mkSqrt - log = lift1 mkLog - (**) = lift2 mkFPow - logBase = lift2 mkLogBase - -lift1 :: (Elt a, Elt b, IsScalar b, b ~ EltRepr a) - => (Exp b -> Exp b) - -> Exp a - -> Exp a -lift1 f x = mkUnsafeCoerce (f (mkUnsafeCoerce x)) - -lift2 :: (Elt a, Elt b, IsScalar b, b ~ EltRepr a) - => (Exp b -> Exp b -> Exp b) - -> Exp a - -> Exp a - -> Exp a -lift2 f x y = mkUnsafeCoerce (f (mkUnsafeCoerce x) (mkUnsafeCoerce y)) - + sin = mkSin + cos = mkCos + tan = mkTan + asin = mkAsin + acos = mkAcos + atan = mkAtan + sinh = mkSinh + cosh = mkCosh + tanh = mkTanh + asinh = mkAsinh + acosh = mkAcosh + atanh = mkAtanh + exp = mkExpFloating + sqrt = mkSqrt + log = mkLog + (**) = mkFPow + logBase = mkLogBase diff --git a/src/Data/Array/Accelerate/Classes/Fractional.hs b/src/Data/Array/Accelerate/Classes/Fractional.hs index bb5fae615..fa159ccd1 100644 --- a/src/Data/Array/Accelerate/Classes/Fractional.hs +++ b/src/Data/Array/Accelerate/Classes/Fractional.hs @@ -20,7 +20,6 @@ module Data.Array.Accelerate.Classes.Fractional ( ) where -import Data.Array.Accelerate.Array.Sugar import Data.Array.Accelerate.Smart import Data.Array.Accelerate.Type @@ -62,25 +61,11 @@ instance P.Fractional (Exp Double) where fromRational = constant . P.fromRational instance P.Fractional (Exp CFloat) where - (/) = lift2 mkFDiv - recip = lift1 mkRecip + (/) = mkFDiv + recip = mkRecip fromRational = constant . P.fromRational instance P.Fractional (Exp CDouble) where - (/) = lift2 mkFDiv - recip = lift1 mkRecip + (/) = mkFDiv + recip = mkRecip fromRational = constant . P.fromRational - -lift1 :: (Elt a, Elt b, b ~ EltRepr a) - => (Exp b -> Exp b) - -> Exp a - -> Exp a -lift1 f = mkUnsafeCoerce . f . mkUnsafeCoerce - -lift2 :: (Elt a, Elt b, b ~ EltRepr a) - => (Exp b -> Exp b -> Exp b) - -> Exp a - -> Exp a - -> Exp a -lift2 f x y = mkUnsafeCoerce (f (mkUnsafeCoerce x) (mkUnsafeCoerce y)) - diff --git a/src/Data/Array/Accelerate/Classes/Integral.hs b/src/Data/Array/Accelerate/Classes/Integral.hs index e03751a66..4d6f9e96c 100644 --- a/src/Data/Array/Accelerate/Classes/Integral.hs +++ b/src/Data/Array/Accelerate/Classes/Integral.hs @@ -26,7 +26,6 @@ module Data.Array.Accelerate.Classes.Integral ( ) where -import Data.Array.Accelerate.Array.Sugar import Data.Array.Accelerate.Smart import Data.Array.Accelerate.Type @@ -135,90 +134,73 @@ instance P.Integral (Exp Word64) where toInteger = error "Prelude.toInteger not supported for Accelerate types" instance P.Integral (Exp CInt) where - quot = lift2 mkQuot - rem = lift2 mkRem - div = lift2 mkIDiv - mod = lift2 mkMod - quotRem = lift2' mkQuotRem - divMod = lift2' mkDivMod + quot = mkQuot + rem = mkRem + div = mkIDiv + mod = mkMod + quotRem = mkQuotRem + divMod = mkDivMod toInteger = error "Prelude.toInteger not supported for Accelerate types" instance P.Integral (Exp CUInt) where - quot = lift2 mkQuot - rem = lift2 mkRem - div = lift2 mkIDiv - mod = lift2 mkMod - quotRem = lift2' mkQuotRem - divMod = lift2' mkDivMod + quot = mkQuot + rem = mkRem + div = mkIDiv + mod = mkMod + quotRem = mkQuotRem + divMod = mkDivMod toInteger = error "Prelude.toInteger not supported for Accelerate types" instance P.Integral (Exp CLong) where - quot = lift2 mkQuot - rem = lift2 mkRem - div = lift2 mkIDiv - mod = lift2 mkMod - quotRem = lift2' mkQuotRem - divMod = lift2' mkDivMod + quot = mkQuot + rem = mkRem + div = mkIDiv + mod = mkMod + quotRem = mkQuotRem + divMod = mkDivMod toInteger = error "Prelude.toInteger not supported for Accelerate types" instance P.Integral (Exp CULong) where - quot = lift2 mkQuot - rem = lift2 mkRem - div = lift2 mkIDiv - mod = lift2 mkMod - quotRem = lift2' mkQuotRem - divMod = lift2' mkDivMod + quot = mkQuot + rem = mkRem + div = mkIDiv + mod = mkMod + quotRem = mkQuotRem + divMod = mkDivMod toInteger = error "Prelude.toInteger not supported for Accelerate types" instance P.Integral (Exp CLLong) where - quot = lift2 mkQuot - rem = lift2 mkRem - div = lift2 mkIDiv - mod = lift2 mkMod - quotRem = lift2' mkQuotRem - divMod = lift2' mkDivMod + quot = mkQuot + rem = mkRem + div = mkIDiv + mod = mkMod + quotRem = mkQuotRem + divMod = mkDivMod toInteger = error "Prelude.toInteger not supported for Accelerate types" instance P.Integral (Exp CULLong) where - quot = lift2 mkQuot - rem = lift2 mkRem - div = lift2 mkIDiv - mod = lift2 mkMod - quotRem = lift2' mkQuotRem - divMod = lift2' mkDivMod + quot = mkQuot + rem = mkRem + div = mkIDiv + mod = mkMod + quotRem = mkQuotRem + divMod = mkDivMod toInteger = error "Prelude.toInteger not supported for Accelerate types" instance P.Integral (Exp CShort) where - quot = lift2 mkQuot - rem = lift2 mkRem - div = lift2 mkIDiv - mod = lift2 mkMod - quotRem = lift2' mkQuotRem - divMod = lift2' mkDivMod + quot = mkQuot + rem = mkRem + div = mkIDiv + mod = mkMod + quotRem = mkQuotRem + divMod = mkDivMod toInteger = error "Prelude.toInteger not supported for Accelerate types" instance P.Integral (Exp CUShort) where - quot = lift2 mkQuot - rem = lift2 mkRem - div = lift2 mkIDiv - mod = lift2 mkMod - quotRem = lift2' mkQuotRem - divMod = lift2' mkDivMod - toInteger = error "Prelude.toInteger not supported for Accelerate types" - -lift2 :: (Elt a, Elt b, IsScalar b, b ~ EltRepr a) - => (Exp b -> Exp b -> Exp b) - -> Exp a - -> Exp a - -> Exp a -lift2 f x y = mkUnsafeCoerce (f (mkUnsafeCoerce x) (mkUnsafeCoerce y)) - -lift2' :: (Elt a, Elt b, IsScalar b, b ~ EltRepr a) - => (Exp b -> Exp b -> (Exp b, Exp b)) - -> Exp a - -> Exp a - -> (Exp a, Exp a) -lift2' f x y = - let (u,v) = f (mkUnsafeCoerce x) (mkUnsafeCoerce y) - in (mkUnsafeCoerce u, mkUnsafeCoerce v) - + quot = mkQuot + rem = mkRem + div = mkIDiv + mod = mkMod + quotRem = mkQuotRem + divMod = mkDivMod + toInteger = error "Prelude.toInteger not supported for Accelerate types" diff --git a/src/Data/Array/Accelerate/Classes/Num.hs b/src/Data/Array/Accelerate/Classes/Num.hs index 536be51d5..b98a82ab3 100644 --- a/src/Data/Array/Accelerate/Classes/Num.hs +++ b/src/Data/Array/Accelerate/Classes/Num.hs @@ -158,75 +158,75 @@ instance P.Num (Exp Word64) where fromInteger = constant . P.fromInteger instance P.Num (Exp CInt) where - (+) = lift2 mkAdd - (-) = lift2 mkSub - (*) = lift2 mkMul - negate = lift1 mkNeg - abs = lift1 mkAbs - signum = lift1 mkSig + (+) = mkAdd + (-) = mkSub + (*) = mkMul + negate = mkNeg + abs = mkAbs + signum = mkSig fromInteger = constant . P.fromInteger instance P.Num (Exp CUInt) where - (+) = lift2 mkAdd - (-) = lift2 mkSub - (*) = lift2 mkMul - negate = lift1 mkNeg - abs = lift1 mkAbs - signum = lift1 mkSig + (+) = mkAdd + (-) = mkSub + (*) = mkMul + negate = mkNeg + abs = mkAbs + signum = mkSig fromInteger = constant . P.fromInteger instance P.Num (Exp CLong) where - (+) = lift2 mkAdd - (-) = lift2 mkSub - (*) = lift2 mkMul - negate = lift1 mkNeg - abs = lift1 mkAbs - signum = lift1 mkSig + (+) = mkAdd + (-) = mkSub + (*) = mkMul + negate = mkNeg + abs = mkAbs + signum = mkSig fromInteger = constant . P.fromInteger instance P.Num (Exp CULong) where - (+) = lift2 mkAdd - (-) = lift2 mkSub - (*) = lift2 mkMul - negate = lift1 mkNeg - abs = lift1 mkAbs - signum = lift1 mkSig + (+) = mkAdd + (-) = mkSub + (*) = mkMul + negate = mkNeg + abs = mkAbs + signum = mkSig fromInteger = constant . P.fromInteger instance P.Num (Exp CLLong) where - (+) = lift2 mkAdd - (-) = lift2 mkSub - (*) = lift2 mkMul - negate = lift1 mkNeg - abs = lift1 mkAbs - signum = lift1 mkSig + (+) = mkAdd + (-) = mkSub + (*) = mkMul + negate = mkNeg + abs = mkAbs + signum = mkSig fromInteger = constant . P.fromInteger instance P.Num (Exp CULLong) where - (+) = lift2 mkAdd - (-) = lift2 mkSub - (*) = lift2 mkMul - negate = lift1 mkNeg - abs = lift1 mkAbs - signum = lift1 mkSig + (+) = mkAdd + (-) = mkSub + (*) = mkMul + negate = mkNeg + abs = mkAbs + signum = mkSig fromInteger = constant . P.fromInteger instance P.Num (Exp CShort) where - (+) = lift2 mkAdd - (-) = lift2 mkSub - (*) = lift2 mkMul - negate = lift1 mkNeg - abs = lift1 mkAbs - signum = lift1 mkSig + (+) = mkAdd + (-) = mkSub + (*) = mkMul + negate = mkNeg + abs = mkAbs + signum = mkSig fromInteger = constant . P.fromInteger instance P.Num (Exp CUShort) where - (+) = lift2 mkAdd - (-) = lift2 mkSub - (*) = lift2 mkMul - negate = lift1 mkNeg - abs = lift1 mkAbs - signum = lift1 mkSig + (+) = mkAdd + (-) = mkSub + (*) = mkMul + negate = mkNeg + abs = mkAbs + signum = mkSig fromInteger = constant . P.fromInteger instance P.Num (Exp Half) where @@ -257,33 +257,19 @@ instance P.Num (Exp Double) where fromInteger = constant . P.fromInteger instance P.Num (Exp CFloat) where - (+) = lift2 mkAdd - (-) = lift2 mkSub - (*) = lift2 mkMul - negate = lift1 mkNeg - abs = lift1 mkAbs - signum = lift1 mkSig + (+) = mkAdd + (-) = mkSub + (*) = mkMul + negate = mkNeg + abs = mkAbs + signum = mkSig fromInteger = constant . P.fromInteger instance P.Num (Exp CDouble) where - (+) = lift2 mkAdd - (-) = lift2 mkSub - (*) = lift2 mkMul - negate = lift1 mkNeg - abs = lift1 mkAbs - signum = lift1 mkSig + (+) = mkAdd + (-) = mkSub + (*) = mkMul + negate = mkNeg + abs = mkAbs + signum = mkSig fromInteger = constant . P.fromInteger - -lift1 :: (Elt a, Elt b, IsScalar b, b ~ EltRepr a) - => (Exp b -> Exp b) - -> Exp a - -> Exp a -lift1 f = mkUnsafeCoerce . f . mkUnsafeCoerce - -lift2 :: (Elt a, Elt b, IsScalar b, b ~ EltRepr a) - => (Exp b -> Exp b -> Exp b) - -> Exp a - -> Exp a - -> Exp a -lift2 f x y = mkUnsafeCoerce (f (mkUnsafeCoerce x) (mkUnsafeCoerce y)) - diff --git a/src/Data/Array/Accelerate/Classes/Ord.hs b/src/Data/Array/Accelerate/Classes/Ord.hs index d3d908b63..a2b52b1c0 100644 --- a/src/Data/Array/Accelerate/Classes/Ord.hs +++ b/src/Data/Array/Accelerate/Classes/Ord.hs @@ -8,7 +8,7 @@ {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} -{-# OPTIONS_GHC -fno-warn-orphans #-} +{-# OPTIONS_GHC -fno-warn-orphans -freduction-depth=100 #-} -- | -- Module : Data.Array.Accelerate.Classes.Ord -- Copyright : [2016..2019] The Accelerate Team @@ -32,28 +32,30 @@ import Data.Array.Accelerate.Pattern import Data.Array.Accelerate.Smart import Data.Array.Accelerate.Type -import Data.Array.Accelerate.Classes.Eq +-- We must hide (==), as that operator is used for the literals 0, 1 and 2 in the pattern synonyms for Ordering. +-- As RebindableSyntax is enabled, a literal pattern is compiled to a call to (==), meaning that the Prelude.(==) should be in scope as (==). +import Data.Array.Accelerate.Classes.Eq hiding ( (==) ) +import qualified Data.Array.Accelerate.Classes.Eq as A import Text.Printf -import Prelude ( ($), (.), (>>=), Ordering(..), Num(..), Maybe(..), String, show, error, unlines, return, concat, map, mapM ) +import Prelude ( ($), (.), (>>=), Ordering(..), Num(..), Maybe(..), String, show, error, unlines, return, concat, map, mapM, (==) ) import Language.Haskell.TH hiding ( Exp ) import Language.Haskell.TH.Extra import qualified Prelude as P - infix 4 < infix 4 > infix 4 <= infix 4 >= pattern LT_ :: Exp Ordering -pattern LT_ = Exp (Const LT) +pattern LT_ = Exp (SmartExp (Const (SingleScalarType (NumSingleType (IntegralNumType TypeInt8))) 0)) pattern EQ_ :: Exp Ordering -pattern EQ_ = Exp (Const EQ) +pattern EQ_ = Exp (SmartExp (Const (SingleScalarType (NumSingleType (IntegralNumType TypeInt8))) 1)) pattern GT_ :: Exp Ordering -pattern GT_ = Exp (Const GT) +pattern GT_ = Exp (SmartExp (Const (SingleScalarType (NumSingleType (IntegralNumType TypeInt8))) 2)) {-# COMPLETE LT_, EQ_, GT_ #-} -- | The 'Ord' class for totally ordered datatypes @@ -68,23 +70,23 @@ class Eq a => Ord a where max :: Exp a -> Exp a -> Exp a compare :: Exp a -> Exp a -> Exp Ordering - x < y = if compare x y == constant LT then constant True else constant False - x <= y = if compare x y == constant GT then constant False else constant True - x > y = if compare x y == constant GT then constant True else constant False - x >= y = if compare x y == constant LT then constant False else constant True + x < y = if compare x y A.== constant LT then constant True else constant False + x <= y = if compare x y A.== constant GT then constant False else constant True + x > y = if compare x y A.== constant GT then constant True else constant False + x >= y = if compare x y A.== constant LT then constant False else constant True min x y = if x <= y then x else y max x y = if x <= y then y else x compare x y = - if x == y then constant EQ else - if x <= y then constant LT - else constant GT + if x A.== y then constant EQ else + if x <= y then constant LT + else constant GT -- Local redefinition for use with RebindableSyntax (pulled forward from Prelude.hs) -- ifThenElse :: Elt a => Exp Bool -> Exp a -> Exp a -> Exp a -ifThenElse = Exp $$$ Cond +ifThenElse (Exp c) (Exp x) (Exp y) = Exp $ SmartExp $ Cond c x y instance Ord () where (<) _ _ = constant False @@ -117,13 +119,13 @@ instance Ord sh => Ord (sh :. Int) where instance Elt Ordering where type EltRepr Ordering = Int8 - eltType = TypeRscalar scalarType + eltType = TupRsingle scalarType fromElt = P.fromIntegral . P.fromEnum toElt = P.toEnum . P.fromIntegral instance Eq Ordering where - x == y = mkBitcast x == (mkBitcast y :: Exp Int8) - x /= y = mkBitcast x /= (mkBitcast y :: Exp Int8) + x == y = mkBitcast x A.== (mkBitcast y :: Exp Int8) + x /= y = mkBitcast x /= (mkBitcast y :: Exp Int8) instance Ord Ordering where x < y = mkBitcast x < (mkBitcast y :: Exp Int8) @@ -160,19 +162,11 @@ preludeError x y , "hierarchy." ] -lift2 :: (Elt a, Elt b, IsScalar b, b ~ EltRepr a) - => (Exp b -> Exp b -> Exp b) - -> Exp a - -> Exp a - -> Exp a -lift2 f x y = mkUnsafeCoerce (f (mkUnsafeCoerce x) (mkUnsafeCoerce y)) - -liftB :: (Elt a, Elt b, IsScalar b, b ~ EltRepr a) - => (Exp b -> Exp b -> Exp Bool) - -> Exp a - -> Exp a - -> Exp Bool -liftB f x y = f (mkUnsafeCoerce x) (mkUnsafeCoerce y) +-- To support 16-tuples, we must set the maximum recursion depth of the type +-- checker higher. The default is 51, which appears to be a problem for +-- 16-tuples (15-tuples do work). Hence we set a compiler flag at the top +-- of this file: -freduction-depth=100 +-- $(runQ $ do let @@ -231,35 +225,24 @@ $(runQ $ do max = mkMax |] - mkCPrim :: Name -> Q [Dec] - mkCPrim t = - [d| instance Ord $(conT t) where - (<) = liftB mkLt - (>) = liftB mkGt - (<=) = liftB mkLtEq - (>=) = liftB mkGtEq - min = lift2 mkMin - max = lift2 mkMax - |] - mkLt' :: [ExpQ] -> [ExpQ] -> ExpQ mkLt' [x] [y] = [| $x < $y |] - mkLt' (x:xs) (y:ys) = [| $x < $y || ( $x == $y && $(mkLt' xs ys) ) |] + mkLt' (x:xs) (y:ys) = [| $x < $y || ( $x A.== $y && $(mkLt' xs ys) ) |] mkLt' _ _ = error "mkLt'" mkGt' :: [ExpQ] -> [ExpQ] -> ExpQ mkGt' [x] [y] = [| $x > $y |] - mkGt' (x:xs) (y:ys) = [| $x > $y || ( $x == $y && $(mkGt' xs ys) ) |] + mkGt' (x:xs) (y:ys) = [| $x > $y || ( $x A.== $y && $(mkGt' xs ys) ) |] mkGt' _ _ = error "mkGt'" mkLtEq' :: [ExpQ] -> [ExpQ] -> ExpQ mkLtEq' [x] [y] = [| $x < $y |] - mkLtEq' (x:xs) (y:ys) = [| $x < $y || ( $x == $y && $(mkLtEq' xs ys) ) |] + mkLtEq' (x:xs) (y:ys) = [| $x < $y || ( $x A.== $y && $(mkLtEq' xs ys) ) |] mkLtEq' _ _ = error "mkLtEq'" mkGtEq' :: [ExpQ] -> [ExpQ] -> ExpQ mkGtEq' [x] [y] = [| $x > $y |] - mkGtEq' (x:xs) (y:ys) = [| $x > $y || ( $x == $y && $(mkGtEq' xs ys) ) |] + mkGtEq' (x:xs) (y:ys) = [| $x > $y || ( $x A.== $y && $(mkGtEq' xs ys) ) |] mkGtEq' _ _ = error "mkGtEq'" mkTup :: Int -> Q [Dec] @@ -281,7 +264,7 @@ $(runQ $ do is <- mapM mkPrim integralTypes fs <- mapM mkPrim floatingTypes ns <- mapM mkPrim nonNumTypes - cs <- mapM mkCPrim cTypes + cs <- mapM mkPrim cTypes ts <- mapM mkTup [2..16] return $ concat (concat [is,fs,ns,cs,ts]) ) diff --git a/src/Data/Array/Accelerate/Classes/RealFloat.hs b/src/Data/Array/Accelerate/Classes/RealFloat.hs index 1fbe13985..9bf70d470 100644 --- a/src/Data/Array/Accelerate/Classes/RealFloat.hs +++ b/src/Data/Array/Accelerate/Classes/RealFloat.hs @@ -79,9 +79,9 @@ class (RealFrac a, Floating a) => RealFloat a where -- | Corresponds to the second component of 'decodeFloat' exponent :: Exp a -> Exp Int exponent x = let (m,n) = decodeFloat x - in Exp $ Cond (m == 0) - 0 - (n + floatDigits x) + in cond (m == 0) + 0 + (n + floatDigits x) -- | Corresponds to the first component of 'decodeFloat' significand :: Exp a -> Exp a @@ -91,8 +91,8 @@ class (RealFrac a, Floating a) => RealFloat a where -- | Multiply a floating point number by an integer power of the radix scaleFloat :: Exp Int -> Exp a -> Exp a scaleFloat k x = - Exp $ Cond (k == 0 || isFix) x - $ encodeFloat m (n + clamp b) + cond (k == 0 || isFix) x + $ encodeFloat m (n + clamp b) where isFix = x == 0 || isNaN x || isInfinite x (m,n) = decodeFloat x @@ -157,7 +157,7 @@ instance RealFloat Double where in (m, n)) instance RealFloat CFloat where - atan2 = lift2 mkAtan2 + atan2 = mkAtan2 isNaN = mkIsNaN . mkBitcast @Float isInfinite = mkIsInfinite . mkBitcast @Float isDenormalized = ieee754 "isDenormalized" (ieee754_f32_is_denormalized . mkBitcast) @@ -167,7 +167,7 @@ instance RealFloat CFloat where encodeFloat x e = mkBitcast (encodeFloat @Float x e) instance RealFloat CDouble where - atan2 = lift2 mkAtan2 + atan2 = mkAtan2 isNaN = mkIsNaN . mkBitcast @Double isInfinite = mkIsInfinite . mkBitcast @Double isDenormalized = ieee754 "isDenormalized" (ieee754_f64_is_denormalized . mkBitcast) @@ -201,13 +201,6 @@ preludeError x ] -lift2 :: (Elt a, Elt b, IsScalar b, b ~ EltRepr a) - => (Exp b -> Exp b -> Exp b) - -> Exp a - -> Exp a - -> Exp a -lift2 f x y = mkUnsafeCoerce (f (mkUnsafeCoerce x) (mkUnsafeCoerce y)) - ieee754 :: forall a b. P.RealFloat a => String -> (Exp a -> b) -> Exp a -> b ieee754 name f x | P.isIEEE (undefined::a) = f x @@ -325,19 +318,19 @@ ieee754_f16_decode i = exp2 = exp1 + 1 T2 high3 exp3 - = Exp $ Cond (exp1 /= _HMINEXP) - -- don't add hidden bit to denorms - (T2 (high2 .|. _HHIGHBIT) exp1) - -- a denorm, normalise the mantissa - (Exp $ While (\(T2 h _) -> (h .&. _HHIGHBIT) /= 0 ) - (\(T2 h e) -> T2 (h `unsafeShiftL` 1) (e-1)) - (T2 high2 exp2)) - - high4 = Exp $ Cond (fromIntegral i < (0 :: Exp Int16)) (-high3) high3 + = cond (exp1 /= _HMINEXP) + -- don't add hidden bit to denorms + (T2 (high2 .|. _HHIGHBIT) exp1) + -- a denorm, normalise the mantissa + (while (\(T2 h _) -> (h .&. _HHIGHBIT) /= 0 ) + (\(T2 h e) -> T2 (h `unsafeShiftL` 1) (e-1)) + (T2 high2 exp2)) + + high4 = cond (fromIntegral i < (0 :: Exp Int16)) (-high3) high3 in - Exp $ Cond (high1 .&. complement _HMSBIT == 0) - (tup2 (0,0)) - (tup2 (high4, exp3)) + cond (high1 .&. complement _HMSBIT == 0) + (T2 0 0) + (T2 high4 exp3) -- From: ghc/rts/StgPrimFloat.c @@ -359,19 +352,19 @@ ieee754_f32_decode i = exp2 = exp1 + 1 T2 high3 exp3 - = Exp $ Cond (exp1 /= _FMINEXP) - -- don't add hidden bit to denorms - (T2 (high2 .|. _FHIGHBIT) exp1) - -- a denorm, normalise the mantissa - (Exp $ While (\(T2 h _) -> (h .&. _FHIGHBIT) /= 0 ) - (\(T2 h e) -> T2 (h `unsafeShiftL` 1) (e-1)) - (T2 high2 exp2)) - - high4 = Exp $ Cond (fromIntegral i < (0 :: Exp Int32)) (-high3) high3 + = cond (exp1 /= _FMINEXP) + -- don't add hidden bit to denorms + (T2 (high2 .|. _FHIGHBIT) exp1) + -- a denorm, normalise the mantissa + (while (\(T2 h _) -> (h .&. _FHIGHBIT) /= 0 ) + (\(T2 h e) -> T2 (h `unsafeShiftL` 1) (e-1)) + (T2 high2 exp2)) + + high4 = cond (fromIntegral i < (0 :: Exp Int32)) (-high3) high3 in - Exp $ Cond (high1 .&. complement _FMSBIT == 0) - (T2 0 0) - (T2 high4 exp3) + cond (high1 .&. complement _FMSBIT == 0) + (T2 0 0) + (T2 high4 exp3) ieee754_f64_decode :: Exp Word64 -> Exp (Int64, Int) @@ -392,25 +385,30 @@ ieee754_f64_decode2 i = high = fromIntegral (i `unsafeShiftR` 32) iexp = (fromIntegral ((high `unsafeShiftR` 20) .&. 0x7FF) + _DMINEXP) - sign = Exp $ Cond (fromIntegral i < (0 :: Exp Int64)) (-1) 1 + sign = cond (fromIntegral i < (0 :: Exp Int64)) (-1) 1 high2 = high .&. (_DHIGHBIT - 1) iexp2 = iexp + 1 T3 hi lo ie - = Exp $ Cond (iexp2 /= _DMINEXP) - -- don't add hidden bit to denorms - (T3 (high2 .|. _DHIGHBIT) low iexp) - -- a denorm, nermalise the mantissa - (Exp $ While (\(T3 h _ _) -> (h .&. _DHIGHBIT) /= 0) - (\(T3 h l e) -> - let h1 = h `unsafeShiftL` 1 - h2 = Exp $ Cond ((l .&. _DMSBIT) /= 0) (h1+1) h1 - in T3 h2 (l `unsafeShiftL` 1) (e-1)) - (T3 high2 low iexp2)) + = cond (iexp2 /= _DMINEXP) + -- don't add hidden bit to denorms + (T3 (high2 .|. _DHIGHBIT) low iexp) + -- a denorm, nermalise the mantissa + (while (\(T3 h _ _) -> (h .&. _DHIGHBIT) /= 0) + (\(T3 h l e) -> + let h1 = h `unsafeShiftL` 1 + h2 = cond ((l .&. _DMSBIT) /= 0) (h1+1) h1 + in T3 h2 (l `unsafeShiftL` 1) (e-1)) + (T3 high2 low iexp2)) in - Exp $ Cond (low == 0 && (high .&. (complement _DMSBIT)) == 0) - (T4 1 0 0 0) - (T4 sign hi lo ie) + cond (low == 0 && (high .&. (complement _DMSBIT)) == 0) + (T4 1 0 0 0) + (T4 sign hi lo ie) + +cond :: Exp Bool -> Exp a -> Exp a -> Exp a +cond (Exp c) (Exp x) (Exp y) = Exp $ SmartExp $ Cond c x y +while :: forall e. Elt e => (Exp e -> Exp Bool) -> (Exp e -> Exp e) -> Exp e -> Exp e +while c f (Exp e) = Exp $ SmartExp $ While (eltType @e) (unExp . c . Exp) (unExp . f . Exp) e diff --git a/src/Data/Array/Accelerate/Classes/RealFrac.hs b/src/Data/Array/Accelerate/Classes/RealFrac.hs index 200bcf7b5..d2048e923 100644 --- a/src/Data/Array/Accelerate/Classes/RealFrac.hs +++ b/src/Data/Array/Accelerate/Classes/RealFrac.hs @@ -1,3 +1,4 @@ +{-# LANGUAGE AllowAmbiguousTypes #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GADTs #-} @@ -39,7 +40,6 @@ import Data.Array.Accelerate.Classes.Num import Data.Array.Accelerate.Classes.ToFloating import {-# SOURCE #-} Data.Array.Accelerate.Classes.RealFloat -- defaultProperFraction -import Data.Typeable import Data.Maybe import Text.Printf import Prelude ( ($), String, error, unlines, otherwise ) @@ -130,17 +130,17 @@ instance RealFrac Double where instance RealFrac CFloat where properFraction = defaultProperFraction - truncate = lift1 defaultTruncate - round = lift1 defaultRound - ceiling = lift1 defaultCeiling - floor = lift1 defaultFloor + truncate = defaultTruncate + round = defaultRound + ceiling = defaultCeiling + floor = defaultFloor instance RealFrac CDouble where properFraction = defaultProperFraction - truncate = lift1 defaultTruncate - round = lift1 defaultRound - ceiling = lift1 defaultCeiling - floor = lift1 defaultFloor + truncate = defaultTruncate + round = defaultRound + ceiling = defaultCeiling + floor = defaultFloor -- Must test for ±0.0 to avoid returning -0.0 in the second component of the @@ -198,6 +198,7 @@ defaultFloor x | otherwise = let (n, r) = properFraction x in cond (r < 0) (n-1) n +-- mkRound :: (Elt a, Elt b, IsFloating (EltRepr a), IsIntegral (EltRepr b)) => Exp a -> Exp b defaultRound :: forall a b. (RealFrac a, Integral b, FromIntegral Int64 b) => Exp a -> Exp b defaultRound x | Just IsFloatingDict <- isFloating @a @@ -221,10 +222,9 @@ data IsFloatingDict a where data IsIntegralDict a where IsIntegralDict :: IsIntegral a => IsIntegralDict a -isFloating :: forall a. Elt a => Maybe (IsFloatingDict a) +isFloating :: forall a. Elt a => Maybe (IsFloatingDict (EltRepr a)) isFloating - | Just Refl <- eqT @a @(EltRepr a) - , TypeRscalar t <- eltType @a + | TupRsingle t <- eltType @a , SingleScalarType s <- t , NumSingleType n <- s , FloatingNumType f <- n @@ -236,10 +236,9 @@ isFloating | otherwise = Nothing -isIntegral :: forall a. Elt a => Maybe (IsIntegralDict a) +isIntegral :: forall a. Elt a => Maybe (IsIntegralDict (EltRepr a)) isIntegral - | Just Refl <- eqT @a @(EltRepr a) - , TypeRscalar t <- eltType @a + | TupRsingle t <- eltType @a , SingleScalarType s <- t , NumSingleType n <- s , IntegralNumType i <- n @@ -276,10 +275,3 @@ preludeError x , "These Prelude.RealFrac instances are present only to fulfil superclass" , "constraints for subsequent classes in the standard Haskell numeric hierarchy." ] - -lift1 :: (Elt a, Elt b, Elt c, IsScalar b, b ~ EltRepr a) - => (Exp b -> Exp c) - -> Exp a - -> Exp c -lift1 f x = f (mkUnsafeCoerce x) - diff --git a/src/Data/Array/Accelerate/Data/Bits.hs b/src/Data/Array/Accelerate/Data/Bits.hs index 013e20acd..ad17ec9fe 100644 --- a/src/Data/Array/Accelerate/Data/Bits.hs +++ b/src/Data/Array/Accelerate/Data/Bits.hs @@ -366,200 +366,200 @@ instance Bits Word64 where popCount = mkPopCount instance Bits CInt where - (.&.) = lift2 mkBAnd - (.|.) = lift2 mkBOr - xor = lift2 mkBXor - complement = lift1 mkBNot + (.&.) = mkBAnd + (.|.) = mkBOr + xor = mkBXor + complement = mkBNot bit = mkBitcast . bitDefault @Int32 testBit b = testBitDefault (mkBitcast @Int32 b) - shift = lift2' shiftDefault - shiftL = lift2' shiftLDefault - shiftR = lift2' shiftRDefault - unsafeShiftL = lift2' mkBShiftL - unsafeShiftR = lift2' mkBShiftR - rotate = lift2' rotateDefault - rotateL = lift2' rotateLDefault - rotateR = lift2' rotateRDefault + shift = shiftDefault + shiftL = shiftLDefault + shiftR = shiftRDefault + unsafeShiftL = mkBShiftL + unsafeShiftR = mkBShiftR + rotate = rotateDefault + rotateL = rotateLDefault + rotateR = rotateRDefault isSigned = isSignedDefault popCount = mkPopCount . mkBitcast @Int32 instance Bits CUInt where - (.&.) = lift2 mkBAnd - (.|.) = lift2 mkBOr - xor = lift2 mkBXor - complement = lift1 mkBNot + (.&.) = mkBAnd + (.|.) = mkBOr + xor = mkBXor + complement = mkBNot bit = mkBitcast . bitDefault @Word32 testBit b = testBitDefault (mkBitcast @Word32 b) - shift = lift2' shiftDefault - shiftL = lift2' shiftLDefault - shiftR = lift2' shiftRDefault - unsafeShiftL = lift2' mkBShiftL - unsafeShiftR = lift2' mkBShiftR - rotate = lift2' rotateDefault - rotateL = lift2' rotateLDefault - rotateR = lift2' rotateRDefault + shift = shiftDefault + shiftL = shiftLDefault + shiftR = shiftRDefault + unsafeShiftL = mkBShiftL + unsafeShiftR = mkBShiftR + rotate = rotateDefault + rotateL = rotateLDefault + rotateR = rotateRDefault isSigned = isSignedDefault popCount = mkPopCount . mkBitcast @Word32 instance Bits CLong where - (.&.) = lift2 mkBAnd - (.|.) = lift2 mkBOr - xor = lift2 mkBXor - complement = lift1 mkBNot + (.&.) = mkBAnd + (.|.) = mkBOr + xor = mkBXor + complement = mkBNot bit = mkBitcast . bitDefault @HTYPE_CLONG testBit b = testBitDefault (mkBitcast @HTYPE_CLONG b) - shift = lift2' shiftDefault - shiftL = lift2' shiftLDefault - shiftR = lift2' shiftRDefault - unsafeShiftL = lift2' mkBShiftL - unsafeShiftR = lift2' mkBShiftR - rotate = lift2' rotateDefault - rotateL = lift2' rotateLDefault - rotateR = lift2' rotateRDefault + shift = shiftDefault + shiftL = shiftLDefault + shiftR = shiftRDefault + unsafeShiftL = mkBShiftL + unsafeShiftR = mkBShiftR + rotate = rotateDefault + rotateL = rotateLDefault + rotateR = rotateRDefault isSigned = isSignedDefault popCount = mkPopCount . mkBitcast @HTYPE_CLONG instance Bits CULong where - (.&.) = lift2 mkBAnd - (.|.) = lift2 mkBOr - xor = lift2 mkBXor - complement = lift1 mkBNot + (.&.) = mkBAnd + (.|.) = mkBOr + xor = mkBXor + complement = mkBNot bit = mkBitcast . bitDefault @HTYPE_CULONG testBit b = testBitDefault (mkBitcast @HTYPE_CULONG b) - shift = lift2' shiftDefault - shiftL = lift2' shiftLDefault - shiftR = lift2' shiftRDefault - unsafeShiftL = lift2' mkBShiftL - unsafeShiftR = lift2' mkBShiftR - rotate = lift2' rotateDefault - rotateL = lift2' rotateLDefault - rotateR = lift2' rotateRDefault + shift = shiftDefault + shiftL = shiftLDefault + shiftR = shiftRDefault + unsafeShiftL = mkBShiftL + unsafeShiftR = mkBShiftR + rotate = rotateDefault + rotateL = rotateLDefault + rotateR = rotateRDefault isSigned = isSignedDefault popCount = mkPopCount . mkBitcast @HTYPE_CULONG instance Bits CLLong where - (.&.) = lift2 mkBAnd - (.|.) = lift2 mkBOr - xor = lift2 mkBXor - complement = lift1 mkBNot + (.&.) = mkBAnd + (.|.) = mkBOr + xor = mkBXor + complement = mkBNot bit = mkBitcast . bitDefault @Int64 testBit b = testBitDefault (mkBitcast @Int64 b) - shift = lift2' shiftDefault - shiftL = lift2' shiftLDefault - shiftR = lift2' shiftRDefault - unsafeShiftL = lift2' mkBShiftL - unsafeShiftR = lift2' mkBShiftR - rotate = lift2' rotateDefault - rotateL = lift2' rotateLDefault - rotateR = lift2' rotateRDefault + shift = shiftDefault + shiftL = shiftLDefault + shiftR = shiftRDefault + unsafeShiftL = mkBShiftL + unsafeShiftR = mkBShiftR + rotate = rotateDefault + rotateL = rotateLDefault + rotateR = rotateRDefault isSigned = isSignedDefault popCount = mkPopCount . mkBitcast @Int64 instance Bits CULLong where - (.&.) = lift2 mkBAnd - (.|.) = lift2 mkBOr - xor = lift2 mkBXor - complement = lift1 mkBNot + (.&.) = mkBAnd + (.|.) = mkBOr + xor = mkBXor + complement = mkBNot bit = mkBitcast . bitDefault @Word64 testBit b = testBitDefault (mkBitcast @Word64 b) - shift = lift2' shiftDefault - shiftL = lift2' shiftLDefault - shiftR = lift2' shiftRDefault - unsafeShiftL = lift2' mkBShiftL - unsafeShiftR = lift2' mkBShiftR - rotate = lift2' rotateDefault - rotateL = lift2' rotateLDefault - rotateR = lift2' rotateRDefault + shift = shiftDefault + shiftL = shiftLDefault + shiftR = shiftRDefault + unsafeShiftL = mkBShiftL + unsafeShiftR = mkBShiftR + rotate = rotateDefault + rotateL = rotateLDefault + rotateR = rotateRDefault isSigned = isSignedDefault popCount = mkPopCount . mkBitcast @Word64 instance Bits CShort where - (.&.) = lift2 mkBAnd - (.|.) = lift2 mkBOr - xor = lift2 mkBXor - complement = lift1 mkBNot + (.&.) = mkBAnd + (.|.) = mkBOr + xor = mkBXor + complement = mkBNot bit = mkBitcast . bitDefault @Int16 testBit b = testBitDefault (mkBitcast @Int16 b) - shift = lift2' shiftDefault - shiftL = lift2' shiftLDefault - shiftR = lift2' shiftRDefault - unsafeShiftL = lift2' mkBShiftL - unsafeShiftR = lift2' mkBShiftR - rotate = lift2' rotateDefault - rotateL = lift2' rotateLDefault - rotateR = lift2' rotateRDefault + shift = shiftDefault + shiftL = shiftLDefault + shiftR = shiftRDefault + unsafeShiftL = mkBShiftL + unsafeShiftR = mkBShiftR + rotate = rotateDefault + rotateL = rotateLDefault + rotateR = rotateRDefault isSigned = isSignedDefault popCount = mkPopCount . mkBitcast @Int16 instance Bits CUShort where - (.&.) = lift2 mkBAnd - (.|.) = lift2 mkBOr - xor = lift2 mkBXor - complement = lift1 mkBNot + (.&.) = mkBAnd + (.|.) = mkBOr + xor = mkBXor + complement = mkBNot bit = mkBitcast . bitDefault @Word16 testBit b = testBitDefault (mkBitcast @Word16 b) - shift = lift2' shiftDefault - shiftL = lift2' shiftLDefault - shiftR = lift2' shiftRDefault - unsafeShiftL = lift2' mkBShiftL - unsafeShiftR = lift2' mkBShiftR - rotate = lift2' rotateDefault - rotateL = lift2' rotateLDefault - rotateR = lift2' rotateRDefault + shift = shiftDefault + shiftL = shiftLDefault + shiftR = shiftRDefault + unsafeShiftL = mkBShiftL + unsafeShiftR = mkBShiftR + rotate = rotateDefault + rotateL = rotateLDefault + rotateR = rotateRDefault isSigned = isSignedDefault popCount = mkPopCount . mkBitcast @Word16 instance Bits CChar where - (.&.) = lift2 mkBAnd - (.|.) = lift2 mkBOr - xor = lift2 mkBXor - complement = lift1 mkBNot + (.&.) = mkBAnd + (.|.) = mkBOr + xor = mkBXor + complement = mkBNot bit = mkBitcast . bitDefault @HTYPE_CCHAR testBit b = testBitDefault (mkBitcast @HTYPE_CCHAR b) - shift = lift2' shiftDefault - shiftL = lift2' shiftLDefault - shiftR = lift2' shiftRDefault - unsafeShiftL = lift2' mkBShiftL - unsafeShiftR = lift2' mkBShiftR - rotate = lift2' rotateDefault - rotateL = lift2' rotateLDefault - rotateR = lift2' rotateRDefault + shift = shiftDefault + shiftL = shiftLDefault + shiftR = shiftRDefault + unsafeShiftL = mkBShiftL + unsafeShiftR = mkBShiftR + rotate = rotateDefault + rotateL = rotateLDefault + rotateR = rotateRDefault isSigned = isSignedDefault popCount = mkPopCount . mkBitcast @HTYPE_CCHAR instance Bits CSChar where - (.&.) = lift2 mkBAnd - (.|.) = lift2 mkBOr - xor = lift2 mkBXor - complement = lift1 mkBNot + (.&.) = mkBAnd + (.|.) = mkBOr + xor = mkBXor + complement = mkBNot bit = mkBitcast . bitDefault @Int8 testBit b = testBitDefault (mkBitcast @Int8 b) - shift = lift2' shiftDefault - shiftL = lift2' shiftLDefault - shiftR = lift2' shiftRDefault - unsafeShiftL = lift2' mkBShiftL - unsafeShiftR = lift2' mkBShiftR - rotate = lift2' rotateDefault - rotateL = lift2' rotateLDefault - rotateR = lift2' rotateRDefault + shift = shiftDefault + shiftL = shiftLDefault + shiftR = shiftRDefault + unsafeShiftL = mkBShiftL + unsafeShiftR = mkBShiftR + rotate = rotateDefault + rotateL = rotateLDefault + rotateR = rotateRDefault isSigned = isSignedDefault popCount = mkPopCount . mkBitcast @Int8 instance Bits CUChar where - (.&.) = lift2 mkBAnd - (.|.) = lift2 mkBOr - xor = lift2 mkBXor - complement = lift1 mkBNot + (.&.) = mkBAnd + (.|.) = mkBOr + xor = mkBXor + complement = mkBNot bit = mkBitcast . bitDefault @Word8 testBit b = testBitDefault (mkBitcast @Word8 b) - shift = lift2' shiftDefault - shiftL = lift2' shiftLDefault - shiftR = lift2' shiftRDefault - unsafeShiftL = lift2' mkBShiftL - unsafeShiftR = lift2' mkBShiftR - rotate = lift2' rotateDefault - rotateL = lift2' rotateLDefault - rotateR = lift2' rotateRDefault + shift = shiftDefault + shiftL = shiftLDefault + shiftR = shiftRDefault + unsafeShiftL = mkBShiftL + unsafeShiftR = mkBShiftR + rotate = rotateDefault + rotateL = rotateLDefault + rotateR = rotateRDefault isSigned = isSignedDefault popCount = mkPopCount . mkBitcast @Word8 @@ -569,176 +569,154 @@ instance Bits CUChar where -- ------------------------ instance FiniteBits Bool where - finiteBitSize _ = constant 8 -- stored as Word8 {- (B.finiteBitSize (undefined::Bool)) -} + finiteBitSize _ = constInt 8 -- stored as Word8 {- (B.finiteBitSize (undefined::Bool)) -} countLeadingZeros x = cond x 0 1 countTrailingZeros x = cond x 0 1 instance FiniteBits Int where - finiteBitSize _ = constant (B.finiteBitSize (undefined::Int)) + finiteBitSize _ = constInt (B.finiteBitSize (undefined::Int)) countLeadingZeros = mkCountLeadingZeros countTrailingZeros = mkCountTrailingZeros instance FiniteBits Int8 where - finiteBitSize _ = constant (B.finiteBitSize (undefined::Int8)) + finiteBitSize _ = constInt (B.finiteBitSize (undefined::Int8)) countLeadingZeros = mkCountLeadingZeros countTrailingZeros = mkCountTrailingZeros instance FiniteBits Int16 where - finiteBitSize _ = constant (B.finiteBitSize (undefined::Int16)) + finiteBitSize _ = constInt (B.finiteBitSize (undefined::Int16)) countLeadingZeros = mkCountLeadingZeros countTrailingZeros = mkCountTrailingZeros instance FiniteBits Int32 where - finiteBitSize _ = constant (B.finiteBitSize (undefined::Int32)) + finiteBitSize _ = constInt (B.finiteBitSize (undefined::Int32)) countLeadingZeros = mkCountLeadingZeros countTrailingZeros = mkCountTrailingZeros instance FiniteBits Int64 where - finiteBitSize _ = constant (B.finiteBitSize (undefined::Int64)) + finiteBitSize _ = constInt (B.finiteBitSize (undefined::Int64)) countLeadingZeros = mkCountLeadingZeros countTrailingZeros = mkCountTrailingZeros instance FiniteBits Word where - finiteBitSize _ = constant (B.finiteBitSize (undefined::Word)) + finiteBitSize _ = constInt (B.finiteBitSize (undefined::Word)) countLeadingZeros = mkCountLeadingZeros countTrailingZeros = mkCountTrailingZeros instance FiniteBits Word8 where - finiteBitSize _ = constant (B.finiteBitSize (undefined::Word8)) + finiteBitSize _ = constInt (B.finiteBitSize (undefined::Word8)) countLeadingZeros = mkCountLeadingZeros countTrailingZeros = mkCountTrailingZeros instance FiniteBits Word16 where - finiteBitSize _ = constant (B.finiteBitSize (undefined::Word16)) + finiteBitSize _ = constInt (B.finiteBitSize (undefined::Word16)) countLeadingZeros = mkCountLeadingZeros countTrailingZeros = mkCountTrailingZeros instance FiniteBits Word32 where - finiteBitSize _ = constant (B.finiteBitSize (undefined::Word32)) + finiteBitSize _ = constInt (B.finiteBitSize (undefined::Word32)) countLeadingZeros = mkCountLeadingZeros countTrailingZeros = mkCountTrailingZeros instance FiniteBits Word64 where - finiteBitSize _ = constant (B.finiteBitSize (undefined::Word64)) + finiteBitSize _ = constInt (B.finiteBitSize (undefined::Word64)) countLeadingZeros = mkCountLeadingZeros countTrailingZeros = mkCountTrailingZeros instance FiniteBits CInt where - finiteBitSize _ = constant (B.finiteBitSize (undefined::CInt)) + finiteBitSize _ = constInt (B.finiteBitSize (undefined::CInt)) countLeadingZeros = mkCountLeadingZeros . mkBitcast @Int32 countTrailingZeros = mkCountTrailingZeros . mkBitcast @Int32 instance FiniteBits CUInt where - finiteBitSize _ = constant (B.finiteBitSize (undefined::CUInt)) + finiteBitSize _ = constInt (B.finiteBitSize (undefined::CUInt)) countLeadingZeros = mkCountLeadingZeros . mkBitcast @Word32 countTrailingZeros = mkCountTrailingZeros . mkBitcast @Word32 instance FiniteBits CLong where - finiteBitSize _ = constant (B.finiteBitSize (undefined::CLong)) + finiteBitSize _ = constInt (B.finiteBitSize (undefined::CLong)) countLeadingZeros = mkCountLeadingZeros . mkBitcast @HTYPE_CLONG countTrailingZeros = mkCountTrailingZeros . mkBitcast @HTYPE_CLONG instance FiniteBits CULong where - finiteBitSize _ = constant (B.finiteBitSize (undefined::CULong)) + finiteBitSize _ = constInt (B.finiteBitSize (undefined::CULong)) countLeadingZeros = mkCountLeadingZeros . mkBitcast @HTYPE_CULONG countTrailingZeros = mkCountTrailingZeros . mkBitcast @HTYPE_CULONG instance FiniteBits CLLong where - finiteBitSize _ = constant (B.finiteBitSize (undefined::CLLong)) + finiteBitSize _ = constInt (B.finiteBitSize (undefined::CLLong)) countLeadingZeros = mkCountLeadingZeros . mkBitcast @Int64 countTrailingZeros = mkCountTrailingZeros . mkBitcast @Int64 instance FiniteBits CULLong where - finiteBitSize _ = constant (B.finiteBitSize (undefined::CULLong)) + finiteBitSize _ = constInt (B.finiteBitSize (undefined::CULLong)) countLeadingZeros = mkCountLeadingZeros . mkBitcast @Word64 countTrailingZeros = mkCountTrailingZeros . mkBitcast @Word64 instance FiniteBits CShort where - finiteBitSize _ = constant (B.finiteBitSize (undefined::CShort)) + finiteBitSize _ = constInt (B.finiteBitSize (undefined::CShort)) countLeadingZeros = mkCountLeadingZeros . mkBitcast @Int16 countTrailingZeros = mkCountTrailingZeros . mkBitcast @Int16 instance FiniteBits CUShort where - finiteBitSize _ = constant (B.finiteBitSize (undefined::CUShort)) + finiteBitSize _ = constInt (B.finiteBitSize (undefined::CUShort)) countLeadingZeros = mkCountLeadingZeros . mkBitcast @Word16 countTrailingZeros = mkCountTrailingZeros . mkBitcast @Word16 instance FiniteBits CChar where - finiteBitSize _ = constant (B.finiteBitSize (undefined::CChar)) + finiteBitSize _ = constInt (B.finiteBitSize (undefined::CChar)) countLeadingZeros = mkCountLeadingZeros . mkBitcast @HTYPE_CCHAR countTrailingZeros = mkCountTrailingZeros . mkBitcast @HTYPE_CCHAR instance FiniteBits CSChar where - finiteBitSize _ = constant (B.finiteBitSize (undefined::CSChar)) + finiteBitSize _ = constInt (B.finiteBitSize (undefined::CSChar)) countLeadingZeros = mkCountLeadingZeros . mkBitcast @Int8 countTrailingZeros = mkCountTrailingZeros . mkBitcast @Int8 instance FiniteBits CUChar where - finiteBitSize _ = constant (B.finiteBitSize (undefined::CUChar)) + finiteBitSize _ = constInt (B.finiteBitSize (undefined::CUChar)) countLeadingZeros = mkCountLeadingZeros . mkBitcast @Word8 countTrailingZeros = mkCountTrailingZeros . mkBitcast @Word8 -- Default implementations -- ----------------------- +bitDefault :: (IsIntegral (EltRepr t), Bits t) => Exp Int -> Exp t +bitDefault x = constInt 1 `shiftL` x -lift1 :: (Elt a, Elt b, IsScalar b, b ~ EltRepr a) - => (Exp b -> Exp b) - -> Exp a - -> Exp a -lift1 f x = mkUnsafeCoerce (f (mkUnsafeCoerce x)) - -lift2 :: (Elt a, Elt b, IsScalar b, b ~ EltRepr a) - => (Exp b -> Exp b -> Exp b) - -> Exp a - -> Exp a - -> Exp a -lift2 f x y = mkUnsafeCoerce (f (mkUnsafeCoerce x) (mkUnsafeCoerce y)) - -lift2' :: (Elt a, Elt b, IsScalar b, b ~ EltRepr a) - => (Exp b -> Exp Int -> Exp b) - -> Exp a - -> Exp Int - -> Exp a -lift2' f x y = mkUnsafeCoerce (f (mkUnsafeCoerce x) y) +testBitDefault :: (IsIntegral (EltRepr t), Bits t) => Exp t -> Exp Int -> Exp Bool +testBitDefault x i = (x .&. bit i) /= constInt 0 - -bitDefault :: (IsIntegral t, Bits t) => Exp Int -> Exp t -bitDefault x = constant 1 `shiftL` x - -testBitDefault :: (IsIntegral t, Bits t) => Exp t -> Exp Int -> Exp Bool -testBitDefault x i = (x .&. bit i) /= constant 0 - -shiftDefault :: (FiniteBits t, IsIntegral t, B.Bits t) => Exp t -> Exp Int -> Exp t +shiftDefault :: (FiniteBits t, IsIntegral (EltRepr t), B.Bits t) => Exp t -> Exp Int -> Exp t shiftDefault x i = cond (i >= 0) (shiftLDefault x i) (shiftRDefault x (-i)) -shiftLDefault :: (FiniteBits t, IsIntegral t) => Exp t -> Exp Int -> Exp t +shiftLDefault :: (FiniteBits t, IsIntegral (EltRepr t)) => Exp t -> Exp Int -> Exp t shiftLDefault x i - = cond (i >= finiteBitSize x) (constant 0) + = cond (i >= finiteBitSize x) (constInt 0) $ mkBShiftL x i -shiftRDefault :: forall t. (B.Bits t, FiniteBits t, IsIntegral t) => Exp t -> Exp Int -> Exp t +shiftRDefault :: forall t. (B.Bits t, FiniteBits t, IsIntegral (EltRepr t)) => Exp t -> Exp Int -> Exp t shiftRDefault | B.isSigned (undefined::t) = shiftRADefault | otherwise = shiftRLDefault -- Shift the argument right (signed) -shiftRADefault :: (FiniteBits t, IsIntegral t) => Exp t -> Exp Int -> Exp t +shiftRADefault :: (FiniteBits t, IsIntegral (EltRepr t)) => Exp t -> Exp Int -> Exp t shiftRADefault x i - = cond (i >= finiteBitSize x) (cond (mkLt x (constant 0)) (constant (-1)) (constant 0)) + = cond (i >= finiteBitSize x) (cond (mkLt x (constInt 0)) (constInt (-1)) (constInt 0)) $ mkBShiftR x i -- Shift the argument right (unsigned) -shiftRLDefault :: (FiniteBits t, IsIntegral t) => Exp t -> Exp Int -> Exp t +shiftRLDefault :: (FiniteBits t, IsIntegral (EltRepr t)) => Exp t -> Exp Int -> Exp t shiftRLDefault x i - = cond (i >= finiteBitSize x) (constant 0) + = cond (i >= finiteBitSize x) (constInt 0) $ mkBShiftR x i -rotateDefault :: forall t. (FiniteBits t, IsIntegral t) => Exp t -> Exp Int -> Exp t +rotateDefault :: forall t. (FiniteBits t, IsIntegral (EltRepr t)) => Exp t -> Exp Int -> Exp t rotateDefault = - case (integralType :: IntegralType t) of + case integralType :: IntegralType (EltRepr t) of TypeInt{} -> rotateDefault' (undefined::Word) TypeInt8{} -> rotateDefault' (undefined::Word8) TypeInt16{} -> rotateDefault' (undefined::Word16) @@ -751,7 +729,7 @@ rotateDefault = TypeWord64{} -> rotateDefault' (undefined::Word64) rotateDefault' - :: forall i w. (Elt w, FiniteBits i, IsIntegral i, IsIntegral w, IsIntegral (EltRepr i), IsIntegral (EltRepr w), BitSizeEq (EltRepr i) (EltRepr w), BitSizeEq (EltRepr w) (EltRepr i)) + :: forall i w. (Elt w, FiniteBits i, IsIntegral (EltRepr i), IsIntegral (EltRepr w), IsIntegral (EltRepr i), IsIntegral (EltRepr w), BitSizeEq (EltRepr i) (EltRepr w), BitSizeEq (EltRepr w) (EltRepr i)) => w {- dummy -} -> Exp i -> Exp Int @@ -767,12 +745,12 @@ rotateDefault' _ x i i' = i `mkBAnd` (wsib - 1) wsib = finiteBitSize x -rotateLDefault :: (Elt t, IsIntegral t) => Exp t -> Exp Int -> Exp t +rotateLDefault :: (Elt t, IsIntegral (EltRepr t)) => Exp t -> Exp Int -> Exp t rotateLDefault x i = cond (i == 0) x $ mkBRotateL x i -rotateRDefault :: (Elt t, IsIntegral t) => Exp t -> Exp Int -> Exp t +rotateRDefault :: (Elt t, IsIntegral (EltRepr t)) => Exp t -> Exp Int -> Exp t rotateRDefault x i = cond (i == 0) x $ mkBRotateR x i @@ -780,6 +758,9 @@ rotateRDefault x i isSignedDefault :: forall b. B.Bits b => Exp b -> Exp Bool isSignedDefault _ = constant (B.isSigned (undefined::b)) +constInt :: IsIntegral (EltRepr e) => EltRepr e -> Exp e +constInt = exp . Const (SingleScalarType $ NumSingleType $ IntegralNumType $ integralType) + {-- _popCountDefault :: forall a. (B.FiniteBits a, IsScalar a, Bits a, Num a) => Exp a -> Exp Int _popCountDefault = diff --git a/src/Data/Array/Accelerate/Data/Complex.hs b/src/Data/Array/Accelerate/Data/Complex.hs index d5f1d0e35..c10b923d9 100644 --- a/src/Data/Array/Accelerate/Data/Complex.hs +++ b/src/Data/Array/Accelerate/Data/Complex.hs @@ -2,6 +2,8 @@ {-# LANGUAGE ConstraintKinds #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE MagicHash #-} {-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE RebindableSyntax #-} @@ -48,91 +50,119 @@ import Data.Array.Accelerate.Classes import Data.Array.Accelerate.Data.Functor import Data.Array.Accelerate.Pattern import Data.Array.Accelerate.Prelude -import Data.Array.Accelerate.Product -import Data.Array.Accelerate.Smart +import Data.Array.Accelerate.Smart hiding (exp) import Data.Array.Accelerate.Type import Data.Complex ( Complex(..) ) import qualified Data.Complex as C +import Prelude (($)) import qualified Prelude as P - infix 6 ::+ -pattern (::+) :: (Elt a, Elt (Complex a)) => Exp a -> Exp a -> Exp (Complex a) -pattern r ::+ c = Pattern (r, c) +pattern (::+) :: Elt a => Exp a -> Exp a -> Exp (Complex a) +pattern r ::+ i <- (deconstructComplex -> (r, i)) + where (::+) = constructComplex {-# COMPLETE (::+) #-} --- Use an array-of-structs representation for complex numbers. This matches the --- standard C-style layout, but means that we can define instances only at --- specific types (not for any type 'a') as we can only have vectors of --- primitive type. +-- Use an array-of-structs representation for complex numbers if possible. +-- This matches the standard C-style layout, but we can use this representation only at +-- specific types (not for any type 'a') as we can only have vectors of primitive type. +-- For other types, we use a structure-of-arrays representation. This is handled by the +-- ComplexRepr. We use the GADT ComplexR and function complexR to reconstruct +-- information on how the elements are represented. -- - -instance Elt (Complex Half) where - type EltRepr (Complex Half) = V2 Half - {-# INLINE eltType #-} - {-# INLINE [1] toElt #-} - {-# INLINE [1] fromElt #-} - eltType = TypeRscalar scalarType - toElt (V2 r i) = r :+ i - fromElt (r :+ i) = V2 r i - -instance Elt (Complex Float) where - type EltRepr (Complex Float) = V2 Float +instance Elt a => Elt (Complex a) where + type EltRepr (Complex a) = ComplexRepr (EltRepr a) {-# INLINE eltType #-} {-# INLINE [1] toElt #-} {-# INLINE [1] fromElt #-} - eltType = TypeRscalar scalarType - toElt (V2 r i) = r :+ i - fromElt (r :+ i) = V2 r i - -instance Elt (Complex Double) where - type EltRepr (Complex Double) = V2 Double - {-# INLINE eltType #-} - {-# INLINE [1] toElt #-} - {-# INLINE [1] fromElt #-} - eltType = TypeRscalar scalarType - toElt (V2 r i) = r :+ i - fromElt (r :+ i) = V2 r i - -instance Elt (Complex CFloat) where - type EltRepr (Complex CFloat) = V2 Float - {-# INLINE eltType #-} - {-# INLINE [1] toElt #-} - {-# INLINE [1] fromElt #-} - eltType = TypeRscalar scalarType - toElt (V2 r i) = CFloat r :+ CFloat i - fromElt (CFloat r :+ CFloat i) = V2 r i - -instance Elt (Complex CDouble) where - type EltRepr (Complex CDouble) = V2 Double - {-# INLINE eltType #-} - {-# INLINE [1] toElt #-} - {-# INLINE [1] fromElt #-} - eltType = TypeRscalar scalarType - toElt (V2 r i) = CDouble r :+ CDouble i - fromElt (CDouble r :+ CDouble i) = V2 r i - -instance cst a => IsProduct cst (Complex a) where - type ProdRepr (Complex a) = ProdRepr (a,a) - fromProd (r :+ i) = fromProd @cst (r,i) - toProd p = let (r,i) = toProd @cst p in (r :+ i) - prod = prod @cst @(a,a) - -instance (Lift Exp a, Elt (Plain a), Elt (Complex (Plain a))) => Lift Exp (Complex a) where + eltType = case complexR tp of + ComplexRvec s -> TupRsingle $ VectorScalarType $ VectorType 2 s + ComplexRtup -> TupRunit `TupRpair` tp `TupRpair` tp + where + tp = eltType @a + toElt = case complexR $ eltType @a of + ComplexRvec _ -> \(V2 r i) -> toElt r :+ toElt i + ComplexRtup -> \(((), r), i) -> toElt r :+ toElt i + fromElt (r :+ i) = case complexR $ eltType @a of + ComplexRvec _ -> V2 (fromElt r) (fromElt i) + ComplexRtup -> (((), fromElt r), fromElt i) + +type family ComplexRepr a where + ComplexRepr Half = V2 Half + ComplexRepr Float = V2 Float + ComplexRepr Double = V2 Double + ComplexRepr Int = V2 Int + ComplexRepr Int8 = V2 Int8 + ComplexRepr Int16 = V2 Int16 + ComplexRepr Int32 = V2 Int32 + ComplexRepr Int64 = V2 Int64 + ComplexRepr Word = V2 Word + ComplexRepr Word8 = V2 Word8 + ComplexRepr Word16 = V2 Word16 + ComplexRepr Word32 = V2 Word32 + ComplexRepr Word64 = V2 Word64 + ComplexRepr a = Tup2 a a + +data ComplexR a c where + ComplexRvec :: VecElt a => SingleType a -> ComplexR a (V2 a) + ComplexRtup :: ComplexR a (Tup2 a a) + +complexR :: TupleType a -> ComplexR a (ComplexRepr a) +complexR (TupRsingle (SingleScalarType (NumSingleType (FloatingNumType TypeHalf )))) = ComplexRvec singleType +complexR (TupRsingle (SingleScalarType (NumSingleType (FloatingNumType TypeFloat )))) = ComplexRvec singleType +complexR (TupRsingle (SingleScalarType (NumSingleType (FloatingNumType TypeDouble)))) = ComplexRvec singleType +complexR (TupRsingle (SingleScalarType (NumSingleType (IntegralNumType TypeInt )))) = ComplexRvec singleType +complexR (TupRsingle (SingleScalarType (NumSingleType (IntegralNumType TypeInt8 )))) = ComplexRvec singleType +complexR (TupRsingle (SingleScalarType (NumSingleType (IntegralNumType TypeInt16 )))) = ComplexRvec singleType +complexR (TupRsingle (SingleScalarType (NumSingleType (IntegralNumType TypeInt32 )))) = ComplexRvec singleType +complexR (TupRsingle (SingleScalarType (NumSingleType (IntegralNumType TypeInt64 )))) = ComplexRvec singleType +complexR (TupRsingle (SingleScalarType (NumSingleType (IntegralNumType TypeWord )))) = ComplexRvec singleType +complexR (TupRsingle (SingleScalarType (NumSingleType (IntegralNumType TypeWord8 )))) = ComplexRvec singleType +complexR (TupRsingle (SingleScalarType (NumSingleType (IntegralNumType TypeWord16)))) = ComplexRvec singleType +complexR (TupRsingle (SingleScalarType (NumSingleType (IntegralNumType TypeWord32)))) = ComplexRvec singleType +complexR (TupRsingle (SingleScalarType (NumSingleType (IntegralNumType TypeWord64)))) = ComplexRvec singleType +complexR (TupRsingle (SingleScalarType (NonNumSingleType TypeChar))) = ComplexRtup +complexR (TupRsingle (SingleScalarType (NonNumSingleType TypeBool))) = ComplexRtup +complexR (TupRsingle (VectorScalarType (_))) = ComplexRtup +complexR TupRunit = ComplexRtup +complexR TupRpair{} = ComplexRtup + +constructComplex :: forall a. Elt a => Exp a -> Exp a -> Exp (Complex a) +constructComplex r i = case complexR $ eltType @a of + ComplexRvec _ -> + let + r', i' :: Exp (EltRepr a) + r' = reExp @a @(EltRepr a) r + i' = reExp i + v :: Exp (V2 (EltRepr a)) + v = V2_ r' i' + in + reExp @(V2 (EltRepr a)) @(Complex a) $ v + ComplexRtup -> reExp $ T2 r i + +deconstructComplex :: forall a. Elt a => Exp (Complex a) -> (Exp a, Exp a) +deconstructComplex c = case complexR $ eltType @a of + ComplexRvec _ -> let V2_ r i = reExp @(Complex a) @(V2 (EltRepr a)) c in (reExp r, reExp i) + ComplexRtup -> let T2 r i = reExp c in (r, i) + +reExp :: EltRepr a ~ EltRepr b => Exp a -> Exp b +reExp (Exp e) = Exp e + +instance (Lift Exp a, Elt (Plain a)) => Lift Exp (Complex a) where type Plain (Complex a) = Complex (Plain a) lift (r :+ i) = lift r ::+ lift i -instance (Elt a, Elt (Complex a)) => Unlift Exp (Complex (Exp a)) where +instance Elt a => Unlift Exp (Complex (Exp a)) where unlift (r ::+ i) = r :+ i -instance (Eq a, Elt (Complex a)) => Eq (Complex a) where +instance Eq a => Eq (Complex a) where r1 ::+ c1 == r2 ::+ c2 = r1 == r2 && c1 == c2 r1 ::+ c1 /= r2 ::+ c2 = r1 /= r2 || c1 /= c2 -instance (RealFloat a, Elt (Complex a)) => P.Num (Exp (Complex a)) where +instance RealFloat a => P.Num (Exp (Complex a)) where (+) = lift2 ((+) :: Complex (Exp a) -> Complex (Exp a) -> Complex (Exp a)) (-) = lift2 ((-) :: Complex (Exp a) -> Complex (Exp a) -> Complex (Exp a)) (*) = lift2 ((*) :: Complex (Exp a) -> Complex (Exp a) -> Complex (Exp a)) @@ -145,7 +175,7 @@ instance (RealFloat a, Elt (Complex a)) => P.Num (Exp (Complex a)) where abs z = magnitude z ::+ 0 fromInteger n = fromInteger n ::+ 0 -instance (RealFloat a, Elt (Complex a)) => P.Fractional (Exp (Complex a)) where +instance RealFloat a => P.Fractional (Exp (Complex a)) where fromRational x = fromRational x ::+ 0 z / z' = (x*x''+y*y'') / d ::+ (y*x''-x*y'') / d where @@ -157,7 +187,7 @@ instance (RealFloat a, Elt (Complex a)) => P.Fractional (Exp (Complex a)) where k = - max (exponent x') (exponent y') d = x'*x'' + y'*y'' -instance (RealFloat a, Elt (Complex a)) => P.Floating (Exp (Complex a)) where +instance RealFloat a => P.Floating (Exp (Complex a)) where pi = pi ::+ 0 exp (x ::+ y) = let expx = exp x in expx * cos y ::+ expx * sin y @@ -234,7 +264,7 @@ instance Functor Complex where -- | The non-negative magnitude of a complex number -- -magnitude :: (RealFloat a, Elt (Complex a)) => Exp (Complex a) -> Exp a +magnitude :: RealFloat a => Exp (Complex a) -> Exp a magnitude (r ::+ i) = scaleFloat k (sqrt (sqr (scaleFloat mk r) + sqr (scaleFloat mk i))) where k = max (exponent r) (exponent i) @@ -246,13 +276,13 @@ magnitude (r ::+ i) = scaleFloat k (sqrt (sqr (scaleFloat mk r) + sqr (scaleFloa -- -- @since 1.3.0.0 -- -magnitude' :: (RealFloat a, Elt (Complex a)) => Exp (Complex a) -> Exp a +magnitude' :: RealFloat a => Exp (Complex a) -> Exp a magnitude' (r ::+ i) = sqrt (r*r + i*i) -- | The phase of a complex number, in the range @(-'pi', 'pi']@. If the -- magnitude is zero, then so is the phase. -- -phase :: (RealFloat a, Elt (Complex a)) => Exp (Complex a) -> Exp a +phase :: RealFloat a => Exp (Complex a) -> Exp a phase z@(r ::+ i) = if z == 0 then 0 @@ -262,15 +292,15 @@ phase z@(r ::+ i) = -- phase) pair in canonical form: the magnitude is non-negative, and the phase -- in the range @(-'pi', 'pi']@; if the magnitude is zero, then so is the phase. -- -polar :: (RealFloat a, Elt (Complex a)) => Exp (Complex a) -> Exp (a,a) +polar :: RealFloat a => Exp (Complex a) -> Exp (a,a) polar z = T2 (magnitude z) (phase z) -- | Form a complex number from polar components of magnitude and phase. -- #if __GLASGOW_HASKELL__ <= 708 -mkPolar :: forall a. (RealFloat a, Elt (Complex a)) => Exp a -> Exp a -> Exp (Complex a) +mkPolar :: forall a. RealFloat a => Exp a -> Exp a -> Exp (Complex a) #else -mkPolar :: forall a. (Floating a, Elt (Complex a)) => Exp a -> Exp a -> Exp (Complex a) +mkPolar :: forall a. Floating a => Exp a -> Exp a -> Exp (Complex a) #endif mkPolar = lift2 (C.mkPolar :: Exp a -> Exp a -> Complex (Exp a)) @@ -278,26 +308,26 @@ mkPolar = lift2 (C.mkPolar :: Exp a -> Exp a -> Complex (Exp a)) -- @2*'pi'@). -- #if __GLASGOW_HASKELL__ <= 708 -cis :: forall a. (RealFloat a, Elt (Complex a)) => Exp a -> Exp (Complex a) +cis :: forall a. RealFloat a => Exp a -> Exp (Complex a) #else -cis :: forall a. (Floating a, Elt (Complex a)) => Exp a -> Exp (Complex a) +cis :: forall a. Floating a => Exp a -> Exp (Complex a) #endif cis = lift1 (C.cis :: Exp a -> Complex (Exp a)) -- | Return the real part of a complex number -- -real :: (Elt a, Elt (Complex a)) => Exp (Complex a) -> Exp a +real :: Elt a => Exp (Complex a) -> Exp a real (r ::+ _) = r -- | Return the imaginary part of a complex number -- -imag :: (Elt a, Elt (Complex a)) => Exp (Complex a) -> Exp a +imag :: Elt a => Exp (Complex a) -> Exp a imag (_ ::+ i) = i -- | Return the complex conjugate of a complex number, defined as -- -- > conjugate(Z) = X - iY -- -conjugate :: (Num a, Elt (Complex a)) => Exp (Complex a) -> Exp (Complex a) +conjugate :: Num a => Exp (Complex a) -> Exp (Complex a) conjugate z = real z ::+ (- imag z) diff --git a/src/Data/Array/Accelerate/Data/Either.hs b/src/Data/Array/Accelerate/Data/Either.hs index b0ebcd389..a427ac858 100644 --- a/src/Data/Array/Accelerate/Data/Either.hs +++ b/src/Data/Array/Accelerate/Data/Either.hs @@ -34,7 +34,7 @@ import Data.Array.Accelerate.Array.Sugar hiding ( (!) import Data.Array.Accelerate.Language hiding ( chr ) import Data.Array.Accelerate.Prelude hiding ( filter ) import Data.Array.Accelerate.Interpreter -import Data.Array.Accelerate.Product +import Data.Array.Accelerate.Pattern import Data.Array.Accelerate.Smart import Data.Array.Accelerate.Type @@ -89,14 +89,16 @@ isRight x = tag x == 1 -- instead. -- fromLeft :: (Elt a, Elt b) => Exp (Either a b) -> Exp a -fromLeft x = Exp $ SuccTupIdx ZeroTupIdx `Prj` x +fromLeft x = a + where T3 _ a _ = asTuple x -- | The 'fromRight' function extracts the element out of the 'Right' -- constructor. If the argument was actually 'Left', you will get an undefined -- value instead. -- fromRight :: (Elt a, Elt b) => Exp (Either a b) -> Exp b -fromRight x = Exp $ ZeroTupIdx `Prj` x +fromRight x = b + where T3 _ _ b = asTuple x -- | The 'either' function performs case analysis on the 'Either' type. If the -- value is @'Left' a@, apply the first function to @a@; if it is @'Right' b@, @@ -145,31 +147,24 @@ instance (Elt a, Elt b) => Semigroup (Exp (Either a b)) where #endif tag :: (Elt a, Elt b) => Exp (Either a b) -> Exp Word8 -tag x = Exp $ SuccTupIdx (SuccTupIdx ZeroTupIdx) `Prj` x +tag x = t + where T3 t _ _ = asTuple x instance (Elt a, Elt b) => Elt (Either a b) where - type EltRepr (Either a b) = TupleRepr (Word8, EltRepr a, EltRepr b) + type EltRepr (Either a b) = Tup3 Word8 (EltRepr a) (EltRepr b) {-# INLINE eltType #-} {-# INLINE [1] toElt #-} {-# INLINE [1] fromElt #-} eltType = eltType @(Word8,a,b) toElt ((((),0),a),_) = Left (toElt a) toElt (_ ,b) = Right (toElt b) - fromElt (Left a) = ((((),0), fromElt a), fromElt (evalUndef @b)) - fromElt (Right b) = ((((),1), fromElt (evalUndef @a)), fromElt b) - -instance (Elt a, Elt b) => IsProduct Elt (Either a b) where - type ProdRepr (Either a b) = ProdRepr (Word8, a, b) - toProd ((((),0),a),_) = Left a - toProd (_ ,b) = Right b - fromProd (Left a) = ((((), 0), a), evalUndef @b) - fromProd (Right b) = ((((), 1), evalUndef @a), b) - prod = prod @Elt @(Word8,a,b) + fromElt (Left a) = ((((),0), fromElt a ), evalUndef $ eltType @b) + fromElt (Right b) = ((((),1), evalUndef $ eltType @a), fromElt b) instance (Lift Exp a, Lift Exp b, Elt (Plain a), Elt (Plain b)) => Lift Exp (Either a b) where type Plain (Either a b) = Either (Plain a) (Plain b) - lift (Left a) = Exp . Tuple $ NilTup `SnocTup` constant 0 `SnocTup` lift a `SnocTup` undef - lift (Right b) = Exp . Tuple $ NilTup `SnocTup` constant 1 `SnocTup` undef `SnocTup` lift b + lift (Left a) = toEither $ T3 (constant 0) (lift a) undef + lift (Right b) = toEither $ T3 (constant 1) undef (lift b) -- Utilities @@ -209,3 +204,9 @@ filter' keep arr emptyArray :: (Shape sh, Elt e) => Acc (Array sh e) emptyArray = fill (constant empty) undef +asTuple :: Exp (Either a b) -> Exp (Word8, a, b) +asTuple (Exp e) = Exp e + +toEither :: Exp (Word8, a, b) -> Exp (Either a b) +toEither (Exp e) = Exp e + diff --git a/src/Data/Array/Accelerate/Data/Maybe.hs b/src/Data/Array/Accelerate/Data/Maybe.hs index 8e5918edf..8ebf345c9 100644 --- a/src/Data/Array/Accelerate/Data/Maybe.hs +++ b/src/Data/Array/Accelerate/Data/Maybe.hs @@ -34,7 +34,6 @@ import Data.Array.Accelerate.Array.Sugar hiding ( (!) import Data.Array.Accelerate.Language hiding ( chr ) import Data.Array.Accelerate.Prelude hiding ( filter ) import Data.Array.Accelerate.Interpreter -import Data.Array.Accelerate.Product import Data.Array.Accelerate.Smart import Data.Array.Accelerate.Type @@ -49,7 +48,7 @@ import Data.Array.Accelerate.Data.Semigroup #endif import Data.Maybe ( Maybe(..) ) -import Prelude ( (.), ($), const, otherwise ) +import Prelude ( ($), const, otherwise ) pattern Nothing_ :: Elt a => Exp (Maybe a) @@ -101,7 +100,7 @@ fromMaybe d x = cond (isNothing x) d (fromJust x) -- instead. -- fromJust :: Elt a => Exp (Maybe a) -> Exp a -fromJust x = Exp $ ZeroTupIdx `Prj` x +fromJust (Exp x) = Exp $ SmartExp $ PairIdxRight `Prj` x -- | The 'maybe' function takes a default value, a function, and a 'Maybe' -- value. If the 'Maybe' value is nothing, the default value is returned; @@ -151,32 +150,28 @@ instance (Semigroup (Exp a), Elt a) => Semigroup (Exp (Maybe a)) where tag :: Elt a => Exp (Maybe a) -> Exp Word8 -tag x = Exp $ SuccTupIdx ZeroTupIdx `Prj` x +tag (Exp x) = Exp $ SmartExp $ Prj PairIdxRight $ SmartExp $ Prj PairIdxLeft x instance Elt a => Elt (Maybe a) where - type EltRepr (Maybe a) = TupleRepr (Word8, EltRepr a) + type EltRepr (Maybe a) = Tup2 Word8 (EltRepr a) {-# INLINE eltType #-} {-# INLINE [1] toElt #-} {-# INLINE [1] fromElt #-} eltType = eltType @(Word8,a) toElt (((),0),_) = Nothing toElt (_ ,x) = Just (toElt x) - fromElt Nothing = (((),0), fromElt (evalUndef @a)) + fromElt Nothing = (((),0), evalUndef $ eltType @a) fromElt (Just a) = (((),1), fromElt a) -instance Elt a => IsProduct Elt (Maybe a) where - type ProdRepr (Maybe a) = ProdRepr (Word8, a) - toProd (((),0),_) = Nothing - toProd (_, x) = Just x - fromProd Nothing = (((), 0), evalUndef @a) - fromProd (Just a) = (((), 1), a) - prod = prod @Elt @(Word8,a) - instance (Lift Exp a, Elt (Plain a)) => Lift Exp (Maybe a) where type Plain (Maybe a) = Maybe (Plain a) - lift Nothing = Exp . Tuple $ NilTup `SnocTup` constant 0 `SnocTup` undef - lift (Just x) = Exp . Tuple $ NilTup `SnocTup` constant 1 `SnocTup` lift x + lift Nothing = Exp $ SmartExp $ Pair t $ unExp $ undef @(Plain a) + where + t = SmartExp $ Pair (SmartExp Nil) $ SmartExp $ Const scalarTypeWord8 0 + lift (Just x) = Exp $ SmartExp $ Pair t $ unExp $ lift x + where + t = SmartExp $ Pair (SmartExp Nil) $ SmartExp $ Const scalarTypeWord8 1 -- Utilities diff --git a/src/Data/Array/Accelerate/Data/Monoid.hs b/src/Data/Array/Accelerate/Data/Monoid.hs index 7c217a161..354e4bb60 100644 --- a/src/Data/Array/Accelerate/Data/Monoid.hs +++ b/src/Data/Array/Accelerate/Data/Monoid.hs @@ -43,7 +43,6 @@ import Data.Array.Accelerate.Classes.Ord import Data.Array.Accelerate.Language import Data.Array.Accelerate.Lift import Data.Array.Accelerate.Pattern -import Data.Array.Accelerate.Product import Data.Array.Accelerate.Type #if __GLASGOW_HASKELL__ >= 800 import Data.Array.Accelerate.Data.Semigroup () @@ -67,7 +66,6 @@ pattern Sum_ x = Pattern x {-# COMPLETE Sum_ #-} instance Elt a => Elt (Sum a) -instance Elt a => IsProduct Elt (Sum a) instance (Lift Exp a, Elt (Plain a)) => Lift Exp (Sum a) where type Plain (Sum a) = Sum (Plain a) @@ -127,7 +125,6 @@ pattern Product_ x = Pattern x {-# COMPLETE Product_ #-} instance Elt a => Elt (Product a) -instance Elt a => IsProduct Elt (Product a) instance (Lift Exp a, Elt (Plain a)) => Lift Exp (Product a) where type Plain (Product a) = Product (Plain a) diff --git a/src/Data/Array/Accelerate/Data/Ratio.hs b/src/Data/Array/Accelerate/Data/Ratio.hs index 10b8441a3..a7fbe7e88 100644 --- a/src/Data/Array/Accelerate/Data/Ratio.hs +++ b/src/Data/Array/Accelerate/Data/Ratio.hs @@ -34,7 +34,6 @@ import Data.Array.Accelerate.Language import Data.Array.Accelerate.Orphans () import Data.Array.Accelerate.Pattern import Data.Array.Accelerate.Prelude -import Data.Array.Accelerate.Product import Data.Array.Accelerate.Type import Data.Array.Accelerate.Classes.Enum @@ -55,7 +54,6 @@ import qualified Prelude as P instance Elt a => Elt (Ratio a) -instance Elt a => IsProduct Elt (Ratio a) pattern (:%) :: Elt a => Exp a -> Exp a -> Exp (Ratio a) pattern (:%) { numerator, denominator } = Pattern (numerator, denominator) diff --git a/src/Data/Array/Accelerate/Data/Semigroup.hs b/src/Data/Array/Accelerate/Data/Semigroup.hs index 718527738..c603f5e87 100644 --- a/src/Data/Array/Accelerate/Data/Semigroup.hs +++ b/src/Data/Array/Accelerate/Data/Semigroup.hs @@ -43,7 +43,6 @@ import Data.Array.Accelerate.Classes.Num import Data.Array.Accelerate.Classes.Ord import Data.Array.Accelerate.Lift import Data.Array.Accelerate.Pattern -import Data.Array.Accelerate.Product import Data.Array.Accelerate.Smart import Data.Function @@ -57,7 +56,6 @@ pattern Min_ x = Pattern x {-# COMPLETE Min_ #-} instance Elt a => Elt (Min a) -instance Elt a => IsProduct Elt (Min a) instance (Lift Exp a, Elt (Plain a)) => Lift Exp (Min a) where type Plain (Min a) = Min (Plain a) @@ -105,7 +103,6 @@ pattern Max_ x = Pattern x {-# COMPLETE Max_ #-} instance Elt a => Elt (Max a) -instance Elt a => IsProduct Elt (Max a) instance (Lift Exp a, Elt (Plain a)) => Lift Exp (Max a) where type Plain (Max a) = Max (Plain a) diff --git a/src/Data/Array/Accelerate/Interpreter.hs b/src/Data/Array/Accelerate/Interpreter.hs index e75d0c357..cdb7d6de6 100644 --- a/src/Data/Array/Accelerate/Interpreter.hs +++ b/src/Data/Array/Accelerate/Interpreter.hs @@ -42,15 +42,14 @@ module Data.Array.Accelerate.Interpreter ( - Smart.Acc, Arrays, + Smart.Acc, Sugar.Arrays, Afunction, AfunctionR, -- * Interpret an array expression run, run1, runN, -- Internal (hidden) - evalPrj, - evalPrim, evalPrimConst, evalUndef, evalCoerce, + evalPrim, evalPrimConst, evalUndef, evalUndefScalar, evalCoerceScalar, ) where @@ -63,25 +62,21 @@ import Data.Bits import Data.Char ( chr, ord ) import Data.Primitive.ByteArray import Data.Primitive.Types -import Data.Typeable import System.IO.Unsafe ( unsafePerformIO ) import Text.Printf ( printf ) import Unsafe.Coerce import Prelude hiding ( (!!), sum ) -- friends -import Data.Array.Accelerate.AST hiding ( Boundary, PreBoundary(..) ) -import Data.Array.Accelerate.Analysis.Match -import Data.Array.Accelerate.Analysis.Type ( sizeOfScalarType, sizeOfSingleType ) +import Data.Array.Accelerate.AST hiding ( Boundary(..) ) +import Data.Array.Accelerate.Analysis.Type ( sizeOfSingleType ) import Data.Array.Accelerate.Array.Data -import Data.Array.Accelerate.Array.Representation ( SliceIndex(..) ) -import Data.Array.Accelerate.Array.Sugar +import Data.Array.Accelerate.Array.Representation +import qualified Data.Array.Accelerate.Array.Sugar as Sugar import Data.Array.Accelerate.Error -import Data.Array.Accelerate.Product import Data.Array.Accelerate.Trafo hiding ( Delayed ) import Data.Array.Accelerate.Type import qualified Data.Array.Accelerate.AST as AST -import qualified Data.Array.Accelerate.Array.Representation as R import qualified Data.Array.Accelerate.Smart as Smart import qualified Data.Array.Accelerate.Trafo as AST @@ -93,7 +88,7 @@ import qualified Data.Array.Accelerate.Debug as D -- | Run a complete embedded array program using the reference interpreter. -- -run :: Arrays a => Smart.Acc a -> a +run :: Sugar.Arrays a => Smart.Acc a -> a run a = unsafePerformIO execute where !acc = convertAcc a @@ -101,11 +96,11 @@ run a = unsafePerformIO execute D.dumpGraph $!! acc D.dumpSimplStats res <- phase "execute" D.elapsed $ evaluate $ evalOpenAcc acc Empty - return $ toArr res + return $ Sugar.toArr $ snd res -- | This is 'runN' specialised to an array program of one argument. -- -run1 :: (Arrays a, Arrays b) => (Smart.Acc a -> Smart.Acc b) -> a -> b +run1 :: (Sugar.Arrays a, Sugar.Arrays b) => (Smart.Acc a -> Smart.Acc b) -> a -> b run1 = runN -- | Prepare and execute an embedded array program. @@ -121,8 +116,8 @@ runN f = go !go = eval (afunctionRepr @f) afun Empty -- eval :: AfunctionRepr g (AfunctionR g) (AreprFunctionR g) -> DelayedOpenAfun aenv (AreprFunctionR g) -> Val aenv -> AfunctionR g - eval (AfunctionReprLam reprF) (Alam lhs f) aenv = \a -> eval reprF f $ aenv `push` (lhs, fromArr a) - eval AfunctionReprBody (Abody b) aenv = unsafePerformIO $ phase "execute" D.elapsed (toArr <$> evaluate (evalOpenAcc b aenv)) + eval (AfunctionReprLam reprF) (Alam lhs f) aenv = \a -> eval reprF f $ aenv `push` (lhs, Sugar.fromArr a) + eval AfunctionReprBody (Abody b) aenv = unsafePerformIO $ phase "execute" D.elapsed (Sugar.toArr . snd <$> evaluate (evalOpenAcc b aenv)) eval _ _aenv _ = error "Two men say they're Jesus; one of them must be wrong" -- -- | Stream a lazily read list of input arrays through the given program, @@ -148,7 +143,8 @@ phase n fmt go = D.timed D.dump_phases (\wall cpu -> printf "phase %s: %s" n (fm -- not require an optional Manifest|Delayed data type to evaluate the program. -- data Delayed a where - Delayed :: sh + Delayed :: ArrayR (Array sh e) + -> sh -> (sh -> e) -> (Int -> e) -> Delayed (Array sh e) @@ -157,13 +153,16 @@ data Delayed a where -- Array expression evaluation -- --------------------------- -type EvalAcc acc = forall aenv a. acc aenv a -> Val aenv -> a +type WithReprs acc = (ArraysR acc, acc) + +fromFunction' :: ArrayR (Array sh e) -> sh -> (sh -> e) -> WithReprs (Array sh e) +fromFunction' repr sh f = (TupRsingle repr, fromFunction repr sh f) -- Evaluate an open array function -- evalOpenAfun :: DelayedOpenAfun aenv f -> Val aenv -> f evalOpenAfun (Alam lhs f) aenv = \a -> evalOpenAfun f $ aenv `push` (lhs, a) -evalOpenAfun (Abody b) aenv = evalOpenAcc b aenv +evalOpenAfun (Abody b) aenv = snd $ evalOpenAcc b aenv -- The core interpreter for optimised array programs @@ -172,61 +171,67 @@ evalOpenAcc :: forall aenv a. DelayedOpenAcc aenv a -> Val aenv - -> a + -> WithReprs a evalOpenAcc AST.Delayed{} _ = $internalError "evalOpenAcc" "expected manifest array" evalOpenAcc (AST.Manifest pacc) aenv = let - manifest :: forall a'. DelayedOpenAcc aenv a' -> a' + manifest :: forall a'. DelayedOpenAcc aenv a' -> WithReprs a' manifest acc = - let a' = evalOpenAcc acc aenv - repr = arraysRepr acc - in rnfArrays repr a' `seq` a' + let (repr, a') = evalOpenAcc acc aenv + in rnfArrays repr a' `seq` (repr, a') - delayed :: (Shape sh, Elt e) => DelayedOpenAcc aenv (Array sh e) -> Delayed (Array sh e) - delayed AST.Delayed{..} = Delayed (evalE extentD) (evalF indexD) (evalF linearIndexD) - delayed (manifest -> a) = Delayed (shape a) (a!) (a!!) + delayed :: DelayedOpenAcc aenv (Array sh e) -> Delayed (Array sh e) + delayed AST.Delayed{..} = Delayed reprD (evalE extentD) (evalF indexD) (evalF linearIndexD) + delayed a' = Delayed repr (shape a) ((repr, a) !) ((arrayRtype repr, a) !!) + where + (TupRsingle repr, a) = manifest a' - evalE :: DelayedExp aenv t -> t - evalE exp = evalPreExp evalOpenAcc exp aenv + evalE :: Exp aenv t -> t + evalE exp = evalExp exp aenv - evalF :: DelayedFun aenv f -> f - evalF fun = evalPreFun evalOpenAcc fun aenv + evalF :: Fun aenv f -> f + evalF fun = evalFun fun aenv - evalB :: AST.PreBoundary DelayedOpenAcc aenv t -> Boundary t - evalB bnd = evalPreBoundary evalOpenAcc bnd aenv + evalB :: AST.Boundary aenv t -> Boundary t + evalB bnd = evalBoundary bnd aenv in case pacc of - Avar (ArrayVar ix) -> prj ix aenv - Alet lhs acc1 acc2 -> evalOpenAcc acc2 $ aenv `push` (lhs, manifest acc1) - Apair acc1 acc2 -> (manifest acc1, manifest acc2) - Anil -> () - Apply afun acc -> evalOpenAfun afun aenv $ manifest acc - Aforeign _ afun acc -> evalOpenAfun afun Empty $ manifest acc + Avar (Var repr ix) -> (TupRsingle repr, prj ix aenv) + Alet lhs acc1 acc2 -> evalOpenAcc acc2 $ aenv `push` (lhs, snd $ manifest acc1) + Apair acc1 acc2 -> let + (r1, a1) = manifest acc1 + (r2, a2) = manifest acc2 + in + (TupRpair r1 r2, (a1, a2)) + Anil -> (TupRunit, ()) + Apply repr afun acc -> (repr, evalOpenAfun afun aenv $ snd $ manifest acc) + Aforeign repr _ afun acc -> (repr, evalOpenAfun afun Empty $ snd $ manifest acc) Acond p acc1 acc2 | evalE p -> manifest acc1 | otherwise -> manifest acc2 - Awhile cond body acc -> go (manifest acc) + Awhile cond body acc -> (repr, go initial) where + (repr, initial) = manifest acc p = evalOpenAfun cond aenv f = evalOpenAfun body aenv go !x - | p x ! Z = go (f x) - | otherwise = x + | (ArrayR ShapeRz (TupRsingle scalarTypeBool), p x) ! () = go (f x) + | otherwise = x - Use arr -> arr - Unit e -> unitOp (evalE e) + Use repr arr -> (TupRsingle repr, arr) + Unit tp e -> unitOp tp (evalE e) -- Collect s -> evalSeq defaultSeqConfig s aenv -- Producers -- --------- - Map f acc -> mapOp (evalF f) (delayed acc) - Generate sh f -> generateOp (evalE sh) (evalF f) - Transform sh p f acc -> transformOp (evalE sh) (evalF p) (evalF f) (delayed acc) - Backpermute sh p acc -> backpermuteOp (evalE sh) (evalF p) (delayed acc) - Reshape sh acc -> reshapeOp (evalE sh) (manifest acc) + Map tp f acc -> mapOp tp (evalF f) (delayed acc) + Generate repr sh f -> generateOp repr (evalE sh) (evalF f) + Transform repr sh p f acc -> transformOp repr (evalE sh) (evalF p) (evalF f) (delayed acc) + Backpermute shr sh p acc -> backpermuteOp shr (evalE sh) (evalF p) (delayed acc) + Reshape shr sh acc -> reshapeOp shr (evalE sh) (manifest acc) - ZipWith f acc1 acc2 -> zipWithOp (evalF f) (delayed acc1) (delayed acc2) + ZipWith tp f acc1 acc2 -> zipWithOp tp (evalF f) (delayed acc1) (delayed acc2) Replicate slice slix acc -> replicateOp slice (evalE slix) (manifest acc) Slice slice acc slix -> sliceOp slice (manifest acc) (evalE slix) @@ -234,8 +239,8 @@ evalOpenAcc (AST.Manifest pacc) aenv = -- --------- Fold f z acc -> foldOp (evalF f) (evalE z) (delayed acc) Fold1 f acc -> fold1Op (evalF f) (delayed acc) - FoldSeg f z acc seg -> foldSegOp (evalF f) (evalE z) (delayed acc) (delayed seg) - Fold1Seg f acc seg -> fold1SegOp (evalF f) (delayed acc) (delayed seg) + FoldSeg i f z acc seg -> foldSegOp i (evalF f) (evalE z) (delayed acc) (delayed seg) + Fold1Seg i f acc seg -> fold1SegOp i (evalF f) (delayed acc) (delayed seg) Scanl f z acc -> scanlOp (evalF f) (evalE z) (delayed acc) Scanl' f z acc -> scanl'Op (evalF f) (evalE z) (delayed acc) Scanl1 f acc -> scanl1Op (evalF f) (delayed acc) @@ -243,55 +248,58 @@ evalOpenAcc (AST.Manifest pacc) aenv = Scanr' f z acc -> scanr'Op (evalF f) (evalE z) (delayed acc) Scanr1 f acc -> scanr1Op (evalF f) (delayed acc) Permute f def p acc -> permuteOp (evalF f) (manifest def) (evalF p) (delayed acc) - Stencil sten b acc -> stencilOp (evalF sten) (evalB b) (delayed acc) - Stencil2 sten b1 a1 b2 a2 -> stencil2Op (evalF sten) (evalB b1) (delayed a1) (evalB b2) (delayed a2) + Stencil s tp sten b acc -> stencilOp s tp (evalF sten) (evalB b) (delayed acc) + Stencil2 s1 s2 tp sten b1 a1 b2 a2 + -> stencil2Op s1 s2 tp (evalF sten) (evalB b1) (delayed a1) (evalB b2) (delayed a2) -- Array primitives -- ---------------- -unitOp :: Elt e => e -> Scalar e -unitOp e = fromFunction Z (const e) +unitOp :: TupleType e -> e -> WithReprs (Scalar e) +unitOp tp e = fromFunction' (ArrayR ShapeRz tp) () (const e) generateOp - :: (Shape sh, Elt e) - => sh + :: ArrayR (Array sh e) + -> sh -> (sh -> e) - -> Array sh e -generateOp = fromFunction + -> WithReprs (Array sh e) +generateOp = fromFunction' transformOp - :: (Shape sh', Elt b) - => sh' + :: ArrayR (Array sh' b) + -> sh' -> (sh' -> sh) -> (a -> b) -> Delayed (Array sh a) - -> Array sh' b -transformOp sh' p f (Delayed _ xs _) - = fromFunction sh' (\ix -> f (xs $ p ix)) + -> WithReprs (Array sh' b) +transformOp repr sh' p f (Delayed _ _ xs _) + = fromFunction' repr sh' (\ix -> f (xs $ p ix)) reshapeOp - :: (Shape sh, Shape sh') - => sh - -> Array sh' e - -> Array sh e -reshapeOp newShape arr@(Array _ adata) - = $boundsCheck "reshape" "shape mismatch" (size newShape == size (shape arr)) - $ Array (fromElt newShape) adata + :: ShapeR sh + -> sh + -> WithReprs (Array sh' e) + -> WithReprs (Array sh e) +reshapeOp newShapeR newShape (TupRsingle (ArrayR shr tp), (Array sh adata)) + = $boundsCheck "reshape" "shape mismatch" (size newShapeR newShape == size shr sh) + ( TupRsingle (ArrayR newShapeR tp) + , Array newShape adata + ) replicateOp - :: (Shape sh, Shape sl, Elt slix, Elt e) - => SliceIndex (EltRepr slix) (EltRepr sl) co (EltRepr sh) + :: SliceIndex slix sl co sh -> slix - -> Array sl e - -> Array sh e -replicateOp slice slix arr - = fromFunction (toElt sh) (\ix -> arr ! liftToElt pf ix) + -> WithReprs (Array sl e) + -> WithReprs (Array sh e) +replicateOp slice slix (TupRsingle repr@(ArrayR _ tp), arr) + = fromFunction' repr' sh (\ix -> (repr, arr) ! pf ix) where - (sh, pf) = extend slice (fromElt slix) (fromElt (shape arr)) + repr' = ArrayR (sliceDomainR slice) tp + (sh, pf) = extend slice slix (shape arr) extend :: SliceIndex slix sl co dim -> slix @@ -308,15 +316,15 @@ replicateOp slice slix arr sliceOp - :: (Shape sh, Shape sl, Elt slix, Elt e) - => SliceIndex (EltRepr slix) (EltRepr sl) co (EltRepr sh) - -> Array sh e + :: SliceIndex slix sl co sh + -> WithReprs (Array sh e) -> slix - -> Array sl e -sliceOp slice arr slix - = fromFunction (toElt sh') (\ix -> arr ! liftToElt pf ix) + -> WithReprs (Array sl e) +sliceOp slice (TupRsingle repr@(ArrayR _ tp), arr) slix + = fromFunction' repr' sh' (\ix -> (repr, arr) ! pf ix) where - (sh', pf) = restrict slice (fromElt slix) (fromElt (shape arr)) + repr' = ArrayR (sliceShapeR slice) tp + (sh', pf) = restrict slice slix (shape arr) restrict :: SliceIndex slix sl co sh -> slix @@ -332,22 +340,22 @@ sliceOp slice arr slix in $indexCheck "slice" i sz $ (sl', \ix -> (f' ix, i)) -mapOp :: (Shape sh, Elt b) - => (a -> b) - -> Delayed (Array sh a) - -> Array sh b -mapOp f (Delayed sh xs _) - = fromFunction sh (\ix -> f (xs ix)) +mapOp :: TupleType b + -> (a -> b) + -> Delayed (Array sh a) + -> WithReprs (Array sh b) +mapOp tp f (Delayed (ArrayR shr _) sh xs _) + = fromFunction' (ArrayR shr tp) sh (\ix -> f (xs ix)) zipWithOp - :: (Shape sh, Elt c) - => (a -> b -> c) - -> Delayed (Array sh a) - -> Delayed (Array sh b) - -> Array sh c -zipWithOp f (Delayed shx xs _) (Delayed shy ys _) - = fromFunction (shx `intersect` shy) (\ix -> f (xs ix) (ys ix)) + :: TupleType c + -> (a -> b -> c) + -> Delayed (Array sh a) + -> Delayed (Array sh b) + -> WithReprs (Array sh c) +zipWithOp tp f (Delayed (ArrayR shr _) shx xs _) (Delayed _ shy ys _) + = fromFunction' (ArrayR shr tp) (intersect shr shx shy) (\ix -> f (xs ix) (ys ix)) -- zipWith'Op -- :: (Shape sh, Elt a) @@ -356,7 +364,7 @@ zipWithOp f (Delayed shx xs _) (Delayed shy ys _) -- -> Delayed (Array sh a) -- -> Array sh a -- zipWith'Op f (Delayed shx xs _) (Delayed shy ys _) --- = fromFunction (shx `union` shy) (\ix -> if ix `outside` shx +-- = fromFunction' (shx `union` shy) (\ix -> if ix `outside` shx -- then ys ix -- else if ix `outside` shy -- then xs ix @@ -366,502 +374,470 @@ zipWithOp f (Delayed shx xs _) (Delayed shy ys _) foldOp - :: (Shape sh, Elt e) - => (e -> e -> e) + :: (e -> e -> e) -> e - -> Delayed (Array (sh :. Int) e) - -> Array sh e -foldOp f z (Delayed (sh :. n) arr _) - = fromFunction sh (\ix -> iter (Z:.n) (\(Z:.i) -> arr (ix :. i)) f z) + -> Delayed (Array (sh, Int) e) + -> WithReprs (Array sh e) +foldOp f z (Delayed (ArrayR (ShapeRsnoc shr) tp) (sh, n) arr _) + = fromFunction' (ArrayR shr tp) sh (\ix -> iter (ShapeRsnoc ShapeRz) ((), n) (\((), i) -> arr (ix, i)) f z) fold1Op - :: (Shape sh, Elt e) - => (e -> e -> e) - -> Delayed (Array (sh :. Int) e) - -> Array sh e -fold1Op f (Delayed (sh :. n) arr _) + :: (e -> e -> e) + -> Delayed (Array (sh, Int) e) + -> WithReprs (Array sh e) +fold1Op f (Delayed (ArrayR (ShapeRsnoc shr) tp) (sh, n) arr _) = $boundsCheck "fold1" "empty array" (n > 0) - $ fromFunction sh (\ix -> iter1 (Z:.n) (\(Z:.i) -> arr (ix :. i)) f) + $ fromFunction' (ArrayR shr tp) sh (\ix -> iter1 (ShapeRsnoc ShapeRz) ((), n) (\((), i) -> arr (ix, i)) f) foldSegOp - :: forall sh e i. (Shape sh, Elt e, Elt i, IsIntegral i) - => (e -> e -> e) + :: IntegralType i + -> (e -> e -> e) -> e - -> Delayed (Array (sh :. Int) e) + -> Delayed (Array (sh, Int) e) -> Delayed (Segments i) - -> Array (sh :. Int) e -foldSegOp f z (Delayed (sh :. _) arr _) (Delayed (Z :. n) _ seg) - | IntegralDict <- integralDict (integralType :: IntegralType i) + -> WithReprs (Array (sh, Int) e) +foldSegOp itp f z (Delayed repr (sh, _) arr _) (Delayed _ ((), n) _ seg) + | IntegralDict <- integralDict itp = $boundsCheck "foldSeg" "empty segment descriptor" (n > 0) - $ fromFunction (sh :. n-1) - $ \(sz :. ix) -> let start = fromIntegral $ seg ix + $ fromFunction' repr (sh, n-1) + $ \(sz, ix) -> let start = fromIntegral $ seg ix end = fromIntegral $ seg (ix+1) in $boundsCheck "foldSeg" "empty segment" (end >= start) - $ iter (Z :. end-start) (\(Z:.i) -> arr (sz :. start+i)) f z + $ iter (ShapeRsnoc ShapeRz) ((), end-start) (\((), i) -> arr (sz, start+i)) f z fold1SegOp - :: forall sh e i. (Shape sh, Elt e, Elt i, IsIntegral i) - => (e -> e -> e) - -> Delayed (Array (sh :. Int) e) + :: IntegralType i + -> (e -> e -> e) + -> Delayed (Array (sh, Int) e) -> Delayed (Segments i) - -> Array (sh :. Int) e -fold1SegOp f (Delayed (sh :. _) arr _) (Delayed (Z :. n) _ seg) - | IntegralDict <- integralDict (integralType :: IntegralType i) + -> WithReprs (Array (sh, Int) e) +fold1SegOp itp f (Delayed repr (sh, _) arr _) (Delayed _ ((), n) _ seg) + | IntegralDict <- integralDict itp = $boundsCheck "foldSeg" "empty segment descriptor" (n > 0) - $ fromFunction (sh :. n-1) - $ \(sz :. ix) -> let start = fromIntegral $ seg ix + $ fromFunction' repr (sh, n-1) + $ \(sz, ix) -> let start = fromIntegral $ seg ix end = fromIntegral $ seg (ix+1) in $boundsCheck "fold1Seg" "empty segment" (end > start) - $ iter1 (Z :. end-start) (\(Z:.i) -> arr (sz :. start+i)) f + $ iter1 (ShapeRsnoc ShapeRz) ((), end-start) (\((), i) -> arr (sz, start+i)) f scanl1Op - :: (Shape sh, Elt e) - => (e -> e -> e) - -> Delayed (Array (sh:.Int) e) - -> Array (sh:.Int) e -scanl1Op f (Delayed sh@(_ :. n) ain _) + :: forall sh e. + (e -> e -> e) + -> Delayed (Array (sh, Int) e) + -> WithReprs (Array (sh, Int) e) +scanl1Op f (Delayed (ArrayR shr tp) sh@(_, n) ain _) = $boundsCheck "scanl1" "empty array" (n > 0) - $ adata `seq` Array (fromElt sh) adata + ( TupRsingle $ ArrayR shr tp + , adata `seq` Array sh adata + ) where - f' = sinkFromElt2 f -- - (adata, _) = runArrayData $ do - aout <- newArrayData (size sh) + (adata, _) = runArrayData @e $ do + aout <- newArrayData tp (size shr sh) - let write (sz:.0) = unsafeWriteArrayData aout (toIndex sh (sz:.0)) (fromElt (ain (sz:.0))) - write (sz:.i) = do - x <- unsafeReadArrayData aout (toIndex sh (sz:.i-1)) - y <- return $ fromElt (ain (sz:.i)) - unsafeWriteArrayData aout (toIndex sh (sz:.i)) (f' x y) + let write (sz, 0) = unsafeWriteArrayData tp aout (toIndex shr sh (sz, 0)) (ain (sz, 0)) + write (sz, i) = do + x <- unsafeReadArrayData tp aout (toIndex shr sh (sz, i-1)) + let y = ain (sz, i) + unsafeWriteArrayData tp aout (toIndex shr sh (sz, i)) (f x y) - iter sh write (>>) (return ()) + iter shr sh write (>>) (return ()) return (aout, undefined) scanlOp - :: (Shape sh, Elt e) - => (e -> e -> e) + :: forall sh e. + (e -> e -> e) -> e - -> Delayed (Array (sh:.Int) e) - -> Array (sh:.Int) e -scanlOp f z (Delayed (sh :. n) ain _) - = adata `seq` Array (fromElt sh') adata + -> Delayed (Array (sh, Int) e) + -> WithReprs (Array (sh, Int) e) +scanlOp f z (Delayed (ArrayR shr tp) (sh, n) ain _) + = ( TupRsingle $ ArrayR shr tp + , adata `seq` Array sh' adata + ) where - sh' = sh :. n+1 - f' = sinkFromElt2 f + sh' = (sh, n+1) -- - (adata, _) = runArrayData $ do - aout <- newArrayData (size sh') + (adata, _) = runArrayData @e $ do + aout <- newArrayData tp (size shr sh') - let write (sz:.0) = unsafeWriteArrayData aout (toIndex sh' (sz:.0)) (fromElt z) - write (sz:.i) = do - x <- unsafeReadArrayData aout (toIndex sh' (sz:.i-1)) - y <- return $ fromElt (ain (sz:.i-1)) - unsafeWriteArrayData aout (toIndex sh' (sz:.i)) (f' x y) + let write (sz, 0) = unsafeWriteArrayData tp aout (toIndex shr sh' (sz, 0)) z + write (sz, i) = do + x <- unsafeReadArrayData tp aout (toIndex shr sh' (sz, i-1)) + let y = ain (sz, i-1) + unsafeWriteArrayData tp aout (toIndex shr sh' (sz, i)) (f x y) - iter sh' write (>>) (return ()) + iter shr sh' write (>>) (return ()) return (aout, undefined) scanl'Op - :: (Shape sh, Elt e) - => (e -> e -> e) + :: forall sh e. + (e -> e -> e) -> e - -> Delayed (Array (sh:.Int) e) - -> ArrRepr (Array (sh:.Int) e, Array sh e) -scanl'Op f z (Delayed (sh :. n) ain _) - = aout `seq` asum `seq` ( ( (), Array (fromElt (sh:.n)) aout ) - , Array (fromElt sh) asum ) + -> Delayed (Array (sh, Int) e) + -> WithReprs (Array (sh, Int) e, Array sh e) +scanl'Op f z (Delayed (ArrayR shr@(ShapeRsnoc shr') tp) (sh, n) ain _) + = ( TupRsingle (ArrayR shr tp) `TupRpair` TupRsingle (ArrayR shr' tp) + , aout `seq` asum `seq` ( Array (sh, n) aout, Array sh asum ) + ) where - f' = sinkFromElt2 f - -- - (AD_Pair aout asum, _) = runArrayData $ do - aout <- newArrayData (size (sh:.n)) - asum <- newArrayData (size sh) - - let write (sz:.0) - | n == 0 = unsafeWriteArrayData asum (toIndex sh sz) (fromElt z) - | otherwise = unsafeWriteArrayData aout (toIndex (sh:.n) (sz:.0)) (fromElt z) - write (sz:.i) = do - x <- unsafeReadArrayData aout (toIndex (sh:.n) (sz:.i-1)) - y <- return $ fromElt (ain (sz:.i-1)) + ((aout, asum), _) = runArrayData @(e, e) $ do + aout <- newArrayData tp (size shr (sh, n)) + asum <- newArrayData tp (size shr' sh) + + let write (sz, 0) + | n == 0 = unsafeWriteArrayData tp asum (toIndex shr' sh sz) z + | otherwise = unsafeWriteArrayData tp aout (toIndex shr (sh, n) (sz, 0)) z + write (sz, i) = do + x <- unsafeReadArrayData tp aout (toIndex shr (sh, n) (sz, i-1)) + let y = ain (sz, i-1) if i == n - then unsafeWriteArrayData asum (toIndex sh sz) (f' x y) - else unsafeWriteArrayData aout (toIndex (sh:.n) (sz:.i)) (f' x y) + then unsafeWriteArrayData tp asum (toIndex shr' sh sz) (f x y) + else unsafeWriteArrayData tp aout (toIndex shr (sh, n) (sz, i)) (f x y) - iter (sh:.n+1) write (>>) (return ()) - return (AD_Pair aout asum, undefined) + iter shr (sh, n+1) write (>>) (return ()) + return ((aout, asum), undefined) scanrOp - :: (Shape sh, Elt e) - => (e -> e -> e) + :: forall sh e. + (e -> e -> e) -> e - -> Delayed (Array (sh:.Int) e) - -> Array (sh:.Int) e -scanrOp f z (Delayed (sz :. n) ain _) - = adata `seq` Array (fromElt sh') adata + -> Delayed (Array (sh, Int) e) + -> WithReprs (Array (sh, Int) e) +scanrOp f z (Delayed (ArrayR shr tp) (sz, n) ain _) + = ( TupRsingle (ArrayR shr tp) + , adata `seq` Array sh' adata + ) where - sh' = sz :. n+1 - f' = sinkFromElt2 f + sh' = (sz, n+1) -- - (adata, _) = runArrayData $ do - aout <- newArrayData (size sh') + (adata, _) = runArrayData @e $ do + aout <- newArrayData tp (size shr sh') - let write (sz:.0) = unsafeWriteArrayData aout (toIndex sh' (sz:.n)) (fromElt z) - write (sz:.i) = do - x <- return $ fromElt (ain (sz:.n-i)) - y <- unsafeReadArrayData aout (toIndex sh' (sz:.n-i+1)) - unsafeWriteArrayData aout (toIndex sh' (sz:.n-i)) (f' x y) + let write (sz, 0) = unsafeWriteArrayData tp aout (toIndex shr sh' (sz, n)) z + write (sz, i) = do + let x = ain (sz, n-i) + y <- unsafeReadArrayData tp aout (toIndex shr sh' (sz, n-i+1)) + unsafeWriteArrayData tp aout (toIndex shr sh' (sz, n-i)) (f x y) - iter sh' write (>>) (return ()) + iter shr sh' write (>>) (return ()) return (aout, undefined) scanr1Op - :: (Shape sh, Elt e) - => (e -> e -> e) - -> Delayed (Array (sh:.Int) e) - -> Array (sh:.Int) e -scanr1Op f (Delayed sh@(_ :. n) ain _) + :: forall sh e. + (e -> e -> e) + -> Delayed (Array (sh, Int) e) + -> WithReprs (Array (sh, Int) e) +scanr1Op f (Delayed (ArrayR shr tp) sh@(_, n) ain _) = $boundsCheck "scanr1" "empty array" (n > 0) - $ adata `seq` Array (fromElt sh) adata + ( TupRsingle $ ArrayR shr tp + , adata `seq` Array sh adata + ) where - f' = sinkFromElt2 f - -- - (adata, _) = runArrayData $ do - aout <- newArrayData (size sh) + (adata, _) = runArrayData @e $ do + aout <- newArrayData tp (size shr sh) - let write (sz:.0) = unsafeWriteArrayData aout (toIndex sh (sz:.n-1)) (fromElt (ain (sz:.n-1))) - write (sz:.i) = do - x <- return $ fromElt (ain (sz:.n-i-1)) - y <- unsafeReadArrayData aout (toIndex sh (sz:.n-i)) - unsafeWriteArrayData aout (toIndex sh (sz:.n-i-1)) (f' x y) + let write (sz, 0) = unsafeWriteArrayData tp aout (toIndex shr sh (sz, n-1)) (ain (sz, n-1)) + write (sz, i) = do + let x = ain (sz, n-i-1) + y <- unsafeReadArrayData tp aout (toIndex shr sh (sz, n-i)) + unsafeWriteArrayData tp aout (toIndex shr sh (sz, n-i-1)) (f x y) - iter sh write (>>) (return ()) + iter shr sh write (>>) (return ()) return (aout, undefined) scanr'Op - :: forall sh e. (Shape sh, Elt e) - => (e -> e -> e) + :: forall sh e. + (e -> e -> e) -> e - -> Delayed (Array (sh:.Int) e) - -> ArrRepr (Array (sh:.Int) e, Array sh e) -scanr'Op f z (Delayed (sh :. n) ain _) - = aout `seq` asum `seq` ( ((), Array (fromElt (sh:.n)) aout ) - , Array (fromElt sh) asum ) + -> Delayed (Array (sh, Int) e) + -> WithReprs (Array (sh, Int) e, Array sh e) +scanr'Op f z (Delayed (ArrayR shr@(ShapeRsnoc shr') tp) (sh, n) ain _) + = ( TupRsingle (ArrayR shr tp) `TupRpair` TupRsingle (ArrayR shr' tp) + , aout `seq` asum `seq` ( Array (sh, n) aout, Array sh asum ) + ) where - f' = sinkFromElt2 f - -- - (AD_Pair aout asum, _) = runArrayData $ do - aout <- newArrayData (size (sh:.n)) - asum <- newArrayData (size sh) + ((aout, asum), _) = runArrayData @(e, e) $ do + aout <- newArrayData tp (size shr (sh, n)) + asum <- newArrayData tp (size shr' sh) - let write (sz:.0) - | n == 0 = unsafeWriteArrayData asum (toIndex sh sz) (fromElt z) - | otherwise = unsafeWriteArrayData aout (toIndex (sh:.n) (sz:.n-1)) (fromElt z) + let write (sz, 0) + | n == 0 = unsafeWriteArrayData tp asum (toIndex shr' sh sz) z + | otherwise = unsafeWriteArrayData tp aout (toIndex shr (sh, n) (sz, n-1)) z - write (sz:.i) = do - x <- return $ fromElt (ain (sz:.n-i)) - y <- unsafeReadArrayData aout (toIndex (sh:.n) (sz:.n-i)) + write (sz, i) = do + let x = ain (sz, n-i) + y <- unsafeReadArrayData tp aout (toIndex shr (sh, n) (sz, n-i)) if i == n - then unsafeWriteArrayData asum (toIndex sh sz) (f' x y) - else unsafeWriteArrayData aout (toIndex (sh:.n) (sz:.n-i-1)) (f' x y) + then unsafeWriteArrayData tp asum (toIndex shr' sh sz) (f x y) + else unsafeWriteArrayData tp aout (toIndex shr (sh, n) (sz, n-i-1)) (f x y) - iter (sh:.n+1) write (>>) (return ()) - return (AD_Pair aout asum, undefined) + iter shr (sh, n+1) write (>>) (return ()) + return ((aout, asum), undefined) permuteOp - :: (Shape sh, Shape sh', Elt e) - => (e -> e -> e) - -> Array sh' e + :: forall sh sh' e. + (e -> e -> e) + -> WithReprs (Array sh' e) -> (sh -> sh') - -> Delayed (Array sh e) - -> Array sh' e -permuteOp f def@(Array _ adef) p (Delayed sh _ ain) - = adata `seq` Array (fromElt sh') adata + -> Delayed (Array sh e) + -> WithReprs (Array sh' e) +permuteOp f (TupRsingle (ArrayR shr' _), def@(Array _ adef)) p (Delayed (ArrayR shr tp) sh _ ain) + = (TupRsingle $ ArrayR shr' tp, adata `seq` Array sh' adata) where sh' = shape def - n' = size sh' - f' = sinkFromElt2 f + n' = size shr' sh' + + ignore' = ignore shr' -- - (adata, _) = runArrayData $ do - aout <- newArrayData n' + (adata, _) = runArrayData @e $ do + aout <- newArrayData tp n' let -- initialise array with default values init i | i >= n' = return () | otherwise = do - x <- unsafeReadArrayData adef i - unsafeWriteArrayData aout i x + x <- unsafeReadArrayData tp adef i + unsafeWriteArrayData tp aout i x init (i+1) -- project each element onto the destination array and update update src = let dst = p src - i = toIndex sh src - j = toIndex sh' dst + i = toIndex shr sh src + j = toIndex shr' sh' dst in - unless (fromElt dst == R.ignore) $ do - x <- return . fromElt $ ain i - y <- unsafeReadArrayData aout j - unsafeWriteArrayData aout j (f' x y) + unless (shapeEq shr' dst ignore') $ do + let x = ain i + y <- unsafeReadArrayData tp aout j + unsafeWriteArrayData tp aout j (f x y) init 0 - iter sh update (>>) (return ()) + iter shr sh update (>>) (return ()) return (aout, undefined) backpermuteOp - :: (Shape sh', Elt e) - => sh' + :: ShapeR sh' + -> sh' -> (sh' -> sh) -> Delayed (Array sh e) - -> Array sh' e -backpermuteOp sh' p (Delayed _ arr _) - = fromFunction sh' (\ix -> arr $ p ix) + -> WithReprs (Array sh' e) +backpermuteOp shr sh' p (Delayed (ArrayR _ tp) _ arr _) + = fromFunction' (ArrayR shr tp) sh' (\ix -> arr $ p ix) stencilOp - :: (Stencil sh a stencil, Elt b) - => (stencil -> b) + :: StencilR sh a stencil + -> TupleType b + -> (stencil -> b) -> Boundary (Array sh a) -> Delayed (Array sh a) - -> Array sh b -stencilOp stencil bnd arr@(Delayed sh _ _) - = fromFunction sh - $ stencil . stencilAccess (bounded bnd arr) + -> WithReprs (Array sh b) +stencilOp stencil tp f bnd arr@(Delayed _ sh _ _) + = fromFunction' (ArrayR shr tp) sh + $ f . stencilAccess stencil (bounded shr bnd arr) + where + shr = stencilShape stencil stencil2Op - :: (Stencil sh a stencil1, Stencil sh b stencil2, Elt c) - => (stencil1 -> stencil2 -> c) + :: StencilR sh a stencil1 + -> StencilR sh b stencil2 + -> TupleType c + -> (stencil1 -> stencil2 -> c) -> Boundary (Array sh a) -> Delayed (Array sh a) -> Boundary (Array sh b) -> Delayed (Array sh b) - -> Array sh c -stencil2Op stencil bnd1 arr1@(Delayed sh1 _ _) bnd2 arr2@(Delayed sh2 _ _) - = fromFunction (sh1 `intersect` sh2) f + -> WithReprs (Array sh c) +stencil2Op s1 s2 tp stencil bnd1 arr1@(Delayed _ sh1 _ _) bnd2 arr2@(Delayed _ sh2 _ _) + = fromFunction' (ArrayR shr tp) (intersect shr sh1 sh2) f where - f ix = stencil (stencilAccess (bounded bnd1 arr1) ix) - (stencilAccess (bounded bnd2 arr2) ix) + f ix = stencil (stencilAccess s1 (bounded shr bnd1 arr1) ix) + (stencilAccess s2 (bounded shr bnd2 arr2) ix) + shr = stencilShape s1 stencilAccess - :: Stencil sh e stencil - => (sh -> e) + :: StencilR sh e stencil + -> (sh -> e) -> sh -> stencil -stencilAccess = goR stencil +stencilAccess stencil = goR (stencilShape stencil) stencil where -- Base cases, nothing interesting to do here since we know the lower -- dimension is Z. -- - goR :: StencilR sh e stencil -> (sh -> e) -> sh -> stencil - goR StencilRunit3 rf ix = + goR :: ShapeR sh -> StencilR sh e stencil -> (sh -> e) -> sh -> stencil + goR _ (StencilRunit3 _) rf ix = let - z :. i = ix - rf' d = rf (z :. i+d) + (z, i) = ix + rf' d = rf (z, i+d) in - ( rf' (-1) - , rf' 0 - , rf' 1 - ) - - goR StencilRunit5 rf ix = - let z :. i = ix - rf' d = rf (z :. i+d) + ((( () + , rf' (-1)) + , rf' 0 ) + , rf' 1 ) + + goR _ (StencilRunit5 _) rf ix = + let (z, i) = ix + rf' d = rf (z, i+d) in - ( rf' (-2) - , rf' (-1) - , rf' 0 - , rf' 1 - , rf' 2 - ) - - goR StencilRunit7 rf ix = - let z :. i = ix - rf' d = rf (z :. i+d) + ((((( () + , rf' (-2)) + , rf' (-1)) + , rf' 0 ) + , rf' 1 ) + , rf' 2 ) + + goR _ (StencilRunit7 _) rf ix = + let (z, i) = ix + rf' d = rf (z, i+d) in - ( rf' (-3) - , rf' (-2) - , rf' (-1) - , rf' 0 - , rf' 1 - , rf' 2 - , rf' 3 - ) - - goR StencilRunit9 rf ix = - let z :. i = ix - rf' d = rf (z :. i+d) + ((((((( () + , rf' (-3)) + , rf' (-2)) + , rf' (-1)) + , rf' 0 ) + , rf' 1 ) + , rf' 2 ) + , rf' 3 ) + + goR _ (StencilRunit9 _) rf ix = + let (z, i) = ix + rf' d = rf (z, i+d) in - ( rf' (-4) - , rf' (-3) - , rf' (-2) - , rf' (-1) - , rf' 0 - , rf' 1 - , rf' 2 - , rf' 3 - , rf' 4 - ) + ((((((((( () + , rf' (-4)) + , rf' (-3)) + , rf' (-2)) + , rf' (-1)) + , rf' 0 ) + , rf' 1 ) + , rf' 2 ) + , rf' 3 ) + , rf' 4 ) -- Recursive cases. Note that because the stencil pattern is defined with -- cons ordering, whereas shapes (and indices) are defined as a snoc-list, -- when we recurse on the stencil structure we must manipulate the -- _left-most_ index component. -- - goR (StencilRtup3 s1 s2 s3) rf ix = - let (i, ix') = uncons ix - rf' d ds = rf (cons (i+d) ds) + goR (ShapeRsnoc shr) (StencilRtup3 s1 s2 s3) rf ix = + let (i, ix') = uncons shr ix + rf' d ds = rf (cons shr (i+d) ds) in - ( goR s1 (rf' (-1)) ix' - , goR s2 (rf' 0) ix' - , goR s3 (rf' 1) ix' - ) - - goR (StencilRtup5 s1 s2 s3 s4 s5) rf ix = - let (i, ix') = uncons ix - rf' d ds = rf (cons (i+d) ds) + ((( () + , goR shr s1 (rf' (-1)) ix') + , goR shr s2 (rf' 0) ix') + , goR shr s3 (rf' 1) ix') + + goR (ShapeRsnoc shr) (StencilRtup5 s1 s2 s3 s4 s5) rf ix = + let (i, ix') = uncons shr ix + rf' d ds = rf (cons shr (i+d) ds) in - ( goR s1 (rf' (-2)) ix' - , goR s2 (rf' (-1)) ix' - , goR s3 (rf' 0) ix' - , goR s4 (rf' 1) ix' - , goR s5 (rf' 2) ix' - ) - - goR (StencilRtup7 s1 s2 s3 s4 s5 s6 s7) rf ix = - let (i, ix') = uncons ix - rf' d ds = rf (cons (i+d) ds) + ((((( () + , goR shr s1 (rf' (-2)) ix') + , goR shr s2 (rf' (-1)) ix') + , goR shr s3 (rf' 0) ix') + , goR shr s4 (rf' 1) ix') + , goR shr s5 (rf' 2) ix') + + goR (ShapeRsnoc shr) (StencilRtup7 s1 s2 s3 s4 s5 s6 s7) rf ix = + let (i, ix') = uncons shr ix + rf' d ds = rf (cons shr (i+d) ds) in - ( goR s1 (rf' (-3)) ix' - , goR s2 (rf' (-2)) ix' - , goR s3 (rf' (-1)) ix' - , goR s4 (rf' 0) ix' - , goR s5 (rf' 1) ix' - , goR s6 (rf' 2) ix' - , goR s7 (rf' 3) ix' - ) - - goR (StencilRtup9 s1 s2 s3 s4 s5 s6 s7 s8 s9) rf ix = - let (i, ix') = uncons ix - rf' d ds = rf (cons (i+d) ds) + ((((((( () + , goR shr s1 (rf' (-3)) ix') + , goR shr s2 (rf' (-2)) ix') + , goR shr s3 (rf' (-1)) ix') + , goR shr s4 (rf' 0) ix') + , goR shr s5 (rf' 1) ix') + , goR shr s6 (rf' 2) ix') + , goR shr s7 (rf' 3) ix') + + goR (ShapeRsnoc shr) (StencilRtup9 s1 s2 s3 s4 s5 s6 s7 s8 s9) rf ix = + let (i, ix') = uncons shr ix + rf' d ds = rf (cons shr (i+d) ds) in - ( goR s1 (rf' (-4)) ix' - , goR s2 (rf' (-3)) ix' - , goR s3 (rf' (-2)) ix' - , goR s4 (rf' (-1)) ix' - , goR s5 (rf' 0) ix' - , goR s6 (rf' 1) ix' - , goR s7 (rf' 2) ix' - , goR s8 (rf' 3) ix' - , goR s9 (rf' 4) ix' - ) + ((((((((( () + , goR shr s1 (rf' (-4)) ix') + , goR shr s2 (rf' (-3)) ix') + , goR shr s3 (rf' (-2)) ix') + , goR shr s4 (rf' (-1)) ix') + , goR shr s5 (rf' 0) ix') + , goR shr s6 (rf' 1) ix') + , goR shr s7 (rf' 2) ix') + , goR shr s8 (rf' 3) ix') + , goR shr s9 (rf' 4) ix') -- Add a left-most component to an index -- - cons :: forall sh. Shape sh => Int -> sh -> (sh :. Int) - cons ix extent = toElt $ go (eltType @sh) (fromElt extent) - where - go :: TupleType t -> t -> (t, Int) - go TypeRunit () = ((), ix) - go (TypeRpair th tz) (sh, sz) - | TypeRscalar t <- tz - , Just Refl <- matchScalarType t (scalarType :: ScalarType Int) - = (go th sh, sz) - go _ _ - = $internalError "cons" "expected index with Int components" + cons :: ShapeR sh -> Int -> sh -> (sh, Int) + cons ShapeRz ix () = ((), ix) + cons (ShapeRsnoc shr) ix (sh, sz) = (cons shr ix sh, sz) -- Remove the left-most index of an index, and return the remainder -- - uncons :: forall sh. Shape sh => sh :. Int -> (Int, sh) - uncons extent = let (i,ix) = go (eltType @(sh:.Int)) (fromElt extent) - in (i, toElt ix) - where - go :: TupleType (t, Int) -> (t, Int) -> (Int, t) - go (TypeRpair TypeRunit _) ((), v) = (v, ()) - go (TypeRpair t1@(TypeRpair _ t2) _) (v1,v3) - | TypeRscalar t <- t2 - , Just Refl <- matchScalarType t (scalarType :: ScalarType Int) - = let (i, v1') = go t1 v1 - in (i, (v1', v3)) - go _ _ - = $internalError "uncons" "expected index with Int components" + uncons :: ShapeR sh -> (sh, Int) -> (Int, sh) + uncons ShapeRz ((), v) = (v, ()) + uncons (ShapeRsnoc shr) (v1, v2) = let (i, v1') = uncons shr v1 + in (i, (v1', v2)) bounded - :: (Shape sh, Elt e) - => Boundary (Array sh e) + :: ShapeR sh + -> Boundary (Array sh e) -> Delayed (Array sh e) -> sh -> e -bounded bnd (Delayed sh f _) ix = - if inside sh ix +bounded shr bnd (Delayed _ sh f _) ix = + if inside shr sh ix then f ix else case bnd of Function g -> g ix - Constant v -> toElt v - _ -> f (bound sh ix) + Constant v -> v + _ -> f (bound shr sh ix) where -- Whether the index (second argument) is inside the bounds of the given -- shape (first argument). -- - inside :: forall sh. Shape sh => sh -> sh -> Bool - inside sh1 ix1 = go (eltType @sh) (fromElt sh1) (fromElt ix1) - where - go :: TupleType t -> t -> t -> Bool - go TypeRunit () () = True - go (TypeRpair tsh ti) (sh, sz) (ih,iz) - = if go ti sz iz - then go tsh sh ih - else False - go (TypeRscalar t) sz iz - | Just Refl <- matchScalarType t (scalarType :: ScalarType Int) - = if iz < 0 || iz >= sz - then False - else True - -- - | otherwise - = $internalError "inside" "expected index with Int components" + inside :: ShapeR sh -> sh -> sh -> Bool + inside ShapeRz () () = True + inside (ShapeRsnoc shr) (sh, sz) (ih, iz) = iz >= 0 && iz < sz && inside shr sh ih -- Return the index (second argument), updated to obey the given boundary -- conditions when outside the bounds of the given shape (first argument) -- - bound :: forall sh. Shape sh => sh -> sh -> sh - bound sh1 ix1 = toElt $ go (eltType @sh) (fromElt sh1) (fromElt ix1) + bound :: ShapeR sh -> sh -> sh -> sh + bound ShapeRz () () = () + bound (ShapeRsnoc shr) (sh, sz) (ih, iz) = (bound shr sh ih, ih') where - go :: TupleType t -> t -> t -> t - go TypeRunit () () = () - go (TypeRpair tsh ti) (sh, sz) (ih, iz) = (go tsh sh ih, go ti sz iz) - go (TypeRscalar t) sz iz - | Just Refl <- matchScalarType t (scalarType :: ScalarType Int) - = let i | iz < 0 = case bnd of - Clamp -> 0 - Mirror -> -iz - Wrap -> sz + iz - _ -> $internalError "bound" "unexpected boundary condition" - | iz >= sz = case bnd of - Clamp -> sz - 1 - Mirror -> sz - (iz - sz + 2) - Wrap -> iz - sz - _ -> $internalError "bound" "unexpected boundary condition" - | otherwise = iz - in i - | otherwise - = $internalError "bound" "expected index with Int components" - + ih' + | iz < 0 = case bnd of + Clamp -> 0 + Mirror -> -iz + Wrap -> sz + iz + _ -> $internalError "bound" "unexpected boundary condition" + | iz >= sz = case bnd of + Clamp -> sz - 1 + Mirror -> sz - (iz - sz + 2) + Wrap -> iz - sz + _ -> $internalError "bound" "unexpected boundary condition" + | otherwise = iz -- toSeqOp :: forall slix sl dim co e proxy. (Elt slix, Shape sl, Shape dim, Elt e) -- => SliceIndex (EltRepr slix) @@ -882,18 +858,18 @@ data Boundary t where Clamp :: Boundary t Mirror :: Boundary t Wrap :: Boundary t - Constant :: Elt t => EltRepr t -> Boundary (Array sh t) - Function :: (Shape sh, Elt e) => (sh -> e) -> Boundary (Array sh e) + Constant :: t -> Boundary (Array sh t) + Function :: (sh -> e) -> Boundary (Array sh e) -evalPreBoundary :: EvalAcc acc -> AST.PreBoundary acc aenv t -> Val aenv -> Boundary t -evalPreBoundary evalAcc bnd aenv = +evalBoundary :: AST.Boundary aenv t -> Val aenv -> Boundary t +evalBoundary bnd aenv = case bnd of AST.Clamp -> Clamp AST.Mirror -> Mirror AST.Wrap -> Wrap AST.Constant v -> Constant v - AST.Function f -> Function (evalPreFun evalAcc f aenv) + AST.Function f -> Function (evalFun f aenv) -- Scalar expression evaluation @@ -901,20 +877,20 @@ evalPreBoundary evalAcc bnd aenv = -- Evaluate a closed scalar expression -- -evalPreExp :: EvalAcc acc -> PreExp acc aenv t -> Val aenv -> t -evalPreExp evalAcc e aenv = evalPreOpenExp evalAcc e EmptyElt aenv +evalExp :: Exp aenv t -> Val aenv -> t +evalExp e aenv = evalOpenExp e Empty aenv -- Evaluate a closed scalar function -- -evalPreFun :: EvalAcc acc -> PreFun acc aenv t -> Val aenv -> t -evalPreFun evalAcc f aenv = evalPreOpenFun evalAcc f EmptyElt aenv +evalFun :: Fun aenv t -> Val aenv -> t +evalFun f aenv = evalOpenFun f Empty aenv -- Evaluate an open scalar function -- -evalPreOpenFun :: EvalAcc acc -> PreOpenFun acc env aenv t -> ValElt env -> Val aenv -> t -evalPreOpenFun evalAcc (Body e) env aenv = evalPreOpenExp evalAcc e env aenv -evalPreOpenFun evalAcc (Lam f) env aenv = - \x -> evalPreOpenFun evalAcc f (env `PushElt` fromElt x) aenv +evalOpenFun :: OpenFun env aenv t -> Val env -> Val aenv -> t +evalOpenFun (Body e) env aenv = evalOpenExp e env aenv +evalOpenFun (Lam lhs f) env aenv = + \x -> evalOpenFun f (env `push` (lhs, x)) aenv -- Evaluate an open scalar expression @@ -925,42 +901,40 @@ evalPreOpenFun evalAcc (Lam f) env aenv = -- mapped over an array, the array argument would be evaluated many times -- leading to a large amount of wasteful recomputation. -- -evalPreOpenExp - :: forall acc env aenv t. - EvalAcc acc - -> PreOpenExp acc env aenv t - -> ValElt env +evalOpenExp + :: forall env aenv t. + OpenExp env aenv t + -> Val env -> Val aenv -> t -evalPreOpenExp evalAcc pexp env aenv = +evalOpenExp pexp env aenv = let - evalE :: PreOpenExp acc env aenv t' -> t' - evalE e = evalPreOpenExp evalAcc e env aenv + evalE :: OpenExp env aenv t' -> t' + evalE e = evalOpenExp e env aenv - evalF :: PreOpenFun acc env aenv f' -> f' - evalF f = evalPreOpenFun evalAcc f env aenv + evalF :: OpenFun env aenv f' -> f' + evalF f = evalOpenFun f env aenv - evalA :: acc aenv a -> a - evalA a = evalAcc a aenv + evalA :: ArrayVar aenv a -> WithReprs a + evalA (Var repr ix) = (TupRsingle repr, prj ix aenv) in case pexp of - Let exp1 exp2 -> let !v1 = evalE exp1 - env' = env `PushElt` fromElt v1 - in evalPreOpenExp evalAcc exp2 env' aenv - Var ix -> prjElt ix env - Const c -> toElt c - Undef -> evalUndef + Let lhs exp1 exp2 -> let !v1 = evalE exp1 + env' = env `push` (lhs, v1) + in evalOpenExp exp2 env' aenv + Evar (Var _ ix) -> prj ix env + Const _ c -> c + Undef tp -> evalUndefScalar tp PrimConst c -> evalPrimConst c PrimApp f x -> evalPrim f (evalE x) - Tuple tup -> toTuple $ evalTuple evalAcc tup env aenv - Prj ix tup -> evalPrj ix . fromTuple $ evalE tup - IndexNil -> Z - IndexAny -> Any - IndexCons sh sz -> evalE sh :. evalE sz - IndexHead sh -> let _ :. ix = evalE sh in ix - IndexTail sh -> let ix :. _ = evalE sh in ix - IndexSlice slice slix sh -> toElt $ restrict slice (fromElt (evalE slix)) - (fromElt (evalE sh)) + Nil -> () + Pair e1 e2 -> let !x1 = evalE e1 + !x2 = evalE e2 + in (x1, x2) + VecPack vecR e -> vecPack vecR $! evalE e + VecUnpack vecR e -> vecUnpack vecR $! evalE e + IndexSlice slice slix sh -> restrict slice (evalE slix) + (evalE sh) where restrict :: SliceIndex slix sl co sh -> slix -> sh -> sl restrict SliceNil () () = () @@ -970,8 +944,8 @@ evalPreOpenExp evalAcc pexp env aenv = restrict (SliceFixed sliceIdx) (slx, _i) (sl, _sz) = restrict sliceIdx slx sl - IndexFull slice slix sh -> toElt $ extend slice (fromElt (evalE slix)) - (fromElt (evalE sh)) + IndexFull slice slix sh -> extend slice (evalE slix) + (evalE sh) where extend :: SliceIndex slix sl co sh -> slix -> sl -> sh extend SliceNil () () = () @@ -982,8 +956,8 @@ evalPreOpenExp evalAcc pexp env aenv = let sh' = extend sliceIdx slx sl in (sh', sz) - ToIndex sh ix -> toIndex (evalE sh) (evalE ix) - FromIndex sh ix -> fromIndex (evalE sh) (evalE ix) + ToIndex shr sh ix -> toIndex shr (evalE sh) (evalE ix) + FromIndex shr sh ix -> fromIndex shr (evalE sh) (evalE ix) Cond c t e | evalE c -> evalE t | otherwise -> evalE e @@ -996,29 +970,28 @@ evalPreOpenExp evalAcc pexp env aenv = | p x = go (f x) | otherwise = x - Index acc ix -> evalA acc ! evalE ix - LinearIndex acc i -> let a = evalA acc - ix = fromIndex (shape a) (evalE i) - in a ! ix - Shape acc -> shape (evalA acc) - ShapeSize sh -> size (evalE sh) - Intersect sh1 sh2 -> intersect (evalE sh1) (evalE sh2) - Union sh1 sh2 -> union (evalE sh1) (evalE sh2) - Foreign _ f e -> evalPreOpenFun evalAcc f EmptyElt Empty $ evalE e - Coerce e -> evalCoerce (evalE e) + Index acc ix -> let (TupRsingle repr, a) = evalA acc + in (repr, a) ! evalE ix + LinearIndex acc i -> let (TupRsingle repr, a) = evalA acc + ix = fromIndex (arrayRshape repr) (shape a) (evalE i) + in (repr, a) ! ix + Shape acc -> shape $ snd $ evalA acc + ShapeSize shr sh -> size shr (evalE sh) + Foreign _ _ f e -> evalOpenFun f Empty Empty $ evalE e + Coerce t1 t2 e -> evalCoerceScalar t1 t2 (evalE e) -- Constant values -- --------------- -evalUndef :: forall a. Elt a => a -evalUndef = toElt (undef (eltType @a)) - where - undef :: TupleType t -> t - undef TypeRunit = () - undef (TypeRpair a b) = (undef a, undef b) - undef (TypeRscalar t) = scalar t +evalUndef :: TupleType a -> a +evalUndef TupRunit = () +evalUndef (TupRsingle tp) = evalUndefScalar tp +evalUndef (TupRpair t1 t2) = (evalUndef t1, evalUndef t2) +evalUndefScalar :: ScalarType a -> a +evalUndefScalar = scalar + where scalar :: ScalarType t -> t scalar (SingleScalarType t) = single t scalar (VectorScalarType t) = vector t @@ -1048,28 +1021,6 @@ evalUndef = toElt (undef (eltType @a)) -- Coercions -- --------- -evalCoerce :: forall a b. (Elt a, Elt b) => a -> b -evalCoerce = toElt . go (eltType @a) (eltType @b) . fromElt - where - go :: TupleType s -> TupleType t -> s -> t - go TypeRunit TypeRunit () = () - go (TypeRpair s1 s2) (TypeRpair t1 t2) (x,y) = (go s1 t1 x, go s2 t2 y) - go (TypeRscalar s) (TypeRscalar t) x - = $internalCheck "evalCoerce" "sizes not equal" (sizeOfScalarType s == sizeOfScalarType t) - $ evalCoerceScalar s t x - -- - -- newtype wrappers are typically declared similarly to `EltRepr (T a) = ((), EltRepr a)' - -- so add some special cases for dealing with redundant parentheses. - -- - go (TypeRpair TypeRunit s) t@TypeRscalar{} ((), x) = go s t x - go s@TypeRscalar{} (TypeRpair TypeRunit t) x = ((), go s t x) - -- - go _ _ _ - = error $ printf "could not coerce type `%s' to `%s'" - (show (typeOf (undefined::a))) - (show (typeOf (undefined::b))) - - -- Coercion between two scalar types. We require that the size of the source and -- destination values are equal (this is not checked at this point). -- @@ -1236,22 +1187,6 @@ evalPrim (PrimFromIntegral ta tb) = evalFromIntegral ta tb evalPrim (PrimToFloating ta tb) = evalToFloating ta tb --- Tuple construction and projection --- --------------------------------- - -evalTuple :: EvalAcc acc -> Tuple (PreOpenExp acc env aenv) t -> ValElt env -> Val aenv -> t -evalTuple _ NilTup _env _aenv = () -evalTuple evalAcc (tup `SnocTup` e) env aenv = - (evalTuple evalAcc tup env aenv, evalPreOpenExp evalAcc e env aenv) - -evalPrj :: TupleIdx t e -> t -> e -evalPrj ZeroTupIdx (!_, v) = v -evalPrj (SuccTupIdx idx) (tup, !_) = evalPrj idx tup - -- FIXME: Strictly speaking, we ought to force all components of a tuples; - -- not only those that we happen to encounter during the recursive - -- walk. - - -- Implementation of scalar primitives -- ----------------------------------- @@ -1882,10 +1817,10 @@ evalSeq conf s aenv = evalSeq' s evalAF f = evalOpenAfun f aenv evalE :: DelayedExp aenv t -> t - evalE exp = evalPreExp evalOpenAcc exp aenv + evalE exp = evalExp exp aenv evalF :: DelayedFun aenv f -> f - evalF fun = evalPreFun evalOpenAcc fun aenv + evalF fun = evalFun fun aenv initProducer :: forall a senv. Producer DelayedOpenAcc aenv senv a @@ -1931,9 +1866,9 @@ evalSeq conf s aenv = evalSeq' s delayed :: DelayedOpenAcc aenv (Array sh e) -> Delayed (Array sh e) delayed AST.Manifest{} = $internalError "evalOpenAcc" "expected delayed array" - delayed AST.Delayed{..} = Delayed (evalPreExp evalOpenAcc extentD aenv) - (evalPreFun evalOpenAcc indexD aenv) - (evalPreFun evalOpenAcc linearIndexD aenv) + delayed AST.Delayed{..} = Delayed (evalExp extentD aenv) + (evalFun indexD aenv) + (evalFun linearIndexD aenv) produce :: Arrays a => ExecP senv a -> Val' senv -> (Chunk a, Maybe (ExecP senv a)) produce p senv = diff --git a/src/Data/Array/Accelerate/Language.hs b/src/Data/Array/Accelerate/Language.hs index b4e3d6c95..c6bef7c64 100644 --- a/src/Data/Array/Accelerate/Language.hs +++ b/src/Data/Array/Accelerate/Language.hs @@ -1,3 +1,5 @@ +{-# LANGUAGE AllowAmbiguousTypes #-} +{-# LANGUAGE CPP #-} {-# LANGUAGE ConstraintKinds #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE GADTs #-} @@ -113,6 +115,8 @@ import Data.Array.Accelerate.Pattern import Data.Array.Accelerate.Smart import Data.Array.Accelerate.Type import qualified Data.Array.Accelerate.Array.Sugar as Sugar +import qualified Data.Array.Accelerate.Array.Representation as Repr +import Data.Array.Accelerate.AST ( PrimFun(..) ) import Data.Array.Accelerate.Classes.Eq import Data.Array.Accelerate.Classes.Fractional @@ -122,7 +126,6 @@ import Data.Array.Accelerate.Classes.Ord -- standard libraries import Prelude ( ($), (.) ) -import Data.Typeable -- $setup -- >>> :seti -XFlexibleContexts @@ -165,27 +168,18 @@ import Data.Typeable -- >>> let tup = use (vec, mat) :: Acc (Vector Int, Matrix Int) -- use :: forall arrays. Arrays arrays => arrays -> Acc arrays -use arrs = Acc acc +use = Acc . use' (arrays @arrays) . fromArr where - HasTypeable acc = use' (arrays @arrays) $ fromArr arrs - - use' :: ArraysR a -> a -> HasTypeable a - use' ArraysRunit () = HasTypeable $ SmartAcc $ Anil - use' ArraysRarray a = HasTypeable $ SmartAcc $ Use a - use' (ArraysRpair r1 r2) (a1, a2) - | HasTypeable acc1 <- use' r1 a1 - , HasTypeable acc2 <- use' r2 a2 = HasTypeable $ SmartAcc $ acc1 `Apair` acc2 - --- Internal data type for 'use' to capture the 'Typeable' type class -data HasTypeable a where - HasTypeable :: Typeable a => SmartAcc a -> HasTypeable a - + use' :: ArraysR a -> a -> SmartAcc a + use' TupRunit () = SmartAcc $ Anil + use' (TupRsingle repr@ArrayR{}) a = SmartAcc $ Use repr a + use' (TupRpair r1 r2) (a1, a2) = SmartAcc $ use' r1 a1 `Apair` use' r2 a2 -- | Construct a singleton (one element) array from a scalar value (or tuple of -- scalar values). -- -unit :: Elt e => Exp e -> Acc (Scalar e) -unit = Acc . SmartAcc . Unit +unit :: forall e. Elt e => Exp e -> Acc (Scalar e) +unit (Exp e) = Acc $ SmartAcc $ Unit (eltType @e) e -- | Replicate an array across one or more dimensions as specified by the -- /generalised/ array index provided as the first argument. @@ -269,11 +263,12 @@ unit = Acc . SmartAcc . Unit -- 0, 1, 2, 3, 4, 5, 6, 7, 8, 9] -- replicate - :: (Slice slix, Elt e) + :: forall slix e. + (Slice slix, Elt e) => Exp slix -> Acc (Array (SliceShape slix) e) -> Acc (Array (FullShape slix) e) -replicate = Acc $$ applyAcc Replicate +replicate = Acc $$ applyAcc (Replicate $ sliceIndex @slix) -- | Construct a new array by applying a function to each index. -- @@ -305,11 +300,12 @@ replicate = Acc $$ applyAcc Replicate -- @.\/Data\/Array\/Accelerate\/Trafo\/Sharing.hs:447 (convertSharingExp): inconsistent valuation \@ shared \'Exp\' tree ...@. -- generate - :: (Shape sh, Elt a) + :: forall sh a. + (Shape sh, Elt a) => Exp sh -> (Exp sh -> Exp a) -> Acc (Array sh a) -generate = Acc $$ applyAcc Generate +generate = Acc $$ applyAcc (Generate $ arrayR @sh @a) -- Shape manipulation -- ------------------ @@ -324,11 +320,12 @@ generate = Acc $$ applyAcc Generate -- an index transformation in the fused code. -- reshape - :: (Shape sh, Shape sh', Elt e) + :: forall sh sh' e. + (Shape sh, Shape sh', Elt e) => Exp sh -> Acc (Array sh' e) -> Acc (Array sh e) -reshape = Acc $$ applyAcc Reshape +reshape = Acc $$ applyAcc (Reshape $ shapeR @sh) -- Extraction of sub-arrays -- ------------------------ @@ -398,11 +395,12 @@ reshape = Acc $$ applyAcc Reshape -- 30, 31, 32, 33, 34, -- 50, 51, 52, 53, 54] -- -slice :: (Slice slix, Elt e) +slice :: forall slix e. + (Slice slix, Elt e) => Acc (Array (FullShape slix) e) -> Exp slix -> Acc (Array (SliceShape slix) e) -slice = Acc $$ applyAcc Slice +slice = Acc $$ applyAcc (Slice $ sliceIndex @slix) -- Map-like functions -- ------------------ @@ -418,11 +416,12 @@ slice = Acc $$ applyAcc Slice -- >>> run $ map (+1) (use xs) -- Vector (Z :. 10) [1,2,3,4,5,6,7,8,9,10] -- -map :: (Shape sh, Elt a, Elt b) +map :: forall sh a b. + (Shape sh, Elt a, Elt b) => (Exp a -> Exp b) -> Acc (Array sh a) -> Acc (Array sh b) -map = Acc $$ applyAcc Map +map = Acc $$ applyAcc (Map (eltType @a) (eltType @b)) -- | Apply the given binary function element-wise to the two arrays. The extent -- of the resulting array is the intersection of the extents of the two source @@ -450,12 +449,13 @@ map = Acc $$ applyAcc Map -- 16, 18, 20, 22, 24, -- 31, 33, 35, 37, 39] -- -zipWith :: (Shape sh, Elt a, Elt b, Elt c) +zipWith :: forall sh a b c. + (Shape sh, Elt a, Elt b, Elt c) => (Exp a -> Exp b -> Exp c) -> Acc (Array sh a) -> Acc (Array sh b) -> Acc (Array sh c) -zipWith = Acc $$$ applyAcc ZipWith +zipWith = Acc $$$ applyAcc (ZipWith (eltType @a) (eltType @b) (eltType @c)) -- Reductions -- ---------- @@ -521,12 +521,13 @@ zipWith = Acc $$$ applyAcc ZipWith -- See also 'Data.Array.Accelerate.Data.Fold.Fold', which can be a useful way to -- compute multiple results from a single reduction. -- -fold :: (Shape sh, Elt a) +fold :: forall sh a. + (Shape sh, Elt a) => (Exp a -> Exp a -> Exp a) -> Exp a -> Acc (Array (sh:.Int) a) -> Acc (Array sh a) -fold = Acc $$$ applyAcc Fold +fold = Acc $$$ applyAcc (Fold $ eltType @a) -- | Variant of 'fold' that requires the innermost dimension of the array to be -- non-empty and doesn't need an default value. @@ -538,11 +539,12 @@ fold = Acc $$$ applyAcc Fold -- The first argument needs to be an /associative/ function to enable an -- efficient parallel implementation, but does not need to be commutative. -- -fold1 :: (Shape sh, Elt a) +fold1 :: forall sh a. + (Shape sh, Elt a) => (Exp a -> Exp a -> Exp a) -> Acc (Array (sh:.Int) a) -> Acc (Array sh a) -fold1 = Acc $$ applyAcc Fold1 +fold1 = Acc $$ applyAcc (Fold1 $ eltType @a) -- | Segmented reduction along the innermost dimension of an array. The -- segment descriptor specifies the starting index (offset) along the @@ -558,13 +560,14 @@ fold1 = Acc $$ applyAcc Fold1 -- @since 1.3.0.0 -- foldSeg' - :: (Shape sh, Elt a, Elt i, IsIntegral i) + :: forall sh a i. + (Shape sh, Elt a, Elt i, IsIntegral i, i ~ EltRepr i) => (Exp a -> Exp a -> Exp a) -> Exp a -> Acc (Array (sh:.Int) a) -> Acc (Segments i) -> Acc (Array (sh:.Int) a) -foldSeg' = Acc $$$$ applyAcc FoldSeg +foldSeg' = Acc $$$$ applyAcc (FoldSeg (integralType @i) (eltType @a)) -- | Variant of 'foldSeg'' that requires /all/ segments of the reduced -- array to be non-empty, and doesn't need a default value. The segment @@ -574,12 +577,13 @@ foldSeg' = Acc $$$$ applyAcc FoldSeg -- @since 1.3.0.0 -- fold1Seg' - :: (Shape sh, Elt a, Elt i, IsIntegral i) + :: forall sh a i. + (Shape sh, Elt a, Elt i, IsIntegral i, i ~ EltRepr i) => (Exp a -> Exp a -> Exp a) -> Acc (Array (sh:.Int) a) -> Acc (Segments i) -> Acc (Array (sh:.Int) a) -fold1Seg' = Acc $$$ applyAcc Fold1Seg +fold1Seg' = Acc $$$ applyAcc (Fold1Seg (integralType @i) (eltType @a)) -- Scan functions -- -------------- @@ -601,12 +605,13 @@ fold1Seg' = Acc $$$ applyAcc Fold1Seg -- 0, 20, 41, 63, 86, 110, 135, 161, 188, 216, 245, -- 0, 30, 61, 93, 126, 160, 195, 231, 268, 306, 345] -- -scanl :: (Shape sh, Elt a) +scanl :: forall sh a. + (Shape sh, Elt a) => (Exp a -> Exp a -> Exp a) -> Exp a -> Acc (Array (sh:.Int) a) -> Acc (Array (sh:.Int) a) -scanl = Acc $$$ applyAcc Scanl +scanl = Acc $$$ applyAcc (Scanl $ eltType @a) -- | Variant of 'scanl', where the last element (final reduction result) along -- each dimension is returned separately. Denotationally we have: @@ -634,12 +639,13 @@ scanl = Acc $$$ applyAcc Scanl -- >>> sums -- Vector (Z :. 4) [45,145,245,345] -- -scanl' :: (Shape sh, Elt a) +scanl' :: forall sh a. + (Shape sh, Elt a) => (Exp a -> Exp a -> Exp a) -> Exp a -> Acc (Array (sh:.Int) a) -> Acc (Array (sh:.Int) a, Array sh a) -scanl' = Acc $$$ applyAcc Scanl' +scanl' = Acc . mkPairToTuple $$$ applyAcc (Scanl' $ eltType @a) -- | Data.List style left-to-right scan along the innermost dimension without an -- initial value (aka inclusive scan). The innermost dimension of the array must @@ -653,37 +659,41 @@ scanl' = Acc $$$ applyAcc Scanl' -- 20, 41, 63, 86, 110, 135, 161, 188, 216, 245, -- 30, 61, 93, 126, 160, 195, 231, 268, 306, 345] -- -scanl1 :: (Shape sh, Elt a) +scanl1 :: forall sh a. + (Shape sh, Elt a) => (Exp a -> Exp a -> Exp a) -> Acc (Array (sh:.Int) a) -> Acc (Array (sh:.Int) a) -scanl1 = Acc $$ applyAcc Scanl1 +scanl1 = Acc $$ applyAcc (Scanl1 $ eltType @a) -- | Right-to-left variant of 'scanl'. -- -scanr :: (Shape sh, Elt a) +scanr :: forall sh a. + (Shape sh, Elt a) => (Exp a -> Exp a -> Exp a) -> Exp a -> Acc (Array (sh:.Int) a) -> Acc (Array (sh:.Int) a) -scanr = Acc $$$ applyAcc Scanr +scanr = Acc $$$ applyAcc (Scanr $ eltType @a) -- | Right-to-left variant of 'scanl''. -- -scanr' :: (Shape sh, Elt a) +scanr' :: forall sh a. + (Shape sh, Elt a) => (Exp a -> Exp a -> Exp a) -> Exp a -> Acc (Array (sh:.Int) a) -> Acc (Array (sh:.Int) a, Array sh a) -scanr' = Acc $$$ applyAcc Scanr' +scanr' = Acc . mkPairToTuple $$$ applyAcc (Scanr' $ eltType @a) -- | Right-to-left variant of 'scanl1'. -- -scanr1 :: (Shape sh, Elt a) +scanr1 :: forall sh a. + (Shape sh, Elt a) => (Exp a -> Exp a -> Exp a) -> Acc (Array (sh:.Int) a) -> Acc (Array (sh:.Int) a) -scanr1 = Acc $$ applyAcc Scanr1 +scanr1 = Acc $$ applyAcc (Scanr1 $ eltType @a) -- Permutations -- ------------ @@ -781,13 +791,13 @@ scanr1 = Acc $$ applyAcc Scanr1 -- @-fno-fast-permute-const@. -- permute - :: (Shape sh, Shape sh', Elt a) + :: forall sh sh' a. (Shape sh, Shape sh', Elt a) => (Exp a -> Exp a -> Exp a) -- ^ combination function -> Acc (Array sh' a) -- ^ array of default values -> (Exp sh -> Exp sh') -- ^ index permutation function -> Acc (Array sh a) -- ^ array of source values to be permuted -> Acc (Array sh' a) -permute = Acc $$$$ applyAcc Permute +permute = Acc $$$$ applyAcc (Permute $ arrayR @sh @a) -- | Generalised backward permutation operation (array gather). -- @@ -833,13 +843,12 @@ permute = Acc $$$$ applyAcc Permute -- 9, 19, 29, 39, 49] -- backpermute - :: (Shape sh, Shape sh', Elt a) + :: forall sh sh' a. (Shape sh, Shape sh', Elt a) => Exp sh' -- ^ shape of the result array -> (Exp sh' -> Exp sh) -- ^ index permutation function -> Acc (Array sh a) -- ^ source array -> Acc (Array sh' a) -backpermute = Acc $$$ applyAcc Backpermute - +backpermute = Acc $$$ applyAcc (Backpermute $ shapeR @sh') -- Stencil operations -- ------------------ @@ -946,20 +955,27 @@ type Stencil5x5x5 a = (Stencil5x5 a, Stencil5x5 a, Stencil5x5 a, Stencil5x5 a, S -- which approach is best for your application. -- stencil - :: (Stencil sh a stencil, Elt b) + :: forall sh stencil a b. + (Stencil sh a stencil, Elt b) => (stencil -> Exp b) -- ^ stencil function -> Boundary (Array sh a) -- ^ boundary condition -> Acc (Array sh a) -- ^ source array -> Acc (Array sh b) -- ^ destination array stencil f (Boundary b) (Acc a) - = Acc $ SmartAcc $ Stencil f b a + = Acc $ SmartAcc $ Stencil + (stencilR @sh @a @stencil) + (eltType @b) + (unExp . f . stencilPrj @sh @a @stencil) + b + a -- | Map a binary stencil of an array. The extent of the resulting array is the -- intersection of the extents of the two source arrays. This is the stencil -- equivalent of 'zipWith'. -- stencil2 - :: (Stencil sh a stencil1, Stencil sh b stencil2, Elt c) + :: forall sh stencil1 stencil2 a b c. + (Stencil sh a stencil1, Stencil sh b stencil2, Elt c) => (stencil1 -> stencil2 -> Exp c) -- ^ binary stencil function -> Boundary (Array sh a) -- ^ boundary condition #1 -> Acc (Array sh a) -- ^ source array #1 @@ -967,7 +983,15 @@ stencil2 -> Acc (Array sh b) -- ^ source array #2 -> Acc (Array sh c) -- ^ destination array stencil2 f (Boundary b1) (Acc a1) (Boundary b2) (Acc a2) - = Acc $ SmartAcc $ Stencil2 f b1 a1 b2 a2 + = Acc $ SmartAcc $ Stencil2 + (stencilR @sh @a @stencil1) + (stencilR @sh @b @stencil2) + (eltType @c) + (\x y -> unExp $ f (stencilPrj @sh @a @stencil1 x) (stencilPrj @sh @b @stencil2 y)) + b1 + a1 + b2 + a2 -- | Boundary condition where elements of the stencil which would be -- out-of-bounds are instead clamped to the edges of the array. @@ -1041,10 +1065,13 @@ wrap = Boundary Wrap -- > Z :. height :. width = unlift (shape xs) -- function - :: (Shape sh, Elt e) + :: forall sh e. (Shape sh, Elt e) => (Exp sh -> Exp e) -> Boundary (Array sh e) -function = Boundary . Function +function f = Boundary $ Function (f') + where + f' :: SmartExp (EltRepr sh) -> SmartExp (EltRepr e) + f' = unExp . f . Exp {-- @@ -1171,11 +1198,11 @@ collect = Acc . Collect -- foreignAcc :: forall as bs asm. (Arrays as, Arrays bs, Foreign asm) - => asm (as -> bs) + => asm (ArrRepr as -> ArrRepr bs) -> (Acc as -> Acc bs) -> Acc as -> Acc bs -foreignAcc asm f (Acc as) = Acc $ SmartAcc $ Aforeign asm f as +foreignAcc asm f (Acc as) = Acc $ SmartAcc $ Aforeign (arrays @bs) asm (unAccFunction f) as -- | Call a foreign scalar expression. -- @@ -1188,12 +1215,12 @@ foreignAcc asm f (Acc as) = Acc $ SmartAcc $ Aforeign asm f as -- purely in Accelerate. -- foreignExp - :: (Elt x, Elt y, Foreign asm) - => asm (x -> y) + :: forall x y asm. (Elt x, Elt y, Foreign asm) + => asm (EltRepr x -> EltRepr y) -> (Exp x -> Exp y) -> Exp x -> Exp y -foreignExp = Exp $$$ Foreign +foreignExp a f (Exp x) = exp $ Foreign (eltType @y) a (unExpFunction f) x -- Composition of array computations @@ -1213,7 +1240,7 @@ foreignExp = Exp $$$ Foreign -- infixl 1 >-> (>->) :: forall a b c. (Arrays a, Arrays b, Arrays c) => (Acc a -> Acc b) -> (Acc b -> Acc c) -> (Acc a -> Acc c) -(>->) = Acc $$$ applyAcc $ Pipe (arrays @a) (arrays @b) +(>->) = Acc $$$ applyAcc $ Pipe (arrays @a) (arrays @b) (arrays @c) -- Flow control constructs @@ -1250,26 +1277,41 @@ awhile = Acc $$$ applyAcc $ Awhile $ arrays @a -- array. -- toIndex - :: Shape sh + :: forall sh. Shape sh => Exp sh -- ^ extent of the array -> Exp sh -- ^ index to remap -> Exp Int -toIndex = Exp $$ ToIndex +toIndex (Exp sh) (Exp ix) = exp $ ToIndex (shapeR @sh) sh ix -- | Inverse of 'toIndex' -- -fromIndex :: Shape sh => Exp sh -> Exp Int -> Exp sh -fromIndex = Exp $$ FromIndex +fromIndex :: forall sh. Shape sh => Exp sh -> Exp Int -> Exp sh +fromIndex (Exp sh) (Exp e) = exp $ FromIndex (shapeR @sh) sh e -- | Intersection of two shapes -- -intersect :: Shape sh => Exp sh -> Exp sh -> Exp sh -intersect = Exp $$ Intersect +intersect :: forall sh. Shape sh => Exp sh -> Exp sh -> Exp sh +intersect (Exp x) (Exp y) = Exp $ intersect' (shapeR @sh) x y + +intersect' :: Repr.ShapeR sh -> SmartExp sh -> SmartExp sh -> SmartExp sh +intersect' Repr.ShapeRz _ _ = SmartExp Nil +intersect' (Repr.ShapeRsnoc shr) (unPair -> (xs, x)) (unPair -> (ys, y)) + = SmartExp + $ intersect' shr xs ys `Pair` + SmartExp (PrimApp (PrimMin singleType) $ SmartExp $ Pair x y) + -- | Union of two shapes -- -union :: Shape sh => Exp sh -> Exp sh -> Exp sh -union = Exp $$ Union +union :: forall sh. Shape sh => Exp sh -> Exp sh -> Exp sh +union (Exp x) (Exp y) = Exp $ union' (shapeR @sh) x y + +union' :: Repr.ShapeR sh -> SmartExp sh -> SmartExp sh -> SmartExp sh +union' Repr.ShapeRz _ _ = SmartExp Nil +union' (Repr.ShapeRsnoc shr) (unPair -> (xs, x)) (unPair -> (ys, y)) + = SmartExp + $ union' shr xs ys `Pair` + SmartExp (PrimApp (PrimMax singleType) $ SmartExp $ Pair x y) -- Flow-control @@ -1285,18 +1327,30 @@ cond :: Elt t -> Exp t -- ^ then-expression -> Exp t -- ^ else-expression -> Exp t -cond = Exp $$$ Cond +cond (Exp c) (Exp x) (Exp y) = exp $ Cond c x y -- | While construct. Continue to apply the given function, starting with the -- initial value, until the test function evaluates to 'False'. -- -while :: Elt e +while :: forall e. Elt e => (Exp e -> Exp Bool) -- ^ keep evaluating while this returns 'True' -> (Exp e -> Exp e) -- ^ function to apply -> Exp e -- ^ initial value -> Exp e -while = Exp $$$ While +#if __GLASGOW_HASKELL__ < 804 +while c f (Exp e) = exp $ While @SmartAcc @SmartExp @(EltRepr e) (eltType @e) (unExp . c . Exp) (unExp . f . Exp) e +#else +while c f (Exp e) = exp $ While @(EltRepr e) (eltType @e) (unExp . c . Exp) (unExp . f . Exp) e +#endif + +{- + While :: TupleType t + -> (SmartExp t -> exp Bool) + -> (SmartExp t -> exp t) + -> exp t + -> PreSmartExp acc exp t + -} -- Array operations with a scalar result -- ------------------------------------- @@ -1317,8 +1371,8 @@ while = Exp $$$ While -- 12 -- infixl 9 ! -(!) :: (Shape sh, Elt e) => Acc (Array sh e) -> Exp sh -> Exp e -Acc a ! ix = Exp $ Index a ix +(!) :: forall sh e. (Shape sh, Elt e) => Acc (Array sh e) -> Exp sh -> Exp e +Acc a ! Exp ix = exp $ Index (eltType @e) a ix -- | Extract the value from an array at the specified linear index. -- Multidimensional arrays in Accelerate are stored in row-major order with @@ -1337,13 +1391,13 @@ Acc a ! ix = Exp $ Index a ix -- 12 -- infixl 9 !! -(!!) :: (Shape sh, Elt e) => Acc (Array sh e) -> Exp Int -> Exp e -Acc a !! ix = Exp $ LinearIndex a ix +(!!) :: forall sh e. (Shape sh, Elt e) => Acc (Array sh e) -> Exp Int -> Exp e +Acc a !! Exp ix = exp $ LinearIndex (eltType @e) a ix -- | Extract the shape (extent) of an array. -- -shape :: (Shape sh, Elt e) => Acc (Array sh e) -> Exp sh -shape = Exp . Shape . unAcc +shape :: forall sh e. (Shape sh, Elt e) => Acc (Array sh e) -> Exp sh +shape = exp . Shape (shapeR @sh) . unAcc -- | The number of elements in the array -- @@ -1352,8 +1406,8 @@ size = shapeSize . shape -- | The number of elements that would be held by an array of the given shape. -- -shapeSize :: Shape sh => Exp sh -> Exp Int -shapeSize = Exp . ShapeSize +shapeSize :: forall sh. Shape sh => Exp sh -> Exp Int +shapeSize (Exp sh) = exp $ ShapeSize (shapeR @sh) sh -- Numeric functions diff --git a/src/Data/Array/Accelerate/Lift.hs b/src/Data/Array/Accelerate/Lift.hs index d2371150e..0485bcb3b 100644 --- a/src/Data/Array/Accelerate/Lift.hs +++ b/src/Data/Array/Accelerate/Lift.hs @@ -3,6 +3,9 @@ {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} {-# LANGUAGE TemplateHaskell #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} @@ -144,153 +147,157 @@ instance Unlift Acc (Acc a) where instance Lift Exp () where type Plain () = () - lift _ = Exp $ Tuple NilTup + lift _ = Exp $ SmartExp Nil instance Unlift Exp () where unlift _ = () instance Lift Exp Z where type Plain Z = Z - lift _ = Exp $ IndexNil + lift _ = Exp $ SmartExp Nil instance Unlift Exp Z where unlift _ = Z instance (Elt (Plain ix), Lift Exp ix) => Lift Exp (ix :. Int) where type Plain (ix :. Int) = Plain ix :. Int - lift (ix:.i) = Exp $ IndexCons (lift ix) (Exp $ Const i) + lift (ix:.i) = Exp $ SmartExp $ Pair (unExp $ lift ix) (unExp $ expConst i) instance (Elt (Plain ix), Lift Exp ix) => Lift Exp (ix :. All) where type Plain (ix :. All) = Plain ix :. All - lift (ix:.i) = Exp $ IndexCons (lift ix) (Exp $ Const i) + lift (ix:.i) = Exp $ SmartExp $ Pair (unExp $ lift ix) (unExp $ constant i) instance (Elt e, Elt (Plain ix), Lift Exp ix) => Lift Exp (ix :. Exp e) where type Plain (ix :. Exp e) = Plain ix :. e - lift (ix:.i) = Exp $ IndexCons (lift ix) i + lift (ix :. Exp i) = Exp $ SmartExp $ Pair (unExp $ lift ix) i instance {-# OVERLAPPABLE #-} (Elt e, Elt (Plain ix), Unlift Exp ix) => Unlift Exp (ix :. Exp e) where - unlift e = unlift (Exp $ IndexTail e) :. Exp (IndexHead e) + unlift (Exp e) = unlift (Exp $ SmartExp $ Prj PairIdxLeft e) :. Exp (SmartExp $ Prj PairIdxRight e) instance {-# OVERLAPPABLE #-} (Elt e, Elt ix) => Unlift Exp (Exp ix :. Exp e) where - unlift e = (Exp $ IndexTail e) :. Exp (IndexHead e) + unlift (Exp e) = (Exp $ SmartExp $ Prj PairIdxLeft e) :. Exp (SmartExp $ Prj PairIdxRight e) -instance Shape sh => Lift Exp (Any sh) where - type Plain (Any sh) = Any sh - lift Any = Exp $ IndexAny +instance (Shape sh, Elt (Any sh)) => Lift Exp (Any sh) where + type Plain (Any sh) = Any sh + lift Any = constant Any -- instances for numeric types +{-# INLINE expConst #-} +expConst :: forall e. Elt e => IsScalar (EltRepr e) => e -> Exp e +expConst = Exp . SmartExp . Const (scalarType @(EltRepr e)) . fromElt + instance Lift Exp Int where type Plain Int = Int - lift = Exp . Const + lift = expConst instance Lift Exp Int8 where type Plain Int8 = Int8 - lift = Exp . Const + lift = expConst instance Lift Exp Int16 where type Plain Int16 = Int16 - lift = Exp . Const + lift = expConst instance Lift Exp Int32 where type Plain Int32 = Int32 - lift = Exp . Const + lift = expConst instance Lift Exp Int64 where type Plain Int64 = Int64 - lift = Exp . Const + lift = expConst instance Lift Exp Word where type Plain Word = Word - lift = Exp . Const + lift = expConst instance Lift Exp Word8 where type Plain Word8 = Word8 - lift = Exp . Const + lift = expConst instance Lift Exp Word16 where type Plain Word16 = Word16 - lift = Exp . Const + lift = expConst instance Lift Exp Word32 where type Plain Word32 = Word32 - lift = Exp . Const + lift = expConst instance Lift Exp Word64 where type Plain Word64 = Word64 - lift = Exp . Const + lift = expConst instance Lift Exp CShort where type Plain CShort = CShort - lift = Exp . Const + lift = expConst instance Lift Exp CUShort where type Plain CUShort = CUShort - lift = Exp . Const + lift = expConst instance Lift Exp CInt where type Plain CInt = CInt - lift = Exp . Const + lift = expConst instance Lift Exp CUInt where type Plain CUInt = CUInt - lift = Exp . Const + lift = expConst instance Lift Exp CLong where type Plain CLong = CLong - lift = Exp . Const + lift = expConst instance Lift Exp CULong where type Plain CULong = CULong - lift = Exp . Const + lift = expConst instance Lift Exp CLLong where type Plain CLLong = CLLong - lift = Exp . Const + lift = expConst instance Lift Exp CULLong where type Plain CULLong = CULLong - lift = Exp . Const + lift = expConst instance Lift Exp Half where type Plain Half = Half - lift = Exp . Const + lift = expConst instance Lift Exp Float where type Plain Float = Float - lift = Exp . Const + lift = expConst instance Lift Exp Double where type Plain Double = Double - lift = Exp . Const + lift = expConst instance Lift Exp CFloat where type Plain CFloat = CFloat - lift = Exp . Const + lift = expConst instance Lift Exp CDouble where type Plain CDouble = CDouble - lift = Exp . Const + lift = expConst instance Lift Exp Bool where type Plain Bool = Bool - lift = Exp . Const + lift = expConst instance Lift Exp Char where type Plain Char = Char - lift = Exp . Const + lift = expConst instance Lift Exp CChar where type Plain CChar = CChar - lift = Exp . Const + lift = expConst instance Lift Exp CSChar where type Plain CSChar = CSChar - lift = Exp . Const + lift = expConst instance Lift Exp CUChar where type Plain CUChar = CUChar - lift = Exp . Const + lift = expConst -- Instances for tuples @@ -482,7 +489,7 @@ instance Lift Acc () where instance (Shape sh, Elt e) => Lift Acc (Array sh e) where type Plain (Array sh e) = Array sh e - lift = Acc . SmartAcc . Use + lift (Array arr) = Acc $ SmartAcc $ Use (arrayR @sh @e) arr -- Lift and Unlift instances for tuples -- diff --git a/src/Data/Array/Accelerate/Pattern.hs b/src/Data/Array/Accelerate/Pattern.hs index 17ccc6621..b2aec3891 100644 --- a/src/Data/Array/Accelerate/Pattern.hs +++ b/src/Data/Array/Accelerate/Pattern.hs @@ -1,204 +1,274 @@ -{-# LANGUAGE CPP #-} -{-# LANGUAGE ConstraintKinds #-} -{-# LANGUAGE FlexibleContexts #-} -{-# LANGUAGE FlexibleInstances #-} -{-# LANGUAGE MultiParamTypeClasses #-} -{-# LANGUAGE PatternSynonyms #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE TemplateHaskell #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE TypeFamilies #-} -{-# LANGUAGE TypeOperators #-} -{-# LANGUAGE ViewPatterns #-} -#if __GLASGOW_HASKELL__ <= 800 -{-# OPTIONS_GHC -fno-warn-unrecognised-pragmas #-} -#endif --- | --- Module : Data.Array.Accelerate.Pattern --- Copyright : [2018..2019] The Accelerate Team --- License : BSD3 --- --- Maintainer : Trevor L. McDonell --- Stability : experimental --- Portability : non-portable (GHC extensions) --- - -module Data.Array.Accelerate.Pattern ( - - pattern Pattern, - pattern T2, pattern T3, pattern T4, pattern T5, pattern T6, - pattern T7, pattern T8, pattern T9, pattern T10, pattern T11, - pattern T12, pattern T13, pattern T14, pattern T15, pattern T16, - - pattern Z_, pattern Ix, pattern (::.), - pattern I0, pattern I1, pattern I2, pattern I3, pattern I4, - pattern I5, pattern I6, pattern I7, pattern I8, pattern I9, - -) where - -import Data.Array.Accelerate.Array.Sugar -import Data.Array.Accelerate.Product -import Data.Array.Accelerate.Smart - -import Language.Haskell.TH hiding ( Exp ) -import Language.Haskell.TH.Extra - - --- | A pattern synonym for working with (product) data types. You can declare --- your own pattern synonyms based off of this. --- -pattern Pattern :: forall b a context. IsPattern context a b => b -> context a -pattern Pattern vars <- (destruct @context -> vars) - where Pattern = construct @context - -class IsPattern con a t where - construct :: t -> con a - destruct :: con a -> t - - --- | Pattern synonyms for indices, which may be more convenient to use than --- 'Data.Array.Accelerate.Lift.lift' and --- 'Data.Array.Accelerate.Lift.unlift'. --- -pattern Z_ :: Exp DIM0 -pattern Z_ = Pattern Z -{-# COMPLETE Z_ #-} - -infixl 3 ::. -pattern (::.) :: (Elt a, Elt b) => Exp a -> Exp b -> Exp (a :. b) -pattern a ::. b = Pattern (a :. b) -{-# COMPLETE (::.) #-} - -pattern Ix :: (Elt a, Elt b) => Exp a -> Exp b -> Exp (a :. b) -pattern a `Ix` b = a ::. b -{-# COMPLETE Ix #-} - --- IsPattern instances for Shape nil and cons --- -instance IsPattern Exp Z Z where - construct _ = Exp IndexNil - destruct _ = Z - -instance (Elt a, Elt b) => IsPattern Exp (a :. b) (Exp a :. Exp b) where - construct (a :. b) = Exp (a `IndexCons` b) - destruct t = Exp (IndexTail t) :. Exp (IndexHead t) - --- IsPattern instances for up to 16-tuples (Acc and Exp). TH takes care of the --- (unremarkable) boilerplate for us, but since the implementation is a little --- tricky it is debatable whether or not this is a good idea... --- -$(runQ $ do - let - mkIsPattern' :: Name -> TypeQ -> ExpQ -> ExpQ -> ExpQ -> ExpQ -> Int -> Q [Dec] - mkIsPattern' con cst tup prj nil snoc n = - let - xs = [ mkName ('x' : show i) | i <- [0 .. n-1]] - b = foldl (\ts t -> appT ts (appT (conT con) (varT t))) (tupleT n) xs - repr = foldl (\ts t -> [t| ($ts, $(varT t)) |]) [t| () |] xs - context = foldl (\ts t -> appT ts (appT cst (varT t))) (tupleT n) xs - -- - tix 0 = [| ZeroTupIdx |] - tix i = [| SuccTupIdx $(tix (i-1)) |] - get x i = [| $(conE con) ($prj $(tix i) $x) |] - in - [d| instance - ( IsProduct $cst a - , ProdRepr a ~ $repr - , $cst a - , $context - ) => IsPattern $(conT con) a $b where - construct $(tupP (map varP xs)) = $(conE con) ($tup $(foldl (\vs v -> appE (appE snoc vs) (varE v)) nil xs)) - destruct _x = $(tupE (map (get [|_x|]) [(n-1), (n-2) .. 0])) - |] - - mkIsPattern :: Name -> TypeQ -> TypeQ -> ExpQ -> ExpQ -> ExpQ -> ExpQ -> Int -> Q [Dec] - mkIsPattern con cst repr smart prj nil pair n = do - let - xs = [ mkName ('x' : show i) | i <- [0 .. n-1] ] - ts = map varT xs - a = tupT ts - b = tupT (map (conT con `appT`) ts) - context = tupT (map (cst `appT`) ts) - equiv = case n of - 1 -> [t| ((), $repr $a) |] - _ -> [t| $repr $a |] - -- - get x 0 = [| $(conE con) ($smart ($prj PairIdxRight $x)) |] - get x i = get [| $smart ($prj PairIdxLeft $x) |] (i-1) - -- - _x <- newName "_x" - [d| instance ($repr a ~ $equiv, $context) => IsPattern $(conT con) a $b where - construct $(tupP (map (conP con . return . varP) xs)) = - $(conE con) $(foldl (\vs v -> appE smart (appE (appE pair vs) (varE v))) (appE smart nil) xs) - destruct $(conP con [varP _x]) = - $(tupE (map (get (varE _x)) [(n-1), (n-2) .. 0])) - |] - - mkExpPattern = mkIsPattern' (mkName "Exp") [t| Elt |] [| Tuple |] [| Prj |] [| NilTup |] [| SnocTup |] - mkAccPattern = mkIsPattern (mkName "Acc") [t| Arrays |] [t| ArrRepr |] [| SmartAcc |] [| Aprj |] [| Anil |] [| Apair |] - -- - es <- mapM mkExpPattern [0..16] - as <- mapM mkAccPattern [0..16] - return $ concat (es ++ as) - ) - --- | Specialised pattern synonyms for tuples, which may be more convenient to --- use than 'Data.Array.Accelerate.Lift.lift' and --- 'Data.Array.Accelerate.Lift.unlift'. For example, to construct a pair: --- --- > let a = 4 :: Exp Int --- > let b = 2 :: Exp Float --- > let c = T2 a b -- :: Exp (Int, Float); equivalent to 'lift (a,b)' --- --- Similarly they can be used to destruct values: --- --- > let T2 x y = c -- x :: Exp Int, y :: Exp Float; equivalent to 'let (x,y) = unlift c' --- --- These pattern synonyms can be used for both 'Exp' and 'Acc' terms. --- --- Similarly, we have patterns for constructing and destructing indices of --- a given dimensionality: --- --- > let ix = Ix 2 3 -- :: Exp DIM2 --- > let I2 y x = ix -- y :: Exp Int, x :: Exp Int --- -$(runQ $ do - let - mkT :: Int -> Q [Dec] - mkT n = - let xs = [ mkName ('x' : show i) | i <- [0 .. n-1] ] - ts = map varT xs - name = mkName ('T':show n) - con = varT (mkName "con") - ty1 = tupT ts - ty2 = tupT (map (con `appT`) ts) - sig = foldr (\t r -> [t| $con $t -> $r |]) (appT con ty1) ts - in - sequence - [ patSynSigD name [t| IsPattern $con $ty1 $ty2 => $sig |] - , patSynD name (prefixPatSyn xs) implBidir [p| Pattern $(tupP (map varP xs)) |] - , pragCompleteD [name] (Just ''Acc) - , pragCompleteD [name] (Just ''Exp) - ] - - mkI :: Int -> Q [Dec] - mkI n = - let xs = [ mkName ('x' : show i) | i <- [0 .. n-1] ] - ts = map varT xs - name = mkName ('I':show n) - ix = mkName "Ix" - cst = tupT (map (\t -> [t| Elt $t |]) ts) - dim = foldl (\h t -> [t| $h :. $t |]) [t| Z |] ts - sig = foldr (\t r -> [t| Exp $t -> $r |]) [t| Exp $dim |] ts - in - sequence - [ patSynSigD name [t| $cst => $sig |] - , patSynD name (prefixPatSyn xs) implBidir (foldl (\ps p -> infixP ps ix (varP p)) [p| Z_ |] xs) - , pragCompleteD [name] Nothing - ] - -- - ts <- mapM mkT [2..16] - is <- mapM mkI [0..9] - return $ concat (ts ++ is) - ) - +{-# LANGUAGE CPP #-} +{-# LANGUAGE ConstraintKinds #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TemplateHaskell #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} +{-# LANGUAGE ViewPatterns #-} +#if __GLASGOW_HASKELL__ <= 800 +{-# OPTIONS_GHC -fno-warn-unrecognised-pragmas #-} +#endif +-- | +-- Module : Data.Array.Accelerate.Pattern +-- Copyright : [2018..2019] The Accelerate Team +-- License : BSD3 +-- +-- Maintainer : Trevor L. McDonell +-- Stability : experimental +-- Portability : non-portable (GHC extensions) +-- + +module Data.Array.Accelerate.Pattern ( + + pattern Pattern, + pattern T2, pattern T3, pattern T4, pattern T5, pattern T6, + pattern T7, pattern T8, pattern T9, pattern T10, pattern T11, + pattern T12, pattern T13, pattern T14, pattern T15, pattern T16, + + pattern Z_, pattern Ix, pattern (::.), + pattern I0, pattern I1, pattern I2, pattern I3, pattern I4, + pattern I5, pattern I6, pattern I7, pattern I8, pattern I9, + + pattern V2_, pattern V3_, pattern V4_, pattern V8_, pattern V16_, + +) where + +import Data.Array.Accelerate.Array.Sugar +import Data.Array.Accelerate.Smart +import Data.Array.Accelerate.Type + +import Language.Haskell.TH hiding ( Exp ) +import Language.Haskell.TH.Extra + + +-- | A pattern synonym for working with (product) data types. You can declare +-- your own pattern synonyms based off of this. +-- +pattern Pattern :: forall b a context. IsPattern context a b => b -> context a +pattern Pattern vars <- (destruct @context -> vars) + where Pattern = construct @context + +class IsPattern con a t where + construct :: t -> con a + destruct :: con a -> t + + +-- | Pattern synonyms for indices, which may be more convenient to use than +-- 'Data.Array.Accelerate.Lift.lift' and +-- 'Data.Array.Accelerate.Lift.unlift'. +-- +pattern Z_ :: Exp DIM0 +pattern Z_ = Pattern Z +{-# COMPLETE Z_ #-} + +infixl 3 ::. +pattern (::.) :: (Elt a, Elt b) => Exp a -> Exp b -> Exp (a :. b) +pattern a ::. b = Pattern (a :. b) +{-# COMPLETE (::.) #-} + +pattern Ix :: (Elt a, Elt b) => Exp a -> Exp b -> Exp (a :. b) +pattern a `Ix` b = a ::. b +{-# COMPLETE Ix #-} + +-- IsPattern instances for Shape nil and cons +-- +instance IsPattern Exp Z Z where + construct _ = constant Z + destruct _ = Z + +instance (Elt a, Elt b) => IsPattern Exp (a :. b) (Exp a :. Exp b) where + construct (Exp a :. Exp b) = Exp $ SmartExp $ Pair a b + destruct (Exp t) = Exp (SmartExp $ Prj PairIdxLeft t) :. Exp (SmartExp $ Prj PairIdxRight t) + +-- IsPattern instances for up to 16-tuples (Acc and Exp). TH takes care of the +-- (unremarkable) boilerplate for us, but since the implementation is a little +-- tricky it is debatable whether or not this is a good idea... +-- +$(runQ $ do + let + -- Generate instance declarations for IsPattern of the form: + -- instance (Elt x, EltRepr x ~ (((), EltRepr a), EltRepr b), Elt a, Elt b,) => IsPattern Exp x (Exp a, Exp b) + mkIsPattern :: Name -> TypeQ -> TypeQ -> ExpQ -> ExpQ -> ExpQ -> ExpQ -> Int -> Q [Dec] + mkIsPattern con cst repr smart prj nil pair n = do + a <- newName "a" + let + -- Type variables for the elements + xs = [ mkName ('x' : show i) | i <- [0 .. n-1] ] + -- Last argument to `IsPattern`, eg (Exp, a, Exp b) in the example + b = foldl (\ts t -> appT ts (appT (conT con) (varT t))) (tupleT n) xs + -- Representation as snoc-list of pairs, eg (((), EltRepr a), EltRepr b) + snoc = foldl (\sn t -> [t| ($sn, $(appT repr $ varT t)) |]) [t| () |] xs + -- Constraints for the type class, consisting of Elt constraints on all type variables, + -- and an equality constraint on the representation type of `a` and the snoc representation `snoc`. + contexts = appT cst [t| $(varT a) |] + : [t| $repr $(varT a) ~ $snoc |] + : map (\t -> appT cst (varT t)) xs + -- Store all constraints in a tuple + context = foldl (\ts t -> appT ts t) (tupleT $ length contexts) contexts + -- + get x 0 = [| $(conE con) ($smart ($prj PairIdxRight $x)) |] + get x i = get [| $smart ($prj PairIdxLeft $x) |] (i-1) + -- + _x <- newName "_x" + [d| instance $context => IsPattern $(conT con) $(varT a) $b where + construct $(tupP (map (conP con . return . varP) xs)) = + $(conE con) $(foldl (\vs v -> appE smart (appE (appE pair vs) (varE v))) (appE smart nil) xs) + destruct $(conP con [varP _x]) = + $(tupE (map (get (varE _x)) [(n-1), (n-2) .. 0])) + |] + + mkExpPattern = mkIsPattern (mkName "Exp") [t| Elt |] [t| EltRepr |] [| SmartExp |] [| Prj |] [| Nil |] [| Pair |] + mkAccPattern = mkIsPattern (mkName "Acc") [t| Arrays |] [t| ArrRepr |] [| SmartAcc |] [| Aprj |] [| Anil |] [| Apair |] + -- + es <- mapM mkExpPattern [0..16] + as <- mapM mkAccPattern [0..16] + return $ concat (es ++ as) + ) + +-- | Specialised pattern synonyms for tuples, which may be more convenient to +-- use than 'Data.Array.Accelerate.Lift.lift' and +-- 'Data.Array.Accelerate.Lift.unlift'. For example, to construct a pair: +-- +-- > let a = 4 :: Exp Int +-- > let b = 2 :: Exp Float +-- > let c = T2 a b -- :: Exp (Int, Float); equivalent to 'lift (a,b)' +-- +-- Similarly they can be used to destruct values: +-- +-- > let T2 x y = c -- x :: Exp Int, y :: Exp Float; equivalent to 'let (x,y) = unlift c' +-- +-- These pattern synonyms can be used for both 'Exp' and 'Acc' terms. +-- +-- Similarly, we have patterns for constructing and destructing indices of +-- a given dimensionality: +-- +-- > let ix = Ix 2 3 -- :: Exp DIM2 +-- > let I2 y x = ix -- y :: Exp Int, x :: Exp Int +-- +$(runQ $ do + let + mkT :: Int -> Q [Dec] + mkT n = + let xs = [ mkName ('x' : show i) | i <- [0 .. n-1] ] + ts = map varT xs + name = mkName ('T':show n) + con = varT (mkName "con") + ty1 = tupT ts + ty2 = tupT (map (con `appT`) ts) + sig = foldr (\t r -> [t| $con $t -> $r |]) (appT con ty1) ts + in + sequence + [ patSynSigD name [t| IsPattern $con $ty1 $ty2 => $sig |] + , patSynD name (prefixPatSyn xs) implBidir [p| Pattern $(tupP (map varP xs)) |] + , pragCompleteD [name] (Just ''Acc) + , pragCompleteD [name] (Just ''Exp) + ] + + mkI :: Int -> Q [Dec] + mkI n = + let xs = [ mkName ('x' : show i) | i <- [0 .. n-1] ] + ts = map varT xs + name = mkName ('I':show n) + ix = mkName "Ix" + cst = tupT (map (\t -> [t| Elt $t |]) ts) + dim = foldl (\h t -> [t| $h :. $t |]) [t| Z |] ts + sig = foldr (\t r -> [t| Exp $t -> $r |]) [t| Exp $dim |] ts + in + sequence + [ patSynSigD name [t| $cst => $sig |] + , patSynD name (prefixPatSyn xs) implBidir (foldl (\ps p -> infixP ps ix (varP p)) [p| Z_ |] xs) + , pragCompleteD [name] Nothing + ] + -- + ts <- mapM mkT [2..16] + is <- mapM mkI [0..9] + return $ concat (ts ++ is) + ) + +-- Newtype to make difference between T and P instances clear +newtype VecPattern a = VecPattern a + +instance VecElt a => IsPattern Exp (Vec 2 a) (VecPattern (Exp a, Exp a)) where + construct (VecPattern as) = Exp $ SmartExp $ VecPack r tup + where + r = vecR2 $ singleType @(EltRepr a) + Exp tup = construct as :: Exp (a, a) + destruct e = VecPattern $ destruct e' + where + e' :: Exp (a, a) + e' = Exp $ SmartExp $ VecUnpack r $ unExp e + r = vecR2 $ singleType @(EltRepr a) + +instance VecElt a => IsPattern Exp (Vec 3 a) (VecPattern (Exp a, Exp a, Exp a)) where + construct (VecPattern as) = Exp $ SmartExp $ VecPack r tup + where + r = vecR3 $ singleType @(EltRepr a) + Exp tup = construct as :: Exp (a, a, a) + destruct e = VecPattern $ destruct e' + where + e' :: Exp (a, a, a) + e' = Exp $ SmartExp $ VecUnpack r $ unExp e + r = vecR3 $ singleType @(EltRepr a) + +instance VecElt a => IsPattern Exp (Vec 4 a) (VecPattern (Exp a, Exp a, Exp a, Exp a)) where + construct (VecPattern as) = Exp $ SmartExp $ VecPack r tup + where + r = vecR4 $ singleType @(EltRepr a) + Exp tup = construct as :: Exp (a, a, a, a) + destruct e = VecPattern $ destruct e' + where + e' :: Exp (a, a, a, a) + e' = Exp $ SmartExp $ VecUnpack r $ unExp e + r = vecR4 $ singleType @(EltRepr a) + +instance VecElt a => IsPattern Exp (Vec 8 a) (VecPattern (Exp a, Exp a, Exp a, Exp a, Exp a, Exp a, Exp a, Exp a)) where + construct (VecPattern as) = Exp $ SmartExp $ VecPack r tup + where + r = vecR8 $ singleType @(EltRepr a) + Exp tup = construct as :: Exp (a, a, a, a, a, a, a, a) + destruct e = VecPattern $ destruct e' + where + e' :: Exp (a, a, a, a, a, a, a, a) + e' = Exp $ SmartExp $ VecUnpack r $ unExp e + r = vecR8 $ singleType @(EltRepr a) + +instance VecElt a => IsPattern Exp (Vec 16 a) (VecPattern (Exp a, Exp a, Exp a, Exp a, Exp a, Exp a, Exp a, Exp a, Exp a, Exp a, Exp a, Exp a, Exp a, Exp a, Exp a, Exp a)) where + construct (VecPattern as) = Exp $ SmartExp $ VecPack r tup + where + r = vecR16 $ singleType @(EltRepr a) + Exp tup = construct as :: Exp (a, a, a, a, a, a, a, a, a, a, a, a, a, a, a, a) + destruct e = VecPattern $ destruct e' + where + e' :: Exp (a, a, a, a, a, a, a, a, a, a, a, a, a, a, a, a) + e' = Exp $ SmartExp $ VecUnpack r $ unExp e + r = vecR16 $ singleType @(EltRepr a) + +pattern V2_ :: VecElt a => Exp a -> Exp a -> Exp (Vec 2 a) +pattern V2_ a b = Pattern (VecPattern (a, b)) +{-# COMPLETE V2_ #-} + +pattern V3_ :: VecElt a => Exp a -> Exp a -> Exp a -> Exp (Vec 3 a) +pattern V3_ a b c = Pattern (VecPattern (a, b, c)) +{-# COMPLETE V3_ #-} + +pattern V4_ :: VecElt a => Exp a -> Exp a -> Exp a -> Exp a -> Exp (Vec 4 a) +pattern V4_ a b c d = Pattern (VecPattern (a, b, c, d)) +{-# COMPLETE V4_ #-} + +pattern V8_ :: VecElt a => Exp a -> Exp a -> Exp a -> Exp a -> Exp a -> Exp a -> Exp a -> Exp a -> Exp (Vec 8 a) +pattern V8_ a b c d e f g h = Pattern (VecPattern (a, b, c, d, e, f, g, h)) +{-# COMPLETE V8_ #-} + +pattern V16_ :: VecElt a + => Exp a -> Exp a -> Exp a -> Exp a -> Exp a -> Exp a -> Exp a -> Exp a -> + Exp a -> Exp a -> Exp a -> Exp a -> Exp a -> Exp a -> Exp a -> Exp a -> Exp (Vec 16 a) +pattern V16_ a b c d e f g h + i j k l m n o p = Pattern (VecPattern (a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p)) +{-# COMPLETE V16_ #-} diff --git a/src/Data/Array/Accelerate/Prelude.hs b/src/Data/Array/Accelerate/Prelude.hs index dbc39e2c8..5cb57836e 100644 --- a/src/Data/Array/Accelerate/Prelude.hs +++ b/src/Data/Array/Accelerate/Prelude.hs @@ -707,7 +707,7 @@ fold1All f arr = fold1 f (flatten arr) -- 40, 170, 0, 138] -- foldSeg - :: forall sh e i. (Shape sh, Elt e, Elt i, IsIntegral i) + :: forall sh e i. (Shape sh, Elt e, Elt i, i ~ EltRepr i, IsIntegral i) => (Exp e -> Exp e -> Exp e) -> Exp e -> Acc (Array (sh:.Int) e) @@ -734,15 +734,17 @@ foldSeg f z arr seg = foldSeg' f z arr (scanl plus zero seg) -- descriptor species the length of each of the logical sub-arrays. -- fold1Seg - :: forall sh e i. (Shape sh, Elt e, Elt i, IsIntegral i) + :: forall sh e i. (Shape sh, Elt e, Elt i, i ~ EltRepr i, IsIntegral i) => (Exp e -> Exp e -> Exp e) -> Acc (Array (sh:.Int) e) -> Acc (Segments i) -> Acc (Array (sh:.Int) e) fold1Seg f arr seg = fold1Seg' f arr (scanl plus zero seg) where + plus :: Exp i -> Exp i -> Exp i + zero :: Exp i (plus, zero) = - case integralType @i of + case integralType @(EltRepr i) of TypeInt{} -> ((+), 0) TypeInt8{} -> ((+), 0) TypeInt16{} -> ((+), 0) diff --git a/src/Data/Array/Accelerate/Pretty.hs b/src/Data/Array/Accelerate/Pretty.hs index fc549c4c0..b5b32a97b 100644 --- a/src/Data/Array/Accelerate/Pretty.hs +++ b/src/Data/Array/Accelerate/Pretty.hs @@ -24,8 +24,8 @@ module Data.Array.Accelerate.Pretty ( PrettyAcc, ExtractAcc, prettyPreOpenAcc, prettyPreOpenAfun, - prettyPreOpenExp, - prettyPreOpenFun, + prettyOpenExp, + prettyOpenFun, -- ** Graphviz Graph, @@ -101,17 +101,11 @@ instance PrettyEnv aenv => Show (DelayedOpenAcc aenv a) where instance PrettyEnv aenv => Show (DelayedOpenAfun aenv f) where show = renderForTerminal . prettyPreOpenAfun prettyDelayedOpenAcc (prettyEnv (pretty 'a')) -instance (PrettyEnv env, PrettyEnv aenv) => Show (PreOpenExp OpenAcc env aenv e) where - show = renderForTerminal . prettyPreOpenExp context0 prettyOpenAcc extractOpenAcc (prettyEnv (pretty 'x')) (prettyEnv (pretty 'a')) +instance (PrettyEnv env, PrettyEnv aenv) => Show (OpenExp env aenv e) where + show = renderForTerminal . prettyOpenExp context0 (prettyEnv (pretty 'x')) (prettyEnv (pretty 'a')) -instance (PrettyEnv env, PrettyEnv aenv) => Show (PreOpenExp DelayedOpenAcc env aenv e) where - show = renderForTerminal . prettyPreOpenExp context0 prettyDelayedOpenAcc extractDelayedOpenAcc (prettyEnv (pretty 'x')) (prettyEnv (pretty 'a')) - -instance (PrettyEnv env, PrettyEnv aenv) => Show (PreOpenFun OpenAcc env aenv e) where - show = renderForTerminal . prettyPreOpenFun prettyOpenAcc extractOpenAcc (prettyEnv (pretty 'x')) (prettyEnv (pretty 'a')) - -instance (PrettyEnv env, PrettyEnv aenv) => Show (PreOpenFun DelayedOpenAcc env aenv e) where - show = renderForTerminal . prettyPreOpenFun prettyDelayedOpenAcc extractDelayedOpenAcc (prettyEnv (pretty 'x')) (prettyEnv (pretty 'a')) +instance (PrettyEnv env, PrettyEnv aenv) => Show (OpenFun env aenv e) where + show = renderForTerminal . prettyOpenFun (prettyEnv (pretty 'x')) (prettyEnv (pretty 'a')) -- Internals @@ -158,12 +152,12 @@ extractOpenAcc (OpenAcc pacc) = pacc prettyDelayedOpenAcc :: PrettyAcc DelayedOpenAcc prettyDelayedOpenAcc context aenv (Manifest pacc) = prettyPreOpenAcc context prettyDelayedOpenAcc extractDelayedOpenAcc aenv pacc -prettyDelayedOpenAcc _ aenv (Delayed sh f _) +prettyDelayedOpenAcc _ aenv (Delayed _ sh f _) = parens $ nest shiftwidth $ sep [ delayed "delayed" - , prettyPreOpenExp app prettyDelayedOpenAcc extractDelayedOpenAcc Empty aenv sh - , parens $ prettyPreOpenFun prettyDelayedOpenAcc extractDelayedOpenAcc Empty aenv f + , prettyOpenExp app Empty aenv sh + , parens $ prettyOpenFun Empty aenv f ] extractDelayedOpenAcc :: DelayedOpenAcc aenv a -> PreOpenAcc DelayedOpenAcc aenv a diff --git a/src/Data/Array/Accelerate/Pretty/Graphviz.hs b/src/Data/Array/Accelerate/Pretty/Graphviz.hs index 8a1123651..3533a6472 100644 --- a/src/Data/Array/Accelerate/Pretty/Graphviz.hs +++ b/src/Data/Array/Accelerate/Pretty/Graphviz.hs @@ -44,8 +44,9 @@ import qualified Data.HashSet as Set import qualified Data.Sequence as Seq -- friends -import Data.Array.Accelerate.AST ( PreOpenAcc(..), PreOpenAfun(..), PreOpenFun(..), PreOpenExp(..), PreBoundary(..), LeftHandSide(..), ArrayVar(..), Idx(..) ) -import Data.Array.Accelerate.Array.Sugar ( Array, Elt, Tuple(..), ArraysR(..), toElt, strForeign ) +import Data.Array.Accelerate.AST hiding ( Val(..), prj ) +import Data.Array.Accelerate.Array.Representation +import Data.Array.Accelerate.Array.Sugar ( strForeign ) import Data.Array.Accelerate.Error import Data.Array.Accelerate.Pretty.Graphviz.Monad import Data.Array.Accelerate.Pretty.Graphviz.Type @@ -195,7 +196,7 @@ prettyDelayedOpenAcc detail ctx aenv atop@(Manifest pacc) = Avar ix -> pnode (avar ix) Alet lhs bnd body -> do bnd'@(PNode ident _ _) <- prettyDelayedOpenAcc detail context0 aenv bnd - (aenv1, a) <- prettyLetLeftHandSide ident aenv lhs + (aenv1, a) <- prettyLetALeftHandSide ident aenv lhs _ <- mkNode bnd' (Just a) body' <- prettyDelayedOpenAcc detail context0 aenv1 body return body' @@ -210,7 +211,7 @@ prettyDelayedOpenAcc detail ctx aenv atop@(Manifest pacc) = deps = (vt, Just "T") : (ve, Just "F") : map (,port) vs return $ PNode ident doc deps - Apply afun acc -> apply <$> prettyDelayedAfun detail aenv afun + Apply _ afun acc -> apply <$> prettyDelayedAfun detail aenv afun <*> prettyDelayedOpenAcc detail ctx aenv acc Awhile p f x -> do @@ -227,19 +228,19 @@ prettyDelayedOpenAcc detail ctx aenv atop@(Manifest pacc) = Anil -> "()" .$ [] - Use arr -> "use" .$ [ return $ PDoc (prettyArray arr) [] ] - Unit e -> "unit" .$ [ ppE e ] - Generate sh f -> "generate" .$ [ ppE sh, ppF f ] - Transform sh ix f xs -> "transform" .$ [ ppE sh, ppF ix, ppF f, ppA xs ] - Reshape sh xs -> "reshape" .$ [ ppE sh, ppA xs ] + Use repr arr -> "use" .$ [ return $ PDoc (prettyArray repr arr) [] ] + Unit _ e -> "unit" .$ [ ppE e ] + Generate _ sh f -> "generate" .$ [ ppE sh, ppF f ] + Transform _ sh ix f xs -> "transform" .$ [ ppE sh, ppF ix, ppF f, ppA xs ] + Reshape _ sh xs -> "reshape" .$ [ ppE sh, ppA xs ] Replicate _ty ix xs -> "replicate" .$ [ ppE ix, ppA xs ] Slice _ty xs ix -> "slice" .$ [ ppA xs, ppE ix ] - Map f xs -> "map" .$ [ ppF f, ppA xs ] - ZipWith f xs ys -> "zipWith" .$ [ ppF f, ppA xs, ppA ys ] + Map _ f xs -> "map" .$ [ ppF f, ppA xs ] + ZipWith _ f xs ys -> "zipWith" .$ [ ppF f, ppA xs, ppA ys ] Fold f e xs -> "fold" .$ [ ppF f, ppE e, ppA xs ] Fold1 f xs -> "fold1" .$ [ ppF f, ppA xs ] - FoldSeg f e xs ys -> "foldSeg" .$ [ ppF f, ppE e, ppA xs, ppA ys ] - Fold1Seg f xs ys -> "fold1Seg" .$ [ ppF f, ppA xs, ppA ys ] + FoldSeg _ f e xs ys -> "foldSeg" .$ [ ppF f, ppE e, ppA xs, ppA ys ] + Fold1Seg _ f xs ys -> "fold1Seg" .$ [ ppF f, ppA xs, ppA ys ] Scanl f e xs -> "scanl" .$ [ ppF f, ppE e, ppA xs ] Scanl' f e xs -> "scanl'" .$ [ ppF f, ppE e, ppA xs ] Scanl1 f xs -> "scanl1" .$ [ ppF f, ppA xs ] @@ -247,11 +248,12 @@ prettyDelayedOpenAcc detail ctx aenv atop@(Manifest pacc) = Scanr' f e xs -> "scanr'" .$ [ ppF f, ppE e, ppA xs ] Scanr1 f xs -> "scanr1" .$ [ ppF f, ppA xs ] Permute f dfts p xs -> "permute" .$ [ ppF f, ppA dfts, ppF p, ppA xs ] - Backpermute sh p xs -> "backpermute" .$ [ ppE sh, ppF p, ppA xs ] - Stencil sten bndy xs -> "stencil" .$ [ ppF sten, ppB bndy, ppA xs ] - Stencil2 sten bndy1 acc1 bndy2 acc2 - -> "stencil2" .$ [ ppF sten, ppB bndy1, ppA acc1, ppB bndy2, ppA acc2 ] - Aforeign ff _afun xs -> "aforeign" .$ [ return (PDoc (pretty (strForeign ff)) []), {- ppAf afun, -} ppA xs ] + Backpermute _ sh p xs -> "backpermute" .$ [ ppE sh, ppF p, ppA xs ] + Stencil s _ sten bndy xs + -> "stencil" .$ [ ppF sten, ppB (stencilElt s) bndy, ppA xs ] + Stencil2 s1 s2 _ sten bndy1 acc1 bndy2 acc2 + -> "stencil2" .$ [ ppF sten, ppB (stencilElt s1) bndy1, ppA acc1, ppB (stencilElt s2) bndy2, ppA acc2 ] + Aforeign _ ff _afun xs -> "aforeign" .$ [ return (PDoc (pretty (strForeign ff)) []), {- ppAf afun, -} ppA xs ] -- Collect{} -> error "Collect" where @@ -278,21 +280,17 @@ prettyDelayedOpenAcc detail ctx aenv atop@(Manifest pacc) = -- Free variables -- - fvA :: FVAcc DelayedOpenAcc - fvA env (Manifest (Avar (ArrayVar ix))) = [ Vertex (fst $ aprj ix env) Nothing ] - fvA _ _ = $internalError "graphviz" "expected array variable" + fvF :: Fun aenv t -> [Vertex] + fvF = fvOpenFun Empty aenv - fvF :: DelayedFun aenv t -> [Vertex] - fvF = fvPreOpenFun fvA Empty aenv - - fvE :: DelayedExp aenv t -> [Vertex] - fvE = fvPreOpenExp fvA Empty aenv + fvE :: Exp aenv t -> [Vertex] + fvE = fvOpenExp Empty aenv -- Pretty-printing -- avar :: ArrayVar aenv t -> PDoc - avar (ArrayVar ix) = let (ident, v) = aprj ix aenv - in PDoc (pretty v) [Vertex ident Nothing] + avar (Var _ ix) = let (ident, v) = aprj ix aenv + in PDoc (pretty v) [Vertex ident Nothing] aenv' :: Val aenv aenv' = avalToVal aenv @@ -306,33 +304,35 @@ prettyDelayedOpenAcc detail ctx aenv atop@(Manifest pacc) = v <- mkLabel ident <- mkNode acc' (Just v) return $ PDoc (pretty v) [Vertex ident Nothing] - ppA (Delayed sh f _) - | Shape a <- sh -- identical shape - , Just Refl <- match f (Lam (Body (Index a (Var ZeroIdx)))) -- identity function - = ppA a - ppA (Delayed sh f _) = do + ppA (Delayed _ sh f _) + | Shape a <- sh -- identical shape + , Just b <- isIdentityIndexing f -- function is `\ix -> b ! ix` + , Just Refl <- match a b -- function thus is `\ix -> a ! ix` + = ppA $ Manifest $ Avar a + ppA (Delayed _ sh f _) = do PDoc d v <- "Delayed" `fmt` [ ppE sh, ppF f ] return $ PDoc (parens d) v - ppB :: forall sh e. Elt e - => PreBoundary DelayedOpenAcc aenv (Array sh e) + ppB :: forall sh e. + TupleType e + -> Boundary aenv (Array sh e) -> Dot PDoc - ppB Clamp = return (PDoc "clamp" []) - ppB Mirror = return (PDoc "mirror" []) - ppB Wrap = return (PDoc "wrap" []) - ppB (Constant e) = return (PDoc (prettyConst (toElt e :: e)) []) - ppB (Function f) = ppF f + ppB _ Clamp = return (PDoc "clamp" []) + ppB _ Mirror = return (PDoc "mirror" []) + ppB _ Wrap = return (PDoc "wrap" []) + ppB tp (Constant e) = return (PDoc (prettyConst tp e) []) + ppB _ (Function f) = ppF f - ppF :: DelayedFun aenv t -> Dot PDoc - ppF = return . uncurry PDoc . (parens . prettyDelayedFun aenv' &&& fvF) + ppF :: Fun aenv t -> Dot PDoc + ppF = return . uncurry PDoc . (parens . prettyFun aenv' &&& fvF) - ppE :: DelayedExp aenv t -> Dot PDoc - ppE = return . uncurry PDoc . (prettyDelayedExp aenv' &&& fvE) + ppE :: Exp aenv t -> Dot PDoc + ppE = return . uncurry PDoc . (prettyExp aenv' &&& fvE) lift :: DelayedOpenAcc aenv a -> Dot Vertex - lift Delayed{} = $internalError "prettyDelayedOpenAcc" "expected manifest array" - lift (Manifest (Avar (ArrayVar ix))) = return $ Vertex (fst (aprj ix aenv)) Nothing - lift acc = do + lift Delayed{} = $internalError "prettyDelayedOpenAcc" "expected manifest array" + lift (Manifest (Avar (Var _ ix))) = return $ Vertex (fst (aprj ix aenv)) Nothing + lift acc = do acc' <- prettyDelayedOpenAcc detail context0 aenv acc ident <- mkNode acc' Nothing return $ Vertex ident Nothing @@ -381,46 +381,46 @@ prettyDelayedAfun detail aenv afun = do go :: Aval aenv' -> DelayedOpenAfun aenv' a' -> Dot Graph go aenv' (Abody b) = graphDelayedOpenAcc detail aenv' b go aenv' (Alam lhs f) = do - aenv'' <- prettyLambdaLeftHandSide aenv' lhs + aenv'' <- prettyLambdaALeftHandSide aenv' lhs go aenv'' f collect :: Aval aenv' -> HashSet NodeId collect Aempty = Set.empty collect (Apush a i _) = Set.insert i (collect a) -prettyLetLeftHandSide +prettyLetALeftHandSide :: forall repr aenv aenv'. NodeId -> Aval aenv - -> LeftHandSide repr aenv aenv' + -> ALeftHandSide repr aenv aenv' -> Dot (Aval aenv', Label) -prettyLetLeftHandSide _ aenv (LeftHandSideWildcard repr) = return (aenv, doc) +prettyLetALeftHandSide _ aenv (LeftHandSideWildcard repr) = return (aenv, doc) where doc = case repr of - ArraysRunit -> "()" - _ -> "_" -prettyLetLeftHandSide ident aenv LeftHandSideArray = do + TupRunit -> "()" + _ -> "_" +prettyLetALeftHandSide ident aenv (LeftHandSideSingle _) = do a <- mkLabel return (Apush aenv ident a, a) -prettyLetLeftHandSide ident aenv (LeftHandSidePair lhs1 lhs2) = do - (aenv1, d1) <- prettyLetLeftHandSide ident aenv lhs1 - (aenv2, d2) <- prettyLetLeftHandSide ident aenv1 lhs2 +prettyLetALeftHandSide ident aenv (LeftHandSidePair lhs1 lhs2) = do + (aenv1, d1) <- prettyLetALeftHandSide ident aenv lhs1 + (aenv2, d2) <- prettyLetALeftHandSide ident aenv1 lhs2 return (aenv2, "(" <> d1 <> ", " <> d2 <> ")") -prettyLambdaLeftHandSide +prettyLambdaALeftHandSide :: forall repr aenv aenv'. Aval aenv - -> LeftHandSide repr aenv aenv' + -> ALeftHandSide repr aenv aenv' -> Dot (Aval aenv') -prettyLambdaLeftHandSide aenv (LeftHandSideWildcard _) = return aenv -prettyLambdaLeftHandSide aenv lhs@LeftHandSideArray = do +prettyLambdaALeftHandSide aenv (LeftHandSideWildcard _) = return aenv +prettyLambdaALeftHandSide aenv lhs@(LeftHandSideSingle _) = do a <- mkLabel ident <- mkNodeId lhs _ <- mkNode (PNode ident (Leaf (Nothing, pretty a)) []) Nothing return $ Apush aenv ident a -prettyLambdaLeftHandSide aenv (LeftHandSidePair lhs1 lhs2) = do - aenv1 <- prettyLambdaLeftHandSide aenv lhs1 - prettyLambdaLeftHandSide aenv1 lhs2 +prettyLambdaALeftHandSide aenv (LeftHandSidePair lhs1 lhs2) = do + aenv1 <- prettyLambdaALeftHandSide aenv lhs1 + prettyLambdaALeftHandSide aenv1 lhs2 -- Display array tuples. This is a little tricky... -- @@ -479,49 +479,6 @@ replant pnode@(PNode ident tree _) = -- nodes. -- -prettyDelayedFun :: Val aenv -> DelayedFun aenv f -> Adoc -prettyDelayedFun = prettyDelayedOpenFun Empty - -prettyDelayedExp :: Val aenv -> DelayedExp aenv t -> Adoc -prettyDelayedExp = prettyDelayedOpenExp context0 Empty - - -prettyDelayedOpenFun - :: forall env aenv f. - Val env - -> Val aenv - -> DelayedOpenFun env aenv f - -> Adoc -prettyDelayedOpenFun env0 aenv = next "\\\\" env0 - where - -- graphviz will silently not print a label containing the string "->", - -- so instead we use the special token "&rarr" for a short right arrow. - -- - next :: Adoc -> Val env' -> PreOpenFun DelayedOpenAcc env' aenv f' -> Adoc - next vs env (Body body) = - nest shiftwidth (sep [ vs <> "→" - , prettyDelayedOpenExp context0 env aenv body ]) - next vs env (Lam lam) = - let x = pretty 'x' <> pretty (sizeEnv env) - in next (vs <> x <> space) (env `Push` x) lam - -prettyDelayedOpenExp - :: Context - -> Val env - -> Val aenv - -> DelayedOpenExp env aenv t - -> Adoc -prettyDelayedOpenExp context = prettyPreOpenExp context pp ex - where - pp :: PrettyAcc DelayedOpenAcc - pp _ aenv (Manifest (Avar (ArrayVar ix))) = prj ix aenv - pp _ _ _ = $internalError "prettyDelayedOpenExp" "expected array variable" - - ex :: ExtractAcc DelayedOpenAcc - ex (Manifest pacc) = pacc - ex Delayed{} = $internalError "prettyDelayedOpenExp" "expected manifest array" - - -- Data dependencies -- ----------------- -- @@ -530,61 +487,55 @@ prettyDelayedOpenExp context = prettyPreOpenExp context pp ex -- nodes (vertices) into the current term. -- -type FVAcc acc = forall aenv a. Aval aenv -> acc aenv a -> [Vertex] +fvAvar :: Aval aenv -> ArrayVar aenv a -> [Vertex] +fvAvar env (Var _ ix) = [ Vertex (fst $ aprj ix env) Nothing ] -fvPreOpenFun - :: forall acc env aenv fun. - FVAcc acc - -> Val env +fvOpenFun + :: forall env aenv fun. + Val env -> Aval aenv - -> PreOpenFun acc env aenv fun + -> OpenFun env aenv fun -> [Vertex] -fvPreOpenFun fvA env aenv (Body b) = fvPreOpenExp fvA env aenv b -fvPreOpenFun fvA env aenv (Lam f) = fvPreOpenFun fvA (env `Push` (pretty 'x' <> pretty (sizeEnv env))) aenv f +fvOpenFun env aenv (Body b) = fvOpenExp env aenv b +fvOpenFun env aenv (Lam lhs f) = fvOpenFun env' aenv f + where + (env', _) = prettyELhs True env lhs -fvPreOpenExp - :: forall acc env aenv exp. - FVAcc acc - -> Val env +fvOpenExp + :: forall env aenv exp. + Val env -> Aval aenv - -> PreOpenExp acc env aenv exp + -> OpenExp env aenv exp -> [Vertex] -fvPreOpenExp fvA env aenv = fv +fvOpenExp env aenv = fv where - fvT :: Tuple (PreOpenExp acc env aenv) t -> [Vertex] - fvT NilTup = [] - fvT (SnocTup tup e) = concat [ fv e, fvT tup ] - - fvF :: PreOpenFun acc env aenv f -> [Vertex] - fvF = fvPreOpenFun fvA env aenv + fvF :: OpenFun env aenv f -> [Vertex] + fvF = fvOpenFun env aenv - fv :: PreOpenExp acc env aenv e -> [Vertex] - fv (Shape acc) = if cfgIncludeShape then fvA aenv acc else [] - fv (Index acc i) = concat [ fvA aenv acc, fv i ] - fv (LinearIndex acc i) = concat [ fvA aenv acc, fv i ] + fv :: OpenExp env aenv e -> [Vertex] + fv (Shape acc) = if cfgIncludeShape then fvAvar aenv acc else [] + fv (Index acc i) = concat [ fvAvar aenv acc, fv i ] + fv (LinearIndex acc i) = concat [ fvAvar aenv acc, fv i ] -- - fv (Let e1 e2) = concat [ fv e1, fvPreOpenExp fvA (env `Push` (pretty 'x' <> pretty (sizeEnv env))) aenv e2 ] - fv Var{} = [] - fv Undef = [] + fv (Let lhs e1 e2) = concat [ fv e1, fvOpenExp env' aenv e2 ] + where + (env', _) = prettyELhs False env lhs + fv Evar{} = [] + fv Undef{} = [] fv Const{} = [] fv PrimConst{} = [] fv (PrimApp _ x) = fv x - fv (Tuple tup) = fvT tup - fv (Prj _ e) = fv e - fv IndexNil = [] - fv IndexAny = [] - fv (IndexHead sh) = fv sh - fv (IndexTail sh) = fv sh - fv (IndexCons t h) = concat [ fv t, fv h ] + fv (Pair e1 e2) = concat [ fv e1, fv e2] + fv Nil = [] + fv (VecPack _ e) = fv e + fv (VecUnpack _ e) = fv e fv (IndexSlice _ slix sh) = concat [ fv slix, fv sh ] fv (IndexFull _ slix sh) = concat [ fv slix, fv sh ] - fv (ToIndex sh ix) = concat [ fv sh, fv ix ] - fv (FromIndex sh ix) = concat [ fv sh, fv ix ] - fv (Union sh1 sh2) = concat [ fv sh1, fv sh2 ] - fv (Intersect sh1 sh2) = concat [ fv sh1, fv sh2 ] - fv (ShapeSize sh) = fv sh + fv (ToIndex _ sh ix) = concat [ fv sh, fv ix ] + fv (FromIndex _ sh ix) = concat [ fv sh, fv ix ] + fv (ShapeSize _ sh) = fv sh fv Foreign{} = [] fv (Cond p t e) = concat [ fv p, fv t, fv e ] fv (While p f x) = concat [ fvF p, fvF f, fv x ] - fv (Coerce e) = fv e + fv (Coerce _ _ e) = fv e diff --git a/src/Data/Array/Accelerate/Pretty/Print.hs b/src/Data/Array/Accelerate/Pretty/Print.hs index 0aff9340d..f64bb347b 100644 --- a/src/Data/Array/Accelerate/Pretty/Print.hs +++ b/src/Data/Array/Accelerate/Pretty/Print.hs @@ -25,10 +25,12 @@ module Data.Array.Accelerate.Pretty.Print ( PrettyAcc, ExtractAcc, prettyPreOpenAcc, prettyPreOpenAfun, - prettyPreOpenExp, - prettyPreOpenFun, + prettyOpenExp, prettyExp, + prettyOpenFun, prettyFun, prettyArray, prettyConst, + prettyELhs, + prettyALhs, -- ** Internals Adoc, @@ -53,11 +55,11 @@ import Data.Char import Data.String import Data.Text.Prettyprint.Doc import Data.Text.Prettyprint.Doc.Render.Terminal -import Data.Typeable ( Typeable, typeOf, showsTypeRep ) import Prelude hiding ( exp ) import Data.Array.Accelerate.AST hiding ( Val(..), prj ) -import Data.Array.Accelerate.Array.Sugar +import Data.Array.Accelerate.Array.Sugar ( strForeign ) +import Data.Array.Accelerate.Array.Representation import Data.Array.Accelerate.Type @@ -112,8 +114,8 @@ prettyPreOpenAfun prettyAcc aenv0 = next (pretty '\\') aenv0 next :: Adoc -> Val aenv' -> PreOpenAfun acc aenv' f' -> Adoc next vs aenv (Abody body) = hang shiftwidth (sep [vs <> "->", prettyAcc context0 aenv body]) next vs aenv (Alam lhs lam) = - let (aenv', lhs') = prettyLHS aenv lhs - in next (vs <> lhs' <> space) aenv' lam + let (aenv', lhs') = prettyALhs True aenv lhs + in next (vs <> lhs' <> space) aenv' lam prettyPreOpenAcc :: forall acc aenv arrs. @@ -125,11 +127,11 @@ prettyPreOpenAcc -> Adoc prettyPreOpenAcc ctx prettyAcc extractAcc aenv pacc = case pacc of - Avar (ArrayVar idx) -> prj idx aenv + Avar (Var _ idx) -> prj idx aenv Alet{} -> prettyAlet ctx prettyAcc extractAcc aenv pacc Apair{} -> prettyAtuple prettyAcc extractAcc aenv pacc Anil -> "()" - Apply f a -> apply + Apply _ f a -> apply where op = Operator ">->" Infix L 1 apply = sep [ ppAF f, group (sep [opName op, ppA a]) ] @@ -147,21 +149,21 @@ prettyPreOpenAcc ctx prettyAcc extractAcc aenv pacc = , hang shiftwidth (sep [ then_, t' ]) , hang shiftwidth (sep [ else_, e' ]) ] - Aforeign ff _f a -> "aforeign" .$ [ pretty (strForeign ff), ppA a ] + Aforeign _ ff _ a -> "aforeign" .$ [ pretty (strForeign ff), ppA a ] Awhile p f a -> "awhile" .$ [ ppAF p, ppAF f, ppA a ] - Use arr -> "use" .$ [ prettyArray arr ] - Unit e -> "unit" .$ [ ppE e ] - Reshape sh a -> "reshape" .$ [ ppE sh, ppA a ] - Generate sh f -> "generate" .$ [ ppE sh, ppF f ] - Transform sh p f a -> "transform" .$ [ ppE sh, ppF p, ppF f, ppA a ] + Use repr arr -> "use" .$ [ prettyArray repr arr ] + Unit _ e -> "unit" .$ [ ppE e ] + Reshape _ sh a -> "reshape" .$ [ ppE sh, ppA a ] + Generate _ sh f -> "generate" .$ [ ppE sh, ppF f ] + Transform _ sh p f a -> "transform" .$ [ ppE sh, ppF p, ppF f, ppA a ] Replicate _ ix a -> "replicate" .$ [ ppE ix, ppA a ] Slice _ a ix -> "slice" .$ [ ppE ix, ppA a ] - Map f a -> "map" .$ [ ppF f, ppA a ] - ZipWith f a b -> "zipWith" .$ [ ppF f, ppA a, ppA b ] + Map _ f a -> "map" .$ [ ppF f, ppA a ] + ZipWith _ f a b -> "zipWith" .$ [ ppF f, ppA a, ppA b ] Fold f z a -> "fold" .$ [ ppF f, ppE z, ppA a ] Fold1 f a -> "fold1" .$ [ ppF f, ppA a ] - FoldSeg f z a s -> "foldSeg" .$ [ ppF f, ppE z, ppA a, ppA s ] - Fold1Seg f a s -> "fold1Seg" .$ [ ppF f, ppA a, ppA s ] + FoldSeg _ f z a s -> "foldSeg" .$ [ ppF f, ppE z, ppA a, ppA s ] + Fold1Seg _ f a s -> "fold1Seg" .$ [ ppF f, ppA a, ppA s ] Scanl f z a -> "scanl" .$ [ ppF f, ppE z, ppA a ] Scanl' f z a -> "scanl'" .$ [ ppF f, ppE z, ppA a ] Scanl1 f a -> "scanl1" .$ [ ppF f, ppA a ] @@ -169,9 +171,10 @@ prettyPreOpenAcc ctx prettyAcc extractAcc aenv pacc = Scanr' f z a -> "scanr'" .$ [ ppF f, ppE z, ppA a ] Scanr1 f a -> "scanr1" .$ [ ppF f, ppA a ] Permute f d p s -> "permute" .$ [ ppF f, ppA d, ppF p, ppA s ] - Backpermute sh f a -> "backpermute" .$ [ ppE sh, ppF f, ppA a ] - Stencil f b a -> "stencil" .$ [ ppF f, ppB b, ppA a ] - Stencil2 f b1 a1 b2 a2 -> "stencil2" .$ [ ppF f, ppB b1, ppA a1, ppB b2, ppA a2 ] + Backpermute _ sh f a -> "backpermute" .$ [ ppE sh, ppF f, ppA a ] + Stencil s _ f b a -> "stencil" .$ [ ppF f, ppB (stencilElt s) b, ppA a ] + Stencil2 s1 s2 _ f b1 a1 b2 a2 + -> "stencil2" .$ [ ppF f, ppB (stencilElt s1) b1, ppA a1, ppB (stencilElt s2) b2, ppA a2 ] where infixr 0 .$ f .$ xs @@ -184,20 +187,21 @@ prettyPreOpenAcc ctx prettyAcc extractAcc aenv pacc = ppAF :: PreOpenAfun acc aenv f -> Adoc ppAF = parens . prettyPreOpenAfun prettyAcc aenv - ppE :: PreExp acc aenv t -> Adoc - ppE = prettyPreOpenExp app prettyAcc extractAcc Empty aenv + ppE :: Exp aenv t -> Adoc + ppE = prettyOpenExp app Empty aenv - ppF :: PreFun acc aenv t -> Adoc - ppF = parens . prettyPreOpenFun prettyAcc extractAcc Empty aenv + ppF :: Fun aenv t -> Adoc + ppF = parens . prettyOpenFun Empty aenv - ppB :: forall sh e. Elt e - => PreBoundary acc aenv (Array sh e) + ppB :: forall sh e. + TupleType e + -> Boundary aenv (Array sh e) -> Adoc - ppB Clamp = "clamp" - ppB Mirror = "mirror" - ppB Wrap = "wrap" - ppB (Constant e) = prettyConst (toElt e :: e) - ppB (Function f) = ppF f + ppB _ Clamp = "clamp" + ppB _ Mirror = "mirror" + ppB _ Wrap = "wrap" + ppB tp (Constant e) = prettyConst tp e + ppB _ (Function f) = ppF f prettyAlet @@ -216,7 +220,7 @@ prettyAlet ctx prettyAcc extractAcc aenv0 collect aenv = \case Alet lhs a1 a2 -> - let (aenv', v) = prettyLHS aenv lhs + let (aenv', v) = prettyALhs False aenv lhs a1' = ppA aenv a1 bnd | isAlet a1 = nest shiftwidth (vsep [v <+> equals, a1']) | otherwise = v <+> align (equals <+> a1') @@ -250,54 +254,72 @@ prettyAtuple -> Val aenv -> PreOpenAcc acc aenv arrs -> Adoc -prettyAtuple prettyAcc extractAcc aenv0 - = align . wrap . collect aenv0 +prettyAtuple prettyAcc extractAcc aenv0 acc = case collect acc of + Just tup -> align $ "T" <> pretty (length tup) <+> sep tup + Nothing -> align $ ppPair acc where - wrap [x] = x - wrap xs = tupled xs - - collect :: Val aenv' -> PreOpenAcc acc aenv' a -> [Adoc] - collect aenv = - \case - Anil -> [] - Apair a1 a2 -> collect aenv (extractAcc a1) ++ [prettyAcc context0 aenv a2] - next -> [prettyPreOpenAcc context0 prettyAcc extractAcc aenv next] - -prettyLHS :: Val aenv -> LeftHandSide arrs aenv aenv' -> (Val aenv', Adoc) -prettyLHS aenv0 = fmap wrap . go aenv0 + ppPair :: PreOpenAcc acc aenv arrs' -> Adoc + ppPair (Apair a1 a2) = "(" <> ppPair (extractAcc a1) <> "," <+> prettyAcc context0 aenv0 a2 <> ")" + ppPair a = prettyPreOpenAcc context0 prettyAcc extractAcc aenv0 a + + collect :: PreOpenAcc acc aenv arrs' -> Maybe [Adoc] + collect Anil = Just [] + collect (Apair a1 a2) + | Just tup <- collect $ extractAcc a1 + = Just $ tup ++ [prettyAcc app aenv0 a2] + collect _ = Nothing + +-- TODO: Should we also print the types of the declared variables? And the types of wildcards? +prettyALhs :: Bool -> Val env -> LeftHandSide s arrs env env' -> (Val env', Adoc) +prettyALhs requiresParens = prettyLhs requiresParens 'a' + +prettyELhs :: Bool -> Val env -> LeftHandSide s arrs env env' -> (Val env', Adoc) +prettyELhs requiresParens = prettyLhs requiresParens 'x' + +prettyLhs :: forall s env env' arrs. Bool -> Char -> Val env -> LeftHandSide s arrs env env' -> (Val env', Adoc) +prettyLhs requiresParens x env0 lhs = case collect lhs of + Just (env1, tup) -> (env1, parensIf requiresParens (pretty 'T' <> pretty (length tup) <+> sep tup)) + Nothing -> ppPair lhs where - wrap [x] = x - wrap xs = tupled xs - - go :: Val aenv -> LeftHandSide arrs aenv aenv' -> (Val aenv', [Adoc]) - go aenv (LeftHandSideWildcard ArraysRunit) = (aenv, []) - go aenv (LeftHandSideWildcard _) = (aenv, ["_"]) - go aenv LeftHandSideArray = (aenv `Push` v, [v]) + ppPair :: LeftHandSide s arrs' env env'' -> (Val env'', Adoc) + ppPair (LeftHandSideWildcard TupRunit) = (env0, "()") + ppPair (LeftHandSideWildcard _) = (env0, "_") + ppPair (LeftHandSideSingle _) = (env0 `Push` v, v) where - v = pretty 'a' <> pretty (sizeEnv aenv) - go aenv (LeftHandSidePair a b) = (aenv2, doc1 ++ [doc2]) + v = pretty x <> pretty (sizeEnv env0) + ppPair (LeftHandSidePair a b) = (env2, tupled [doc1, doc2]) where - (aenv1, doc1) = go aenv a - (aenv2, doc2) = prettyLHS aenv1 b + (env1, doc1) = ppPair a + (env2, doc2) = prettyLhs False x env1 b + + collect :: LeftHandSide s arrs' env env'' -> Maybe (Val env'', [Adoc]) + collect (LeftHandSidePair l1 l2) + | Just (env1, tup ) <- collect l1 + , (env2, doc2) <- prettyLhs True x env1 l2 = Just (env2, tup ++ [doc2]) + collect (LeftHandSideWildcard TupRunit) = Just (env0, []) + collect _ = Nothing -prettyArray :: (Shape sh, Elt e) => Array sh e -> Adoc -prettyArray = parens . viaShow +prettyArray :: ArrayR (Array sh e) -> Array sh e -> Adoc +prettyArray repr = parens . fromString . showArray repr -- Scalar expressions -- ------------------ +prettyFun :: Val aenv -> Fun aenv f -> Adoc +prettyFun = prettyOpenFun Empty -prettyPreOpenFun - :: forall acc env aenv f. - PrettyAcc acc - -> ExtractAcc acc - -> Val env +prettyExp :: Val aenv -> Exp aenv t -> Adoc +prettyExp = prettyOpenExp context0 Empty + +prettyOpenFun + :: forall env aenv f. + Val env -> Val aenv - -> PreOpenFun acc env aenv f + -> OpenFun env aenv f -> Adoc -prettyPreOpenFun prettyAcc extractAcc env0 aenv = next (pretty '\\') env0 +prettyOpenFun env0 aenv = next (pretty '\\') env0 where - next :: Adoc -> Val env' -> PreOpenFun acc env' aenv f' -> Adoc + next :: Adoc -> Val env' -> OpenFun env' aenv f' -> Adoc next vs env (Body body) -- PrimApp f x <- body -- , op <- primOperator f @@ -308,35 +330,35 @@ prettyPreOpenFun prettyAcc extractAcc env0 aenv = next (pretty '\\') env0 -- = opName op -- surrounding context will add parens -- = hang shiftwidth (sep [ vs <> "->" - , prettyPreOpenExp context0 prettyAcc extractAcc env aenv body]) - next vs env (Lam lam) = - let x = pretty 'x' <> pretty (sizeEnv env) - in next (vs <> x <> space) (env `Push` x) lam + , prettyOpenExp context0 env aenv body]) + next vs env (Lam lhs lam) = + let (env', lhs') = prettyELhs True env lhs + in next (vs <> lhs' <> space) env' lam -prettyPreOpenExp - :: forall acc env aenv t. +prettyOpenExp + :: forall env aenv t. Context - -> PrettyAcc acc - -> ExtractAcc acc -> Val env -> Val aenv - -> PreOpenExp acc env aenv t + -> OpenExp env aenv t -> Adoc -prettyPreOpenExp ctx prettyAcc extractAcc env aenv exp = +prettyOpenExp ctx env aenv exp = case exp of - Var idx -> prj idx env - Let{} -> prettyLet ctx prettyAcc extractAcc env aenv exp + Evar (Var _ idx) -> prj idx env + Let{} -> prettyLet ctx env aenv exp PrimApp f x - | Tuple (NilTup `SnocTup` a `SnocTup` b) <- x -> ppF2 op (ppE a) (ppE b) - | otherwise -> ppF1 op' (ppE x) + | a `Pair` b <- x -> ppF2 op (ppE a) (ppE b) + | otherwise -> ppF1 op' (ppE x) where op = primOperator f op' = isInfix op ? (Operator (parens (opName op)) App L 10, op) -- PrimConst c -> prettyPrimConst c - Const c -> prettyConst (toElt c :: t) - Tuple t -> prettyTuple (eltType @t) prettyAcc extractAcc env aenv t - Prj tix e -> ppF2 (Operator "#" Infix L 8) (ppE e) (\_ -> pretty (tupleIdxToInt tix)) + Const tp c -> prettyConst (TupRsingle tp) c + Pair{} -> prettyTuple ctx env aenv exp + Nil -> "()" + VecPack _ e -> ppF1 "vecPack" (ppE e) + VecUnpack _ e -> ppF1 "vecUnpack" (ppE e) Cond p t e -> flatAlt multi single where p' = ppE p context0 @@ -350,35 +372,28 @@ prettyPreOpenExp ctx prettyAcc extractAcc env aenv exp = , hang shiftwidth (sep [ then_, t' ]) , hang shiftwidth (sep [ else_, e' ]) ] -- - IndexAny -> "Any" - IndexNil -> pretty 'Z' - IndexCons sh sz -> ppF2 (Operator ":." Infix L 3) (ppE sh) (ppE sz) - IndexHead sh -> ppF1 "indexHead" (ppE sh) - IndexTail sh -> ppF1 "indexTail" (ppE sh) IndexSlice _ slix sh -> ppF2 "indexSlice" (ppE slix) (ppE sh) IndexFull _ slix sl -> ppF2 "indexFull" (ppE slix) (ppE sl) - ToIndex sh ix -> ppF2 "toIndex" (ppE sh) (ppE ix) - FromIndex sh ix -> ppF2 "fromIndex" (ppE sh) (ppE ix) + ToIndex _ sh ix -> ppF2 "toIndex" (ppE sh) (ppE ix) + FromIndex _ sh ix -> ppF2 "fromIndex" (ppE sh) (ppE ix) While p f x -> ppF3 "while" (ppF p) (ppF f) (ppE x) - Foreign ff _f e -> ppF2 "foreign" (\_ -> pretty (strForeign ff)) (ppE e) + Foreign _ ff _ e -> ppF2 "foreign" (\_ -> pretty (strForeign ff)) (ppE e) Shape arr -> ppF1 "shape" (ppA arr) - ShapeSize sh -> ppF1 "shapeSize" (ppE sh) - Intersect sh1 sh2 -> ppF2 "intersect" (ppE sh1) (ppE sh2) - Union sh1 sh2 -> ppF2 "union" (ppE sh1) (ppE sh2) + ShapeSize _ sh -> ppF1 "shapeSize" (ppE sh) Index arr ix -> ppF2 (Operator (pretty '!') Infix L 9) (ppA arr) (ppE ix) LinearIndex arr ix -> ppF2 (Operator "!!" Infix L 9) (ppA arr) (ppE ix) - Coerce x -> ppF1 (Operator (withTypeRep "coerce") App L 10) (ppE x) - Undef -> withTypeRep "undef" + Coerce _ tp x -> ppF1 (Operator (withTypeRep tp "coerce") App L 10) (ppE x) + Undef tp -> withTypeRep tp "undef" where - ppE :: PreOpenExp acc env aenv e -> Context -> Adoc - ppE e c = prettyPreOpenExp c prettyAcc extractAcc env aenv e + ppE :: OpenExp env aenv e -> Context -> Adoc + ppE e c = prettyOpenExp c env aenv e - ppA :: acc aenv a -> Context -> Adoc - ppA acc _ = prettyAcc app aenv acc + ppA :: ArrayVar aenv a -> Context -> Adoc + ppA acc _ = prettyArrayVar aenv acc - ppF :: PreOpenFun acc env aenv f -> Context -> Adoc - ppF f _ = parens $ prettyPreOpenFun prettyAcc extractAcc env aenv f + ppF :: OpenFun env aenv f -> Context -> Adoc + ppF f _ = parens $ prettyOpenFun env aenv f ppF1 :: Operator -> (Context -> Adoc) -> Adoc ppF1 op x @@ -401,29 +416,32 @@ prettyPreOpenExp ctx prettyAcc extractAcc env aenv exp = $ hang 2 $ sep [ opName op, x app, y app, z app ] - withTypeRep :: Typeable t => Adoc -> Adoc - withTypeRep op = op <> enclose langle rangle (pretty (showsTypeRep (typeOf (undefined::t)) "")) + withTypeRep :: ScalarType t -> Adoc -> Adoc + withTypeRep tp op = op <> enclose langle rangle (pretty (showScalarType tp)) +prettyArrayVar + :: forall aenv a. + Val aenv + -> ArrayVar aenv a + -> Adoc +prettyArrayVar aenv (Var _ idx) = prj idx aenv prettyLet - :: forall acc env aenv t. + :: forall env aenv t. Context - -> PrettyAcc acc - -> ExtractAcc acc -> Val env -> Val aenv - -> PreOpenExp acc env aenv t + -> OpenExp env aenv t -> Adoc -prettyLet ctx prettyAcc extractAcc env0 aenv +prettyLet ctx env0 aenv = parensIf (needsParens ctx "let") . align . wrap . collect env0 where - collect :: Val env' -> PreOpenExp acc env' aenv e -> ([Adoc], Adoc) + collect :: Val env' -> OpenExp env' aenv e -> ([Adoc], Adoc) collect env = \case - Let e1 e2 -> - let env' = env `Push` v - v = pretty 'x' <> pretty (sizeEnv env) + Let lhs e1 e2 -> + let (env', v) = prettyELhs False env lhs e1' = ppE env e1 bnd | isLet e1 = nest shiftwidth (vsep [v <+> equals, e1']) | otherwise = v <+> align (equals <+> e1') @@ -433,12 +451,12 @@ prettyLet ctx prettyAcc extractAcc env0 aenv -- next -> ([], ppE env next) - isLet :: PreOpenExp acc env' aenv t' -> Bool + isLet :: OpenExp env' aenv t' -> Bool isLet Let{} = True isLet _ = False - ppE :: Val env' -> PreOpenExp acc env' aenv t' -> Adoc - ppE env = prettyPreOpenExp context0 prettyAcc extractAcc env aenv + ppE :: Val env' -> OpenExp env' aenv t' -> Adoc + ppE env = prettyOpenExp context0 env aenv wrap :: ([Adoc], Adoc) -> Adoc wrap ([], body) = body -- shouldn't happen! @@ -451,30 +469,55 @@ prettyLet ctx prettyAcc extractAcc env0 aenv ] prettyTuple - :: forall acc env aenv t p. - TupleType t - -> PrettyAcc acc - -> ExtractAcc acc + :: forall env aenv t. + Context -> Val env -> Val aenv - -> Tuple (PreOpenExp acc env aenv) p + -> OpenExp env aenv t -> Adoc -prettyTuple tt prettyAcc extractAcc env aenv = wrap . collect [] +prettyTuple ctx env aenv exp = case collect exp of + Just tup -> align $ parensIf (ctxPrecedence ctx > 0) ("T" <> pretty (length tup) <+> sep tup) + Nothing -> align $ ppPair exp where - collect :: [Adoc] -> Tuple (PreOpenExp acc env aenv) s -> [Adoc] - collect acc = - \case - NilTup -> acc - SnocTup tup e -> collect (align (prettyPreOpenExp context0 prettyAcc extractAcc env aenv e) : acc) tup - -- - wrap - | TypeRscalar VectorScalarType{} <- tt = group . encloseSep (flatAlt "< " "<") (flatAlt " >" ">") ", " - | otherwise = tupled -- as above, with parenthesis + ppPair :: OpenExp env aenv t' -> Adoc + ppPair (Pair e1 e2) = "(" <> ppPair e1 <> "," <+> prettyOpenExp context0 env aenv e2 <> ")" + ppPair e = prettyOpenExp context0 env aenv e + + collect :: OpenExp env aenv t' -> Maybe [Adoc] + collect Nil = Just [] + collect (Pair e1 e2) + | Just tup <- collect e1 + = Just $ tup ++ [prettyOpenExp app env aenv e2] + collect _ = Nothing +{- -prettyConst :: Elt e => e -> Adoc -prettyConst x = - let y = show x +prettyAtuple + :: forall acc aenv arrs. + PrettyAcc acc + -> ExtractAcc acc + -> Val aenv + -> PreOpenAcc acc aenv arrs + -> Adoc +prettyAtuple prettyAcc extractAcc aenv0 acc = case collect acc of + Just tup -> align $ "T" <> pretty (length tup) <+> sep tup + Nothing -> align $ ppPair acc + where + ppPair :: PreOpenAcc acc aenv arrs' -> Adoc + ppPair (Apair a1 a2) = "(" <> ppPair (extractAcc a1) <> "," <+> prettyAcc context0 aenv0 a2 <> ")" + ppPair a = prettyPreOpenAcc context0 prettyAcc extractAcc aenv0 a + + collect :: PreOpenAcc acc aenv arrs' -> Maybe [Adoc] + collect Anil = Just [] + collect (Apair a1 a2) + | Just tup <- collect $ extractAcc a1 + = Just $ tup ++ [prettyAcc app aenv0 a2] + collect _ = Nothing +-} + +prettyConst :: TupleType e -> e -> Adoc +prettyConst tp x = + let y = showElement tp x in parensIf (any isSpace y) (pretty y) prettyPrimConst :: PrimConst a -> Adoc diff --git a/src/Data/Array/Accelerate/Smart.hs b/src/Data/Array/Accelerate/Smart.hs index 16e99fa26..d48603093 100644 --- a/src/Data/Array/Accelerate/Smart.hs +++ b/src/Data/Array/Accelerate/Smart.hs @@ -1,5 +1,6 @@ {-# LANGUAGE AllowAmbiguousTypes #-} -{-# LANGUAGE DeriveDataTypeable #-} +{-# LANGUAGE CPP #-} +{-# LANGUAGE DataKinds #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GADTs #-} @@ -27,8 +28,8 @@ module Data.Array.Accelerate.Smart ( -- * HOAS AST - Acc(..), SmartAcc(..), PreSmartAcc(..), PairIdx(..), Exp(..), PreExp(..), - Boundary(..), PreBoundary(..), Stencil(..), Level, + Acc(..), SmartAcc(..), PreSmartAcc(..), PairIdx(..), Exp(..), SmartExp(..), PreSmartExp(..), + Boundary(..), PreBoundary(..), Stencil(..), Level, unExp, -- * Smart constructors for literals constant, undef, @@ -58,10 +59,11 @@ module Data.Array.Accelerate.Smart ( mkLAnd, mkLOr, mkLNot, mkIsNaN, mkIsInfinite, -- * Smart constructors for type coercion functions - mkOrd, mkChr, mkBoolToInt, mkFromIntegral, mkToFloating, mkBitcast, mkUnsafeCoerce, + mkOrd, mkChr, mkBoolToInt, mkFromIntegral, mkToFloating, mkBitcast, mkCoerce, Coerce, -- * Auxiliary functions - ($$), ($$$), ($$$$), ($$$$$), unAcc, unAccFunction, ApplyAcc(..), + ($$), ($$$), ($$$$), ($$$$$), unAcc, unAccFunction, ApplyAcc(..), exp, unPair, mkPairToTuple, HasExpType(..), HasArraysRepr(..), + vecR2, vecR3, vecR4, vecR5, vecR6, vecR7, vecR8, vecR9, vecR16, unExpFunction, -- Debugging showPreAccOp, showPreExpOp, @@ -69,19 +71,19 @@ module Data.Array.Accelerate.Smart ( ) where -- standard library -import Prelude hiding ( exp ) +import Prelude hiding ( exp ) import Data.Kind -import Data.Typeable -- friends import Data.Array.Accelerate.Type -import Data.Array.Accelerate.Array.Sugar -import Data.Array.Accelerate.Product -import Data.Array.Accelerate.AST hiding ( PreOpenAcc(..), OpenAcc(..), Acc - , PreOpenExp(..), OpenExp, PreExp, Exp - , Stencil(..), PreBoundary(..), Boundary - , showPreAccOp, showPreExpOp ) -import qualified Data.Array.Accelerate.AST as AST +import Data.Array.Accelerate.Array.Sugar (Elt, Arrays, EltRepr, ArrRepr, (:.), Foreign, eltType, fromElt, DIM1) +import qualified Data.Array.Accelerate.Array.Sugar as Sugar +import Data.Array.Accelerate.Array.Representation hiding (DIM1) +import Data.Array.Accelerate.AST hiding ( PreOpenAcc(..), OpenAcc(..), Acc + , OpenExp(..), Exp + , Boundary(..), HasArraysRepr(..), arrayRepr, expType + , showPreAccOp, showPreExpOp ) +import GHC.TypeNats -- Array computations -- ------------------ @@ -271,8 +273,7 @@ import qualified Data.Array.Accelerate.AST as AST -- newtype Acc a = Acc (SmartAcc (ArrRepr a)) -newtype SmartAcc a = SmartAcc (PreSmartAcc SmartAcc Exp a) -deriving instance Typeable Acc +newtype SmartAcc a = SmartAcc (PreSmartAcc SmartAcc SmartExp a) -- The level of lambda-bound variables. The root has level 0; then it increases with each bound @@ -284,32 +285,31 @@ type Level = Int -- data PreSmartAcc acc exp as where -- Needed for conversion to de Bruijn form - Atag :: Typeable as - => Level -- environment size at defining occurrence + Atag :: ArraysR as + -> Level -- environment size at defining occurrence -> PreSmartAcc acc exp as - Pipe :: (Typeable as, Typeable bs, Typeable cs) - => ArraysR as + Pipe :: ArraysR as -> ArraysR bs + -> ArraysR cs -> (SmartAcc as -> acc bs) -> (SmartAcc bs -> acc cs) -> acc as -> PreSmartAcc acc exp cs - Aforeign :: (Typeable (ArrRepr as), Typeable (ArrRepr bs), Arrays as, Arrays bs, Foreign asm) - => asm (as -> bs) - -> (Acc as -> Acc bs) - -> acc (ArrRepr as) - -> PreSmartAcc acc exp (ArrRepr bs) + Aforeign :: Foreign asm + => ArraysR bs + -> asm (as -> bs) + -> (SmartAcc as -> SmartAcc bs) + -> acc as + -> PreSmartAcc acc exp bs - Acond :: Typeable as - => exp Bool + Acond :: exp Bool -> acc as -> acc as -> PreSmartAcc acc exp as - Awhile :: Typeable arrs - => ArraysR arrs + Awhile :: ArraysR arrs -> (SmartAcc arrs -> acc (Scalar Bool)) -> (SmartAcc arrs -> acc arrs) -> acc arrs @@ -317,134 +317,140 @@ data PreSmartAcc acc exp as where Anil :: PreSmartAcc acc exp () - Apair :: (Typeable arrs1, Typeable arrs2) - => acc arrs1 + Apair :: acc arrs1 -> acc arrs2 -> PreSmartAcc acc exp (arrs1, arrs2) - Aprj :: (Typeable arrs1, Typeable arrs2) - => PairIdx (arrs1, arrs2) arrs + Aprj :: PairIdx (arrs1, arrs2) arrs -> acc (arrs1, arrs2) -> PreSmartAcc acc exp arrs - Use :: (Shape sh, Elt e) - => Array sh e + Use :: ArrayR (Array sh e) + -> Array sh e -> PreSmartAcc acc exp (Array sh e) - Unit :: Elt e - => exp e + Unit :: TupleType e + -> exp e -> PreSmartAcc acc exp (Scalar e) - Generate :: (Shape sh, Elt e) - => exp sh - -> (Exp sh -> exp e) + Generate :: ArrayR (Array sh e) + -> exp sh + -> (SmartExp sh -> exp e) -> PreSmartAcc acc exp (Array sh e) - Reshape :: (Shape sh, Shape sh', Elt e) - => exp sh + Reshape :: ShapeR sh + -> exp sh -> acc (Array sh' e) -> PreSmartAcc acc exp (Array sh e) - Replicate :: (Slice slix, Elt e) - => exp slix - -> acc (Array (SliceShape slix) e) - -> PreSmartAcc acc exp (Array (FullShape slix) e) + Replicate :: SliceIndex slix sl co sh + -> exp slix + -> acc (Array sl e) + -> PreSmartAcc acc exp (Array sh e) - Slice :: (Slice slix, Elt e) - => acc (Array (FullShape slix) e) + Slice :: SliceIndex slix sl co sh + -> acc (Array sh e) -> exp slix - -> PreSmartAcc acc exp (Array (SliceShape slix) e) + -> PreSmartAcc acc exp (Array sl e) - Map :: (Shape sh, Elt e, Elt e') - => (Exp e -> exp e') + Map :: TupleType e + -> TupleType e' + -> (SmartExp e -> exp e') -> acc (Array sh e) -> PreSmartAcc acc exp (Array sh e') - ZipWith :: (Shape sh, Elt e1, Elt e2, Elt e3) - => (Exp e1 -> Exp e2 -> exp e3) + ZipWith :: TupleType e1 + -> TupleType e2 + -> TupleType e3 + -> (SmartExp e1 -> SmartExp e2 -> exp e3) -> acc (Array sh e1) -> acc (Array sh e2) -> PreSmartAcc acc exp (Array sh e3) - Fold :: (Shape sh, Elt e) - => (Exp e -> Exp e -> exp e) + Fold :: TupleType e + -> (SmartExp e -> SmartExp e -> exp e) -> exp e - -> acc (Array (sh:.Int) e) + -> acc (Array (sh, Int) e) -> PreSmartAcc acc exp (Array sh e) - Fold1 :: (Shape sh, Elt e) - => (Exp e -> Exp e -> exp e) - -> acc (Array (sh:.Int) e) + Fold1 :: TupleType e + -> (SmartExp e -> SmartExp e -> exp e) + -> acc (Array (sh, Int) e) -> PreSmartAcc acc exp (Array sh e) - FoldSeg :: (Shape sh, Elt e, Elt i, IsIntegral i) - => (Exp e -> Exp e -> exp e) + FoldSeg :: IntegralType i + -> TupleType e + -> (SmartExp e -> SmartExp e -> exp e) -> exp e - -> acc (Array (sh:.Int) e) + -> acc (Array (sh, Int) e) -> acc (Segments i) - -> PreSmartAcc acc exp (Array (sh:.Int) e) + -> PreSmartAcc acc exp (Array (sh, Int) e) - Fold1Seg :: (Shape sh, Elt e, Elt i, IsIntegral i) - => (Exp e -> Exp e -> exp e) - -> acc (Array (sh:.Int) e) + Fold1Seg :: IntegralType i + -> TupleType e + -> (SmartExp e -> SmartExp e -> exp e) + -> acc (Array (sh, Int) e) -> acc (Segments i) - -> PreSmartAcc acc exp (Array (sh:.Int) e) + -> PreSmartAcc acc exp (Array (sh, Int) e) - Scanl :: (Shape sh, Elt e) - => (Exp e -> Exp e -> exp e) + Scanl :: TupleType e + -> (SmartExp e -> SmartExp e -> exp e) -> exp e - -> acc (Array (sh :. Int) e) - -> PreSmartAcc acc exp (Array (sh :. Int) e) + -> acc (Array (sh, Int) e) + -> PreSmartAcc acc exp (Array (sh, Int) e) - Scanl' :: (Shape sh, Elt e) - => (Exp e -> Exp e -> exp e) + Scanl' :: TupleType e + -> (SmartExp e -> SmartExp e -> exp e) -> exp e - -> acc (Array (sh :. Int) e) - -> PreSmartAcc acc exp (ArrRepr (Array (sh :. Int) e, Array sh e)) + -> acc (Array (sh, Int) e) + -> PreSmartAcc acc exp (Array (sh, Int) e, Array sh e) - Scanl1 :: (Shape sh, Elt e) - => (Exp e -> Exp e -> exp e) - -> acc (Array (sh :. Int) e) - -> PreSmartAcc acc exp (Array (sh :. Int) e) + Scanl1 :: TupleType e + -> (SmartExp e -> SmartExp e -> exp e) + -> acc (Array (sh, Int) e) + -> PreSmartAcc acc exp (Array (sh, Int) e) - Scanr :: (Shape sh, Elt e) - => (Exp e -> Exp e -> exp e) + Scanr :: TupleType e + -> (SmartExp e -> SmartExp e -> exp e) -> exp e - -> acc (Array (sh :. Int) e) - -> PreSmartAcc acc exp (Array (sh :. Int) e) + -> acc (Array (sh, Int) e) + -> PreSmartAcc acc exp (Array (sh, Int) e) - Scanr' :: (Shape sh, Elt e) - => (Exp e -> Exp e -> exp e) + Scanr' :: TupleType e + -> (SmartExp e -> SmartExp e -> exp e) -> exp e - -> acc (Array (sh :. Int) e) - -> PreSmartAcc acc exp (ArrRepr (Array (sh :. Int) e, Array sh e)) + -> acc (Array (sh, Int) e) + -> PreSmartAcc acc exp (Array (sh, Int) e, Array sh e) - Scanr1 :: (Shape sh, Elt e) - => (Exp e -> Exp e -> exp e) - -> acc (Array (sh :. Int) e) - -> PreSmartAcc acc exp (Array (sh :. Int) e) + Scanr1 :: TupleType e + -> (SmartExp e -> SmartExp e -> exp e) + -> acc (Array (sh, Int) e) + -> PreSmartAcc acc exp (Array (sh, Int) e) - Permute :: (Shape sh, Shape sh', Elt e) - => (Exp e -> Exp e -> exp e) + Permute :: ArrayR (Array sh e) + -> (SmartExp e -> SmartExp e -> exp e) -> acc (Array sh' e) - -> (Exp sh -> exp sh') + -> (SmartExp sh -> exp sh') -> acc (Array sh e) -> PreSmartAcc acc exp (Array sh' e) - Backpermute :: (Shape sh, Shape sh', Elt e) - => exp sh' - -> (Exp sh' -> exp sh) + Backpermute :: ShapeR sh' + -> exp sh' + -> (SmartExp sh' -> exp sh) -> acc (Array sh e) -> PreSmartAcc acc exp (Array sh' e) - Stencil :: (Shape sh, Elt a, Elt b, Stencil sh a stencil) - => (stencil -> exp b) + Stencil :: StencilR sh a stencil + -> TupleType b + -> (SmartExp stencil -> exp b) -> PreBoundary acc exp (Array sh a) -> acc (Array sh a) -> PreSmartAcc acc exp (Array sh b) - Stencil2 :: (Shape sh, Elt a, Elt b, Elt c, Stencil sh a stencil1, Stencil sh b stencil2) - => (stencil1 -> stencil2 -> exp c) + Stencil2 :: StencilR sh a stencil1 + -> StencilR sh b stencil2 + -> TupleType c + -> (SmartExp stencil1 -> SmartExp stencil2 -> exp c) -> PreBoundary acc exp (Array sh a) -> acc (Array sh a) -> PreBoundary acc exp (Array sh b) @@ -455,9 +461,63 @@ data PreSmartAcc acc exp as where -- => seq arrs -- -> PreSmartAcc acc seq exp arrs -data PairIdx p a where - PairIdxLeft :: PairIdx (a, b) a - PairIdxRight :: PairIdx (a, b) b +class HasArraysRepr f where + arraysRepr :: f a -> ArraysR a + +arrayRepr :: HasArraysRepr f => f (Array sh e) -> ArrayR (Array sh e) +arrayRepr acc = case arraysRepr acc of + TupRsingle repr -> repr + +instance HasArraysRepr acc => HasArraysRepr (PreSmartAcc acc exp) where + arraysRepr acc = case acc of + Atag repr _ -> repr + Pipe _ _ repr _ _ _ -> repr + Aforeign repr _ _ _ -> repr + Acond _ a _ -> arraysRepr a + Awhile _ _ _ a -> arraysRepr a + Anil -> TupRunit + Apair a1 a2 -> arraysRepr a1 `TupRpair` arraysRepr a2 + Aprj idx a | TupRpair t1 t2 <- arraysRepr a + -> case idx of + PairIdxLeft -> t1 + PairIdxRight -> t2 + Aprj _ _ -> error "Ejector seat? You're joking!" + Use repr _ -> TupRsingle repr + Unit tp _ -> TupRsingle $ ArrayR ShapeRz $ tp + Generate repr _ _ -> TupRsingle repr + Reshape shr _ a -> let ArrayR _ tp = arrayRepr a + in TupRsingle $ ArrayR shr tp + Replicate si _ a -> let ArrayR _ tp = arrayRepr a + in TupRsingle $ ArrayR (sliceDomainR si) tp + Slice si a _ -> let ArrayR _ tp = arrayRepr a + in TupRsingle $ ArrayR (sliceShapeR si) tp + Map _ tp _ a -> let ArrayR shr _ = arrayRepr a + in TupRsingle $ ArrayR shr tp + ZipWith _ _ tp _ a _ -> let ArrayR shr _ = arrayRepr a + in TupRsingle $ ArrayR shr tp + Fold _ _ _ a -> let ArrayR (ShapeRsnoc shr) tp = arrayRepr a + in TupRsingle (ArrayR shr tp) + Fold1 _ _ a -> let ArrayR (ShapeRsnoc shr) tp = arrayRepr a + in TupRsingle (ArrayR shr tp) + FoldSeg _ _ _ _ a _ -> arraysRepr a + Fold1Seg _ _ _ a _ -> arraysRepr a + Scanl _ _ _ a -> arraysRepr a + Scanl' _ _ _ a -> let repr@(ArrayR (ShapeRsnoc shr) tp) = arrayRepr a + in TupRsingle repr `TupRpair` TupRsingle (ArrayR shr tp) + Scanl1 _ _ a -> arraysRepr a + Scanr _ _ _ a -> arraysRepr a + Scanr' _ _ _ a -> let repr@(ArrayR (ShapeRsnoc shr) tp) = arrayRepr a + in TupRsingle repr `TupRpair` TupRsingle (ArrayR shr tp) + Scanr1 _ _ a -> arraysRepr a + Permute _ _ a _ _ -> arraysRepr a + Backpermute shr _ _ a -> let ArrayR _ tp = arrayRepr a + in TupRsingle (ArrayR shr tp) + Stencil s tp _ _ _ -> TupRsingle $ ArrayR (stencilShape s) tp + Stencil2 s _ tp _ _ _ _ _ -> TupRsingle $ ArrayR (stencilShape s) tp + +instance HasArraysRepr SmartAcc where + arraysRepr (SmartAcc e) = arraysRepr e + {-- data PreSeq acc seq exp arrs where @@ -574,9 +634,7 @@ deriving instance Typeable Seq -- -------------------------------------------- -- HOAS expressions mirror the constructors of 'AST.OpenExp', but with the 'Tag' --- constructor instead of variables in the form of de Bruijn indices. Moreover, --- HOAS expression use n-tuples and the type class 'Elt' to constrain element --- types, whereas 'AST.OpenExp' uses nested pairs and the GADT 'TupleType'. +-- constructor instead of variables in the form of de Bruijn indices. -- -- | The type 'Exp' represents embedded scalar expressions. The collective @@ -588,142 +646,157 @@ deriving instance Typeable Seq -- efficiently on constrained hardware such as GPUs, and is thus currently -- unsupported. -- -newtype Exp t = Exp (PreExp SmartAcc Exp t) - -deriving instance Typeable Exp +newtype Exp t = Exp (SmartExp (EltRepr t)) +newtype SmartExp t = SmartExp (PreSmartExp SmartAcc SmartExp t) -- | Scalar expressions to parametrise collective array operations, themselves parameterised over -- the type of collective array operations. -- -data PreExp acc exp t where +data PreSmartExp acc exp t where -- Needed for conversion to de Bruijn form - Tag :: Elt t - => Level -- environment size at defining occurrence - -> PreExp acc exp t - - -- All the same constructors as 'AST.Exp' - Const :: Elt t - => t - -> PreExp acc exp t - - Tuple :: (Elt t, IsTuple t) - => Tuple exp (TupleRepr t) - -> PreExp acc exp t + Tag :: TupleType t + -> Level -- environment size at defining occurrence + -> PreSmartExp acc exp t - Prj :: (Elt t, IsTuple t, Elt e) - => TupleIdx (TupleRepr t) e - -> exp t - -> PreExp acc exp e + -- All the same constructors as 'AST.Exp', plus projection + Const :: ScalarType t + -> t + -> PreSmartExp acc exp t - IndexNil :: PreExp acc exp Z + Nil :: PreSmartExp acc exp () - IndexCons :: (Elt sl, Elt a) - => exp sl - -> exp a - -> PreExp acc exp (sl:.a) + Pair :: exp t1 + -> exp t2 + -> PreSmartExp acc exp (t1, t2) - IndexHead :: (Elt sl, Elt a) - => exp (sl:.a) - -> PreExp acc exp a + Prj :: PairIdx (t1, t2) t + -> exp (t1, t2) + -> PreSmartExp acc exp t - IndexTail :: (Elt sl, Elt a) - => exp (sl:.a) - -> PreExp acc exp sl + -- SIMD vectors + VecPack :: KnownNat n + => VecR n s tup + -> exp tup + -> PreSmartExp acc exp (Vec n s) - IndexAny :: Shape sh - => PreExp acc exp (Any sh) + VecUnpack :: KnownNat n + => VecR n s tup + -> exp (Vec n s) + -> PreSmartExp acc exp tup - ToIndex :: Shape sh - => exp sh + ToIndex :: ShapeR sh + -> exp sh -> exp sh - -> PreExp acc exp Int + -> PreSmartExp acc exp Int - FromIndex :: Shape sh - => exp sh + FromIndex :: ShapeR sh + -> exp sh -> exp Int - -> PreExp acc exp sh + -> PreSmartExp acc exp sh - Cond :: Elt t - => exp Bool + Cond :: exp Bool -> exp t -> exp t - -> PreExp acc exp t + -> PreSmartExp acc exp t - While :: Elt t - => (Exp t -> exp Bool) - -> (Exp t -> exp t) + While :: TupleType t + -> (SmartExp t -> exp Bool) + -> (SmartExp t -> exp t) -> exp t - -> PreExp acc exp t + -> PreSmartExp acc exp t - PrimConst :: Elt t - => PrimConst t - -> PreExp acc exp t + PrimConst :: PrimConst t + -> PreSmartExp acc exp t - PrimApp :: (Elt a, Elt r) - => PrimFun (a -> r) + PrimApp :: PrimFun (a -> r) -> exp a - -> PreExp acc exp r + -> PreSmartExp acc exp r - Index :: (Shape sh, Elt t) - => acc (Array sh t) + Index :: TupleType t + -> acc (Array sh t) -> exp sh - -> PreExp acc exp t + -> PreSmartExp acc exp t - LinearIndex :: (Shape sh, Elt t) - => acc (Array sh t) + LinearIndex :: TupleType t + -> acc (Array sh t) -> exp Int - -> PreExp acc exp t - - Shape :: (Shape sh, Elt e) - => acc (Array sh e) - -> PreExp acc exp sh + -> PreSmartExp acc exp t - ShapeSize :: Shape sh - => exp sh - -> PreExp acc exp Int - - Intersect :: Shape sh - => exp sh - -> exp sh - -> PreExp acc exp sh + Shape :: ShapeR sh + -> acc (Array sh e) + -> PreSmartExp acc exp sh - Union :: Shape sh - => exp sh + ShapeSize :: ShapeR sh -> exp sh - -> PreExp acc exp sh + -> PreSmartExp acc exp Int - Foreign :: (Elt x, Elt y, Foreign asm) - => asm (x -> y) - -> (Exp x -> Exp y) -- RCE: Using Exp instead of exp to aid in sharing recovery. + Foreign :: Foreign asm + => TupleType y + -> asm (x -> y) + -> (SmartExp x -> SmartExp y) -- RCE: Using SmartExp instead of exp to aid in sharing recovery. -> exp x - -> PreExp acc exp y + -> PreSmartExp acc exp y - Undef :: Elt t - => PreExp acc exp t - - Coerce :: (Elt a, Elt b) - => exp a - -> PreExp acc exp b + Undef :: ScalarType t + -> PreSmartExp acc exp t + Coerce :: BitSizeEq a b + => ScalarType a + -> ScalarType b + -> exp a + -> PreSmartExp acc exp b + +class HasExpType f where + expType :: f t -> TupleType t + +instance HasExpType exp => HasExpType (PreSmartExp acc exp) where + expType expr = case expr of + Tag tp _ -> tp + Const tp _ -> TupRsingle tp + Nil -> TupRunit + Pair e1 e2 -> expType e1 `TupRpair` expType e2 + Prj idx e | TupRpair t1 t2 <- expType e + -> case idx of + PairIdxLeft -> t1 + PairIdxRight -> t2 + Prj _ _ -> error "I never joke about my work" + VecPack vecR _ -> TupRsingle $ VectorScalarType $ vecRvector vecR + VecUnpack vecR _ -> vecRtuple vecR + ToIndex _ _ _ -> TupRsingle $ scalarTypeInt + FromIndex shr _ _ -> shapeType shr + Cond _ e _ -> expType e + While t _ _ _ -> t + PrimConst c -> TupRsingle $ SingleScalarType $ primConstType c + PrimApp f _ -> snd $ primFunType f + Index tp _ _ -> tp + LinearIndex tp _ _ -> tp + Shape shr _ -> shapeType shr + ShapeSize _ _ -> TupRsingle $ scalarTypeInt + Foreign tp _ _ _ -> tp + Undef tp -> TupRsingle tp + Coerce _ tp _ -> TupRsingle tp + +instance HasExpType SmartExp where + expType (SmartExp e) = expType e -- Smart constructors for stencils -- ------------------------------- -- | Boundary condition specification for stencil operations -- -newtype Boundary t = Boundary (PreBoundary SmartAcc Exp t) +data Boundary t where + Boundary :: !(PreBoundary SmartAcc SmartExp (Array (EltRepr sh) (EltRepr e))) + -> Boundary (Sugar.Array sh e) data PreBoundary acc exp t where Clamp :: PreBoundary acc exp t Mirror :: PreBoundary acc exp t Wrap :: PreBoundary acc exp t - Constant :: Elt e - => e + Constant :: e -> PreBoundary acc exp (Array sh e) - Function :: (Shape sh, Elt e) - => (Exp sh -> exp e) + Function :: (SmartExp sh -> exp e) -> PreBoundary acc exp (Array sh e) @@ -737,170 +810,189 @@ data PreBoundary acc exp t where -- Bruijn index). The various positions in the stencil are accessed via -- tuple indices (i.e., projections). -- -class (Elt (StencilRepr sh stencil), AST.Stencil sh a (StencilRepr sh stencil)) => Stencil sh a stencil where +class Stencil sh e stencil where type StencilRepr sh stencil :: Type - stencilPrj :: Exp (StencilRepr sh stencil) -> stencil + + stencilR :: StencilR (EltRepr sh) (EltRepr e) (StencilRepr sh stencil) + stencilPrj :: SmartExp (StencilRepr sh stencil) -> stencil -- DIM1 instance Elt e => Stencil DIM1 e (Exp e, Exp e, Exp e) where type StencilRepr DIM1 (Exp e, Exp e, Exp e) - = (e, e, e) - stencilPrj s = (Exp $ Prj tix2 s, - Exp $ Prj tix1 s, - Exp $ Prj tix0 s) + = EltRepr (e, e, e) + stencilR = StencilRunit3 @(EltRepr e) $ eltType @e + stencilPrj s = (Exp $ prj2 s, + Exp $ prj1 s, + Exp $ prj0 s) instance Elt e => Stencil DIM1 e (Exp e, Exp e, Exp e, Exp e, Exp e) where type StencilRepr DIM1 (Exp e, Exp e, Exp e, Exp e, Exp e) - = (e, e, e, e, e) - stencilPrj s = (Exp $ Prj tix4 s, - Exp $ Prj tix3 s, - Exp $ Prj tix2 s, - Exp $ Prj tix1 s, - Exp $ Prj tix0 s) + = EltRepr (e, e, e, e, e) + stencilR = StencilRunit5 $ eltType @e + stencilPrj s = (Exp $ prj4 s, + Exp $ prj3 s, + Exp $ prj2 s, + Exp $ prj1 s, + Exp $ prj0 s) instance Elt e => Stencil DIM1 e (Exp e, Exp e, Exp e, Exp e, Exp e, Exp e, Exp e) where type StencilRepr DIM1 (Exp e, Exp e, Exp e, Exp e, Exp e, Exp e, Exp e) - = (e, e, e, e, e, e, e) - stencilPrj s = (Exp $ Prj tix6 s, - Exp $ Prj tix5 s, - Exp $ Prj tix4 s, - Exp $ Prj tix3 s, - Exp $ Prj tix2 s, - Exp $ Prj tix1 s, - Exp $ Prj tix0 s) + = EltRepr (e, e, e, e, e, e, e) + stencilR = StencilRunit7 $ eltType @e + stencilPrj s = (Exp $ prj6 s, + Exp $ prj5 s, + Exp $ prj4 s, + Exp $ prj3 s, + Exp $ prj2 s, + Exp $ prj1 s, + Exp $ prj0 s) instance Elt e => Stencil DIM1 e (Exp e, Exp e, Exp e, Exp e, Exp e, Exp e, Exp e, Exp e, Exp e) where type StencilRepr DIM1 (Exp e, Exp e, Exp e, Exp e, Exp e, Exp e, Exp e, Exp e, Exp e) - = (e, e, e, e, e, e, e, e, e) - stencilPrj s = (Exp $ Prj tix8 s, - Exp $ Prj tix7 s, - Exp $ Prj tix6 s, - Exp $ Prj tix5 s, - Exp $ Prj tix4 s, - Exp $ Prj tix3 s, - Exp $ Prj tix2 s, - Exp $ Prj tix1 s, - Exp $ Prj tix0 s) + = EltRepr (e, e, e, e, e, e, e, e, e) + stencilR = StencilRunit9 $ eltType @e + stencilPrj s = (Exp $ prj8 s, + Exp $ prj7 s, + Exp $ prj6 s, + Exp $ prj5 s, + Exp $ prj4 s, + Exp $ prj3 s, + Exp $ prj2 s, + Exp $ prj1 s, + Exp $ prj0 s) -- DIM(n+1) instance (Stencil (sh:.Int) a row2, Stencil (sh:.Int) a row1, Stencil (sh:.Int) a row0) => Stencil (sh:.Int:.Int) a (row2, row1, row0) where type StencilRepr (sh:.Int:.Int) (row2, row1, row0) - = (StencilRepr (sh:.Int) row2, StencilRepr (sh:.Int) row1, StencilRepr (sh:.Int) row0) - stencilPrj s = (stencilPrj @(sh:.Int) @a (Exp $ Prj tix2 s), - stencilPrj @(sh:.Int) @a (Exp $ Prj tix1 s), - stencilPrj @(sh:.Int) @a (Exp $ Prj tix0 s)) + = Tup3 (StencilRepr (sh:.Int) row2) (StencilRepr (sh:.Int) row1) (StencilRepr (sh:.Int) row0) + stencilR = StencilRtup3 (stencilR @(sh:.Int) @a @row2) (stencilR @(sh:.Int) @a @row1) (stencilR @(sh:.Int) @a @row0) + stencilPrj s = (stencilPrj @(sh:.Int) @a $ prj2 s, + stencilPrj @(sh:.Int) @a $ prj1 s, + stencilPrj @(sh:.Int) @a $ prj0 s) -instance (Stencil (sh:.Int) a row1, - Stencil (sh:.Int) a row2, +instance (Stencil (sh:.Int) a row4, Stencil (sh:.Int) a row3, - Stencil (sh:.Int) a row4, - Stencil (sh:.Int) a row5) => Stencil (sh:.Int:.Int) a (row1, row2, row3, row4, row5) where - type StencilRepr (sh:.Int:.Int) (row1, row2, row3, row4, row5) - = (StencilRepr (sh:.Int) row1, StencilRepr (sh:.Int) row2, StencilRepr (sh:.Int) row3, - StencilRepr (sh:.Int) row4, StencilRepr (sh:.Int) row5) - stencilPrj s = (stencilPrj @(sh:.Int) @a (Exp $ Prj tix4 s), - stencilPrj @(sh:.Int) @a (Exp $ Prj tix3 s), - stencilPrj @(sh:.Int) @a (Exp $ Prj tix2 s), - stencilPrj @(sh:.Int) @a (Exp $ Prj tix1 s), - stencilPrj @(sh:.Int) @a (Exp $ Prj tix0 s)) - -instance (Stencil (sh:.Int) a row1, Stencil (sh:.Int) a row2, - Stencil (sh:.Int) a row3, - Stencil (sh:.Int) a row4, + Stencil (sh:.Int) a row1, + Stencil (sh:.Int) a row0) => Stencil (sh:.Int:.Int) a (row4, row3, row2, row1, row0) where + type StencilRepr (sh:.Int:.Int) (row4, row3, row2, row1, row0) + = Tup5 (StencilRepr (sh:.Int) row4) (StencilRepr (sh:.Int) row3) (StencilRepr (sh:.Int) row2) + (StencilRepr (sh:.Int) row1) (StencilRepr (sh:.Int) row0) + stencilR = StencilRtup5 (stencilR @(sh:.Int) @a @row4) (stencilR @(sh:.Int) @a @row3) + (stencilR @(sh:.Int) @a @row2) (stencilR @(sh:.Int) @a @row1) (stencilR @(sh:.Int) @a @row0) + stencilPrj s = (stencilPrj @(sh:.Int) @a $ prj4 s, + stencilPrj @(sh:.Int) @a $ prj3 s, + stencilPrj @(sh:.Int) @a $ prj2 s, + stencilPrj @(sh:.Int) @a $ prj1 s, + stencilPrj @(sh:.Int) @a $ prj0 s) + +instance (Stencil (sh:.Int) a row6, Stencil (sh:.Int) a row5, - Stencil (sh:.Int) a row6, - Stencil (sh:.Int) a row7) - => Stencil (sh:.Int:.Int) a (row1, row2, row3, row4, row5, row6, row7) where - type StencilRepr (sh:.Int:.Int) (row1, row2, row3, row4, row5, row6, row7) - = (StencilRepr (sh:.Int) row1, StencilRepr (sh:.Int) row2, StencilRepr (sh:.Int) row3, - StencilRepr (sh:.Int) row4, StencilRepr (sh:.Int) row5, StencilRepr (sh:.Int) row6, - StencilRepr (sh:.Int) row7) - stencilPrj s = (stencilPrj @(sh:.Int) @a (Exp $ Prj tix6 s), - stencilPrj @(sh:.Int) @a (Exp $ Prj tix5 s), - stencilPrj @(sh:.Int) @a (Exp $ Prj tix4 s), - stencilPrj @(sh:.Int) @a (Exp $ Prj tix3 s), - stencilPrj @(sh:.Int) @a (Exp $ Prj tix2 s), - stencilPrj @(sh:.Int) @a (Exp $ Prj tix1 s), - stencilPrj @(sh:.Int) @a (Exp $ Prj tix0 s)) - -instance (Stencil (sh:.Int) a row1, - Stencil (sh:.Int) a row2, - Stencil (sh:.Int) a row3, Stencil (sh:.Int) a row4, - Stencil (sh:.Int) a row5, - Stencil (sh:.Int) a row6, + Stencil (sh:.Int) a row3, + Stencil (sh:.Int) a row2, + Stencil (sh:.Int) a row1, + Stencil (sh:.Int) a row0) + => Stencil (sh:.Int:.Int) a (row6, row5, row4, row3, row2, row1, row0) where + type StencilRepr (sh:.Int:.Int) (row6, row5, row4, row3, row2, row1, row0) + = Tup7 (StencilRepr (sh:.Int) row6) (StencilRepr (sh:.Int) row5) (StencilRepr (sh:.Int) row4) + (StencilRepr (sh:.Int) row3) (StencilRepr (sh:.Int) row2) (StencilRepr (sh:.Int) row1) + (StencilRepr (sh:.Int) row0) + stencilR = StencilRtup7 (stencilR @(sh:.Int) @a @row6) + (stencilR @(sh:.Int) @a @row5) (stencilR @(sh:.Int) @a @row4) (stencilR @(sh:.Int) @a @row3) + (stencilR @(sh:.Int) @a @row2) (stencilR @(sh:.Int) @a @row1) (stencilR @(sh:.Int) @a @row0) + stencilPrj s = (stencilPrj @(sh:.Int) @a $ prj6 s, + stencilPrj @(sh:.Int) @a $ prj5 s, + stencilPrj @(sh:.Int) @a $ prj4 s, + stencilPrj @(sh:.Int) @a $ prj3 s, + stencilPrj @(sh:.Int) @a $ prj2 s, + stencilPrj @(sh:.Int) @a $ prj1 s, + stencilPrj @(sh:.Int) @a $ prj0 s) + +instance (Stencil (sh:.Int) a row8, Stencil (sh:.Int) a row7, - Stencil (sh:.Int) a row8, - Stencil (sh:.Int) a row9) - => Stencil (sh:.Int:.Int) a (row1, row2, row3, row4, row5, row6, row7, row8, row9) where - type StencilRepr (sh:.Int:.Int) (row1, row2, row3, row4, row5, row6, row7, row8, row9) - = (StencilRepr (sh:.Int) row1, StencilRepr (sh:.Int) row2, StencilRepr (sh:.Int) row3, - StencilRepr (sh:.Int) row4, StencilRepr (sh:.Int) row5, StencilRepr (sh:.Int) row6, - StencilRepr (sh:.Int) row7, StencilRepr (sh:.Int) row8, StencilRepr (sh:.Int) row9) - stencilPrj s = (stencilPrj @(sh:.Int) @a (Exp $ Prj tix8 s), - stencilPrj @(sh:.Int) @a (Exp $ Prj tix7 s), - stencilPrj @(sh:.Int) @a (Exp $ Prj tix6 s), - stencilPrj @(sh:.Int) @a (Exp $ Prj tix5 s), - stencilPrj @(sh:.Int) @a (Exp $ Prj tix4 s), - stencilPrj @(sh:.Int) @a (Exp $ Prj tix3 s), - stencilPrj @(sh:.Int) @a (Exp $ Prj tix2 s), - stencilPrj @(sh:.Int) @a (Exp $ Prj tix1 s), - stencilPrj @(sh:.Int) @a (Exp $ Prj tix0 s)) + Stencil (sh:.Int) a row6, + Stencil (sh:.Int) a row5, + Stencil (sh:.Int) a row4, + Stencil (sh:.Int) a row3, + Stencil (sh:.Int) a row2, + Stencil (sh:.Int) a row1, + Stencil (sh:.Int) a row0) + => Stencil (sh:.Int:.Int) a (row8, row7, row6, row5, row4, row3, row2, row1, row0) where + type StencilRepr (sh:.Int:.Int) (row8, row7, row6, row5, row4, row3, row2, row1, row0) + = Tup9 (StencilRepr (sh:.Int) row8) (StencilRepr (sh:.Int) row7) (StencilRepr (sh:.Int) row6) + (StencilRepr (sh:.Int) row5) (StencilRepr (sh:.Int) row4) (StencilRepr (sh:.Int) row3) + (StencilRepr (sh:.Int) row2) (StencilRepr (sh:.Int) row1) (StencilRepr (sh:.Int) row0) + stencilR = StencilRtup9 + (stencilR @(sh:.Int) @a @row8) (stencilR @(sh:.Int) @a @row7) (stencilR @(sh:.Int) @a @row6) + (stencilR @(sh:.Int) @a @row5) (stencilR @(sh:.Int) @a @row4) (stencilR @(sh:.Int) @a @row3) + (stencilR @(sh:.Int) @a @row2) (stencilR @(sh:.Int) @a @row1) (stencilR @(sh:.Int) @a @row0) + stencilPrj s = (stencilPrj @(sh:.Int) @a $ prj8 s, + stencilPrj @(sh:.Int) @a $ prj7 s, + stencilPrj @(sh:.Int) @a $ prj6 s, + stencilPrj @(sh:.Int) @a $ prj5 s, + stencilPrj @(sh:.Int) @a $ prj4 s, + stencilPrj @(sh:.Int) @a $ prj3 s, + stencilPrj @(sh:.Int) @a $ prj2 s, + stencilPrj @(sh:.Int) @a $ prj1 s, + stencilPrj @(sh:.Int) @a $ prj0 s) -- Auxiliary tuple index constants -- -tix0 :: TupleIdx (t, s0) s0 -tix0 = ZeroTupIdx -tix1 :: TupleIdx ((t, s1), s0) s1 -tix1 = SuccTupIdx tix0 +prjTail :: SmartExp (t, a) -> SmartExp t +prjTail = SmartExp . Prj PairIdxLeft + +prj0 :: SmartExp (t, a) -> SmartExp a +prj0 = SmartExp . Prj PairIdxRight -tix2 :: TupleIdx (((t, s2), s1), s0) s2 -tix2 = SuccTupIdx tix1 +prj1 :: SmartExp ((t, a), s0) -> SmartExp a +prj1 = prj0 . prjTail -tix3 :: TupleIdx ((((t, s3), s2), s1), s0) s3 -tix3 = SuccTupIdx tix2 +prj2 :: SmartExp (((t, a), s1), s0) -> SmartExp a +prj2 = prj1 . prjTail -tix4 :: TupleIdx (((((t, s4), s3), s2), s1), s0) s4 -tix4 = SuccTupIdx tix3 +prj3 :: SmartExp ((((t, a), s2), s1), s0) -> SmartExp a +prj3 = prj2 . prjTail -tix5 :: TupleIdx ((((((t, s5), s4), s3), s2), s1), s0) s5 -tix5 = SuccTupIdx tix4 +prj4 :: SmartExp (((((t, a), s3), s2), s1), s0) -> SmartExp a +prj4 = prj3 . prjTail -tix6 :: TupleIdx (((((((t, s6), s5), s4), s3), s2), s1), s0) s6 -tix6 = SuccTupIdx tix5 +prj5 :: SmartExp ((((((t, a), s4), s3), s2), s1), s0) -> SmartExp a +prj5 = prj4 . prjTail -tix7 :: TupleIdx ((((((((t, s7), s6), s5), s4), s3), s2), s1), s0) s7 -tix7 = SuccTupIdx tix6 +prj6 :: SmartExp (((((((t, a), s5), s4), s3), s2), s1), s0) -> SmartExp a +prj6 = prj5 . prjTail -tix8 :: TupleIdx (((((((((t, s8), s7), s6), s5), s4), s3), s2), s1), s0) s8 -tix8 = SuccTupIdx tix7 +prj7 :: SmartExp ((((((((t, a), s6), s5), s4), s3), s2), s1), s0) -> SmartExp a +prj7 = prj6 . prjTail -tix9 :: TupleIdx ((((((((((t, s9), s8), s7), s6), s5), s4), s3), s2), s1), s0) s9 -tix9 = SuccTupIdx tix8 +prj8 :: SmartExp (((((((((t, a), s7), s6), s5), s4), s3), s2), s1), s0) -> SmartExp a +prj8 = prj7 . prjTail -tix10 :: TupleIdx (((((((((((t, s10), s9), s8), s7), s6), s5), s4), s3), s2), s1), s0) s10 -tix10 = SuccTupIdx tix9 +prj9 :: SmartExp ((((((((((t, a), s8), s7), s6), s5), s4), s3), s2), s1), s0) -> SmartExp a +prj9 = prj8 . prjTail -tix11 :: TupleIdx ((((((((((((t, s11), s10), s9), s8), s7), s6), s5), s4), s3), s2), s1), s0) s11 -tix11 = SuccTupIdx tix10 +prj10 :: SmartExp (((((((((((t, a), s9), s8), s7), s6), s5), s4), s3), s2), s1), s0) -> SmartExp a +prj10 = prj9 . prjTail -tix12 :: TupleIdx (((((((((((((t, s12), s11), s10), s9), s8), s7), s6), s5), s4), s3), s2), s1), s0) s12 -tix12 = SuccTupIdx tix11 +prj11 :: SmartExp ((((((((((((t, a), s10), s9), s8), s7), s6), s5), s4), s3), s2), s1), s0) -> SmartExp a +prj11 = prj10 . prjTail -tix13 :: TupleIdx ((((((((((((((t, s13), s12), s11), s10), s9), s8), s7), s6), s5), s4), s3), s2), s1), s0) s13 -tix13 = SuccTupIdx tix12 +prj12 :: SmartExp (((((((((((((t, a), s11), s10), s9), s8), s7), s6), s5), s4), s3), s2), s1), s0) -> SmartExp a +prj12 = prj11 . prjTail -tix14 :: TupleIdx (((((((((((((((t, s14), s13), s12), s11), s10), s9), s8), s7), s6), s5), s4), s3), s2), s1), s0) s14 -tix14 = SuccTupIdx tix13 +prj13 :: SmartExp ((((((((((((((t, a), s12), s11), s10), s9), s8), s7), s6), s5), s4), s3), s2), s1), s0) -> SmartExp a +prj13 = prj12 . prjTail -tix15 :: TupleIdx ((((((((((((((((t, s15), s14), s13), s12), s11), s10), s9), s8), s7), s6), s5), s4), s3), s2), s1), s0) s15 -tix15 = SuccTupIdx tix14 +prj14 :: SmartExp (((((((((((((((t, a), s13), s12), s11), s10), s9), s8), s7), s6), s5), s4), s3), s2), s1), s0) -> SmartExp a +prj14 = prj13 . prjTail +prj15 :: SmartExp ((((((((((((((((t, a), s14), s13), s12), s11), s10), s9), s8), s7), s6), s5), s4), s3), s2), s1), s0) -> SmartExp a +prj15 = prj14 . prjTail -- Smart constructor for literals -- @@ -917,8 +1009,13 @@ tix15 = SuccTupIdx tix14 -- they can be passed as an input to the computation and thus the value can -- change without the need to generate fresh code. -- -constant :: Elt t => t -> Exp t -constant = Exp . Const +constant :: forall e. Elt e => e -> Exp e +constant = Exp . go (eltType @e) . fromElt + where + go :: TupleType t -> t -> SmartExp t + go TupRunit () = SmartExp $ Nil + go (TupRsingle tp) c = SmartExp $ Const tp c + go (TupRpair t1 t2) (c1, c2) = SmartExp $ go t1 c1 `Pair` go t2 c2 -- | 'undef' can be used anywhere a constant is expected, and indicates that the -- consumer of the value can receive an unspecified bit pattern. @@ -943,8 +1040,13 @@ constant = Exp . Const -- -- @since 1.2.0.0 -- -undef :: Elt t => Exp t -undef = Exp Undef +undef :: forall e. Elt e => Exp e +undef = Exp $ go $ eltType @e + where + go :: TupleType t -> SmartExp t + go TupRunit = SmartExp $ Nil + go (TupRsingle t) = SmartExp $ Undef t + go (TupRpair t1 t2) = SmartExp $ go t1 `Pair` go t2 -- | Get the innermost dimension of a shape. -- @@ -957,471 +1059,462 @@ undef = Exp Undef -- innermost nested loop. -- indexHead :: (Elt sh, Elt a) => Exp (sh :. a) -> Exp a -indexHead = Exp . IndexHead +indexHead (Exp x) = exp $ Prj PairIdxRight x -- | Get all but the innermost element of a shape -- indexTail :: (Elt sh, Elt a) => Exp (sh :. a) -> Exp sh -indexTail = Exp . IndexTail +indexTail (Exp x) = exp $ Prj PairIdxLeft x -- Smart constructor and destructors for scalar tuples -- +nilTup :: SmartExp () +nilTup = SmartExp Nil + +snocTup :: Elt b => SmartExp a -> Exp b -> SmartExp (a, EltRepr b) +snocTup a (Exp b) = SmartExp $ Pair a b + tup2 :: (Elt a, Elt b) => (Exp a, Exp b) -> Exp (a, b) tup2 (a, b) = Exp - $ Tuple - $ NilTup `SnocTup` a - `SnocTup` b + $ nilTup `snocTup` a + `snocTup` b tup3 :: (Elt a, Elt b, Elt c) => (Exp a, Exp b, Exp c) -> Exp (a, b, c) tup3 (a, b, c) = Exp - $ Tuple - $ NilTup `SnocTup` a - `SnocTup` b - `SnocTup` c + $ nilTup `snocTup` a + `snocTup` b + `snocTup` c tup4 :: (Elt a, Elt b, Elt c, Elt d) => (Exp a, Exp b, Exp c, Exp d) -> Exp (a, b, c, d) tup4 (a, b, c, d) = Exp - $ Tuple - $ NilTup `SnocTup` a - `SnocTup` b - `SnocTup` c - `SnocTup` d + $ nilTup `snocTup` a + `snocTup` b + `snocTup` c + `snocTup` d tup5 :: (Elt a, Elt b, Elt c, Elt d, Elt e) => (Exp a, Exp b, Exp c, Exp d, Exp e) -> Exp (a, b, c, d, e) tup5 (a, b, c, d, e) = Exp - $ Tuple - $ NilTup `SnocTup` a - `SnocTup` b - `SnocTup` c - `SnocTup` d - `SnocTup` e + $ nilTup `snocTup` a + `snocTup` b + `snocTup` c + `snocTup` d + `snocTup` e tup6 :: (Elt a, Elt b, Elt c, Elt d, Elt e, Elt f) => (Exp a, Exp b, Exp c, Exp d, Exp e, Exp f) -> Exp (a, b, c, d, e, f) tup6 (a, b, c, d, e, f) = Exp - $ Tuple - $ NilTup `SnocTup` a - `SnocTup` b - `SnocTup` c - `SnocTup` d - `SnocTup` e - `SnocTup` f + $ nilTup `snocTup` a + `snocTup` b + `snocTup` c + `snocTup` d + `snocTup` e + `snocTup` f tup7 :: (Elt a, Elt b, Elt c, Elt d, Elt e, Elt f, Elt g) => (Exp a, Exp b, Exp c, Exp d, Exp e, Exp f, Exp g) -> Exp (a, b, c, d, e, f, g) tup7 (a, b, c, d, e, f, g) = Exp - $ Tuple - $ NilTup `SnocTup` a - `SnocTup` b - `SnocTup` c - `SnocTup` d - `SnocTup` e - `SnocTup` f - `SnocTup` g + $ nilTup `snocTup` a + `snocTup` b + `snocTup` c + `snocTup` d + `snocTup` e + `snocTup` f + `snocTup` g tup8 :: (Elt a, Elt b, Elt c, Elt d, Elt e, Elt f, Elt g, Elt h) => (Exp a, Exp b, Exp c, Exp d, Exp e, Exp f, Exp g, Exp h) -> Exp (a, b, c, d, e, f, g, h) tup8 (a, b, c, d, e, f, g, h) = Exp - $ Tuple - $ NilTup `SnocTup` a - `SnocTup` b - `SnocTup` c - `SnocTup` d - `SnocTup` e - `SnocTup` f - `SnocTup` g - `SnocTup` h + $ nilTup `snocTup` a + `snocTup` b + `snocTup` c + `snocTup` d + `snocTup` e + `snocTup` f + `snocTup` g + `snocTup` h tup9 :: (Elt a, Elt b, Elt c, Elt d, Elt e, Elt f, Elt g, Elt h, Elt i) => (Exp a, Exp b, Exp c, Exp d, Exp e, Exp f, Exp g, Exp h, Exp i) -> Exp (a, b, c, d, e, f, g, h, i) tup9 (a, b, c, d, e, f, g, h, i) = Exp - $ Tuple - $ NilTup `SnocTup` a - `SnocTup` b - `SnocTup` c - `SnocTup` d - `SnocTup` e - `SnocTup` f - `SnocTup` g - `SnocTup` h - `SnocTup` i + $ nilTup `snocTup` a + `snocTup` b + `snocTup` c + `snocTup` d + `snocTup` e + `snocTup` f + `snocTup` g + `snocTup` h + `snocTup` i tup10 :: (Elt a, Elt b, Elt c, Elt d, Elt e, Elt f, Elt g, Elt h, Elt i, Elt j) => (Exp a, Exp b, Exp c, Exp d, Exp e, Exp f, Exp g, Exp h, Exp i, Exp j) -> Exp (a, b, c, d, e, f, g, h, i, j) tup10 (a, b, c, d, e, f, g, h, i, j) = Exp - $ Tuple - $ NilTup `SnocTup` a - `SnocTup` b - `SnocTup` c - `SnocTup` d - `SnocTup` e - `SnocTup` f - `SnocTup` g - `SnocTup` h - `SnocTup` i - `SnocTup` j + $ nilTup `snocTup` a + `snocTup` b + `snocTup` c + `snocTup` d + `snocTup` e + `snocTup` f + `snocTup` g + `snocTup` h + `snocTup` i + `snocTup` j tup11 :: (Elt a, Elt b, Elt c, Elt d, Elt e, Elt f, Elt g, Elt h, Elt i, Elt j, Elt k) => (Exp a, Exp b, Exp c, Exp d, Exp e, Exp f, Exp g, Exp h, Exp i, Exp j, Exp k) -> Exp (a, b, c, d, e, f, g, h, i, j, k) tup11 (a, b, c, d, e, f, g, h, i, j, k) = Exp - $ Tuple - $ NilTup `SnocTup` a - `SnocTup` b - `SnocTup` c - `SnocTup` d - `SnocTup` e - `SnocTup` f - `SnocTup` g - `SnocTup` h - `SnocTup` i - `SnocTup` j - `SnocTup` k + $ nilTup `snocTup` a + `snocTup` b + `snocTup` c + `snocTup` d + `snocTup` e + `snocTup` f + `snocTup` g + `snocTup` h + `snocTup` i + `snocTup` j + `snocTup` k tup12 :: (Elt a, Elt b, Elt c, Elt d, Elt e, Elt f, Elt g, Elt h, Elt i, Elt j, Elt k, Elt l) => (Exp a, Exp b, Exp c, Exp d, Exp e, Exp f, Exp g, Exp h, Exp i, Exp j, Exp k, Exp l) -> Exp (a, b, c, d, e, f, g, h, i, j, k, l) tup12 (a, b, c, d, e, f, g, h, i, j, k, l) = Exp - $ Tuple - $ NilTup `SnocTup` a - `SnocTup` b - `SnocTup` c - `SnocTup` d - `SnocTup` e - `SnocTup` f - `SnocTup` g - `SnocTup` h - `SnocTup` i - `SnocTup` j - `SnocTup` k - `SnocTup` l + $ nilTup `snocTup` a + `snocTup` b + `snocTup` c + `snocTup` d + `snocTup` e + `snocTup` f + `snocTup` g + `snocTup` h + `snocTup` i + `snocTup` j + `snocTup` k + `snocTup` l tup13 :: (Elt a, Elt b, Elt c, Elt d, Elt e, Elt f, Elt g, Elt h, Elt i, Elt j, Elt k, Elt l, Elt m) => (Exp a, Exp b, Exp c, Exp d, Exp e, Exp f, Exp g, Exp h, Exp i, Exp j, Exp k, Exp l, Exp m) -> Exp (a, b, c, d, e, f, g, h, i, j, k, l, m) tup13 (a, b, c, d, e, f, g, h, i, j, k, l, m) = Exp - $ Tuple - $ NilTup `SnocTup` a - `SnocTup` b - `SnocTup` c - `SnocTup` d - `SnocTup` e - `SnocTup` f - `SnocTup` g - `SnocTup` h - `SnocTup` i - `SnocTup` j - `SnocTup` k - `SnocTup` l - `SnocTup` m + $ nilTup `snocTup` a + `snocTup` b + `snocTup` c + `snocTup` d + `snocTup` e + `snocTup` f + `snocTup` g + `snocTup` h + `snocTup` i + `snocTup` j + `snocTup` k + `snocTup` l + `snocTup` m tup14 :: (Elt a, Elt b, Elt c, Elt d, Elt e, Elt f, Elt g, Elt h, Elt i, Elt j, Elt k, Elt l, Elt m, Elt n) => (Exp a, Exp b, Exp c, Exp d, Exp e, Exp f, Exp g, Exp h, Exp i, Exp j, Exp k, Exp l, Exp m, Exp n) -> Exp (a, b, c, d, e, f, g, h, i, j, k, l, m, n) tup14 (a, b, c, d, e, f, g, h, i, j, k, l, m, n) = Exp - $ Tuple - $ NilTup `SnocTup` a - `SnocTup` b - `SnocTup` c - `SnocTup` d - `SnocTup` e - `SnocTup` f - `SnocTup` g - `SnocTup` h - `SnocTup` i - `SnocTup` j - `SnocTup` k - `SnocTup` l - `SnocTup` m - `SnocTup` n + $ nilTup `snocTup` a + `snocTup` b + `snocTup` c + `snocTup` d + `snocTup` e + `snocTup` f + `snocTup` g + `snocTup` h + `snocTup` i + `snocTup` j + `snocTup` k + `snocTup` l + `snocTup` m + `snocTup` n tup15 :: (Elt a, Elt b, Elt c, Elt d, Elt e, Elt f, Elt g, Elt h, Elt i, Elt j, Elt k, Elt l, Elt m, Elt n, Elt o) => (Exp a, Exp b, Exp c, Exp d, Exp e, Exp f, Exp g, Exp h, Exp i, Exp j, Exp k, Exp l, Exp m, Exp n, Exp o) -> Exp (a, b, c, d, e, f, g, h, i, j, k, l, m, n, o) tup15 (a, b, c, d, e, f, g, h, i, j, k, l, m, n, o) = Exp - $ Tuple - $ NilTup `SnocTup` a - `SnocTup` b - `SnocTup` c - `SnocTup` d - `SnocTup` e - `SnocTup` f - `SnocTup` g - `SnocTup` h - `SnocTup` i - `SnocTup` j - `SnocTup` k - `SnocTup` l - `SnocTup` m - `SnocTup` n - `SnocTup` o + $ nilTup `snocTup` a + `snocTup` b + `snocTup` c + `snocTup` d + `snocTup` e + `snocTup` f + `snocTup` g + `snocTup` h + `snocTup` i + `snocTup` j + `snocTup` k + `snocTup` l + `snocTup` m + `snocTup` n + `snocTup` o tup16 :: (Elt a, Elt b, Elt c, Elt d, Elt e, Elt f, Elt g, Elt h, Elt i, Elt j, Elt k, Elt l, Elt m, Elt n, Elt o, Elt p) => (Exp a, Exp b, Exp c, Exp d, Exp e, Exp f, Exp g, Exp h, Exp i, Exp j, Exp k, Exp l, Exp m, Exp n, Exp o, Exp p) -> Exp (a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p) tup16 (a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p) = Exp - $ Tuple - $ NilTup `SnocTup` a - `SnocTup` b - `SnocTup` c - `SnocTup` d - `SnocTup` e - `SnocTup` f - `SnocTup` g - `SnocTup` h - `SnocTup` i - `SnocTup` j - `SnocTup` k - `SnocTup` l - `SnocTup` m - `SnocTup` n - `SnocTup` o - `SnocTup` p + $ nilTup `snocTup` a + `snocTup` b + `snocTup` c + `snocTup` d + `snocTup` e + `snocTup` f + `snocTup` g + `snocTup` h + `snocTup` i + `snocTup` j + `snocTup` k + `snocTup` l + `snocTup` m + `snocTup` n + `snocTup` o + `snocTup` p untup2 :: (Elt a, Elt b) => Exp (a, b) -> (Exp a, Exp b) -untup2 e = - ( Exp $ tix1 `Prj` e - , Exp $ tix0 `Prj` e ) +untup2 (Exp e) = + ( Exp $ prj1 e + , Exp $ prj0 e ) untup3 :: (Elt a, Elt b, Elt c) => Exp (a, b, c) -> (Exp a, Exp b, Exp c) -untup3 e = - ( Exp $ tix2 `Prj` e - , Exp $ tix1 `Prj` e - , Exp $ tix0 `Prj` e ) +untup3 (Exp e) = + ( Exp $ prj2 e + , Exp $ prj1 e + , Exp $ prj0 e ) untup4 :: (Elt a, Elt b, Elt c, Elt d) => Exp (a, b, c, d) -> (Exp a, Exp b, Exp c, Exp d) -untup4 e = - ( Exp $ tix3 `Prj` e - , Exp $ tix2 `Prj` e - , Exp $ tix1 `Prj` e - , Exp $ tix0 `Prj` e ) +untup4 (Exp e) = + ( Exp $ prj3 e + , Exp $ prj2 e + , Exp $ prj1 e + , Exp $ prj0 e ) untup5 :: (Elt a, Elt b, Elt c, Elt d, Elt e) => Exp (a, b, c, d, e) -> (Exp a, Exp b, Exp c, Exp d, Exp e) -untup5 e = - ( Exp $ tix4 `Prj` e - , Exp $ tix3 `Prj` e - , Exp $ tix2 `Prj` e - , Exp $ tix1 `Prj` e - , Exp $ tix0 `Prj` e ) +untup5 (Exp e) = + ( Exp $ prj4 e + , Exp $ prj3 e + , Exp $ prj2 e + , Exp $ prj1 e + , Exp $ prj0 e ) untup6 :: (Elt a, Elt b, Elt c, Elt d, Elt e, Elt f) => Exp (a, b, c, d, e, f) -> (Exp a, Exp b, Exp c, Exp d, Exp e, Exp f) -untup6 e = - ( Exp $ tix5 `Prj` e - , Exp $ tix4 `Prj` e - , Exp $ tix3 `Prj` e - , Exp $ tix2 `Prj` e - , Exp $ tix1 `Prj` e - , Exp $ tix0 `Prj` e ) +untup6 (Exp e) = + ( Exp $ prj5 e + , Exp $ prj4 e + , Exp $ prj3 e + , Exp $ prj2 e + , Exp $ prj1 e + , Exp $ prj0 e ) untup7 :: (Elt a, Elt b, Elt c, Elt d, Elt e, Elt f, Elt g) => Exp (a, b, c, d, e, f, g) -> (Exp a, Exp b, Exp c, Exp d, Exp e, Exp f, Exp g) -untup7 e = - ( Exp $ tix6 `Prj` e - , Exp $ tix5 `Prj` e - , Exp $ tix4 `Prj` e - , Exp $ tix3 `Prj` e - , Exp $ tix2 `Prj` e - , Exp $ tix1 `Prj` e - , Exp $ tix0 `Prj` e ) +untup7 (Exp e) = + ( Exp $ prj6 e + , Exp $ prj5 e + , Exp $ prj4 e + , Exp $ prj3 e + , Exp $ prj2 e + , Exp $ prj1 e + , Exp $ prj0 e ) untup8 :: (Elt a, Elt b, Elt c, Elt d, Elt e, Elt f, Elt g, Elt h) => Exp (a, b, c, d, e, f, g, h) -> (Exp a, Exp b, Exp c, Exp d, Exp e, Exp f, Exp g, Exp h) -untup8 e = - ( Exp $ tix7 `Prj` e - , Exp $ tix6 `Prj` e - , Exp $ tix5 `Prj` e - , Exp $ tix4 `Prj` e - , Exp $ tix3 `Prj` e - , Exp $ tix2 `Prj` e - , Exp $ tix1 `Prj` e - , Exp $ tix0 `Prj` e ) +untup8 (Exp e) = + ( Exp $ prj7 e + , Exp $ prj6 e + , Exp $ prj5 e + , Exp $ prj4 e + , Exp $ prj3 e + , Exp $ prj2 e + , Exp $ prj1 e + , Exp $ prj0 e ) untup9 :: (Elt a, Elt b, Elt c, Elt d, Elt e, Elt f, Elt g, Elt h, Elt i) => Exp (a, b, c, d, e, f, g, h, i) -> (Exp a, Exp b, Exp c, Exp d, Exp e, Exp f, Exp g, Exp h, Exp i) -untup9 e = - ( Exp $ tix8 `Prj` e - , Exp $ tix7 `Prj` e - , Exp $ tix6 `Prj` e - , Exp $ tix5 `Prj` e - , Exp $ tix4 `Prj` e - , Exp $ tix3 `Prj` e - , Exp $ tix2 `Prj` e - , Exp $ tix1 `Prj` e - , Exp $ tix0 `Prj` e ) +untup9 (Exp e) = + ( Exp $ prj8 e + , Exp $ prj7 e + , Exp $ prj6 e + , Exp $ prj5 e + , Exp $ prj4 e + , Exp $ prj3 e + , Exp $ prj2 e + , Exp $ prj1 e + , Exp $ prj0 e ) untup10 :: (Elt a, Elt b, Elt c, Elt d, Elt e, Elt f, Elt g, Elt h, Elt i, Elt j) => Exp (a, b, c, d, e, f, g, h, i, j) -> (Exp a, Exp b, Exp c, Exp d, Exp e, Exp f, Exp g, Exp h, Exp i, Exp j) -untup10 e = - ( Exp $ tix9 `Prj` e - , Exp $ tix8 `Prj` e - , Exp $ tix7 `Prj` e - , Exp $ tix6 `Prj` e - , Exp $ tix5 `Prj` e - , Exp $ tix4 `Prj` e - , Exp $ tix3 `Prj` e - , Exp $ tix2 `Prj` e - , Exp $ tix1 `Prj` e - , Exp $ tix0 `Prj` e ) +untup10 (Exp e) = + ( Exp $ prj9 e + , Exp $ prj8 e + , Exp $ prj7 e + , Exp $ prj6 e + , Exp $ prj5 e + , Exp $ prj4 e + , Exp $ prj3 e + , Exp $ prj2 e + , Exp $ prj1 e + , Exp $ prj0 e ) untup11 :: (Elt a, Elt b, Elt c, Elt d, Elt e, Elt f, Elt g, Elt h, Elt i, Elt j, Elt k) => Exp (a, b, c, d, e, f, g, h, i, j, k) -> (Exp a, Exp b, Exp c, Exp d, Exp e, Exp f, Exp g, Exp h, Exp i, Exp j, Exp k) -untup11 e = - ( Exp $ tix10 `Prj` e - , Exp $ tix9 `Prj` e - , Exp $ tix8 `Prj` e - , Exp $ tix7 `Prj` e - , Exp $ tix6 `Prj` e - , Exp $ tix5 `Prj` e - , Exp $ tix4 `Prj` e - , Exp $ tix3 `Prj` e - , Exp $ tix2 `Prj` e - , Exp $ tix1 `Prj` e - , Exp $ tix0 `Prj` e ) +untup11 (Exp e) = + ( Exp $ prj10 e + , Exp $ prj9 e + , Exp $ prj8 e + , Exp $ prj7 e + , Exp $ prj6 e + , Exp $ prj5 e + , Exp $ prj4 e + , Exp $ prj3 e + , Exp $ prj2 e + , Exp $ prj1 e + , Exp $ prj0 e ) untup12 :: (Elt a, Elt b, Elt c, Elt d, Elt e, Elt f, Elt g, Elt h, Elt i, Elt j, Elt k, Elt l) => Exp (a, b, c, d, e, f, g, h, i, j, k, l) -> (Exp a, Exp b, Exp c, Exp d, Exp e, Exp f, Exp g, Exp h, Exp i, Exp j, Exp k, Exp l) -untup12 e = - ( Exp $ tix11 `Prj` e - , Exp $ tix10 `Prj` e - , Exp $ tix9 `Prj` e - , Exp $ tix8 `Prj` e - , Exp $ tix7 `Prj` e - , Exp $ tix6 `Prj` e - , Exp $ tix5 `Prj` e - , Exp $ tix4 `Prj` e - , Exp $ tix3 `Prj` e - , Exp $ tix2 `Prj` e - , Exp $ tix1 `Prj` e - , Exp $ tix0 `Prj` e ) +untup12 (Exp e) = + ( Exp $ prj11 e + , Exp $ prj10 e + , Exp $ prj9 e + , Exp $ prj8 e + , Exp $ prj7 e + , Exp $ prj6 e + , Exp $ prj5 e + , Exp $ prj4 e + , Exp $ prj3 e + , Exp $ prj2 e + , Exp $ prj1 e + , Exp $ prj0 e ) untup13 :: (Elt a, Elt b, Elt c, Elt d, Elt e, Elt f, Elt g, Elt h, Elt i, Elt j, Elt k, Elt l, Elt m) => Exp (a, b, c, d, e, f, g, h, i, j, k, l, m) -> (Exp a, Exp b, Exp c, Exp d, Exp e, Exp f, Exp g, Exp h, Exp i, Exp j, Exp k, Exp l, Exp m) -untup13 e = - ( Exp $ tix12 `Prj` e - , Exp $ tix11 `Prj` e - , Exp $ tix10 `Prj` e - , Exp $ tix9 `Prj` e - , Exp $ tix8 `Prj` e - , Exp $ tix7 `Prj` e - , Exp $ tix6 `Prj` e - , Exp $ tix5 `Prj` e - , Exp $ tix4 `Prj` e - , Exp $ tix3 `Prj` e - , Exp $ tix2 `Prj` e - , Exp $ tix1 `Prj` e - , Exp $ tix0 `Prj` e ) +untup13 (Exp e) = + ( Exp $ prj12 e + , Exp $ prj11 e + , Exp $ prj10 e + , Exp $ prj9 e + , Exp $ prj8 e + , Exp $ prj7 e + , Exp $ prj6 e + , Exp $ prj5 e + , Exp $ prj4 e + , Exp $ prj3 e + , Exp $ prj2 e + , Exp $ prj1 e + , Exp $ prj0 e ) untup14 :: (Elt a, Elt b, Elt c, Elt d, Elt e, Elt f, Elt g, Elt h, Elt i, Elt j, Elt k, Elt l, Elt m, Elt n) => Exp (a, b, c, d, e, f, g, h, i, j, k, l, m, n) -> (Exp a, Exp b, Exp c, Exp d, Exp e, Exp f, Exp g, Exp h, Exp i, Exp j, Exp k, Exp l, Exp m, Exp n) -untup14 e = - ( Exp $ tix13 `Prj` e - , Exp $ tix12 `Prj` e - , Exp $ tix11 `Prj` e - , Exp $ tix10 `Prj` e - , Exp $ tix9 `Prj` e - , Exp $ tix8 `Prj` e - , Exp $ tix7 `Prj` e - , Exp $ tix6 `Prj` e - , Exp $ tix5 `Prj` e - , Exp $ tix4 `Prj` e - , Exp $ tix3 `Prj` e - , Exp $ tix2 `Prj` e - , Exp $ tix1 `Prj` e - , Exp $ tix0 `Prj` e ) +untup14 (Exp e) = + ( Exp $ prj13 e + , Exp $ prj12 e + , Exp $ prj11 e + , Exp $ prj10 e + , Exp $ prj9 e + , Exp $ prj8 e + , Exp $ prj7 e + , Exp $ prj6 e + , Exp $ prj5 e + , Exp $ prj4 e + , Exp $ prj3 e + , Exp $ prj2 e + , Exp $ prj1 e + , Exp $ prj0 e ) untup15 :: (Elt a, Elt b, Elt c, Elt d, Elt e, Elt f, Elt g, Elt h, Elt i, Elt j, Elt k, Elt l, Elt m, Elt n, Elt o) => Exp (a, b, c, d, e, f, g, h, i, j, k, l, m, n, o) -> (Exp a, Exp b, Exp c, Exp d, Exp e, Exp f, Exp g, Exp h, Exp i, Exp j, Exp k, Exp l, Exp m, Exp n, Exp o) -untup15 e = - ( Exp $ tix14 `Prj` e - , Exp $ tix13 `Prj` e - , Exp $ tix12 `Prj` e - , Exp $ tix11 `Prj` e - , Exp $ tix10 `Prj` e - , Exp $ tix9 `Prj` e - , Exp $ tix8 `Prj` e - , Exp $ tix7 `Prj` e - , Exp $ tix6 `Prj` e - , Exp $ tix5 `Prj` e - , Exp $ tix4 `Prj` e - , Exp $ tix3 `Prj` e - , Exp $ tix2 `Prj` e - , Exp $ tix1 `Prj` e - , Exp $ tix0 `Prj` e ) +untup15 (Exp e) = + ( Exp $ prj14 e + , Exp $ prj13 e + , Exp $ prj12 e + , Exp $ prj11 e + , Exp $ prj10 e + , Exp $ prj9 e + , Exp $ prj8 e + , Exp $ prj7 e + , Exp $ prj6 e + , Exp $ prj5 e + , Exp $ prj4 e + , Exp $ prj3 e + , Exp $ prj2 e + , Exp $ prj1 e + , Exp $ prj0 e ) untup16 :: (Elt a, Elt b, Elt c, Elt d, Elt e, Elt f, Elt g, Elt h, Elt i, Elt j, Elt k, Elt l, Elt m, Elt n, Elt o, Elt p) => Exp (a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p) -> (Exp a, Exp b, Exp c, Exp d, Exp e, Exp f, Exp g, Exp h, Exp i, Exp j, Exp k, Exp l, Exp m, Exp n, Exp o, Exp p) -untup16 e = - ( Exp $ tix15 `Prj` e - , Exp $ tix14 `Prj` e - , Exp $ tix13 `Prj` e - , Exp $ tix12 `Prj` e - , Exp $ tix11 `Prj` e - , Exp $ tix10 `Prj` e - , Exp $ tix9 `Prj` e - , Exp $ tix8 `Prj` e - , Exp $ tix7 `Prj` e - , Exp $ tix6 `Prj` e - , Exp $ tix5 `Prj` e - , Exp $ tix4 `Prj` e - , Exp $ tix3 `Prj` e - , Exp $ tix2 `Prj` e - , Exp $ tix1 `Prj` e - , Exp $ tix0 `Prj` e ) +untup16 (Exp e) = + ( Exp $ prj15 e + , Exp $ prj14 e + , Exp $ prj13 e + , Exp $ prj12 e + , Exp $ prj11 e + , Exp $ prj10 e + , Exp $ prj9 e + , Exp $ prj8 e + , Exp $ prj7 e + , Exp $ prj6 e + , Exp $ prj5 e + , Exp $ prj4 e + , Exp $ prj3 e + , Exp $ prj2 e + , Exp $ prj1 e + , Exp $ prj0 e ) -- Smart constructor for constants -- -mkMinBound :: (Elt t, IsBounded t) => Exp t -mkMinBound = Exp $ PrimConst (PrimMinBound boundedType) +mkMinBound :: (Elt t, IsBounded (EltRepr t)) => Exp t +mkMinBound = exp $ PrimConst (PrimMinBound boundedType) -mkMaxBound :: (Elt t, IsBounded t) => Exp t -mkMaxBound = Exp $ PrimConst (PrimMaxBound boundedType) +mkMaxBound :: (Elt t, IsBounded (EltRepr t)) => Exp t +mkMaxBound = exp $ PrimConst (PrimMaxBound boundedType) -mkPi :: (Elt r, IsFloating r) => Exp r -mkPi = Exp $ PrimConst (PrimPi floatingType) +mkPi :: (Elt r, IsFloating (EltRepr r)) => Exp r +mkPi = exp $ PrimConst (PrimPi floatingType) -- Smart constructors for primitive applications @@ -1429,239 +1522,264 @@ mkPi = Exp $ PrimConst (PrimPi floatingType) -- Operators from Floating -mkSin :: (Elt t, IsFloating t) => Exp t -> Exp t -mkSin x = Exp $ PrimSin floatingType `PrimApp` x +mkSin :: (Elt t, IsFloating (EltRepr t)) => Exp t -> Exp t +mkSin = mkPrimUnary $ PrimSin floatingType -mkCos :: (Elt t, IsFloating t) => Exp t -> Exp t -mkCos x = Exp $ PrimCos floatingType `PrimApp` x +mkCos :: (Elt t, IsFloating (EltRepr t)) => Exp t -> Exp t +mkCos = mkPrimUnary $ PrimCos floatingType -mkTan :: (Elt t, IsFloating t) => Exp t -> Exp t -mkTan x = Exp $ PrimTan floatingType `PrimApp` x +mkTan :: (Elt t, IsFloating (EltRepr t)) => Exp t -> Exp t +mkTan = mkPrimUnary $ PrimTan floatingType -mkAsin :: (Elt t, IsFloating t) => Exp t -> Exp t -mkAsin x = Exp $ PrimAsin floatingType `PrimApp` x +mkAsin :: (Elt t, IsFloating (EltRepr t)) => Exp t -> Exp t +mkAsin = mkPrimUnary $ PrimAsin floatingType -mkAcos :: (Elt t, IsFloating t) => Exp t -> Exp t -mkAcos x = Exp $ PrimAcos floatingType `PrimApp` x +mkAcos :: (Elt t, IsFloating (EltRepr t)) => Exp t -> Exp t +mkAcos = mkPrimUnary $ PrimAcos floatingType -mkAtan :: (Elt t, IsFloating t) => Exp t -> Exp t -mkAtan x = Exp $ PrimAtan floatingType `PrimApp` x +mkAtan :: (Elt t, IsFloating (EltRepr t)) => Exp t -> Exp t +mkAtan = mkPrimUnary $ PrimAtan floatingType -mkSinh :: (Elt t, IsFloating t) => Exp t -> Exp t -mkSinh x = Exp $ PrimSinh floatingType `PrimApp` x +mkSinh :: (Elt t, IsFloating (EltRepr t)) => Exp t -> Exp t +mkSinh = mkPrimUnary $ PrimSinh floatingType -mkCosh :: (Elt t, IsFloating t) => Exp t -> Exp t -mkCosh x = Exp $ PrimCosh floatingType `PrimApp` x +mkCosh :: (Elt t, IsFloating (EltRepr t)) => Exp t -> Exp t +mkCosh = mkPrimUnary $ PrimCosh floatingType -mkTanh :: (Elt t, IsFloating t) => Exp t -> Exp t -mkTanh x = Exp $ PrimTanh floatingType `PrimApp` x +mkTanh :: (Elt t, IsFloating (EltRepr t)) => Exp t -> Exp t +mkTanh = mkPrimUnary $ PrimTanh floatingType -mkAsinh :: (Elt t, IsFloating t) => Exp t -> Exp t -mkAsinh x = Exp $ PrimAsinh floatingType `PrimApp` x +mkAsinh :: (Elt t, IsFloating (EltRepr t)) => Exp t -> Exp t +mkAsinh = mkPrimUnary $ PrimAsinh floatingType -mkAcosh :: (Elt t, IsFloating t) => Exp t -> Exp t -mkAcosh x = Exp $ PrimAcosh floatingType `PrimApp` x +mkAcosh :: (Elt t, IsFloating (EltRepr t)) => Exp t -> Exp t +mkAcosh = mkPrimUnary $ PrimAcosh floatingType -mkAtanh :: (Elt t, IsFloating t) => Exp t -> Exp t -mkAtanh x = Exp $ PrimAtanh floatingType `PrimApp` x +mkAtanh :: (Elt t, IsFloating (EltRepr t)) => Exp t -> Exp t +mkAtanh = mkPrimUnary $ PrimAtanh floatingType -mkExpFloating :: (Elt t, IsFloating t) => Exp t -> Exp t -mkExpFloating x = Exp $ PrimExpFloating floatingType `PrimApp` x +mkExpFloating :: (Elt t, IsFloating (EltRepr t)) => Exp t -> Exp t +mkExpFloating = mkPrimUnary $ PrimExpFloating floatingType -mkSqrt :: (Elt t, IsFloating t) => Exp t -> Exp t -mkSqrt x = Exp $ PrimSqrt floatingType `PrimApp` x +mkSqrt :: (Elt t, IsFloating (EltRepr t)) => Exp t -> Exp t +mkSqrt = mkPrimUnary $ PrimSqrt floatingType -mkLog :: (Elt t, IsFloating t) => Exp t -> Exp t -mkLog x = Exp $ PrimLog floatingType `PrimApp` x +mkLog :: (Elt t, IsFloating (EltRepr t)) => Exp t -> Exp t +mkLog = mkPrimUnary $ PrimLog floatingType -mkFPow :: (Elt t, IsFloating t) => Exp t -> Exp t -> Exp t -mkFPow x y = Exp $ PrimFPow floatingType `PrimApp` tup2 (x, y) +mkFPow :: (Elt t, IsFloating (EltRepr t)) => Exp t -> Exp t -> Exp t +mkFPow = mkPrimBinary $ PrimFPow floatingType -mkLogBase :: (Elt t, IsFloating t) => Exp t -> Exp t -> Exp t -mkLogBase x y = Exp $ PrimLogBase floatingType `PrimApp` tup2 (x, y) +mkLogBase :: (Elt t, IsFloating (EltRepr t)) => Exp t -> Exp t -> Exp t +mkLogBase = mkPrimBinary $ PrimLogBase floatingType -- Operators from Num -mkAdd :: (Elt t, IsNum t) => Exp t -> Exp t -> Exp t -mkAdd x y = Exp $ PrimAdd numType `PrimApp` tup2 (x, y) +mkAdd :: (Elt t, IsNum (EltRepr t)) => Exp t -> Exp t -> Exp t +mkAdd = mkPrimBinary $ PrimAdd numType -mkSub :: (Elt t, IsNum t) => Exp t -> Exp t -> Exp t -mkSub x y = Exp $ PrimSub numType `PrimApp` tup2 (x, y) +mkSub :: (Elt t, IsNum (EltRepr t)) => Exp t -> Exp t -> Exp t +mkSub = mkPrimBinary $ PrimSub numType -mkMul :: (Elt t, IsNum t) => Exp t -> Exp t -> Exp t -mkMul x y = Exp $ PrimMul numType `PrimApp` tup2 (x, y) +mkMul :: (Elt t, IsNum (EltRepr t)) => Exp t -> Exp t -> Exp t +mkMul = mkPrimBinary $ PrimMul numType -mkNeg :: (Elt t, IsNum t) => Exp t -> Exp t -mkNeg x = Exp $ PrimNeg numType `PrimApp` x +mkNeg :: (Elt t, IsNum (EltRepr t)) => Exp t -> Exp t +mkNeg = mkPrimUnary $ PrimNeg numType -mkAbs :: (Elt t, IsNum t) => Exp t -> Exp t -mkAbs x = Exp $ PrimAbs numType `PrimApp` x +mkAbs :: (Elt t, IsNum (EltRepr t)) => Exp t -> Exp t +mkAbs = mkPrimUnary $ PrimAbs numType -mkSig :: (Elt t, IsNum t) => Exp t -> Exp t -mkSig x = Exp $ PrimSig numType `PrimApp` x +mkSig :: (Elt t, IsNum (EltRepr t)) => Exp t -> Exp t +mkSig = mkPrimUnary $ PrimSig numType -- Operators from Integral -mkQuot :: (Elt t, IsIntegral t) => Exp t -> Exp t -> Exp t -mkQuot x y = Exp $ PrimQuot integralType `PrimApp` tup2 (x, y) - -mkRem :: (Elt t, IsIntegral t) => Exp t -> Exp t -> Exp t -mkRem x y = Exp $ PrimRem integralType `PrimApp` tup2 (x, y) +mkQuot :: (Elt t, IsIntegral (EltRepr t)) => Exp t -> Exp t -> Exp t +mkQuot = mkPrimBinary $ PrimQuot integralType -mkQuotRem :: (Elt t, IsIntegral t) => Exp t -> Exp t -> (Exp t, Exp t) -mkQuotRem x y = untup2 $ Exp $ PrimQuotRem integralType `PrimApp` tup2 (x ,y) +mkRem :: (Elt t, IsIntegral (EltRepr t)) => Exp t -> Exp t -> Exp t +mkRem = mkPrimBinary $ PrimRem integralType -mkIDiv :: (Elt t, IsIntegral t) => Exp t -> Exp t -> Exp t -mkIDiv x y = Exp $ PrimIDiv integralType `PrimApp` tup2 (x, y) +mkQuotRem :: (Elt t, IsIntegral (EltRepr t)) => Exp t -> Exp t -> (Exp t, Exp t) +mkQuotRem (Exp x) (Exp y) = + let pair = SmartExp $ PrimQuotRem integralType `PrimApp` (SmartExp $ Pair x y) + in (exp $ Prj PairIdxLeft pair, exp $ Prj PairIdxRight pair) -mkMod :: (Elt t, IsIntegral t) => Exp t -> Exp t -> Exp t -mkMod x y = Exp $ PrimMod integralType `PrimApp` tup2 (x, y) +mkIDiv :: (Elt t, IsIntegral (EltRepr t)) => Exp t -> Exp t -> Exp t +mkIDiv = mkPrimBinary $ PrimIDiv integralType -mkDivMod :: (Elt t, IsIntegral t) => Exp t -> Exp t -> (Exp t, Exp t) -mkDivMod x y = untup2 $ Exp $ PrimDivMod integralType `PrimApp` tup2 (x ,y) +mkMod :: (Elt t, IsIntegral (EltRepr t)) => Exp t -> Exp t -> Exp t +mkMod = mkPrimBinary $ PrimMod integralType +mkDivMod :: (Elt t, IsIntegral (EltRepr t)) => Exp t -> Exp t -> (Exp t, Exp t) +mkDivMod (Exp x) (Exp y) = + let pair = SmartExp $ PrimDivMod integralType `PrimApp` (SmartExp $ Pair x y) + in (exp $ Prj PairIdxLeft pair, exp $ Prj PairIdxRight pair) -- Operators from Bits and FiniteBits -mkBAnd :: (Elt t, IsIntegral t) => Exp t -> Exp t -> Exp t -mkBAnd x y = Exp $ PrimBAnd integralType `PrimApp` tup2 (x, y) +mkBAnd :: (Elt t, IsIntegral (EltRepr t)) => Exp t -> Exp t -> Exp t +mkBAnd = mkPrimBinary $ PrimBAnd integralType -mkBOr :: (Elt t, IsIntegral t) => Exp t -> Exp t -> Exp t -mkBOr x y = Exp $ PrimBOr integralType `PrimApp` tup2 (x, y) +mkBOr :: (Elt t, IsIntegral (EltRepr t)) => Exp t -> Exp t -> Exp t +mkBOr = mkPrimBinary $ PrimBOr integralType -mkBXor :: (Elt t, IsIntegral t) => Exp t -> Exp t -> Exp t -mkBXor x y = Exp $ PrimBXor integralType `PrimApp` tup2 (x, y) +mkBXor :: (Elt t, IsIntegral (EltRepr t)) => Exp t -> Exp t -> Exp t +mkBXor = mkPrimBinary $ PrimBXor integralType -mkBNot :: (Elt t, IsIntegral t) => Exp t -> Exp t -mkBNot x = Exp $ PrimBNot integralType `PrimApp` x +mkBNot :: (Elt t, IsIntegral (EltRepr t)) => Exp t -> Exp t +mkBNot = mkPrimUnary $ PrimBNot integralType -mkBShiftL :: (Elt t, IsIntegral t) => Exp t -> Exp Int -> Exp t -mkBShiftL x i = Exp $ PrimBShiftL integralType `PrimApp` tup2 (x, i) +mkBShiftL :: (Elt t, IsIntegral (EltRepr t)) => Exp t -> Exp Int -> Exp t +mkBShiftL = mkPrimBinary $ PrimBShiftL integralType -mkBShiftR :: (Elt t, IsIntegral t) => Exp t -> Exp Int -> Exp t -mkBShiftR x i = Exp $ PrimBShiftR integralType `PrimApp` tup2 (x, i) +mkBShiftR :: (Elt t, IsIntegral (EltRepr t)) => Exp t -> Exp Int -> Exp t +mkBShiftR = mkPrimBinary $ PrimBShiftR integralType -mkBRotateL :: (Elt t, IsIntegral t) => Exp t -> Exp Int -> Exp t -mkBRotateL x i = Exp $ PrimBRotateL integralType `PrimApp` tup2 (x, i) +mkBRotateL :: (Elt t, IsIntegral (EltRepr t)) => Exp t -> Exp Int -> Exp t +mkBRotateL = mkPrimBinary $ PrimBRotateL integralType -mkBRotateR :: (Elt t, IsIntegral t) => Exp t -> Exp Int -> Exp t -mkBRotateR x i = Exp $ PrimBRotateR integralType `PrimApp` tup2 (x, i) +mkBRotateR :: (Elt t, IsIntegral (EltRepr t)) => Exp t -> Exp Int -> Exp t +mkBRotateR = mkPrimBinary $ PrimBRotateR integralType -mkPopCount :: (Elt t, IsIntegral t) => Exp t -> Exp Int -mkPopCount x = Exp $ PrimPopCount integralType `PrimApp` x +mkPopCount :: (Elt t, IsIntegral (EltRepr t)) => Exp t -> Exp Int +mkPopCount = mkPrimUnary $ PrimPopCount integralType -mkCountLeadingZeros :: (Elt t, IsIntegral t) => Exp t -> Exp Int -mkCountLeadingZeros x = Exp $ PrimCountLeadingZeros integralType `PrimApp` x +mkCountLeadingZeros :: (Elt t, IsIntegral (EltRepr t)) => Exp t -> Exp Int +mkCountLeadingZeros = mkPrimUnary $ PrimCountLeadingZeros integralType -mkCountTrailingZeros :: (Elt t, IsIntegral t) => Exp t -> Exp Int -mkCountTrailingZeros x = Exp $ PrimCountTrailingZeros integralType `PrimApp` x +mkCountTrailingZeros :: (Elt t, IsIntegral (EltRepr t)) => Exp t -> Exp Int +mkCountTrailingZeros = mkPrimUnary $ PrimCountTrailingZeros integralType -- Operators from Fractional -mkFDiv :: (Elt t, IsFloating t) => Exp t -> Exp t -> Exp t -mkFDiv x y = Exp $ PrimFDiv floatingType `PrimApp` tup2 (x, y) +mkFDiv :: (Elt t, IsFloating (EltRepr t)) => Exp t -> Exp t -> Exp t +mkFDiv = mkPrimBinary $ PrimFDiv floatingType -mkRecip :: (Elt t, IsFloating t) => Exp t -> Exp t -mkRecip x = Exp $ PrimRecip floatingType `PrimApp` x +mkRecip :: (Elt t, IsFloating (EltRepr t)) => Exp t -> Exp t +mkRecip = mkPrimUnary $ PrimRecip floatingType -- Operators from RealFrac -mkTruncate :: (Elt a, Elt b, IsFloating a, IsIntegral b) => Exp a -> Exp b -mkTruncate x = Exp $ PrimTruncate floatingType integralType `PrimApp` x +mkTruncate :: (Elt a, Elt b, IsFloating (EltRepr a), IsIntegral (EltRepr b)) => Exp a -> Exp b +mkTruncate = mkPrimUnary $ PrimTruncate floatingType integralType -mkRound :: (Elt a, Elt b, IsFloating a, IsIntegral b) => Exp a -> Exp b -mkRound x = Exp $ PrimRound floatingType integralType `PrimApp` x +mkRound :: (Elt a, Elt b, IsFloating (EltRepr a), IsIntegral (EltRepr b)) => Exp a -> Exp b +mkRound = mkPrimUnary $ PrimRound floatingType integralType -mkFloor :: (Elt a, Elt b, IsFloating a, IsIntegral b) => Exp a -> Exp b -mkFloor x = Exp $ PrimFloor floatingType integralType `PrimApp` x +mkFloor :: (Elt a, Elt b, IsFloating (EltRepr a), IsIntegral (EltRepr b)) => Exp a -> Exp b +mkFloor = mkPrimUnary $ PrimFloor floatingType integralType -mkCeiling :: (Elt a, Elt b, IsFloating a, IsIntegral b) => Exp a -> Exp b -mkCeiling x = Exp $ PrimCeiling floatingType integralType `PrimApp` x +mkCeiling :: (Elt a, Elt b, IsFloating (EltRepr a), IsIntegral (EltRepr b)) => Exp a -> Exp b +mkCeiling = mkPrimUnary $ PrimCeiling floatingType integralType -- Operators from RealFloat -mkAtan2 :: (Elt t, IsFloating t) => Exp t -> Exp t -> Exp t -mkAtan2 x y = Exp $ PrimAtan2 floatingType `PrimApp` tup2 (x, y) +mkAtan2 :: (Elt t, IsFloating (EltRepr t)) => Exp t -> Exp t -> Exp t +mkAtan2 = mkPrimBinary $ PrimAtan2 floatingType -mkIsNaN :: (Elt t, IsFloating t) => Exp t -> Exp Bool -mkIsNaN x = Exp $ PrimIsNaN floatingType `PrimApp` x +mkIsNaN :: (Elt t, IsFloating (EltRepr t)) => Exp t -> Exp Bool +mkIsNaN = mkPrimUnary $ PrimIsNaN floatingType -mkIsInfinite :: (Elt t, IsFloating t) => Exp t -> Exp Bool -mkIsInfinite x = Exp $ PrimIsInfinite floatingType `PrimApp` x +mkIsInfinite :: (Elt t, IsFloating (EltRepr t)) => Exp t -> Exp Bool +mkIsInfinite = mkPrimUnary $ PrimIsInfinite floatingType -- FIXME: add missing operations from Floating, RealFrac & RealFloat -- Relational and equality operators -mkLt :: (Elt t, IsSingle t) => Exp t -> Exp t -> Exp Bool -mkLt x y = Exp $ PrimLt singleType `PrimApp` tup2 (x, y) +mkLt :: (Elt t, IsSingle (EltRepr t)) => Exp t -> Exp t -> Exp Bool +mkLt = mkPrimBinary $ PrimLt singleType -mkGt :: (Elt t, IsSingle t) => Exp t -> Exp t -> Exp Bool -mkGt x y = Exp $ PrimGt singleType `PrimApp` tup2 (x, y) +mkGt :: (Elt t, IsSingle (EltRepr t)) => Exp t -> Exp t -> Exp Bool +mkGt = mkPrimBinary $ PrimGt singleType -mkLtEq :: (Elt t, IsSingle t) => Exp t -> Exp t -> Exp Bool -mkLtEq x y = Exp $ PrimLtEq singleType `PrimApp` tup2 (x, y) +mkLtEq :: (Elt t, IsSingle (EltRepr t)) => Exp t -> Exp t -> Exp Bool +mkLtEq = mkPrimBinary $ PrimLtEq singleType -mkGtEq :: (Elt t, IsSingle t) => Exp t -> Exp t -> Exp Bool -mkGtEq x y = Exp $ PrimGtEq singleType `PrimApp` tup2 (x, y) +mkGtEq :: (Elt t, IsSingle (EltRepr t)) => Exp t -> Exp t -> Exp Bool +mkGtEq = mkPrimBinary $ PrimGtEq singleType -mkEq :: (Elt t, IsSingle t) => Exp t -> Exp t -> Exp Bool -mkEq x y = Exp $ PrimEq singleType `PrimApp` tup2 (x, y) +mkEq :: (Elt t, IsSingle (EltRepr t)) => Exp t -> Exp t -> Exp Bool +mkEq = mkPrimBinary $ PrimEq singleType -mkNEq :: (Elt t, IsSingle t) => Exp t -> Exp t -> Exp Bool -mkNEq x y = Exp $ PrimNEq singleType `PrimApp` tup2 (x, y) +mkNEq :: (Elt t, IsSingle (EltRepr t)) => Exp t -> Exp t -> Exp Bool +mkNEq = mkPrimBinary $ PrimNEq singleType -mkMax :: (Elt t, IsSingle t) => Exp t -> Exp t -> Exp t -mkMax x y = Exp $ PrimMax singleType `PrimApp` tup2 (x, y) +mkMax :: (Elt t, IsSingle (EltRepr t)) => Exp t -> Exp t -> Exp t +mkMax = mkPrimBinary $ PrimMax singleType -mkMin :: (Elt t, IsSingle t) => Exp t -> Exp t -> Exp t -mkMin x y = Exp $ PrimMin singleType `PrimApp` tup2 (x, y) +mkMin :: (Elt t, IsSingle (EltRepr t)) => Exp t -> Exp t -> Exp t +mkMin = mkPrimBinary $ PrimMin singleType -- Logical operators mkLAnd :: Exp Bool -> Exp Bool -> Exp Bool -mkLAnd x y = Exp $ PrimLAnd `PrimApp` tup2 (x, y) +mkLAnd = mkPrimBinary PrimLAnd mkLOr :: Exp Bool -> Exp Bool -> Exp Bool -mkLOr x y = Exp $ PrimLOr `PrimApp` tup2 (x, y) +mkLOr = mkPrimBinary PrimLOr mkLNot :: Exp Bool -> Exp Bool -mkLNot x = Exp $ PrimLNot `PrimApp` x +mkLNot = mkPrimUnary PrimLNot -- Character conversions mkOrd :: Exp Char -> Exp Int -mkOrd x = Exp $ PrimOrd `PrimApp` x +mkOrd = mkPrimUnary PrimOrd mkChr :: Exp Int -> Exp Char -mkChr x = Exp $ PrimChr `PrimApp` x +mkChr = mkPrimUnary PrimChr -- Numeric conversions -mkFromIntegral :: (Elt a, Elt b, IsIntegral a, IsNum b) => Exp a -> Exp b -mkFromIntegral x = Exp $ PrimFromIntegral integralType numType `PrimApp` x +mkFromIntegral :: (Elt a, Elt b, IsIntegral (EltRepr a), IsNum (EltRepr b)) => Exp a -> Exp b +mkFromIntegral = mkPrimUnary $ PrimFromIntegral integralType numType -mkToFloating :: (Elt a, Elt b, IsNum a, IsFloating b) => Exp a -> Exp b -mkToFloating x = Exp $ PrimToFloating numType floatingType `PrimApp` x +mkToFloating :: (Elt a, Elt b, IsNum (EltRepr a), IsFloating (EltRepr b)) => Exp a -> Exp b +mkToFloating = mkPrimUnary $ PrimToFloating numType floatingType -- Other conversions mkBoolToInt :: Exp Bool -> Exp Int -mkBoolToInt b = Exp $ PrimBoolToInt `PrimApp` b +mkBoolToInt (Exp b) = exp $ PrimBoolToInt `PrimApp` b -- NOTE: Restricted to scalar types with a type-level BitSizeEq constraint to -- make this version "safe" mkBitcast :: forall b a. (Elt a, Elt b, IsScalar (EltRepr a), IsScalar (EltRepr b), BitSizeEq (EltRepr a) (EltRepr b)) => Exp a -> Exp b -mkBitcast = mkUnsafeCoerce +mkBitcast (Exp a) = exp $ Coerce (scalarType @(EltRepr a)) (scalarType @(EltRepr b)) a + +mkCoerce :: Coerce (EltRepr a) (EltRepr b) => Exp a -> Exp b +mkCoerce (Exp a) = Exp $ mkCoerce' a + +class Coerce a b where + mkCoerce' :: SmartExp a -> SmartExp b + +instance (IsScalar a, IsScalar b, BitSizeEq a b) => Coerce a b where + mkCoerce' = SmartExp . Coerce (scalarType @a) (scalarType @b) + +instance (Coerce a1 b1, Coerce a2 b2) => Coerce (a1, a2) (b1, b2) where + mkCoerce' a = SmartExp $ Pair (mkCoerce' $ SmartExp $ Prj PairIdxLeft a) (mkCoerce' $ SmartExp $ Prj PairIdxRight a) + +instance Coerce () () where + mkCoerce' _ = SmartExp $ Nil + +instance Coerce ((), a) a where + mkCoerce' a = SmartExp $ Prj PairIdxRight a + +instance Coerce a ((), a) where + mkCoerce' = SmartExp . Pair (SmartExp $ Nil) -mkUnsafeCoerce :: forall b a. (Elt a, Elt b) => Exp a -> Exp b -mkUnsafeCoerce = Exp . Coerce -- Auxiliary functions -- -------------------- +exp :: PreSmartExp SmartAcc SmartExp (EltRepr t) -> Exp t +exp = Exp . SmartExp + infixr 0 $$ ($$) :: (b -> a) -> (c -> d -> b) -> c -> d -> a (f $$ g) x y = f (g x y) @@ -1684,13 +1802,38 @@ unAcc (Acc a) = a unAccFunction :: (Arrays a, Arrays b) => (Acc a -> Acc b) -> SmartAcc (ArrRepr a) -> SmartAcc (ArrRepr b) unAccFunction f = unAcc . f . Acc +unExp :: Elt e => Exp e -> SmartExp (EltRepr e) +unExp (Exp e) = e + +unExpFunction :: (Elt a, Elt b) => (Exp a -> Exp b) -> SmartExp (EltRepr a) -> SmartExp (EltRepr b) +unExpFunction f = unExp . f . Exp + +unExpBinaryFunction :: (Elt a, Elt b, Elt c) => (Exp a -> Exp b -> Exp c) -> SmartExp (EltRepr a) -> SmartExp (EltRepr b) -> SmartExp (EltRepr c) +unExpBinaryFunction f a b = unExp $ f (Exp a) (Exp b) + +mkPrimUnary :: (Elt a, Elt b) => PrimFun (EltRepr a -> EltRepr b) -> Exp a -> Exp b +mkPrimUnary prim (Exp a) = exp $ PrimApp prim a + +mkPrimBinary :: (Elt a, Elt b, Elt c) => PrimFun ((EltRepr a, EltRepr b) -> EltRepr c) -> Exp a -> Exp b -> Exp c +mkPrimBinary prim (Exp a) (Exp b) = exp $ PrimApp prim (SmartExp $ Pair a b) + +unPair :: SmartExp (a, b) -> (SmartExp a, SmartExp b) +unPair e = (SmartExp $ Prj PairIdxLeft e, SmartExp $ Prj PairIdxRight e) + +mkPairToTuple :: SmartAcc (a, b) -> SmartAcc (((), a), b) +mkPairToTuple e = SmartAcc Anil `pair` a `pair` b + where + a = SmartAcc $ Aprj PairIdxLeft e + b = SmartAcc $ Aprj PairIdxRight e + pair x y = SmartAcc $ Apair x y + class ApplyAcc a where type FromApplyAcc a applyAcc :: FromApplyAcc a -> a instance ApplyAcc (SmartAcc a) where - type FromApplyAcc (SmartAcc a) = PreSmartAcc SmartAcc Exp a + type FromApplyAcc (SmartAcc a) = PreSmartAcc SmartAcc SmartExp a applyAcc = SmartAcc @@ -1699,26 +1842,32 @@ instance (Arrays a, ApplyAcc t) => ApplyAcc (Acc a -> t) where applyAcc f a = applyAcc $ f (unAcc a) -instance ApplyAcc t => ApplyAcc (Exp a -> t) where - type FromApplyAcc (Exp a -> t) = Exp a -> FromApplyAcc t +instance (Elt a, ApplyAcc t) => ApplyAcc (Exp a -> t) where + type FromApplyAcc (Exp a -> t) = SmartExp (EltRepr a) -> FromApplyAcc t + + applyAcc f a = applyAcc $ f (unExp a) - applyAcc f a = applyAcc $ f a +instance (Elt a, Elt b, ApplyAcc t) => ApplyAcc ((Exp a -> Exp b) -> t) where + type FromApplyAcc ((Exp a -> Exp b) -> t) = (SmartExp (EltRepr a) -> SmartExp (EltRepr b)) -> FromApplyAcc t -instance ApplyAcc t => ApplyAcc ((Exp a -> b) -> t) where - type FromApplyAcc ((Exp a -> b) -> t) = (Exp a -> b) -> FromApplyAcc t + applyAcc f a = applyAcc $ f (unExpFunction a) - applyAcc f a = applyAcc $ f a +instance (Elt a, Elt b, Elt c, ApplyAcc t) => ApplyAcc ((Exp a -> Exp b -> Exp c) -> t) where + type FromApplyAcc ((Exp a -> Exp b -> Exp c) -> t) = (SmartExp (EltRepr a) -> SmartExp (EltRepr b) -> SmartExp (EltRepr c)) -> FromApplyAcc t + + applyAcc f a = applyAcc $ f (unExpBinaryFunction a) instance (Arrays a, Arrays b, ApplyAcc t) => ApplyAcc ((Acc a -> Acc b) -> t) where type FromApplyAcc ((Acc a -> Acc b) -> t) = (SmartAcc (ArrRepr a) -> SmartAcc (ArrRepr b)) -> FromApplyAcc t applyAcc f a = applyAcc $ f (unAccFunction a) + -- Debugging -- --------- showPreAccOp :: forall acc exp arrs. PreSmartAcc acc exp arrs -> String -showPreAccOp (Atag i) = "Atag " ++ show i -showPreAccOp (Use a) = "Use " ++ showShortendArr a +showPreAccOp (Atag _ i) = "Atag " ++ show i +showPreAccOp (Use repr a) = "Use " ++ showShortendArr repr a showPreAccOp Pipe{} = "Pipe" showPreAccOp Acond{} = "Acond" showPreAccOp Awhile{} = "Awhile" @@ -1762,25 +1911,15 @@ showPreSeqOp (Stuple{}) = "Stuple" --} -showShortendArr :: (Shape sh, Elt e) => Array sh e -> String -showShortendArr arr - = show (take cutoff l) ++ if length l > cutoff then ".." else "" - where - l = toList arr - cutoff = 5 - - -showPreExpOp :: PreExp acc exp t -> String -showPreExpOp (Tag i) = "Tag" ++ show i -showPreExpOp (Const c) = "Const " ++ show c -showPreExpOp Undef = "Undef" -showPreExpOp Tuple{} = "Tuple" +showPreExpOp :: PreSmartExp acc exp t -> String +showPreExpOp (Tag _ i) = "Tag" ++ show i +showPreExpOp (Const tp c) = "Const " ++ showElement (TupRsingle tp) c +showPreExpOp (Undef _) = "Undef" +showPreExpOp Nil{} = "Nil" +showPreExpOp Pair{} = "Pair" showPreExpOp Prj{} = "Prj" -showPreExpOp IndexNil = "IndexNil" -showPreExpOp IndexCons{} = "IndexCons" -showPreExpOp IndexHead{} = "IndexHead" -showPreExpOp IndexTail{} = "IndexTail" -showPreExpOp IndexAny = "IndexAny" +showPreExpOp VecPack{} = "VecPack" +showPreExpOp VecUnpack{} = "VecUnpack" showPreExpOp ToIndex{} = "ToIndex" showPreExpOp FromIndex{} = "FromIndex" showPreExpOp Cond{} = "Cond" @@ -1791,8 +1930,33 @@ showPreExpOp Index{} = "Index" showPreExpOp LinearIndex{} = "LinearIndex" showPreExpOp Shape{} = "Shape" showPreExpOp ShapeSize{} = "ShapeSize" -showPreExpOp Intersect{} = "Intersect" -showPreExpOp Union{} = "Union" showPreExpOp Foreign{} = "Foreign" showPreExpOp Coerce{} = "Coerce" +vecR2 :: SingleType s -> VecR 2 s (Tup2 s s) +vecR2 s = VecRsucc $ VecRsucc $ VecRnil s + +vecR3 :: SingleType s -> VecR 3 s (Tup3 s s s) +vecR3 = VecRsucc . vecR2 + +vecR4 :: SingleType s -> VecR 4 s (Tup4 s s s s) +vecR4 = VecRsucc . vecR3 + +vecR5 :: SingleType s -> VecR 5 s (Tup5 s s s s s) +vecR5 = VecRsucc . vecR4 + +vecR6 :: SingleType s -> VecR 6 s (Tup6 s s s s s s) +vecR6 = VecRsucc . vecR5 + +vecR7 :: SingleType s -> VecR 7 s (Tup7 s s s s s s s) +vecR7 = VecRsucc . vecR6 + +vecR8 :: SingleType s -> VecR 8 s (Tup8 s s s s s s s s) +vecR8 = VecRsucc . vecR7 + +vecR9 :: SingleType s -> VecR 9 s (Tup9 s s s s s s s s s) +vecR9 = VecRsucc . vecR8 + +vecR16 :: SingleType s -> VecR 16 s (Tup16 s s s s s s s s s s s s s s s s) +vecR16 = VecRsucc . VecRsucc . VecRsucc . VecRsucc . VecRsucc . VecRsucc . VecRsucc . vecR9 + diff --git a/src/Data/Array/Accelerate/Test/NoFib/Imaginary/DotP.hs b/src/Data/Array/Accelerate/Test/NoFib/Imaginary/DotP.hs index 196541362..0c1bbd475 100644 --- a/src/Data/Array/Accelerate/Test/NoFib/Imaginary/DotP.hs +++ b/src/Data/Array/Accelerate/Test/NoFib/Imaginary/DotP.hs @@ -20,7 +20,6 @@ module Data.Array.Accelerate.Test.NoFib.Imaginary.DotP ( ) where -import Data.Typeable import Prelude as P import Data.Array.Accelerate as A @@ -57,7 +56,7 @@ test_dotp runN = => Gen a -> TestTree testElt e = - testProperty (show (typeOf (undefined :: a))) $ test_dotp' runN e + testProperty (show (eltType @a)) $ test_dotp' runN e test_dotp' diff --git a/src/Data/Array/Accelerate/Test/NoFib/Imaginary/SASUM.hs b/src/Data/Array/Accelerate/Test/NoFib/Imaginary/SASUM.hs index f87b791be..b7b41eb61 100644 --- a/src/Data/Array/Accelerate/Test/NoFib/Imaginary/SASUM.hs +++ b/src/Data/Array/Accelerate/Test/NoFib/Imaginary/SASUM.hs @@ -20,7 +20,6 @@ module Data.Array.Accelerate.Test.NoFib.Imaginary.SASUM ( ) where -import Data.Typeable import Prelude as P import Data.Array.Accelerate as A @@ -57,7 +56,7 @@ test_sasum runN = => Gen a -> TestTree testElt e = - testProperty (show (typeOf (undefined :: a))) $ test_sasum' runN e + testProperty (show (eltType @a)) $ test_sasum' runN e test_sasum' diff --git a/src/Data/Array/Accelerate/Test/NoFib/Imaginary/SAXPY.hs b/src/Data/Array/Accelerate/Test/NoFib/Imaginary/SAXPY.hs index e0dd2cadc..c0ba90e49 100644 --- a/src/Data/Array/Accelerate/Test/NoFib/Imaginary/SAXPY.hs +++ b/src/Data/Array/Accelerate/Test/NoFib/Imaginary/SAXPY.hs @@ -20,7 +20,6 @@ module Data.Array.Accelerate.Test.NoFib.Imaginary.SAXPY ( ) where -import Data.Typeable import Prelude as P import Data.Array.Accelerate as A @@ -57,7 +56,7 @@ test_saxpy runN = => Gen a -> TestTree testElt e = - testProperty (show (typeOf (undefined :: a))) $ test_saxpy' runN e + testProperty (show (eltType @a)) $ test_saxpy' runN e test_saxpy' diff --git a/src/Data/Array/Accelerate/Test/NoFib/Issues/Issue264.hs b/src/Data/Array/Accelerate/Test/NoFib/Issues/Issue264.hs index b396f3a9e..267fc89ad 100644 --- a/src/Data/Array/Accelerate/Test/NoFib/Issues/Issue264.hs +++ b/src/Data/Array/Accelerate/Test/NoFib/Issues/Issue264.hs @@ -23,7 +23,6 @@ module Data.Array.Accelerate.Test.NoFib.Issues.Issue264 ( ) where -import Data.Typeable import Prelude as P import Data.Array.Accelerate as A @@ -60,7 +59,7 @@ test_issue264 runN = => Gen a -> TestTree testElt e = - testGroup (show (typeOf (undefined :: a))) + testGroup (show (eltType @a)) [ testProperty "neg.neg" $ test_neg_neg runN e ] diff --git a/src/Data/Array/Accelerate/Test/NoFib/Issues/Issue364.hs b/src/Data/Array/Accelerate/Test/NoFib/Issues/Issue364.hs index e2481c69f..9c4042865 100644 --- a/src/Data/Array/Accelerate/Test/NoFib/Issues/Issue364.hs +++ b/src/Data/Array/Accelerate/Test/NoFib/Issues/Issue364.hs @@ -24,7 +24,6 @@ module Data.Array.Accelerate.Test.NoFib.Issues.Issue364 ( ) where -import Data.Typeable import Prelude ( fromInteger, show ) import qualified Prelude as P #if __GLASGOW_HASKELL__ == 800 @@ -58,7 +57,7 @@ test_issue364 runN = => Gen e -> TestTree testElt _ = - testGroup (show (typeOf (undefined :: e))) + testGroup (show (eltType @e)) [ testCase "A" $ expectedArray @_ @e Z 64 @=? runN (scanl iappend one) (intervalArray Z 64) , testCase "B" $ expectedArray @_ @e Z 65 @=? runN (scanl iappend one) (intervalArray Z 65) -- failed for integral types ] diff --git a/src/Data/Array/Accelerate/Test/NoFib/Issues/Issue407.hs b/src/Data/Array/Accelerate/Test/NoFib/Issues/Issue407.hs index e431981f1..d07bf63a8 100644 --- a/src/Data/Array/Accelerate/Test/NoFib/Issues/Issue407.hs +++ b/src/Data/Array/Accelerate/Test/NoFib/Issues/Issue407.hs @@ -24,10 +24,10 @@ module Data.Array.Accelerate.Test.NoFib.Issues.Issue407 ( ) where -import Data.Typeable import Prelude as P import Data.Array.Accelerate as A +import Data.Array.Accelerate.Array.Sugar as A import Data.Array.Accelerate.Test.NoFib.Base import Test.Tasty @@ -45,7 +45,7 @@ test_issue407 runN = :: forall a. (P.Fractional a, A.RealFloat a) => TestTree testElt = - testGroup (show (typeOf (undefined :: a))) + testGroup (show (A.eltType @a)) [ testCase "isNaN" $ eNaN @=? runN (A.map A.isNaN) xs , testCase "isInfinite" $ eInf @=? runN (A.map A.isInfinite) xs ] diff --git a/src/Data/Array/Accelerate/Test/NoFib/Issues/Issue409.hs b/src/Data/Array/Accelerate/Test/NoFib/Issues/Issue409.hs index 385e126fa..7b3c9b122 100644 --- a/src/Data/Array/Accelerate/Test/NoFib/Issues/Issue409.hs +++ b/src/Data/Array/Accelerate/Test/NoFib/Issues/Issue409.hs @@ -22,10 +22,10 @@ module Data.Array.Accelerate.Test.NoFib.Issues.Issue409 ( ) where -import Data.Typeable import Prelude as P import Data.Array.Accelerate as A +import Data.Array.Accelerate.Array.Sugar as A import Data.Array.Accelerate.Test.NoFib.Base import Test.Tasty @@ -43,7 +43,7 @@ test_issue409 runN = :: forall a. (P.Floating a, P.Eq a, A.Floating a) => TestTree testElt = - testGroup (show (typeOf (undefined :: a))) + testGroup (show (A.eltType @a)) [ testCase "A" $ e1 @=? indexArray (runN (A.map f) t1) Z ] where diff --git a/src/Data/Array/Accelerate/Test/NoFib/Prelude/Backpermute.hs b/src/Data/Array/Accelerate/Test/NoFib/Prelude/Backpermute.hs index 561d6553d..e7909667e 100644 --- a/src/Data/Array/Accelerate/Test/NoFib/Prelude/Backpermute.hs +++ b/src/Data/Array/Accelerate/Test/NoFib/Prelude/Backpermute.hs @@ -21,7 +21,6 @@ module Data.Array.Accelerate.Test.NoFib.Prelude.Backpermute ( ) where -import Data.Typeable import Prelude as P import Data.Array.Accelerate as A @@ -59,7 +58,7 @@ test_backpermute runN = => Gen a -> TestTree testElt e = - testGroup (show (typeOf (undefined :: a))) + testGroup (show (eltType @a)) [ testDim dim1 , testDim dim2 , testDim dim3 diff --git a/src/Data/Array/Accelerate/Test/NoFib/Prelude/Filter.hs b/src/Data/Array/Accelerate/Test/NoFib/Prelude/Filter.hs index 81c2fdc99..deabe823a 100644 --- a/src/Data/Array/Accelerate/Test/NoFib/Prelude/Filter.hs +++ b/src/Data/Array/Accelerate/Test/NoFib/Prelude/Filter.hs @@ -23,7 +23,6 @@ module Data.Array.Accelerate.Test.NoFib.Prelude.Filter ( ) where -import Data.Typeable import Prelude as P import Data.Array.Accelerate as A @@ -59,7 +58,7 @@ test_filter runN = => Gen a -> TestTree testIntegralElt e = - testGroup (show (typeOf (undefined :: a))) + testGroup (show (eltType @a)) [ testDim dim1 , testDim dim2 , testDim dim3 @@ -79,7 +78,7 @@ test_filter runN = => Gen a -> TestTree testFloatingElt e = - testGroup (show (typeOf (undefined :: a))) + testGroup (show (eltType @a)) [ testDim dim1 , testDim dim2 , testDim dim3 diff --git a/src/Data/Array/Accelerate/Test/NoFib/Prelude/Fold.hs b/src/Data/Array/Accelerate/Test/NoFib/Prelude/Fold.hs index d44c840cb..285e88e66 100644 --- a/src/Data/Array/Accelerate/Test/NoFib/Prelude/Fold.hs +++ b/src/Data/Array/Accelerate/Test/NoFib/Prelude/Fold.hs @@ -22,7 +22,6 @@ module Data.Array.Accelerate.Test.NoFib.Prelude.Fold ( ) where -import Data.Typeable import Prelude as P import Data.Array.Accelerate as A @@ -61,7 +60,7 @@ test_fold runN = -> Gen a -> TestTree testElt e small = - testGroup (show (typeOf (undefined :: a))) + testGroup (show (eltType @a)) [ testDim dim1 , testDim dim2 , testDim dim3 @@ -101,7 +100,7 @@ test_foldSeg runN = => Gen a -> TestTree testElt e = - testGroup (show (typeOf (undefined :: a))) + testGroup (show (eltType @a)) [ testDim dim1 , testDim dim2 , testDim dim3 diff --git a/src/Data/Array/Accelerate/Test/NoFib/Prelude/Map.hs b/src/Data/Array/Accelerate/Test/NoFib/Prelude/Map.hs index 747ffcafc..3e1ebcd15 100644 --- a/src/Data/Array/Accelerate/Test/NoFib/Prelude/Map.hs +++ b/src/Data/Array/Accelerate/Test/NoFib/Prelude/Map.hs @@ -24,7 +24,6 @@ module Data.Array.Accelerate.Test.NoFib.Prelude.Map ( ) where import Data.Bits as P -import Data.Typeable import Prelude as P import Data.Array.Accelerate as A @@ -65,7 +64,7 @@ test_map runN = => Gen a -> TestTree testIntegralElt e = - testGroup (show (typeOf (undefined :: a))) + testGroup (show (eltType @a)) [ testDim dim0 , testDim dim1 , testDim dim2 @@ -97,7 +96,7 @@ test_map runN = => (Range a -> Gen a) -> TestTree testFloatingElt e = - testGroup (show (typeOf (undefined :: a))) + testGroup (show (eltType @a)) [ testDim dim0 , testDim dim1 , testDim dim2 diff --git a/src/Data/Array/Accelerate/Test/NoFib/Prelude/Permute.hs b/src/Data/Array/Accelerate/Test/NoFib/Prelude/Permute.hs index 31bf21931..bbbc395d5 100644 --- a/src/Data/Array/Accelerate/Test/NoFib/Prelude/Permute.hs +++ b/src/Data/Array/Accelerate/Test/NoFib/Prelude/Permute.hs @@ -22,13 +22,13 @@ module Data.Array.Accelerate.Test.NoFib.Prelude.Permute ( ) where import Control.Monad -import Data.Typeable import System.IO.Unsafe import Prelude as P import qualified Data.Set as Set import Data.Array.Accelerate as A import Data.Array.Accelerate.Array.Sugar as S +import qualified Data.Array.Accelerate.Array.Representation as R import Data.Array.Accelerate.Array.Data import Data.Array.Accelerate.Test.NoFib.Base import Data.Array.Accelerate.Test.NoFib.Config @@ -63,7 +63,7 @@ test_permute runN = => Gen a -> TestTree testElt e = - testGroup (show (typeOf (undefined :: a))) + testGroup (show (eltType @a)) [ testDim dim1 , testDim dim2 , testDim dim3 @@ -144,15 +144,16 @@ test_accumulate runN dim dim' e = permuteRef - :: (Shape sh, Shape sh', P.Eq sh', Elt e) + :: forall sh sh' e. (Shape sh, Shape sh', P.Eq sh', Elt e) => (e -> e -> e) -> Array sh' e -> (sh -> sh') -> Array sh e -> Array sh' e -permuteRef f def@(Array _ aold) p arr@(Array _ anew) = +permuteRef f def@(Array (R.Array _ aold)) p arr@(Array (R.Array _ anew)) = unsafePerformIO $ do let + tp = S.eltType @e sh = S.shape arr sh' = S.shape def n = S.size sh @@ -165,9 +166,9 @@ permuteRef f def@(Array _ aold) p arr@(Array _ anew) = -- unless (ix' P.== S.ignore) $ do let i' = S.toIndex sh' ix' - x <- toElt <$> unsafeReadArrayData anew i - x' <- toElt <$> unsafeReadArrayData aold i' - unsafeWriteArrayData aold i' (fromElt (f x x')) + x <- toElt <$> unsafeReadArrayData tp anew i + x' <- toElt <$> unsafeReadArrayData tp aold i' + unsafeWriteArrayData tp aold i' (fromElt (f x x')) -- go (i+1) -- diff --git a/src/Data/Array/Accelerate/Test/NoFib/Prelude/SIMD.hs b/src/Data/Array/Accelerate/Test/NoFib/Prelude/SIMD.hs index 485873ddd..ba1c3a41e 100644 --- a/src/Data/Array/Accelerate/Test/NoFib/Prelude/SIMD.hs +++ b/src/Data/Array/Accelerate/Test/NoFib/Prelude/SIMD.hs @@ -19,8 +19,6 @@ module Data.Array.Accelerate.Test.NoFib.Prelude.SIMD ( ) where -import Data.Typeable -import Data.Primitive.Types import Control.Lens ( view, _1, _2, _3, _4 ) import Prelude as P @@ -29,8 +27,6 @@ import Data.Array.Accelerate.Array.Sugar as S import Data.Array.Accelerate.Test.NoFib.Base import Data.Array.Accelerate.Test.NoFib.Config import Data.Array.Accelerate.Type -import Data.Array.Accelerate.Smart -import Data.Array.Accelerate.Product import Hedgehog import qualified Hedgehog.Gen as Gen @@ -55,16 +51,16 @@ test_simd runN = , at @TestDouble $ testElt f64 ] where - testElt :: forall e. (Prim e, P.Eq e, Elt e, Elt (V2 e), Elt (V3 e), Elt (V4 e)) + testElt :: forall e. (VecElt e, P.Eq e) => Gen e -> TestTree testElt e = - testGroup (show (typeOf (undefined::e))) + testGroup (show (eltType @e)) [ testExtract e , testInject e ] - testExtract :: forall e. (Prim e, P.Eq e, Elt e, Elt (V2 e), Elt (V3 e), Elt (V4 e)) + testExtract :: forall e. (VecElt e, P.Eq e) => Gen e -> TestTree testExtract e = @@ -74,7 +70,7 @@ test_simd runN = , testProperty "V4" $ test_extract_v4 runN dim1 e ] - testInject :: forall e. (Prim e, P.Eq e, Elt e, Elt (V2 e), Elt (V3 e), Elt (V4 e)) + testInject :: forall e. (VecElt e, P.Eq e) => Gen e -> TestTree testInject e = @@ -86,7 +82,7 @@ test_simd runN = test_extract_v2 - :: (Shape sh, Prim e, P.Eq e, P.Eq sh, Elt e, Elt (V2 e)) + :: (Shape sh, VecElt e, P.Eq e, P.Eq sh) => RunN -> Gen sh -> Gen e @@ -99,7 +95,7 @@ test_extract_v2 runN dim e = let !go = runN (A.map (view _m . unpackV2')) in go xs === mapRef (view _l . unpackV2) xs test_extract_v3 - :: (Shape sh, Prim e, P.Eq e, P.Eq sh, Elt e, Elt (V3 e)) + :: (Shape sh, VecElt e, P.Eq e, P.Eq sh) => RunN -> Gen sh -> Gen e @@ -112,7 +108,7 @@ test_extract_v3 runN dim e = let !go = runN (A.map (view _m . unpackV3')) in go xs === mapRef (view _l . unpackV3) xs test_extract_v4 - :: (Shape sh, Prim e, P.Eq e, P.Eq sh, Elt e, Elt (V4 e)) + :: (Shape sh, VecElt e, P.Eq e, P.Eq sh) => RunN -> Gen sh -> Gen e @@ -125,7 +121,7 @@ test_extract_v4 runN dim e = let !go = runN (A.map (view _m . unpackV4')) in go xs === mapRef (view _l . unpackV4) xs test_inject_v2 - :: (Shape sh, Prim e, P.Eq e, P.Eq sh, Elt e, Elt (V2 e)) + :: (Shape sh, VecElt e, P.Eq e, P.Eq sh) => RunN -> Gen sh -> Gen e @@ -136,10 +132,10 @@ test_inject_v2 runN dim e = sh2 <- forAll dim xs <- forAll (array sh1 e) ys <- forAll (array sh2 e) - let !go = runN (A.zipWith packV2') in go xs ys === zipWithRef V2 xs ys + let !go = runN (A.zipWith A.V2_) in go xs ys === zipWithRef V2 xs ys test_inject_v3 - :: (Shape sh, Prim e, P.Eq e, P.Eq sh, Elt e, Elt (V3 e)) + :: (Shape sh, VecElt e, P.Eq e, P.Eq sh) => RunN -> Gen sh -> Gen e @@ -152,10 +148,10 @@ test_inject_v3 runN dim e = xs <- forAll (array sh1 e) ys <- forAll (array sh2 e) zs <- forAll (array sh3 e) - let !go = runN (A.zipWith3 packV3') in go xs ys zs === zipWith3Ref V3 xs ys zs + let !go = runN (A.zipWith3 A.V3_) in go xs ys zs === zipWith3Ref V3 xs ys zs test_inject_v4 - :: (Shape sh, Prim e, P.Eq e, P.Eq sh, Elt e, Elt (V4 e)) + :: (Shape sh, VecElt e, P.Eq e, P.Eq sh) => RunN -> Gen sh -> Gen e @@ -170,38 +166,17 @@ test_inject_v4 runN dim e = ys <- forAll (array sh2 e) zs <- forAll (array sh3 e) ws <- forAll (array sh4 e) - let !go = runN (A.zipWith4 packV4') in go xs ys zs ws === zipWith4Ref V4 xs ys zs ws - - -unpackV2' :: (Prim e, Elt e, Elt (V2 e)) => Exp (V2 e) -> (Exp e, Exp e) -unpackV2' e = - ( Exp $ SuccTupIdx ZeroTupIdx `Prj` e - , Exp $ ZeroTupIdx `Prj` e - ) - -unpackV3' :: (Prim e, Elt e, Elt (V3 e)) => Exp (V3 e) -> (Exp e, Exp e, Exp e) -unpackV3' e = - ( Exp $ SuccTupIdx (SuccTupIdx ZeroTupIdx) `Prj` e - , Exp $ SuccTupIdx ZeroTupIdx `Prj` e - , Exp $ ZeroTupIdx `Prj` e - ) - -unpackV4' :: (Prim e, Elt e, Elt (V4 e)) => Exp (V4 e) -> (Exp e, Exp e, Exp e, Exp e) -unpackV4' e = - ( Exp $ SuccTupIdx (SuccTupIdx (SuccTupIdx ZeroTupIdx)) `Prj` e - , Exp $ SuccTupIdx (SuccTupIdx ZeroTupIdx) `Prj` e - , Exp $ SuccTupIdx ZeroTupIdx `Prj` e - , Exp $ ZeroTupIdx `Prj` e - ) - -packV2' :: (Prim e, Elt e, Elt (V2 e)) => Exp e -> Exp e -> Exp (V2 e) -packV2' x y = Exp . Tuple $ NilTup `SnocTup` x `SnocTup` y - -packV3' :: (Prim e, Elt e, Elt (V3 e)) => Exp e -> Exp e -> Exp e -> Exp (V3 e) -packV3' x y z = Exp . Tuple $ NilTup `SnocTup` x `SnocTup` y `SnocTup` z - -packV4' :: (Prim e, Elt e, Elt (V4 e)) => Exp e -> Exp e -> Exp e -> Exp e -> Exp (V4 e) -packV4' x y z w = Exp . Tuple $ NilTup `SnocTup` x `SnocTup` y `SnocTup` z `SnocTup` w + let !go = runN (A.zipWith4 A.V4_) in go xs ys zs ws === zipWith4Ref V4 xs ys zs ws + + +unpackV2' :: VecElt e => Exp (V2 e) -> (Exp e, Exp e) +unpackV2' (A.V2_ a b) = (a, b) + +unpackV3' :: VecElt e => Exp (V3 e) -> (Exp e, Exp e, Exp e) +unpackV3' (A.V3_ a b c) = (a, b, c) + +unpackV4' :: VecElt e => Exp (V4 e) -> (Exp e, Exp e, Exp e, Exp e) +unpackV4' (A.V4_ a b c d) = (a, b, c, d) -- Reference Implementation diff --git a/src/Data/Array/Accelerate/Test/NoFib/Prelude/Scan.hs b/src/Data/Array/Accelerate/Test/NoFib/Prelude/Scan.hs index 22c06786d..a327b050a 100644 --- a/src/Data/Array/Accelerate/Test/NoFib/Prelude/Scan.hs +++ b/src/Data/Array/Accelerate/Test/NoFib/Prelude/Scan.hs @@ -27,7 +27,6 @@ module Data.Array.Accelerate.Test.NoFib.Prelude.Scan ( ) where -import Data.Typeable import Prelude as P import Data.Array.Accelerate as A @@ -65,7 +64,7 @@ test_scanl runN = => Gen a -> TestTree testElt e = - testGroup (show (typeOf (undefined :: a))) + testGroup (show (eltType @a)) [ testDim dim1 , testDim dim2 , testDim dim3 @@ -101,7 +100,7 @@ test_scanl1 runN = => Gen a -> TestTree testElt e = - testGroup (show (typeOf (undefined :: a))) + testGroup (show (eltType @a)) [ testDim dim1 , testDim dim2 , testDim dim3 @@ -136,7 +135,7 @@ test_scanl' runN = => Gen a -> TestTree testElt e = - testGroup (show (typeOf (undefined :: a))) + testGroup (show (eltType @a)) [ testDim dim1 , testDim dim2 , testDim dim3 @@ -172,7 +171,7 @@ test_scanr runN = => Gen a -> TestTree testElt e = - testGroup (show (typeOf (undefined :: a))) + testGroup (show (eltType @a)) [ testDim dim1 , testDim dim2 , testDim dim3 @@ -208,7 +207,7 @@ test_scanr1 runN = => Gen a -> TestTree testElt e = - testGroup (show (typeOf (undefined :: a))) + testGroup (show (eltType @a)) [ testDim dim1 , testDim dim2 , testDim dim3 @@ -243,7 +242,7 @@ test_scanr' runN = => Gen a -> TestTree testElt e = - testGroup (show (typeOf (undefined :: a))) + testGroup (show (eltType @a)) [ testDim dim1 , testDim dim2 , testDim dim3 @@ -279,7 +278,7 @@ test_scanlSeg runN = => Gen a -> TestTree testElt e = - testGroup (show (typeOf (undefined :: a))) + testGroup (show (eltType @a)) [ testDim dim1 , testDim dim2 , testDim dim3 @@ -314,7 +313,7 @@ test_scanl1Seg runN = => Gen a -> TestTree testElt e = - testGroup (show (typeOf (undefined :: a))) + testGroup (show (eltType @a)) [ testDim dim1 , testDim dim2 , testDim dim3 @@ -348,7 +347,7 @@ test_scanl'Seg runN = => Gen a -> TestTree testElt e = - testGroup (show (typeOf (undefined :: a))) + testGroup (show (eltType @a)) [ testDim dim1 , testDim dim2 , testDim dim3 @@ -383,7 +382,7 @@ test_scanrSeg runN = => Gen a -> TestTree testElt e = - testGroup (show (typeOf (undefined :: a))) + testGroup (show (eltType @a)) [ testDim dim1 , testDim dim2 , testDim dim3 @@ -418,7 +417,7 @@ test_scanr1Seg runN = => Gen a -> TestTree testElt e = - testGroup (show (typeOf (undefined :: a))) + testGroup (show (eltType @a)) [ testDim dim1 , testDim dim2 , testDim dim3 @@ -452,7 +451,7 @@ test_scanr'Seg runN = => Gen a -> TestTree testElt e = - testGroup (show (typeOf (undefined :: a))) + testGroup (show (eltType @a)) [ testDim dim1 , testDim dim2 , testDim dim3 diff --git a/src/Data/Array/Accelerate/Test/NoFib/Prelude/Stencil.hs b/src/Data/Array/Accelerate/Test/NoFib/Prelude/Stencil.hs index 465fba69c..e7c4eaefb 100644 --- a/src/Data/Array/Accelerate/Test/NoFib/Prelude/Stencil.hs +++ b/src/Data/Array/Accelerate/Test/NoFib/Prelude/Stencil.hs @@ -64,7 +64,7 @@ test_stencil runN = => Gen a -> TestTree testElt e = - testGroup (show (typeOf (undefined :: a))) + testGroup (show (eltType @a)) [ testDim1 , testDim2 , testDim3 @@ -630,9 +630,9 @@ bound bnd sh0 ix0 = Right ix' -> Right (toElt ix') where go :: TupleType t -> t -> t -> Either e t - go TypeRunit () () = Right () - go (TypeRpair tsh tsz) (sh,sz) (ih,iz) = go tsh sh ih `addDim` go tsz sz iz - go (TypeRscalar t) sh i + go TupRunit () () = Right () + go (TupRpair tsh tsz) (sh,sz) (ih,iz) = go tsh sh ih `addDim` go tsz sz iz + go (TupRsingle t) sh i | Just Refl <- matchScalarType t (scalarType :: ScalarType Int) = if i P.< 0 then case bnd of diff --git a/src/Data/Array/Accelerate/Test/NoFib/Prelude/ZipWith.hs b/src/Data/Array/Accelerate/Test/NoFib/Prelude/ZipWith.hs index 26b2b1455..d5324b1c0 100644 --- a/src/Data/Array/Accelerate/Test/NoFib/Prelude/ZipWith.hs +++ b/src/Data/Array/Accelerate/Test/NoFib/Prelude/ZipWith.hs @@ -23,7 +23,6 @@ module Data.Array.Accelerate.Test.NoFib.Prelude.ZipWith ( ) where import Data.Bits as P -import Data.Typeable import Prelude as P import Data.Array.Accelerate as A @@ -65,7 +64,7 @@ test_zipWith runN = => Gen a -> TestTree testIntegralElt e = - testGroup (show (typeOf (undefined :: a))) + testGroup (show (eltType @a)) [ testDim dim0 , testDim dim1 , testDim dim2 @@ -115,7 +114,7 @@ test_zipWith runN = => (Range a -> Gen a) -> TestTree testFloatingElt e = - testGroup (show (typeOf (undefined :: a))) + testGroup (show (eltType @a)) [ testDim dim0 , testDim dim1 , testDim dim2 diff --git a/src/Data/Array/Accelerate/Test/NoFib/Spectral/BlackScholes.hs b/src/Data/Array/Accelerate/Test/NoFib/Spectral/BlackScholes.hs index adbfdaec5..34b4b8915 100644 --- a/src/Data/Array/Accelerate/Test/NoFib/Spectral/BlackScholes.hs +++ b/src/Data/Array/Accelerate/Test/NoFib/Spectral/BlackScholes.hs @@ -22,7 +22,6 @@ module Data.Array.Accelerate.Test.NoFib.Spectral.BlackScholes ( ) where -import Data.Typeable import Prelude as P import Data.Array.Accelerate as A @@ -52,7 +51,7 @@ test_blackscholes runN = => (Range a -> Gen a) -> TestTree testElt e = - testProperty (show (typeOf (undefined :: a))) $ test_blackscholes' runN e + testProperty (show (eltType @a)) $ test_blackscholes' runN e test_blackscholes' diff --git a/src/Data/Array/Accelerate/Test/NoFib/Spectral/RadixSort.hs b/src/Data/Array/Accelerate/Test/NoFib/Spectral/RadixSort.hs index c8be0ff1e..1ec13bac3 100644 --- a/src/Data/Array/Accelerate/Test/NoFib/Spectral/RadixSort.hs +++ b/src/Data/Array/Accelerate/Test/NoFib/Spectral/RadixSort.hs @@ -22,7 +22,6 @@ module Data.Array.Accelerate.Test.NoFib.Spectral.RadixSort ( ) where -import Data.Typeable import Data.Function import Data.List import Prelude as P @@ -30,7 +29,7 @@ import qualified Data.Bits as P import Data.Array.Accelerate as A import Data.Array.Accelerate.Data.Bits as A -import Data.Array.Accelerate.Array.Sugar as S ( shape ) +import Data.Array.Accelerate.Array.Sugar as S ( shape, eltType ) import Data.Array.Accelerate.Test.NoFib.Base import Data.Array.Accelerate.Test.NoFib.Config import Data.Array.Accelerate.Test.Similar @@ -62,7 +61,7 @@ test_radixsort runN = => Gen a -> TestTree testElt e = - testGroup (show (typeOf (undefined :: a))) + testGroup (show (eltType @a)) [ testProperty "ascending" $ test_sort_ascending runN e , testProperty "descending" $ test_sort_descending runN e , testProperty "key-value" $ test_sort_keyval runN e f32 diff --git a/src/Data/Array/Accelerate/Test/NoFib/Spectral/SMVM.hs b/src/Data/Array/Accelerate/Test/NoFib/Spectral/SMVM.hs index b78714a01..31c1a8404 100644 --- a/src/Data/Array/Accelerate/Test/NoFib/Spectral/SMVM.hs +++ b/src/Data/Array/Accelerate/Test/NoFib/Spectral/SMVM.hs @@ -21,7 +21,6 @@ module Data.Array.Accelerate.Test.NoFib.Spectral.SMVM ( ) where -import Data.Typeable import Prelude as P import Data.Array.Accelerate as A @@ -50,7 +49,7 @@ test_smvm runN = => Gen a -> TestTree testElt e = - testProperty (show (typeOf (undefined :: a))) $ test_smvm' runN e + testProperty (show (eltType @a)) $ test_smvm' runN e test_smvm' :: (A.Num e, P.Num e, Similar e) => RunN -> Gen e -> Property diff --git a/src/Data/Array/Accelerate/Trafo.hs b/src/Data/Array/Accelerate/Trafo.hs index 579c75fe0..9cadbe8c4 100644 --- a/src/Data/Array/Accelerate/Trafo.hs +++ b/src/Data/Array/Accelerate/Trafo.hs @@ -38,8 +38,6 @@ module Data.Array.Accelerate.Trafo ( -- * Fusion DelayedAcc, DelayedOpenAcc(..), DelayedAfun, DelayedOpenAfun, - DelayedExp, DelayedOpenExp, - DelayedFun, DelayedOpenFun, -- * Substitution module Data.Array.Accelerate.Trafo.Substitution, @@ -57,14 +55,15 @@ import Control.DeepSeq import Data.Typeable import Data.Array.Accelerate.Smart -import Data.Array.Accelerate.Array.Sugar ( Arrays, Elt, ArrRepr ) +import Data.Array.Accelerate.Array.Sugar ( ArrRepr, EltRepr ) import Data.Array.Accelerate.Trafo.Base ( Match(..), matchDelayedOpenAcc, encodeDelayedOpenAcc ) import Data.Array.Accelerate.Trafo.Config -import Data.Array.Accelerate.Trafo.Fusion ( DelayedAcc, DelayedOpenAcc(..), DelayedAfun, DelayedOpenAfun, DelayedExp, DelayedFun, DelayedOpenExp, DelayedOpenFun ) -import Data.Array.Accelerate.Trafo.Sharing ( Function, FunctionR, Afunction, AfunctionR, AreprFunctionR, AfunctionRepr(..), afunctionRepr ) +import Data.Array.Accelerate.Trafo.Fusion ( DelayedAcc, DelayedOpenAcc(..), DelayedAfun, DelayedOpenAfun ) +import Data.Array.Accelerate.Trafo.Sharing ( Function, FunctionR, Afunction, AfunctionR, AreprFunctionR, AfunctionRepr(..), afunctionRepr, EltReprFunctionR ) import Data.Array.Accelerate.Trafo.Substitution import qualified Data.Array.Accelerate.AST as AST import qualified Data.Array.Accelerate.Trafo.Fusion as Fusion +import qualified Data.Array.Accelerate.Trafo.LetSplit as LetSplit import qualified Data.Array.Accelerate.Trafo.Simplify as Rewrite import qualified Data.Array.Accelerate.Trafo.Sharing as Sharing -- import qualified Data.Array.Accelerate.Trafo.Vectorise as Vectorise @@ -83,15 +82,15 @@ import Data.Array.Accelerate.Debug.Timed -- | Convert a closed array expression to de Bruijn form while also -- incorporating sharing observation and array fusion. -- -convertAcc :: Arrays arrs => Acc arrs -> DelayedAcc (ArrRepr arrs) +convertAcc :: Acc arrs -> DelayedAcc (ArrRepr arrs) convertAcc = convertAccWith defaultOptions -convertAccWith :: Arrays arrs => Config -> Acc arrs -> DelayedAcc (ArrRepr arrs) -convertAccWith config acc +convertAccWith :: Config -> Acc arrs -> DelayedAcc (ArrRepr arrs) +convertAccWith config = phase "array-fusion" (Fusion.convertAccWith config) + . phase "array-split-lets" LetSplit.convertAcc -- phase "vectorise-sequences" Vectorise.vectoriseSeqAcc `when` vectoriseSequences - $ phase "sharing-recovery" (Sharing.convertAccWith config) - $ acc + . phase "sharing-recovery" (Sharing.convertAccWith config) -- | Convert a unary function over array computations, incorporating sharing @@ -101,17 +100,17 @@ convertAfun :: Afunction f => f -> DelayedAfun (AreprFunctionR f) convertAfun = convertAfunWith defaultOptions convertAfunWith :: Afunction f => Config -> f -> DelayedAfun (AreprFunctionR f) -convertAfunWith config acc +convertAfunWith config = phase "array-fusion" (Fusion.convertAfunWith config) + . phase "array-split-lets" LetSplit.convertAfun -- phase "vectorise-sequences" Vectorise.vectoriseSeqAfun `when` vectoriseSequences - $ phase "sharing-recovery" (Sharing.convertAfunWith config) - $ acc + . phase "sharing-recovery" (Sharing.convertAfunWith config) -- | Convert a closed scalar expression, incorporating sharing observation and -- optimisation. -- -convertExp :: Elt e => Exp e -> AST.Exp () e +convertExp :: Exp e -> AST.Exp () (EltRepr e) convertExp = phase "exp-simplify" Rewrite.simplify -- XXX: only if simplification is enabled . phase "sharing-recovery" Sharing.convertExp @@ -120,7 +119,7 @@ convertExp -- | Convert closed scalar functions, incorporating sharing observation and -- optimisation. -- -convertFun :: Function f => f -> AST.Fun () (FunctionR f) +convertFun :: Function f => f -> AST.Fun () (EltReprFunctionR f) convertFun = phase "exp-simplify" Rewrite.simplify . phase "sharing-recovery" Sharing.convertFun diff --git a/src/Data/Array/Accelerate/Trafo/Algebra.hs b/src/Data/Array/Accelerate/Trafo/Algebra.hs index 9ffd57c2e..775a237ec 100644 --- a/src/Data/Array/Accelerate/Trafo/Algebra.hs +++ b/src/Data/Array/Accelerate/Trafo/Algebra.hs @@ -41,9 +41,7 @@ import qualified Prelude as P -- friends import Data.Array.Accelerate.AST import Data.Array.Accelerate.Analysis.Match -import Data.Array.Accelerate.Array.Sugar hiding ( Any ) import Data.Array.Accelerate.Pretty.Print ( primOperator, isInfix, opName ) -import Data.Array.Accelerate.Product import Data.Array.Accelerate.Trafo.Base import Data.Array.Accelerate.Type @@ -54,46 +52,32 @@ import qualified Data.Array.Accelerate.Debug.Stats as Stats -- or constant let bindings. Be careful not to follow self-cycles. -- propagate - :: forall acc env aenv exp. Kit acc - => Gamma acc env env aenv - -> PreOpenExp acc env aenv exp + :: forall env aenv exp. + Gamma env env aenv + -> OpenExp env aenv exp -> Maybe exp propagate env = cvtE where - cvtE :: PreOpenExp acc env aenv e -> Maybe e + cvtE :: OpenExp env aenv e -> Maybe e cvtE exp = case exp of - Const c -> Just (toElt c) + Const _ c -> Just c PrimConst c -> Just (evalPrimConst c) - Prj ix (Var v) | Tuple t <- prjExp v env -> cvtT ix t - Prj ix e | Just c <- cvtE e -> cvtP ix (fromTuple c) - Var ix + Evar (Var _ ix) | e <- prjExp ix env , Nothing <- match exp e -> cvtE e - -- - IndexHead (cvtE -> Just (_ :. z)) -> Just z - IndexTail (cvtE -> Just (sh :. _)) -> Just sh + Nil -> Just () + Pair e1 e2 -> (,) <$> cvtE e1 <*> cvtE e2 _ -> Nothing - cvtP :: TupleIdx t e -> t -> Maybe e - cvtP ZeroTupIdx (_, v) = Just v - cvtP (SuccTupIdx idx) (tup, _) = cvtP idx tup - - cvtT :: TupleIdx t e -> Tuple (PreOpenExp acc env aenv) t -> Maybe e - cvtT ZeroTupIdx (SnocTup _ e) = cvtE e - cvtT (SuccTupIdx idx) (SnocTup tup _) = cvtT idx tup -#if __GLASGOW_HASKELL__ < 800 - cvtT _ _ = error "hey what's the head angle on that thing?" -#endif - -- Attempt to evaluate primitive function applications -- evalPrimApp - :: forall acc env aenv a r. (Kit acc, Elt a, Elt r) - => Gamma acc env env aenv + :: forall env aenv a r. + Gamma env env aenv -> PrimFun (a -> r) - -> PreOpenExp acc env aenv a - -> (Any, PreOpenExp acc env aenv r) + -> OpenExp env aenv a + -> (Any, OpenExp env aenv r) evalPrimApp env f x -- First attempt to move constant values towards the left | Just r <- commutes f x env = evalPrimApp env f r @@ -175,11 +159,11 @@ evalPrimApp env f x -- to the left of the operator. Returning Nothing indicates no change is made. -- commutes - :: forall acc env aenv a r. Kit acc - => PrimFun (a -> r) - -> PreOpenExp acc env aenv a - -> Gamma acc env env aenv - -> Maybe (PreOpenExp acc env aenv a) + :: forall env aenv a r. + PrimFun (a -> r) + -> OpenExp env aenv a + -> Gamma env env aenv + -> Maybe (OpenExp env aenv a) commutes f x env = case f of PrimAdd _ -> swizzle x PrimMul _ -> swizzle x @@ -192,12 +176,12 @@ commutes f x env = case f of PrimMin _ -> swizzle x _ -> Nothing where - swizzle :: PreOpenExp acc env aenv (b,b) -> Maybe (PreOpenExp acc env aenv (b,b)) - swizzle (Tuple (NilTup `SnocTup` a `SnocTup` b)) + swizzle :: OpenExp env aenv (b,b) -> Maybe (OpenExp env aenv (b,b)) + swizzle (Pair a b) | Nothing <- propagate env a , Just _ <- propagate env b = Stats.ruleFired (pprFun "commutes" f) - $ Just $ Tuple (NilTup `SnocTup` b `SnocTup` a) + $ Just $ Pair b a -- TLM: changing the ordering here when neither term can be reduced can be -- disadvantageous: for example in (x &&* y), the user might have put a @@ -229,8 +213,8 @@ commutes f x env = case f of associates :: (Elt a, Elt r) => PrimFun (a -> r) - -> PreOpenExp acc env aenv a - -> Maybe (PreOpenExp acc env aenv r) + -> OpenExp env aenv a + -> Maybe (OpenExp env aenv r) associates fun exp = case fun of PrimAdd _ -> swizzle fun exp [PrimAdd ty, PrimSub ty] PrimSub _ -> swizzle fun exp [PrimAdd ty, PrimSub ty] @@ -242,7 +226,7 @@ associates fun exp = case fun of ty = undefined ops = [ PrimMul ty, PrimFDiv ty, PrimAdd ty, PrimSub ty, PrimBAnd ty, PrimBOr ty, PrimBXor ty ] - swizzle :: (Elt a, Elt r) => PrimFun (a -> r) -> PreOpenExp acc env aenv a -> [PrimFun (a -> r)] -> Maybe (PreOpenExp acc env aenv r) + swizzle :: (Elt a, Elt r) => PrimFun (a -> r) -> OpenExp env aenv a -> [PrimFun (a -> r)] -> Maybe (OpenExp env aenv r) swizzle f x lvl | Just Refl <- matches f ops , Just (a,bc) <- untup2 x @@ -269,30 +253,30 @@ associates fun exp = case fun of -- Helper functions -- ---------------- -type a :-> b = forall acc env aenv. Kit acc => PreOpenExp acc env aenv a -> Gamma acc env env aenv -> Maybe (PreOpenExp acc env aenv b) +type a :-> b = forall env aenv. OpenExp env aenv a -> Gamma env env aenv -> Maybe (OpenExp env aenv b) -eval1 :: Elt b => (a -> b) -> a :-> b -eval1 f x env - | Just a <- propagate env x = Stats.substitution "constant fold" . Just $ Const (fromElt (f a)) +eval1 :: SingleType b -> (a -> b) -> a :-> b +eval1 tp f x env + | Just a <- propagate env x = Stats.substitution "constant fold" . Just $ Const (SingleScalarType tp) (f a) | otherwise = Nothing -eval2 :: Elt c => (a -> b -> c) -> (a,b) :-> c -eval2 f (untup2 -> Just (x,y)) env +eval2 :: SingleType c -> (a -> b -> c) -> (a,b) :-> c +eval2 tp f (untup2 -> Just (x,y)) env | Just a <- propagate env x , Just b <- propagate env y = Stats.substitution "constant fold" - $ Just $ Const (fromElt (f a b)) + $ Just $ Const (SingleScalarType tp) (f a b) -eval2 _ _ _ +eval2 _ _ _ _ = Nothing -tup2 :: (Elt a, Elt b) => (PreOpenExp acc env aenv a, PreOpenExp acc env aenv b) -> PreOpenExp acc env aenv (a, b) -tup2 (a,b) = Tuple (NilTup `SnocTup` a `SnocTup` b) +tup2 :: (OpenExp env aenv a, OpenExp env aenv b) -> OpenExp env aenv (a, b) +tup2 (a,b) = Pair a b -untup2 :: PreOpenExp acc env aenv (a, b) -> Maybe (PreOpenExp acc env aenv a, PreOpenExp acc env aenv b) +untup2 :: OpenExp env aenv (a, b) -> Maybe (OpenExp env aenv a, OpenExp env aenv b) untup2 exp - | Tuple (NilTup `SnocTup` a `SnocTup` b) <- exp = Just (a, b) - | otherwise = Nothing + | Pair a b <- exp = Just (a, b) + | otherwise = Nothing pprFun :: Text -> PrimFun f -> Text @@ -310,25 +294,25 @@ pprFun rule f -- Methods of Num -- -------------- -evalAdd :: Elt a => NumType a -> (a,a) :-> a -evalAdd (IntegralNumType ty) | IntegralDict <- integralDict ty = evalAdd' -evalAdd (FloatingNumType ty) | FloatingDict <- floatingDict ty = evalAdd' +evalAdd :: NumType a -> (a,a) :-> a +evalAdd ty@(IntegralNumType ty') | IntegralDict <- integralDict ty' = evalAdd' ty +evalAdd ty@(FloatingNumType ty') | FloatingDict <- floatingDict ty' = evalAdd' ty -evalAdd' :: (Elt a, Eq a, Num a) => (a,a) :-> a -evalAdd' (untup2 -> Just (x,y)) env +evalAdd' :: (Eq a, Num a) => NumType a -> (a,a) :-> a +evalAdd' _ (untup2 -> Just (x,y)) env | Just a <- propagate env x , a == 0 = Stats.ruleFired "x+0" $ Just y -evalAdd' arg env - = eval2 (+) arg env +evalAdd' ty arg env + = eval2 (NumSingleType ty) (+) arg env -evalSub :: Elt a => NumType a -> (a,a) :-> a +evalSub :: NumType a -> (a,a) :-> a evalSub ty@(IntegralNumType ty') | IntegralDict <- integralDict ty' = evalSub' ty evalSub ty@(FloatingNumType ty') | FloatingDict <- floatingDict ty' = evalSub' ty -evalSub' :: forall a. (Elt a, Eq a, Num a) => NumType a -> (a,a) :-> a +evalSub' :: forall a. (Eq a, Num a) => NumType a -> (a,a) :-> a evalSub' ty (untup2 -> Just (x,y)) env | Just b <- propagate env y , b == 0 @@ -337,22 +321,25 @@ evalSub' ty (untup2 -> Just (x,y)) env | Nothing <- propagate env x , Just b <- propagate env y = Stats.ruleFired "-y+x" - $ Just . snd $ evalPrimApp env (PrimAdd ty) (Tuple $ NilTup `SnocTup` Const (fromElt (-b)) `SnocTup` x) + $ Just . snd $ evalPrimApp env (PrimAdd ty) (Const tp (-b) `Pair` x) + -- (Tuple $ NilTup `SnocTup` Const (fromElt (-b)) `SnocTup` x) | Just Refl <- match x y = Stats.ruleFired "x-x" - $ Just $ Const (fromElt (0::a)) + $ Just $ Const tp 0 + where + tp = SingleScalarType $ NumSingleType ty -evalSub' _ arg env - = eval2 (-) arg env +evalSub' ty arg env + = eval2 (NumSingleType ty) (-) arg env -evalMul :: Elt a => NumType a -> (a,a) :-> a -evalMul (IntegralNumType ty) | IntegralDict <- integralDict ty = evalMul' -evalMul (FloatingNumType ty) | FloatingDict <- floatingDict ty = evalMul' +evalMul :: NumType a -> (a,a) :-> a +evalMul ty@(IntegralNumType ty') | IntegralDict <- integralDict ty' = evalMul' ty +evalMul ty@(FloatingNumType ty') | FloatingDict <- floatingDict ty' = evalMul' ty -evalMul' :: (Elt a, Eq a, Num a) => (a,a) :-> a -evalMul' (untup2 -> Just (x,y)) env +evalMul' :: (Eq a, Num a) => NumType a -> (a,a) :-> a +evalMul' _ (untup2 -> Just (x,y)) env | Just a <- propagate env x , Nothing <- propagate env y = case a of @@ -360,21 +347,21 @@ evalMul' (untup2 -> Just (x,y)) env 1 -> Stats.ruleFired "x*1" $ Just y _ -> Nothing -evalMul' arg env - = eval2 (*) arg env +evalMul' ty arg env + = eval2 (NumSingleType ty) (*) arg env -evalNeg :: Elt a => NumType a -> a :-> a +evalNeg :: NumType a -> a :-> a evalNeg _ x _ | PrimApp PrimNeg{} x' <- x = Stats.ruleFired "negate/negate" $ Just x' -evalNeg (IntegralNumType ty) x env | IntegralDict <- integralDict ty = eval1 negate x env -evalNeg (FloatingNumType ty) x env | FloatingDict <- floatingDict ty = eval1 negate x env +evalNeg (IntegralNumType ty) x env | IntegralDict <- integralDict ty = eval1 (NumSingleType $ IntegralNumType ty) negate x env +evalNeg (FloatingNumType ty) x env | FloatingDict <- floatingDict ty = eval1 (NumSingleType $ FloatingNumType ty) negate x env -evalAbs :: Elt a => NumType a -> a :-> a -evalAbs (IntegralNumType ty) | IntegralDict <- integralDict ty = eval1 abs -evalAbs (FloatingNumType ty) | FloatingDict <- floatingDict ty = eval1 abs +evalAbs :: NumType a -> a :-> a +evalAbs (IntegralNumType ty) | IntegralDict <- integralDict ty = eval1 (NumSingleType $ IntegralNumType ty) abs +evalAbs (FloatingNumType ty) | FloatingDict <- floatingDict ty = eval1 (NumSingleType $ FloatingNumType ty) abs -evalSig :: Elt a => NumType a -> a :-> a -evalSig (IntegralNumType ty) | IntegralDict <- integralDict ty = eval1 signum -evalSig (FloatingNumType ty) | FloatingDict <- floatingDict ty = eval1 signum +evalSig :: NumType a -> a :-> a +evalSig (IntegralNumType ty) | IntegralDict <- integralDict ty = eval1 (NumSingleType $ IntegralNumType ty) signum +evalSig (FloatingNumType ty) | FloatingDict <- floatingDict ty = eval1 (NumSingleType $ FloatingNumType ty) signum -- Methods of Integral & Bits @@ -398,17 +385,19 @@ evalRem _ _ _ evalQuotRem :: forall a. IntegralType a -> (a,a) :-> (a,a) evalQuotRem ty exp env - | IntegralDict <- integralDict ty - , Tuple (NilTup `SnocTup` x `SnocTup` y) <- exp -- TLM: untup2, but inlined to expose the Elt dictionary - , Just b <- propagate env y + | IntegralDict <- integralDict ty + , Just (x, y) <- untup2 exp + , Just b <- propagate env y = case b of 0 -> Nothing - 1 -> Stats.ruleFired "quotRem x 1" $ Just (tup2 (x, Const (fromElt (0::a)))) + 1 -> Stats.ruleFired "quotRem x 1" $ Just (tup2 (x, Const tp 0)) _ -> case propagate env x of Nothing -> Nothing Just a -> Stats.substitution "constant fold" $ Just $ let (u,v) = quotRem a b - in tup2 (Const (fromElt u), Const (fromElt v)) + in tup2 (Const tp u, Const tp v) + where + tp = SingleScalarType $ NumSingleType $ IntegralNumType ty evalQuotRem _ _ _ = Nothing @@ -431,78 +420,80 @@ evalMod _ _ _ evalDivMod :: forall a. IntegralType a -> (a,a) :-> (a,a) evalDivMod ty exp env - | IntegralDict <- integralDict ty - , Tuple (NilTup `SnocTup` x `SnocTup` y) <- exp -- TLM: untup2, but inlined to expose the Elt dictionary - , Just b <- propagate env y + | IntegralDict <- integralDict ty + , Just (x, y) <- untup2 exp + , Just b <- propagate env y = case b of 0 -> Nothing - 1 -> Stats.ruleFired "divMod x 1" $ Just (tup2 (x, Const (fromElt (0::a)))) + 1 -> Stats.ruleFired "divMod x 1" $ Just (tup2 (x, Const tp 0)) _ -> case propagate env x of Nothing -> Nothing Just a -> Stats.substitution "constant fold" $ Just $ let (u,v) = divMod a b - in tup2 (Const (fromElt u), Const (fromElt v)) + in tup2 (Const tp u, Const tp v) + where + tp = SingleScalarType $ NumSingleType $ IntegralNumType ty evalDivMod _ _ _ = Nothing -evalBAnd :: Elt a => IntegralType a -> (a,a) :-> a -evalBAnd ty | IntegralDict <- integralDict ty = eval2 (.&.) +evalBAnd :: IntegralType a -> (a,a) :-> a +evalBAnd ty | IntegralDict <- integralDict ty = eval2 (NumSingleType $ IntegralNumType ty) (.&.) -evalBOr :: Elt a => IntegralType a -> (a,a) :-> a -evalBOr ty | IntegralDict <- integralDict ty = evalBOr' +evalBOr :: IntegralType a -> (a,a) :-> a +evalBOr ty | IntegralDict <- integralDict ty = evalBOr' ty -evalBOr' :: (Elt a, Eq a, Num a, Bits a) => (a,a) :-> a -evalBOr' (untup2 -> Just (x,y)) env +evalBOr' :: (Eq a, Num a, Bits a) => IntegralType a -> (a,a) :-> a +evalBOr' _ (untup2 -> Just (x,y)) env | Just 0 <- propagate env x = Stats.ruleFired "x .|. 0" $ Just y -evalBOr' arg env - = eval2 (.|.) arg env +evalBOr' ty arg env + = eval2 (NumSingleType $ IntegralNumType ty) (.|.) arg env -evalBXor :: Elt a => IntegralType a -> (a,a) :-> a -evalBXor ty | IntegralDict <- integralDict ty = eval2 xor +evalBXor :: IntegralType a -> (a,a) :-> a +evalBXor ty | IntegralDict <- integralDict ty = eval2 (NumSingleType $ IntegralNumType ty) xor -evalBNot :: Elt a => IntegralType a -> a :-> a -evalBNot ty | IntegralDict <- integralDict ty = eval1 complement +evalBNot :: IntegralType a -> a :-> a +evalBNot ty | IntegralDict <- integralDict ty = eval1 (NumSingleType $ IntegralNumType ty) complement -evalBShiftL :: Elt a => IntegralType a -> (a,Int) :-> a +evalBShiftL :: IntegralType a -> (a,Int) :-> a evalBShiftL _ (untup2 -> Just (x,i)) env | Just 0 <- propagate env i = Stats.ruleFired "x `shiftL` 0" $ Just x evalBShiftL ty arg env - | IntegralDict <- integralDict ty = eval2 shiftL arg env + | IntegralDict <- integralDict ty = eval2 (NumSingleType $ IntegralNumType ty) shiftL arg env -evalBShiftR :: Elt a => IntegralType a -> (a,Int) :-> a +evalBShiftR :: IntegralType a -> (a,Int) :-> a evalBShiftR _ (untup2 -> Just (x,i)) env | Just 0 <- propagate env i = Stats.ruleFired "x `shiftR` 0" $ Just x evalBShiftR ty arg env - | IntegralDict <- integralDict ty = eval2 shiftR arg env + | IntegralDict <- integralDict ty = eval2 (NumSingleType $ IntegralNumType ty) shiftR arg env -evalBRotateL :: Elt a => IntegralType a -> (a,Int) :-> a +evalBRotateL :: IntegralType a -> (a,Int) :-> a evalBRotateL _ (untup2 -> Just (x,i)) env | Just 0 <- propagate env i = Stats.ruleFired "x `rotateL` 0" $ Just x evalBRotateL ty arg env - | IntegralDict <- integralDict ty = eval2 rotateL arg env + | IntegralDict <- integralDict ty = eval2 (NumSingleType $ IntegralNumType ty) rotateL arg env -evalBRotateR :: Elt a => IntegralType a -> (a,Int) :-> a +evalBRotateR :: IntegralType a -> (a,Int) :-> a evalBRotateR _ (untup2 -> Just (x,i)) env | Just 0 <- propagate env i = Stats.ruleFired "x `rotateR` 0" $ Just x evalBRotateR ty arg env - | IntegralDict <- integralDict ty = eval2 rotateR arg env + | IntegralDict <- integralDict ty = eval2 (NumSingleType $ IntegralNumType ty) rotateR arg env evalPopCount :: IntegralType a -> a :-> Int -evalPopCount ty | IntegralDict <- integralDict ty = eval1 popCount +evalPopCount ty | IntegralDict <- integralDict ty = eval1 (NumSingleType $ IntegralNumType TypeInt) popCount evalCountLeadingZeros :: IntegralType a -> a :-> Int #if __GLASGOW_HASKELL__ >= 710 -evalCountLeadingZeros ty | IntegralDict <- integralDict ty = eval1 countLeadingZeros +evalCountLeadingZeros ty | IntegralDict <- integralDict ty = eval1 (NumSingleType $ IntegralNumType TypeInt) countLeadingZeros #else -evalCountLeadingZeros ty | IntegralDict <- integralDict ty = eval1 clz +evalCountLeadingZeros ty | IntegralDict <- integralDict ty = eval1 (NumSingleType $ IntegralNumType TypeInt) clz where clz x = (w-1) - go (w-1) where @@ -514,9 +505,9 @@ evalCountLeadingZeros ty | IntegralDict <- integralDict ty = eval1 clz evalCountTrailingZeros :: IntegralType a -> a :-> Int #if __GLASGOW_HASKELL__ >= 710 -evalCountTrailingZeros ty | IntegralDict <- integralDict ty = eval1 countTrailingZeros +evalCountTrailingZeros ty | IntegralDict <- integralDict ty = eval1 (NumSingleType $ IntegralNumType TypeInt) countTrailingZeros #else -evalCountTrailingZeros ty | IntegralDict <- integralDict ty = eval1 ctz +evalCountTrailingZeros ty | IntegralDict <- integralDict ty = eval1 (NumSingleType $ IntegralNumType TypeInt) ctz where ctz x = go 0 where @@ -530,109 +521,109 @@ evalCountTrailingZeros ty | IntegralDict <- integralDict ty = eval1 ctz -- Methods of Fractional & Floating -- -------------------------------- -evalFDiv :: Elt a => FloatingType a -> (a,a) :-> a -evalFDiv ty | FloatingDict <- floatingDict ty = evalFDiv' +evalFDiv :: FloatingType a -> (a,a) :-> a +evalFDiv ty | FloatingDict <- floatingDict ty = evalFDiv' ty -evalFDiv' :: (Elt a, Fractional a, Eq a) => (a,a) :-> a -evalFDiv' (untup2 -> Just (x,y)) env +evalFDiv' :: (Fractional a, Eq a) => FloatingType a -> (a,a) :-> a +evalFDiv' _ (untup2 -> Just (x,y)) env | Just 1 <- propagate env y = Stats.ruleFired "x/1" $ Just x -evalFDiv' arg env - = eval2 (/) arg env +evalFDiv' ty arg env + = eval2 (NumSingleType $ FloatingNumType ty) (/) arg env -evalRecip :: Elt a => FloatingType a -> a :-> a -evalRecip ty | FloatingDict <- floatingDict ty = eval1 recip +evalRecip :: FloatingType a -> a :-> a +evalRecip ty | FloatingDict <- floatingDict ty = eval1 (NumSingleType $ FloatingNumType ty) recip -evalSin :: Elt a => FloatingType a -> a :-> a -evalSin ty | FloatingDict <- floatingDict ty = eval1 sin +evalSin :: FloatingType a -> a :-> a +evalSin ty | FloatingDict <- floatingDict ty = eval1 (NumSingleType $ FloatingNumType ty) sin -evalCos :: Elt a => FloatingType a -> a :-> a -evalCos ty | FloatingDict <- floatingDict ty = eval1 cos +evalCos :: FloatingType a -> a :-> a +evalCos ty | FloatingDict <- floatingDict ty = eval1 (NumSingleType $ FloatingNumType ty) cos -evalTan :: Elt a => FloatingType a -> a :-> a -evalTan ty | FloatingDict <- floatingDict ty = eval1 tan +evalTan :: FloatingType a -> a :-> a +evalTan ty | FloatingDict <- floatingDict ty = eval1 (NumSingleType $ FloatingNumType ty) tan -evalAsin :: Elt a => FloatingType a -> a :-> a -evalAsin ty | FloatingDict <- floatingDict ty = eval1 asin +evalAsin :: FloatingType a -> a :-> a +evalAsin ty | FloatingDict <- floatingDict ty = eval1 (NumSingleType $ FloatingNumType ty) asin -evalAcos :: Elt a => FloatingType a -> a :-> a -evalAcos ty | FloatingDict <- floatingDict ty = eval1 acos +evalAcos :: FloatingType a -> a :-> a +evalAcos ty | FloatingDict <- floatingDict ty = eval1 (NumSingleType $ FloatingNumType ty) acos -evalAtan :: Elt a => FloatingType a -> a :-> a -evalAtan ty | FloatingDict <- floatingDict ty = eval1 atan +evalAtan :: FloatingType a -> a :-> a +evalAtan ty | FloatingDict <- floatingDict ty = eval1 (NumSingleType $ FloatingNumType ty) atan -evalSinh :: Elt a => FloatingType a -> a :-> a -evalSinh ty | FloatingDict <- floatingDict ty = eval1 sinh +evalSinh :: FloatingType a -> a :-> a +evalSinh ty | FloatingDict <- floatingDict ty = eval1 (NumSingleType $ FloatingNumType ty) sinh -evalCosh :: Elt a => FloatingType a -> a :-> a -evalCosh ty | FloatingDict <- floatingDict ty = eval1 cosh +evalCosh :: FloatingType a -> a :-> a +evalCosh ty | FloatingDict <- floatingDict ty = eval1 (NumSingleType $ FloatingNumType ty) cosh -evalTanh :: Elt a => FloatingType a -> a :-> a -evalTanh ty | FloatingDict <- floatingDict ty = eval1 tanh +evalTanh :: FloatingType a -> a :-> a +evalTanh ty | FloatingDict <- floatingDict ty = eval1 (NumSingleType $ FloatingNumType ty) tanh -evalAsinh :: Elt a => FloatingType a -> a :-> a -evalAsinh ty | FloatingDict <- floatingDict ty = eval1 asinh +evalAsinh :: FloatingType a -> a :-> a +evalAsinh ty | FloatingDict <- floatingDict ty = eval1 (NumSingleType $ FloatingNumType ty) asinh -evalAcosh :: Elt a => FloatingType a -> a :-> a -evalAcosh ty | FloatingDict <- floatingDict ty = eval1 acosh +evalAcosh :: FloatingType a -> a :-> a +evalAcosh ty | FloatingDict <- floatingDict ty = eval1 (NumSingleType $ FloatingNumType ty) acosh -evalAtanh :: Elt a => FloatingType a -> a :-> a -evalAtanh ty | FloatingDict <- floatingDict ty = eval1 atanh +evalAtanh :: FloatingType a -> a :-> a +evalAtanh ty | FloatingDict <- floatingDict ty = eval1 (NumSingleType $ FloatingNumType ty) atanh -evalExpFloating :: Elt a => FloatingType a -> a :-> a -evalExpFloating ty | FloatingDict <- floatingDict ty = eval1 P.exp +evalExpFloating :: FloatingType a -> a :-> a +evalExpFloating ty | FloatingDict <- floatingDict ty = eval1 (NumSingleType $ FloatingNumType ty) P.exp -evalSqrt :: Elt a => FloatingType a -> a :-> a -evalSqrt ty | FloatingDict <- floatingDict ty = eval1 sqrt +evalSqrt :: FloatingType a -> a :-> a +evalSqrt ty | FloatingDict <- floatingDict ty = eval1 (NumSingleType $ FloatingNumType ty) sqrt -evalLog :: Elt a => FloatingType a -> a :-> a -evalLog ty | FloatingDict <- floatingDict ty = eval1 log +evalLog :: FloatingType a -> a :-> a +evalLog ty | FloatingDict <- floatingDict ty = eval1 (NumSingleType $ FloatingNumType ty) log -evalFPow :: Elt a => FloatingType a -> (a,a) :-> a -evalFPow ty | FloatingDict <- floatingDict ty = eval2 (**) +evalFPow :: FloatingType a -> (a,a) :-> a +evalFPow ty | FloatingDict <- floatingDict ty = eval2 (NumSingleType $ FloatingNumType ty) (**) -evalLogBase :: Elt a => FloatingType a -> (a,a) :-> a -evalLogBase ty | FloatingDict <- floatingDict ty = eval2 logBase +evalLogBase :: FloatingType a -> (a,a) :-> a +evalLogBase ty | FloatingDict <- floatingDict ty = eval2 (NumSingleType $ FloatingNumType ty) logBase -evalAtan2 :: Elt a => FloatingType a -> (a,a) :-> a -evalAtan2 ty | FloatingDict <- floatingDict ty = eval2 atan2 +evalAtan2 :: FloatingType a -> (a,a) :-> a +evalAtan2 ty | FloatingDict <- floatingDict ty = eval2 (NumSingleType $ FloatingNumType ty) atan2 -evalTruncate :: Elt b => FloatingType a -> IntegralType b -> a :-> b +evalTruncate :: FloatingType a -> IntegralType b -> a :-> b evalTruncate ta tb | FloatingDict <- floatingDict ta - , IntegralDict <- integralDict tb = eval1 truncate + , IntegralDict <- integralDict tb = eval1 (NumSingleType $ IntegralNumType tb) truncate -evalRound :: Elt b => FloatingType a -> IntegralType b -> a :-> b +evalRound :: FloatingType a -> IntegralType b -> a :-> b evalRound ta tb | FloatingDict <- floatingDict ta - , IntegralDict <- integralDict tb = eval1 round + , IntegralDict <- integralDict tb = eval1 (NumSingleType $ IntegralNumType tb) round -evalFloor :: Elt b => FloatingType a -> IntegralType b -> a :-> b +evalFloor :: FloatingType a -> IntegralType b -> a :-> b evalFloor ta tb | FloatingDict <- floatingDict ta - , IntegralDict <- integralDict tb = eval1 floor + , IntegralDict <- integralDict tb = eval1 (NumSingleType $ IntegralNumType tb) floor -evalCeiling :: Elt b => FloatingType a -> IntegralType b -> a :-> b +evalCeiling :: FloatingType a -> IntegralType b -> a :-> b evalCeiling ta tb | FloatingDict <- floatingDict ta - , IntegralDict <- integralDict tb = eval1 ceiling + , IntegralDict <- integralDict tb = eval1 (NumSingleType $ IntegralNumType tb) ceiling evalIsNaN :: FloatingType a -> a :-> Bool -evalIsNaN ty | FloatingDict <- floatingDict ty = eval1 isNaN +evalIsNaN ty | FloatingDict <- floatingDict ty = eval1 (NonNumSingleType TypeBool) isNaN evalIsInfinite :: FloatingType a -> a :-> Bool -evalIsInfinite ty | FloatingDict <- floatingDict ty = eval1 isInfinite +evalIsInfinite ty | FloatingDict <- floatingDict ty = eval1 (NonNumSingleType TypeBool) isInfinite -- Relational & Equality -- --------------------- evalLt :: SingleType a -> (a,a) :-> Bool -evalLt (NumSingleType (IntegralNumType ty)) | IntegralDict <- integralDict ty = eval2 (<) -evalLt (NumSingleType (FloatingNumType ty)) | FloatingDict <- floatingDict ty = eval2 (<) -evalLt (NonNumSingleType ty) | NonNumDict <- nonNumDict ty = eval2 (<) +evalLt (NumSingleType (IntegralNumType ty)) | IntegralDict <- integralDict ty = eval2 (NonNumSingleType TypeBool) (<) +evalLt (NumSingleType (FloatingNumType ty)) | FloatingDict <- floatingDict ty = eval2 (NonNumSingleType TypeBool) (<) +evalLt (NonNumSingleType ty) | NonNumDict <- nonNumDict ty = eval2 (NonNumSingleType TypeBool) (<) -- evalLt (SingleScalarType s) = -- case s of @@ -646,9 +637,9 @@ evalLt (NonNumSingleType ty) | NonNumDict <- nonNumDict ty = -- NonNumSingleType t | NonNumDict <- t -> eval2 (<) evalGt :: SingleType a -> (a,a) :-> Bool -evalGt (NumSingleType (IntegralNumType ty)) | IntegralDict <- integralDict ty = eval2 (>) -evalGt (NumSingleType (FloatingNumType ty)) | FloatingDict <- floatingDict ty = eval2 (>) -evalGt (NonNumSingleType ty) | NonNumDict <- nonNumDict ty = eval2 (>) +evalGt (NumSingleType (IntegralNumType ty)) | IntegralDict <- integralDict ty = eval2 (NonNumSingleType TypeBool) (>) +evalGt (NumSingleType (FloatingNumType ty)) | FloatingDict <- floatingDict ty = eval2 (NonNumSingleType TypeBool) (>) +evalGt (NonNumSingleType ty) | NonNumDict <- nonNumDict ty = eval2 (NonNumSingleType TypeBool) (>) -- evalGt (SingleScalarType s) = -- case s of @@ -657,9 +648,9 @@ evalGt (NonNumSingleType ty) | NonNumDict <- nonNumDict ty = -- NonNumSingleType t | NonNumDict <- nonNumDict t -> eval2 (>) evalLtEq :: SingleType a -> (a,a) :-> Bool -evalLtEq (NumSingleType (IntegralNumType ty)) | IntegralDict <- integralDict ty = eval2 (<=) -evalLtEq (NumSingleType (FloatingNumType ty)) | FloatingDict <- floatingDict ty = eval2 (<=) -evalLtEq (NonNumSingleType ty) | NonNumDict <- nonNumDict ty = eval2 (<=) +evalLtEq (NumSingleType (IntegralNumType ty)) | IntegralDict <- integralDict ty = eval2 (NonNumSingleType TypeBool) (<=) +evalLtEq (NumSingleType (FloatingNumType ty)) | FloatingDict <- floatingDict ty = eval2 (NonNumSingleType TypeBool) (<=) +evalLtEq (NonNumSingleType ty) | NonNumDict <- nonNumDict ty = eval2 (NonNumSingleType TypeBool) (<=) -- evalLtEq (SingleScalarType s) = -- case s of @@ -668,9 +659,9 @@ evalLtEq (NonNumSingleType ty) | NonNumDict <- nonNumDict ty -- NonNumSingleType t | NonNumDict <- nonNumDict t -> eval2 (<=) evalGtEq :: SingleType a -> (a,a) :-> Bool -evalGtEq (NumSingleType (IntegralNumType ty)) | IntegralDict <- integralDict ty = eval2 (>=) -evalGtEq (NumSingleType (FloatingNumType ty)) | FloatingDict <- floatingDict ty = eval2 (>=) -evalGtEq (NonNumSingleType ty) | NonNumDict <- nonNumDict ty = eval2 (>=) +evalGtEq (NumSingleType (IntegralNumType ty)) | IntegralDict <- integralDict ty = eval2 (NonNumSingleType TypeBool) (>=) +evalGtEq (NumSingleType (FloatingNumType ty)) | FloatingDict <- floatingDict ty = eval2 (NonNumSingleType TypeBool) (>=) +evalGtEq (NonNumSingleType ty) | NonNumDict <- nonNumDict ty = eval2 (NonNumSingleType TypeBool) (>=) -- evalGtEq (SingleScalarType s) = -- case s of @@ -679,9 +670,9 @@ evalGtEq (NonNumSingleType ty) | NonNumDict <- nonNumDict ty -- NonNumSingleType t | NonNumDict <- nonNumDict t -> eval2 (>=) evalEq :: SingleType a -> (a,a) :-> Bool -evalEq (NumSingleType (IntegralNumType ty)) | IntegralDict <- integralDict ty = eval2 (==) -evalEq (NumSingleType (FloatingNumType ty)) | FloatingDict <- floatingDict ty = eval2 (==) -evalEq (NonNumSingleType ty) | NonNumDict <- nonNumDict ty = eval2 (==) +evalEq (NumSingleType (IntegralNumType ty)) | IntegralDict <- integralDict ty = eval2 (NonNumSingleType TypeBool) (==) +evalEq (NumSingleType (FloatingNumType ty)) | FloatingDict <- floatingDict ty = eval2 (NonNumSingleType TypeBool) (==) +evalEq (NonNumSingleType ty) | NonNumDict <- nonNumDict ty = eval2 (NonNumSingleType TypeBool) (==) -- evalEq (SingleScalarType s) = -- case s of @@ -690,9 +681,9 @@ evalEq (NonNumSingleType ty) | NonNumDict <- nonNumDict ty = -- NonNumSingleType t | NonNumDict <- nonNumDict t -> eval2 (==) evalNEq :: SingleType a -> (a,a) :-> Bool -evalNEq (NumSingleType (IntegralNumType ty)) | IntegralDict <- integralDict ty = eval2 (/=) -evalNEq (NumSingleType (FloatingNumType ty)) | FloatingDict <- floatingDict ty = eval2 (/=) -evalNEq (NonNumSingleType ty) | NonNumDict <- nonNumDict ty = eval2 (/=) +evalNEq (NumSingleType (IntegralNumType ty)) | IntegralDict <- integralDict ty = eval2 (NonNumSingleType TypeBool) (/=) +evalNEq (NumSingleType (FloatingNumType ty)) | FloatingDict <- floatingDict ty = eval2 (NonNumSingleType TypeBool) (/=) +evalNEq (NonNumSingleType ty) | NonNumDict <- nonNumDict ty = eval2 (NonNumSingleType TypeBool) (/=) -- evalNEq (SingleScalarType s) = -- case s of @@ -700,10 +691,10 @@ evalNEq (NonNumSingleType ty) | NonNumDict <- nonNumDict ty = -- NumSingleType (FloatingNumType t) | FloatingDict <- floatingDict t -> eval2 (/=) -- NonNumSingleType t | NonNumDict <- nonNumDict t -> eval2 (/=) -evalMax :: Elt a => SingleType a -> (a,a) :-> a -evalMax (NumSingleType (IntegralNumType ty)) | IntegralDict <- integralDict ty = eval2 max -evalMax (NumSingleType (FloatingNumType ty)) | FloatingDict <- floatingDict ty = eval2 max -evalMax (NonNumSingleType ty) | NonNumDict <- nonNumDict ty = eval2 max +evalMax :: SingleType a -> (a,a) :-> a +evalMax ty@(NumSingleType (IntegralNumType ty')) | IntegralDict <- integralDict ty' = eval2 ty max +evalMax ty@(NumSingleType (FloatingNumType ty')) | FloatingDict <- floatingDict ty' = eval2 ty max +evalMax ty@(NonNumSingleType ty') | NonNumDict <- nonNumDict ty' = eval2 ty max -- evalMax (SingleScalarType s) = -- case s of @@ -711,10 +702,10 @@ evalMax (NonNumSingleType ty) | NonNumDict <- nonNumDict ty = -- NumSingleType (FloatingNumType t) | FloatingDict <- floatingDict t -> eval2 max -- NonNumSingleType t | NonNumDict <- nonNumDict t -> eval2 max -evalMin :: Elt a => SingleType a -> (a,a) :-> a -evalMin (NumSingleType (IntegralNumType ty)) | IntegralDict <- integralDict ty = eval2 min -evalMin (NumSingleType (FloatingNumType ty)) | FloatingDict <- floatingDict ty = eval2 min -evalMin (NonNumSingleType ty) | NonNumDict <- nonNumDict ty = eval2 min +evalMin :: SingleType a -> (a,a) :-> a +evalMin ty@(NumSingleType (IntegralNumType ty')) | IntegralDict <- integralDict ty' = eval2 ty min +evalMin ty@(NumSingleType (FloatingNumType ty')) | FloatingDict <- floatingDict ty' = eval2 ty min +evalMin ty@(NonNumSingleType ty') | NonNumDict <- nonNumDict ty' = eval2 ty min -- evalMin (SingleScalarType s) = -- case s of @@ -730,11 +721,11 @@ evalLAnd :: (Bool,Bool) :-> Bool evalLAnd (untup2 -> Just (x,y)) env | Just a <- propagate env x = Just $ if a then Stats.ruleFired "True &&" y - else Stats.ruleFired "False &&" $ Const (fromElt False) + else Stats.ruleFired "False &&" $ Const scalarTypeBool False | Just b <- propagate env y = Just $ if b then Stats.ruleFired "True &&" x - else Stats.ruleFired "False &&" $ Const (fromElt False) + else Stats.ruleFired "False &&" $ Const scalarTypeBool False evalLAnd _ _ = Nothing @@ -742,11 +733,11 @@ evalLAnd _ _ evalLOr :: (Bool,Bool) :-> Bool evalLOr (untup2 -> Just (x,y)) env | Just a <- propagate env x - = Just $ if a then Stats.ruleFired "True ||" $ Const (fromElt True) + = Just $ if a then Stats.ruleFired "True ||" $ Const scalarTypeBool True else Stats.ruleFired "False ||" y | Just b <- propagate env y - = Just $ if b then Stats.ruleFired "True ||" $ Const (fromElt True) + = Just $ if b then Stats.ruleFired "True ||" $ Const scalarTypeBool True else Stats.ruleFired "False ||" x evalLOr _ _ @@ -754,49 +745,49 @@ evalLOr _ _ evalLNot :: Bool :-> Bool evalLNot x _ | PrimApp PrimLNot x' <- x = Stats.ruleFired "not/not" $ Just x' -evalLNot x env = eval1 not x env +evalLNot x env = eval1 (NonNumSingleType TypeBool) not x env evalOrd :: Char :-> Int -evalOrd = eval1 ord +evalOrd = eval1 (NumSingleType $ IntegralNumType $ TypeInt) ord evalChr :: Int :-> Char -evalChr = eval1 chr +evalChr = eval1 (NonNumSingleType $ TypeChar) chr evalBoolToInt :: Bool :-> Int -evalBoolToInt = eval1 fromEnum +evalBoolToInt = eval1 (NumSingleType $ IntegralNumType $ TypeInt) fromEnum -evalFromIntegral :: Elt b => IntegralType a -> NumType b -> a :-> b +evalFromIntegral :: IntegralType a -> NumType b -> a :-> b evalFromIntegral ta (IntegralNumType tb) | IntegralDict <- integralDict ta - , IntegralDict <- integralDict tb = eval1 fromIntegral + , IntegralDict <- integralDict tb = eval1 (NumSingleType $ IntegralNumType tb) fromIntegral evalFromIntegral ta (FloatingNumType tb) | IntegralDict <- integralDict ta - , FloatingDict <- floatingDict tb = eval1 fromIntegral + , FloatingDict <- floatingDict tb = eval1 (NumSingleType $ FloatingNumType tb) fromIntegral -evalToFloating :: Elt b => NumType a -> FloatingType b -> a :-> b +evalToFloating :: NumType a -> FloatingType b -> a :-> b evalToFloating (IntegralNumType ta) tb x env | IntegralDict <- integralDict ta - , FloatingDict <- floatingDict tb = eval1 realToFrac x env + , FloatingDict <- floatingDict tb = eval1 (NumSingleType $ FloatingNumType tb) realToFrac x env evalToFloating (FloatingNumType ta) tb x env - | TypeHalf FloatingDict <- ta - , TypeHalf FloatingDict <- tb = Just x + | TypeHalf <- ta + , TypeHalf <- tb = Just x - | TypeFloat FloatingDict <- ta - , TypeFloat FloatingDict <- tb = Just x + | TypeFloat <- ta + , TypeFloat <- tb = Just x - | TypeDouble FloatingDict <- ta - , TypeDouble FloatingDict <- tb = Just x + | TypeDouble <- ta + , TypeDouble <- tb = Just x - | TypeFloat FloatingDict <- ta - , TypeDouble FloatingDict <- tb = eval1 float2Double x env + | TypeFloat <- ta + , TypeDouble <- tb = eval1 (NumSingleType $ FloatingNumType tb) float2Double x env - | TypeDouble FloatingDict <- ta - , TypeFloat FloatingDict <- tb = eval1 double2Float x env + | TypeDouble <- ta + , TypeFloat <- tb = eval1 (NumSingleType $ FloatingNumType tb) double2Float x env | FloatingDict <- floatingDict ta - , FloatingDict <- floatingDict tb = eval1 realToFrac x env + , FloatingDict <- floatingDict tb = eval1 (NumSingleType $ FloatingNumType tb) realToFrac x env -- Scalar primitives diff --git a/src/Data/Array/Accelerate/Trafo/Base.hs b/src/Data/Array/Accelerate/Trafo/Base.hs index 036744f69..cfdaef3a5 100644 --- a/src/Data/Array/Accelerate/Trafo/Base.hs +++ b/src/Data/Array/Accelerate/Trafo/Base.hs @@ -37,23 +37,23 @@ module Data.Array.Accelerate.Trafo.Base ( -- Delayed Arrays DelayedAcc, DelayedOpenAcc(..), DelayedAfun, DelayedOpenAfun, - DelayedExp, DelayedOpenExp, - DelayedFun, DelayedOpenFun, matchDelayedOpenAcc, encodeDelayedOpenAcc, -- Environments Gamma(..), incExp, prjExp, pushExp, Extend(..), pushArrayEnv, append, bind, - Sink(..), sink, sink1, - Supplement(..), bindExps, - - leftHandSideChangeEnv, + Sink(..), SinkExp(..), sinkA, sink1, + OpenExp', bindExps, -- Adding new variables to the environment - declareArrays, DeclareArrays(..), + declareVars, DeclareVars(..), + + -- Checks + isIdentity, isIdentityIndexing, - aletBodyIsTrivial, + -- Utilities + mkIntersect, mkUnion, ) where -- standard library @@ -68,9 +68,10 @@ import Prelude hiding ( until ) -- friends import Data.Array.Accelerate.AST hiding ( Val(..) ) +import Data.Array.Accelerate.Type import Data.Array.Accelerate.Analysis.Hash import Data.Array.Accelerate.Analysis.Match -import Data.Array.Accelerate.Array.Sugar ( Array, Arrays, ArraysR(..), Shape, Elt ) +import Data.Array.Accelerate.Array.Representation import Data.Array.Accelerate.Error import Data.Array.Accelerate.Trafo.Substitution @@ -102,15 +103,15 @@ encodeOpenAcc :: EncodeAcc OpenAcc encodeOpenAcc options (OpenAcc pacc) = encodePreOpenAcc options encodeAcc pacc matchOpenAcc :: MatchAcc OpenAcc -matchOpenAcc (OpenAcc pacc1) (OpenAcc pacc2) = matchPreOpenAcc matchAcc encodeAcc pacc1 pacc2 +matchOpenAcc (OpenAcc pacc1) (OpenAcc pacc2) = matchPreOpenAcc matchAcc pacc1 pacc2 avarIn :: forall acc aenv a. Kit acc => ArrayVar aenv a -> acc aenv a -avarIn v@ArrayVar{} = inject $ Avar v +avarIn v@(Var ArrayR{} _) = inject $ Avar v avarsIn :: forall acc aenv arrs. Kit acc => ArrayVars aenv arrs -> acc aenv arrs -avarsIn ArrayVarsNil = inject Anil -avarsIn (ArrayVarsArray v) = avarIn v -avarsIn (ArrayVarsPair a b) = inject $ avarsIn a `Apair` avarsIn b +avarsIn VarsNil = inject Anil +avarsIn (VarsSingle v) = avarIn v +avarsIn (VarsPair a b) = inject $ avarsIn a `Apair` avarsIn b kmap :: Kit acc => (PreOpenAcc acc aenv a -> PreOpenAcc acc aenv b) -> acc aenv a -> acc aenv b kmap f = inject . f . fromJust . extract @@ -118,31 +119,30 @@ kmap f = inject . f . fromJust . extract extractArrayVars :: Kit acc => acc aenv a -> Maybe (ArrayVars aenv a) extractArrayVars (extract -> Just acc) = case acc of Apair (extractArrayVars -> Just a) (extractArrayVars -> Just b) - -> Just $ ArrayVarsPair a b + -> Just $ VarsPair a b Anil - -> Just ArrayVarsNil + -> Just VarsNil Avar v - -> Just $ ArrayVarsArray v + -> Just $ VarsSingle v _ -> Nothing extractArrayVars _ = Nothing -data DeclareArrays arrs aenv where - DeclareArrays - :: LeftHandSide arrs aenv aenv' - -> (aenv :> aenv') - -> (forall aenv''. aenv' :> aenv'' -> ArrayVars aenv'' arrs) - -> DeclareArrays arrs aenv - -declareArrays :: ArraysR arrs -> DeclareArrays arrs aenv -declareArrays ArraysRarray - = DeclareArrays LeftHandSideArray SuccIdx $ \k -> ArrayVarsArray $ ArrayVar $ k ZeroIdx -declareArrays ArraysRunit - = DeclareArrays (LeftHandSideWildcard ArraysRunit) id $ const $ ArrayVarsNil -declareArrays (ArraysRpair r1 r2) = case declareArrays r1 of - DeclareArrays lhs1 subst1 a1 -> case declareArrays r2 of - DeclareArrays lhs2 subst2 a2 -> - DeclareArrays (LeftHandSidePair lhs1 lhs2) (subst2 . subst1) $ \k -> a1 (k . subst2) `ArrayVarsPair` a2 k +data DeclareVars s t aenv where + DeclareVars + :: LeftHandSide s t env env' + -> (env :> env') + -> (forall env''. env' :> env'' -> Vars s env'' t) + -> DeclareVars s t env +declareVars :: TupR s t -> DeclareVars s t env +declareVars (TupRsingle s) + = DeclareVars (LeftHandSideSingle s) (weakenSucc weakenId) $ \k -> VarsSingle $ Var s $ k >:> ZeroIdx +declareVars TupRunit + = DeclareVars (LeftHandSideWildcard TupRunit) weakenId $ const $ VarsNil +declareVars (TupRpair r1 r2) + | DeclareVars lhs1 subst1 a1 <- declareVars r1 + , DeclareVars lhs2 subst2 a2 <- declareVars r2 + = DeclareVars (LeftHandSidePair lhs1 lhs2) (subst2 .> subst1) $ \k -> a1 (k .> subst2) `VarsPair` a2 k -- fromOpenAfun :: Kit acc => OpenAfun aenv f -> PreOpenAfun acc aenv f @@ -159,23 +159,32 @@ instance Match (Idx env) where {-# INLINEABLE match #-} match = matchIdx -instance Match (ArrayVar env) where +instance Match (Var s env) where {-# INLINEABLE match #-} - match (ArrayVar a) (ArrayVar b) + match (Var _ a) (Var _ b) | Just Refl <- match a b = Just Refl | otherwise = Nothing -instance Kit acc => Match (PreOpenExp acc env aenv) where +instance Match ScalarType where + match = matchScalarType + +instance Match ArrayR where + match = matchArrayR + +instance Match a => Match (TupR a) where + match = matchTupR match + +instance Match (OpenExp env aenv) where {-# INLINEABLE match #-} - match = matchPreOpenExp matchAcc encodeAcc + match = matchOpenExp -instance Kit acc => Match (PreOpenFun acc env aenv) where +instance Match (OpenFun env aenv) where {-# INLINEABLE match #-} - match = matchPreOpenFun matchAcc encodeAcc + match = matchOpenFun instance Kit acc => Match (PreOpenAcc acc aenv) where {-# INLINEABLE match #-} - match = matchPreOpenAcc matchAcc encodeAcc + match = matchPreOpenAcc matchAcc instance {-# INCOHERENT #-} Kit acc => Match (acc aenv) where {-# INLINEABLE match #-} @@ -192,40 +201,37 @@ instance {-# INCOHERENT #-} Kit acc => Match (acc aenv) where type DelayedAcc = DelayedOpenAcc () type DelayedAfun = PreOpenAfun DelayedOpenAcc () -type DelayedExp = DelayedOpenExp () -type DelayedFun = DelayedOpenFun () - -- data DelayedSeq t where -- DelayedSeq :: Extend DelayedOpenAcc () aenv -- -> DelayedOpenSeq aenv () t -- -> DelayedSeq t type DelayedOpenAfun = PreOpenAfun DelayedOpenAcc -type DelayedOpenExp = PreOpenExp DelayedOpenAcc -type DelayedOpenFun = PreOpenFun DelayedOpenAcc -- type DelayedOpenSeq = PreOpenSeq DelayedOpenAcc data DelayedOpenAcc aenv a where Manifest :: PreOpenAcc DelayedOpenAcc aenv a -> DelayedOpenAcc aenv a - Delayed :: (Shape sh, Elt e) => - { extentD :: PreExp DelayedOpenAcc aenv sh - , indexD :: PreFun DelayedOpenAcc aenv (sh -> e) - , linearIndexD :: PreFun DelayedOpenAcc aenv (Int -> e) + Delayed :: + { reprD :: ArrayR (Array sh e) + , extentD :: Exp aenv sh + , indexD :: Fun aenv (sh -> e) + , linearIndexD :: Fun aenv (Int -> e) } -> DelayedOpenAcc aenv (Array sh e) instance HasArraysRepr DelayedOpenAcc where arraysRepr (Manifest a) = arraysRepr a - arraysRepr Delayed{} = ArraysRarray + arraysRepr Delayed{..} = TupRsingle reprD instance Rebuildable DelayedOpenAcc where type AccClo DelayedOpenAcc = DelayedOpenAcc {-# INLINEABLE rebuildPartial #-} rebuildPartial v acc = case acc of Manifest pacc -> Manifest <$> rebuildPartial v pacc - Delayed{..} -> Delayed <$> rebuildPartial v extentD - <*> rebuildPartial v indexD - <*> rebuildPartial v linearIndexD + Delayed{..} -> (\e i l -> Delayed reprD (unOpenAccExp e) (unOpenAccFun i) (unOpenAccFun l)) + <$> rebuildPartial v (OpenAccExp extentD) + <*> rebuildPartial v (OpenAccFun indexD) + <*> rebuildPartial v (OpenAccFun linearIndexD) instance Sink DelayedOpenAcc where weaken k = Stats.substitution "weaken" . rebuildA (rebuildWeakenVar k) @@ -252,11 +258,11 @@ instance NFData (DelayedOpenAcc aenv t) where encodeDelayedOpenAcc :: EncodeAcc DelayedOpenAcc encodeDelayedOpenAcc options acc = let - travE :: DelayedExp aenv sh -> Builder - travE = encodePreOpenExp options encodeDelayedOpenAcc + travE :: Exp aenv sh -> Builder + travE = encodeOpenExp - travF :: DelayedFun aenv f -> Builder - travF = encodePreOpenFun options encodeDelayedOpenAcc + travF :: Fun aenv f -> Builder + travF = encodeOpenFun travA :: PreOpenAcc DelayedOpenAcc aenv a -> Builder travA = encodePreOpenAcc options encodeDelayedOpenAcc @@ -266,28 +272,29 @@ encodeDelayedOpenAcc options acc = | otherwise = encodeArraysType . arraysRepr in case acc of - Manifest pacc -> intHost $(hashQ ("Manifest" :: String)) <> deepA pacc - Delayed sh f g -> intHost $(hashQ ("Delayed" :: String)) <> travE sh <> travF f <> travF g + Manifest pacc -> intHost $(hashQ ("Manifest" :: String)) <> deepA pacc + Delayed _ sh f g -> intHost $(hashQ ("Delayed" :: String)) <> travE sh <> travF f <> travF g {-# INLINEABLE matchDelayedOpenAcc #-} matchDelayedOpenAcc :: MatchAcc DelayedOpenAcc matchDelayedOpenAcc (Manifest pacc1) (Manifest pacc2) - = matchPreOpenAcc matchDelayedOpenAcc encodeDelayedOpenAcc pacc1 pacc2 + = matchPreOpenAcc matchDelayedOpenAcc pacc1 pacc2 -matchDelayedOpenAcc (Delayed sh1 ix1 lx1) (Delayed sh2 ix2 lx2) - | Just Refl <- matchPreOpenExp matchDelayedOpenAcc encodeDelayedOpenAcc sh1 sh2 - , Just Refl <- matchPreOpenFun matchDelayedOpenAcc encodeDelayedOpenAcc ix1 ix2 - , Just Refl <- matchPreOpenFun matchDelayedOpenAcc encodeDelayedOpenAcc lx1 lx2 +matchDelayedOpenAcc (Delayed _ sh1 ix1 lx1) (Delayed _ sh2 ix2 lx2) + | Just Refl <- matchOpenExp sh1 sh2 + , Just Refl <- matchOpenFun ix1 ix2 + , Just Refl <- matchOpenFun lx1 lx2 = Just Refl matchDelayedOpenAcc _ _ = Nothing rnfDelayedOpenAcc :: DelayedOpenAcc aenv t -> () -rnfDelayedOpenAcc (Manifest pacc) = rnfPreOpenAcc rnfDelayedOpenAcc pacc -rnfDelayedOpenAcc (Delayed sh ix lx) = rnfPreOpenExp rnfDelayedOpenAcc sh - `seq` rnfPreOpenFun rnfDelayedOpenAcc ix - `seq` rnfPreOpenFun rnfDelayedOpenAcc lx +rnfDelayedOpenAcc (Manifest pacc) = rnfPreOpenAcc rnfDelayedOpenAcc pacc +rnfDelayedOpenAcc (Delayed repr sh ix lx) = rnfArrayR repr + `seq` rnfOpenExp sh + `seq` rnfOpenFun ix + `seq` rnfOpenFun lx {-- rnfDelayedSeq :: DelayedSeq t -> () @@ -307,19 +314,18 @@ rnfExtend rnfA (PushEnv env a) = rnfExtend rnfA env `seq` rnfA a -- environment variable env' is used to project out the corresponding -- index when looking up in the environment congruent expressions. -- -data Gamma acc env env' aenv where - EmptyExp :: Gamma acc env env' aenv +data Gamma env env' aenv where + EmptyExp :: Gamma env env' aenv - PushExp :: Elt t - => Gamma acc env env' aenv - -> WeakPreOpenExp acc env aenv t - -> Gamma acc env (env', t) aenv + PushExp :: Gamma env env' aenv + -> WeakOpenExp env aenv t + -> Gamma env (env', t) aenv -data WeakPreOpenExp acc env aenv t where +data WeakOpenExp env aenv t where Subst :: env :> env' - -> PreOpenExp acc env aenv t - -> PreOpenExp acc env' aenv t {- LAZY -} - -> WeakPreOpenExp acc env' aenv t + -> OpenExp env aenv t + -> OpenExp env' aenv t {- LAZY -} + -> WeakOpenExp env' aenv t -- XXX: The simplifier calls this function every time it moves under a let -- binding. This means we have a number of calls to 'weakenE' exponential in the @@ -333,31 +339,26 @@ data WeakPreOpenExp acc env aenv t where -- -- incExp - :: Kit acc - => Gamma acc env env' aenv - -> Gamma acc (env,s) env' aenv + :: Gamma env env' aenv + -> Gamma (env,s) env' aenv incExp EmptyExp = EmptyExp incExp (PushExp env w) = incExp env `PushExp` subs w where - subs :: forall acc env aenv s t. Kit acc => WeakPreOpenExp acc env aenv t -> WeakPreOpenExp acc (env,s) aenv t - subs (Subst k (e :: PreOpenExp acc env_ aenv t) _) = Subst k' e (weakenE k' e) - where - k' :: env_ :> (env,s) - k' = SuccIdx . k + subs :: forall env aenv s t. WeakOpenExp env aenv t -> WeakOpenExp (env,s) aenv t + subs (Subst k (e :: OpenExp env_ aenv t) _) = Subst (weakenSucc' k) e (weakenE (weakenSucc' k) e) -prjExp :: Idx env' t -> Gamma acc env env' aenv -> PreOpenExp acc env aenv t +prjExp :: Idx env' t -> Gamma env env' aenv -> OpenExp env aenv t prjExp ZeroIdx (PushExp _ (Subst _ _ e)) = e prjExp (SuccIdx ix) (PushExp env _) = prjExp ix env prjExp _ _ = $internalError "prjExp" "inconsistent valuation" -pushExp :: Elt t => Gamma acc env env' aenv -> PreOpenExp acc env aenv t -> Gamma acc env (env',t) aenv -pushExp env e = env `PushExp` Subst id e e +pushExp :: Gamma env env' aenv -> OpenExp env aenv t -> Gamma env (env',t) aenv +pushExp env e = env `PushExp` Subst weakenId e e {-- lookupExp - :: Kit acc - => Gamma acc env env' aenv - -> PreOpenExp acc env aenv t + :: Gamma env env' aenv + -> OpenExp env aenv t -> Maybe (Idx env' t) lookupExp EmptyExp _ = Nothing lookupExp (PushExp env e) x @@ -365,19 +366,18 @@ lookupExp (PushExp env e) x | otherwise = SuccIdx `fmap` lookupExp env x weakenGamma1 - :: Kit acc - => Gamma acc env env' aenv - -> Gamma acc env env' (aenv,t) + :: Gamma env env' aenv + -> Gamma env env' (aenv,t) weakenGamma1 EmptyExp = EmptyExp weakenGamma1 (PushExp env e) = PushExp (weakenGamma1 env) (weaken SuccIdx e) sinkGamma :: Kit acc => Extend acc aenv aenv' - -> Gamma acc env env' aenv - -> Gamma acc env env' aenv' + -> Gamma env env' aenv + -> Gamma env env' aenv' sinkGamma _ EmptyExp = EmptyExp -sinkGamma ext (PushExp env e) = PushExp (sinkGamma ext env) (sink ext e) +sinkGamma ext (PushExp env e) = PushExp (sinkGamma ext env) (sinkA ext e) --} -- As part of various transformations we often need to lift out array valued @@ -386,90 +386,90 @@ sinkGamma ext (PushExp env e) = PushExp (sinkGamma ext env) (sink ext e) -- The Extend type is a heterogeneous snoc-list of array terms that witnesses -- how the array environment is extended by binding these additional terms. -- -data Extend acc aenv aenv' where - BaseEnv :: Extend acc aenv aenv +data Extend s f env env' where + BaseEnv :: Extend s f env env - PushEnv :: Extend acc aenv aenv' - -> LeftHandSide arrs aenv' aenv'' - -> acc aenv' arrs - -> Extend acc aenv aenv'' + PushEnv :: Extend s f env env' + -> LeftHandSide s t env' env'' + -> f env' t + -> Extend s f env env'' + +pushArrayEnv :: HasArraysRepr acc => Extend ArrayR acc aenv aenv' -> acc aenv' (Array sh e) -> Extend ArrayR acc aenv (aenv', Array sh e) +pushArrayEnv env a = PushEnv env (LeftHandSideSingle $ arrayRepr a) a -pushArrayEnv :: (Shape sh, Elt e) => Extend acc aenv aenv' -> acc aenv' (Array sh e) -> Extend acc aenv (aenv', Array sh e) -pushArrayEnv env a = PushEnv env LeftHandSideArray a -- Append two environment witnesses -- -append :: Extend acc env env' -> Extend acc env' env'' -> Extend acc env env'' +append :: Extend s acc env env' -> Extend s acc env' env'' -> Extend s acc env env'' append x BaseEnv = x append x (PushEnv e lhs a) = PushEnv (append x e) lhs a -- Bring into scope all of the array terms in the Extend environment list. This -- converts a term in the inner environment (aenv') into the outer (aenv). -- -bind :: (Kit acc, Arrays a) - => Extend acc aenv aenv' +bind :: Kit acc + => Extend ArrayR acc aenv aenv' -> PreOpenAcc acc aenv' a -> PreOpenAcc acc aenv a -bind BaseEnv = id -bind (PushEnv env lhs a) = bind env . Alet lhs a . inject +bind BaseEnv = id +bind (PushEnv g lhs a) = bind g . Alet lhs a . inject -- Sink a term from one array environment into another, where additional -- bindings have come into scope according to the witness and no old things have -- vanished. -- -sink :: Sink f => Extend acc env env' -> f env t -> f env' t -sink env = weaken (k env) - where - k :: Extend acc env env' -> Idx env t -> Idx env' t - k BaseEnv = Stats.substitution "sink" id - k (PushEnv e (LeftHandSideWildcard _) _) = k e - k (PushEnv e (LeftHandSideArray) _) = SuccIdx . k e - k (PushEnv e (LeftHandSidePair l1 l2) _) = k (PushEnv (PushEnv e l1 undefined) l2 undefined) - -sink1 :: Sink f => Extend acc env env' -> f (env,s) t -> f (env',s) t -sink1 env = weaken (k env) +sinkA :: Sink f => Extend s acc env env' -> f env t -> f env' t +sinkA env = weaken (sinkWeaken env) -- TODO: Fix Stats sinkA vs sink1 + +sinkWeaken :: Extend s acc env env' -> env :> env' +sinkWeaken BaseEnv = Stats.substitution "sink" weakenId +sinkWeaken (PushEnv e (LeftHandSideWildcard _) _) = sinkWeaken e +sinkWeaken (PushEnv e (LeftHandSideSingle _) _) = weakenSucc' $ sinkWeaken e +sinkWeaken (PushEnv e (LeftHandSidePair l1 l2) _) = sinkWeaken (PushEnv (PushEnv e l1 undefined) l2 undefined) + +sink1 :: Sink f => Extend s acc env env' -> f (env,t') t -> f (env',t') t +sink1 env = weaken $ sink $ sinkWeaken env + +-- Wrapper around OpenExp, with the order of type arguments env and aenv flipped +newtype OpenExp' aenv env e = OpenExp' (OpenExp env aenv e) + +bindExps :: Extend ScalarType (OpenExp' aenv) env env' + -> OpenExp env' aenv e + -> OpenExp env aenv e +bindExps BaseEnv = id +bindExps (PushEnv g lhs (OpenExp' b)) = bindExps g . Let lhs b + + +-- Utilities for working with shapes +mkShapeBinary :: (forall env'. OpenExp env' aenv Int -> OpenExp env' aenv Int -> OpenExp env' aenv Int) + -> ShapeR sh + -> OpenExp env aenv sh + -> OpenExp env aenv sh + -> OpenExp env aenv sh +mkShapeBinary _ ShapeRz _ _ = Nil +mkShapeBinary f (ShapeRsnoc shr) (Pair as a) (Pair bs b) = mkShapeBinary f shr as bs `Pair` f a b +mkShapeBinary f shr (Let lhs bnd a) b = Let lhs bnd $ mkShapeBinary f shr a (weakenE (weakenWithLHS lhs) b) +mkShapeBinary f shr a (Let lhs bnd b) = Let lhs bnd $ mkShapeBinary f shr (weakenE (weakenWithLHS lhs) a) b +mkShapeBinary f shr a b@Pair{} -- `a` is not Pair + | DeclareVars lhs k value <- declareVars $ shapeType shr + = Let lhs a $ mkShapeBinary f shr (evars $ value weakenId) (weakenE k b) +mkShapeBinary f shr a b -- `b` is not a Pair + | DeclareVars lhs k value <- declareVars $ shapeType shr + = Let lhs b $ mkShapeBinary f shr (weakenE k a) (evars $ value weakenId) + +mkIntersect :: ShapeR sh + -> OpenExp env aenv sh + -> OpenExp env aenv sh + -> OpenExp env aenv sh +mkIntersect = mkShapeBinary f where - k :: Extend acc env env' -> Idx (env,s) t -> Idx (env',s) t - k BaseEnv = Stats.substitution "sink1" id - k (PushEnv e (LeftHandSideWildcard _) _) = k e - k (PushEnv e (LeftHandSideArray) _) = split . k e - k (PushEnv e (LeftHandSidePair l1 l2) _) = k (PushEnv (PushEnv e l1 undefined) l2 undefined) - - split :: Idx (env,s) t -> Idx ((env,u),s) t - split ZeroIdx = ZeroIdx - split (SuccIdx ix) = SuccIdx (SuccIdx ix) + f a b = PrimApp (PrimMin singleType) $ Pair a b +mkUnion :: ShapeR sh + -> OpenExp env aenv sh + -> OpenExp env aenv sh + -> OpenExp env aenv sh +mkUnion = mkShapeBinary f + where + f a b = PrimApp (PrimMax singleType) $ Pair a b --- This is the same as Extend, but for the scalar environment. --- -data Supplement acc env env' aenv where - BaseSup :: Supplement acc env env aenv - - PushSup :: Elt e - => Supplement acc env env' aenv - -> PreOpenExp acc env' aenv e - -> Supplement acc env (env', e) aenv - -bindExps :: (Kit acc, Elt e) - => Supplement acc env env' aenv - -> PreOpenExp acc env' aenv e - -> PreOpenExp acc env aenv e -bindExps BaseSup = id -bindExps (PushSup g b) = bindExps g . Let b - -leftHandSideChangeEnv :: LeftHandSide arrs env1 env2 -> Exists (LeftHandSide arrs env3) -leftHandSideChangeEnv (LeftHandSideWildcard repr) = Exists $ LeftHandSideWildcard repr -leftHandSideChangeEnv LeftHandSideArray = Exists $ LeftHandSideArray -leftHandSideChangeEnv (LeftHandSidePair l1 l2) = case leftHandSideChangeEnv l1 of - Exists l1' -> case leftHandSideChangeEnv l2 of - Exists l2' -> Exists $ LeftHandSidePair l1' l2' - -aletBodyIsTrivial :: forall acc aenv aenv' a b. Kit acc => LeftHandSide a aenv aenv' -> acc aenv' b -> Maybe (a :~: b) -aletBodyIsTrivial lhs rhs = case extractArrayVars rhs of - Just vars -> case declareArrays @a @aenv (lhsToArraysR lhs) of - DeclareArrays lhs' _ value - | Just Refl <- matchLeftHandSide lhs lhs' - , Just Refl <- matchArrayVars vars $ value id - -> Just Refl - _ -> Nothing - Nothing -> Nothing diff --git a/src/Data/Array/Accelerate/Trafo/Config.hs b/src/Data/Array/Accelerate/Trafo/Config.hs index a3f984023..488d02001 100644 --- a/src/Data/Array/Accelerate/Trafo/Config.hs +++ b/src/Data/Array/Accelerate/Trafo/Config.hs @@ -17,7 +17,7 @@ module Data.Array.Accelerate.Trafo.Config ( defaultOptions, -- Other options not controlled by the command line flags - float_out_acc, + -- float_out_acc, ) where @@ -46,5 +46,5 @@ defaultOptions = unsafePerformIO $! -- Extra options not covered by command line flags -- -float_out_acc = Flag 31 +-- float_out_acc = Flag 31 diff --git a/src/Data/Array/Accelerate/Trafo/Fusion.hs b/src/Data/Array/Accelerate/Trafo/Fusion.hs index 38a0d7404..174064e2f 100644 --- a/src/Data/Array/Accelerate/Trafo/Fusion.hs +++ b/src/Data/Array/Accelerate/Trafo/Fusion.hs @@ -38,7 +38,6 @@ module Data.Array.Accelerate.Trafo.Fusion ( -- ** Types DelayedAcc, DelayedOpenAcc(..), DelayedAfun, DelayedOpenAfun, - DelayedExp, DelayedFun, DelayedOpenExp, DelayedOpenFun, -- ** Conversion convertAcc, convertAccWith, @@ -58,9 +57,8 @@ import Data.Array.Accelerate.Trafo.Config import Data.Array.Accelerate.Trafo.Shrink import Data.Array.Accelerate.Trafo.Simplify import Data.Array.Accelerate.Trafo.Substitution -import Data.Array.Accelerate.Array.Representation ( SliceIndex(..) ) -import Data.Array.Accelerate.Array.Sugar ( Array, ArraysR(..), arraysRtuple2 - , Elt, EltRepr, Shape, Tuple(..), eltType ) +import Data.Array.Accelerate.Array.Representation hiding (fromIndex, toIndex, shape) +import Data.Array.Accelerate.Analysis.Match import Data.Array.Accelerate.Type import Data.Array.Accelerate.Debug.Flags ( array_fusion ) @@ -129,26 +127,19 @@ convertOpenAcc config = manifest config . computeAcc . embedOpenAcc config -- representation. It is safe to match on BaseEnv because the first pass -- will put producers adjacent to the term consuming it. -- -delayed :: (Shape sh, Elt e) => Config -> OpenAcc aenv (Array sh e) -> DelayedOpenAcc aenv (Array sh e) +delayed :: Config -> OpenAcc aenv (Array sh e) -> DelayedOpenAcc aenv (Array sh e) delayed config (embedOpenAcc config -> Embed env cc) | BaseEnv <- env = case simplify cc of Done v -> avarsIn v - Yield (cvtE -> sh) (cvtF -> f) -> Delayed sh f (f `compose` fromIndex sh) - Step (cvtE -> sh) (cvtF -> p) (cvtF -> f) v + Yield repr sh f -> Delayed repr sh f (f `compose` fromIndex (arrayRshape repr) sh) + Step repr sh p f v | Just Refl <- match sh (arrayShape v) - , Just Refl <- isIdentity p -> Delayed sh (f `compose` indexArray v) (f `compose` linearIndex v) - | f' <- f `compose` indexArray v `compose` p -> Delayed sh f' (f' `compose` fromIndex sh) + , Just Refl <- isIdentity p -> Delayed repr sh (f `compose` indexArray v) (f `compose` linearIndex v) + | f' <- f `compose` indexArray v `compose` p -> Delayed repr sh f' (f' `compose` fromIndex (arrayRshape repr) sh) -- | otherwise = manifest config (computeAcc (Embed env cc)) - where - cvtE :: OpenExp env aenv t -> DelayedOpenExp env aenv t - cvtE = convertOpenExp config - - cvtF :: OpenFun env aenv f -> DelayedOpenFun env aenv f - cvtF (Lam f) = Lam (cvtF f) - cvtF (Body b) = Body (cvtE b) -- Convert array programs as manifest terms. @@ -161,15 +152,15 @@ manifest config (OpenAcc pacc) = -- Non-fusible terms -- ----------------- Avar ix -> Avar ix - Use arr -> Use arr - Unit e -> Unit (cvtE e) + Use repr arr -> Use repr arr + Unit tp e -> Unit tp e Alet lhs bnd body -> alet lhs (manifest config bnd) (manifest config body) - Acond p t e -> Acond (cvtE p) (manifest config t) (manifest config e) + Acond p t e -> Acond p (manifest config t) (manifest config e) Awhile p f a -> Awhile (cvtAF p) (cvtAF f) (manifest config a) Apair a1 a2 -> Apair (manifest config a1) (manifest config a2) Anil -> Anil - Apply f a -> apply (cvtAF f) (manifest config a) - Aforeign ff f a -> Aforeign ff (cvtAF f) (manifest config a) + Apply repr f a -> apply repr (cvtAF f) (manifest config a) + Aforeign repr ff f a -> Aforeign repr ff (cvtAF f) (manifest config a) -- Producers -- --------- @@ -179,11 +170,11 @@ manifest config (OpenAcc pacc) = -- of a let-binding to be used multiple times. The input array here -- should be a evaluated array term, else something went wrong. -- - Map f a -> Map (cvtF f) (delayed config a) - Generate sh f -> Generate (cvtE sh) (cvtF f) - Transform sh p f a -> Transform (cvtE sh) (cvtF p) (cvtF f) (delayed config a) - Backpermute sh p a -> Backpermute (cvtE sh) (cvtF p) (delayed config a) - Reshape sl a -> Reshape (cvtE sl) (manifest config a) + Map tp f a -> Map tp f (delayed config a) + Generate repr sh f -> Generate repr sh f + Transform repr sh p f a -> Transform repr sh p f (delayed config a) + Backpermute shr sh p a -> Backpermute shr sh p (delayed config a) + Reshape slr sl a -> Reshape slr sl (manifest config a) Replicate{} -> fusionError Slice{} -> fusionError @@ -196,31 +187,33 @@ manifest config (OpenAcc pacc) = -- with local bindings, these will have been floated up above the -- consumer already -- - Fold f z a -> Fold (cvtF f) (cvtE z) (delayed config a) - Fold1 f a -> Fold1 (cvtF f) (delayed config a) - FoldSeg f z a s -> FoldSeg (cvtF f) (cvtE z) (delayed config a) (delayed config s) - Fold1Seg f a s -> Fold1Seg (cvtF f) (delayed config a) (delayed config s) - Scanl f z a -> Scanl (cvtF f) (cvtE z) (delayed config a) - Scanl1 f a -> Scanl1 (cvtF f) (delayed config a) - Scanl' f z a -> Scanl' (cvtF f) (cvtE z) (delayed config a) - Scanr f z a -> Scanr (cvtF f) (cvtE z) (delayed config a) - Scanr1 f a -> Scanr1 (cvtF f) (delayed config a) - Scanr' f z a -> Scanr' (cvtF f) (cvtE z) (delayed config a) - Permute f d p a -> Permute (cvtF f) (manifest config d) (cvtF p) (delayed config a) - Stencil f x a -> Stencil (cvtF f) (cvtB x) (delayed config a) - Stencil2 f x a y b -> Stencil2 (cvtF f) (cvtB x) (delayed config a) (cvtB y) (delayed config b) + Fold f z a -> Fold f z (delayed config a) + Fold1 f a -> Fold1 f (delayed config a) + FoldSeg i f z a s -> FoldSeg i f z (delayed config a) (delayed config s) + Fold1Seg i f a s -> Fold1Seg i f (delayed config a) (delayed config s) + Scanl f z a -> Scanl f z (delayed config a) + Scanl1 f a -> Scanl1 f (delayed config a) + Scanl' f z a -> Scanl' f z (delayed config a) + Scanr f z a -> Scanr f z (delayed config a) + Scanr1 f a -> Scanr1 f (delayed config a) + Scanr' f z a -> Scanr' f z (delayed config a) + Permute f d p a -> Permute f (manifest config d) p (delayed config a) + Stencil s tp f x a -> Stencil s tp f x (delayed config a) + Stencil2 s1 s2 tp f x a y b + -> Stencil2 s1 s2 tp f x (delayed config a) y (delayed config b) -- Collect s -> Collect (cvtS s) where -- Flatten needless let-binds, which can be introduced by the -- conversion to the internal embeddable representation. -- - alet :: LeftHandSide a aenv aenv' + alet :: ALeftHandSide a aenv aenv' -> DelayedOpenAcc aenv a -> DelayedOpenAcc aenv' b -> PreOpenAcc DelayedOpenAcc aenv b alet lhs bnd body - | Just Refl <- aletBodyIsTrivial lhs body + | Just bodyVars <- extractArrayVars body + , Just Refl <- bindingIsTrivial lhs bodyVars , Manifest x <- bnd = x -- @@ -234,17 +227,19 @@ manifest config (OpenAcc pacc) = -- > compute :: Acc a -> Acc a -- > compute = id >-> id -- - apply :: PreOpenAfun DelayedOpenAcc aenv (a -> b) + apply :: ArraysR b + -> PreOpenAfun DelayedOpenAcc aenv (a -> b) -> DelayedOpenAcc aenv a -> PreOpenAcc DelayedOpenAcc aenv b - apply afun x + apply repr afun x | Alam lhs (Abody body) <- afun - , Just Refl <- aletBodyIsTrivial lhs body + , Just bodyVars <- extractArrayVars body + , Just Refl <- bindingIsTrivial lhs bodyVars , Manifest x' <- x = Stats.ruleFired "applyD/identity" x' -- | otherwise - = Apply afun x + = Apply repr afun x cvtAF :: OpenAfun aenv f -> PreOpenAfun DelayedOpenAcc aenv f cvtAF (Alam lhs f) = Alam lhs (cvtAF f) @@ -253,67 +248,6 @@ manifest config (OpenAcc pacc) = -- cvtS :: PreOpenSeq OpenAcc aenv senv s -> PreOpenSeq DelayedOpenAcc aenv senv s -- cvtS = convertOpenSeq config - -- Conversions for closed scalar functions and expressions - -- - cvtF :: OpenFun env aenv f -> DelayedOpenFun env aenv f - cvtF (Lam f) = Lam (cvtF f) - cvtF (Body b) = Body (cvtE b) - - cvtE :: OpenExp env aenv t -> DelayedOpenExp env aenv t - cvtE = convertOpenExp config - - cvtB :: Boundary aenv t -> PreBoundary DelayedOpenAcc aenv t - cvtB Clamp = Clamp - cvtB Mirror = Mirror - cvtB Wrap = Wrap - cvtB (Constant v) = Constant v - cvtB (Function f) = Function (cvtF f) - -convertOpenExp :: Config -> OpenExp env aenv t -> DelayedOpenExp env aenv t -convertOpenExp config exp = - case exp of - Let bnd body -> Let (cvtE bnd) (cvtE body) - Var ix -> Var ix - Const c -> Const c - Undef -> Undef - Tuple tup -> Tuple (cvtT tup) - Prj ix t -> Prj ix (cvtE t) - IndexNil -> IndexNil - IndexCons sh sz -> IndexCons (cvtE sh) (cvtE sz) - IndexHead sh -> IndexHead (cvtE sh) - IndexTail sh -> IndexTail (cvtE sh) - IndexAny -> IndexAny - IndexSlice x ix sh -> IndexSlice x (cvtE ix) (cvtE sh) - IndexFull x ix sl -> IndexFull x (cvtE ix) (cvtE sl) - ToIndex sh ix -> ToIndex (cvtE sh) (cvtE ix) - FromIndex sh ix -> FromIndex (cvtE sh) (cvtE ix) - Cond p t e -> Cond (cvtE p) (cvtE t) (cvtE e) - While p f x -> While (cvtF p) (cvtF f) (cvtE x) - PrimConst c -> PrimConst c - PrimApp f x -> PrimApp f (cvtE x) - Index a sh -> Index (manifest config a) (cvtE sh) - LinearIndex a i -> LinearIndex (manifest config a) (cvtE i) - Shape a -> Shape (manifest config a) - ShapeSize sh -> ShapeSize (cvtE sh) - Intersect s t -> Intersect (cvtE s) (cvtE t) - Union s t -> Union (cvtE s) (cvtE t) - Foreign ff f e -> Foreign ff (cvtF f) (cvtE e) - Coerce e -> Coerce (cvtE e) - where - cvtT :: Tuple (OpenExp env aenv) t -> Tuple (DelayedOpenExp env aenv) t - cvtT NilTup = NilTup - cvtT (SnocTup t e) = cvtT t `SnocTup` cvtE e - - -- Conversions for closed scalar functions and expressions - -- - cvtF :: OpenFun env aenv f -> DelayedOpenFun env aenv f - cvtF (Lam f) = Lam (cvtF f) - cvtF (Body b) = Body (cvtE b) - - cvtE :: OpenExp env aenv t -> DelayedOpenExp env aenv t - cvtE = convertOpenExp config - - convertOpenAfun :: Config -> OpenAfun aenv f -> DelayedOpenAfun aenv f convertOpenAfun c (Alam lhs f) = Alam lhs (convertOpenAfun c f) convertOpenAfun c (Abody b) = Abody (convertOpenAcc c b) @@ -411,16 +345,16 @@ embedPreAcc config embedAcc elimAcc pacc Alet lhs bnd body -> aletD embedAcc elimAcc lhs bnd body Anil -> done $ Anil Acond p at ae -> acondD embedAcc (cvtE p) at ae - Apply f a -> done $ Apply (cvtAF f) (cvtA a) + Apply repr f a -> done $ Apply repr (cvtAF f) (cvtA a) Awhile p f a -> done $ Awhile (cvtAF p) (cvtAF f) (cvtA a) Apair a1 a2 -> done $ Apair (cvtA a1) (cvtA a2) - Aforeign ff f a -> done $ Aforeign ff (cvtAF f) (cvtA a) + Aforeign repr ff f a -> done $ Aforeign repr ff (cvtAF f) (cvtA a) -- Collect s -> collectD s -- Array injection Avar v -> done $ Avar v - Use arrs -> done $ Use arrs - Unit e -> done $ Unit (cvtE e) + Use repr arr -> done $ Use repr arr + Unit tp e -> done $ Unit tp (cvtE e) -- Producers -- --------- @@ -435,16 +369,17 @@ embedPreAcc config embedAcc elimAcc pacc -- independently of all others, and so we can aggressively fuse arbitrary -- sequences of these operations. -- - Generate sh f -> generateD (cvtE sh) (cvtF f) + Generate repr sh f -> generateD repr (cvtE sh) (cvtF f) - Map f a -> mapD (cvtF f) (embedAcc a) - ZipWith f a b -> fuse2 (into zipWithD (cvtF f)) a b - Transform sh p f a -> transformD (cvtE sh) (cvtF p) (cvtF f) (embedAcc a) + Map tp f a -> mapD tp (cvtF f) (embedAcc a) + ZipWith tp f a b -> fuse2 (into (zipWithD tp) (cvtF f)) a b + Transform repr sh p f a -> transformD repr (cvtE sh) (cvtF p) (cvtF f) (embedAcc a) - Backpermute sl p a -> fuse (into2 backpermuteD (cvtE sl) (cvtF p)) a - Slice slix a sl -> fuse (into (sliceD slix) (cvtE sl)) a - Replicate slix sh a -> fuse (into (replicateD slix) (cvtE sh)) a - Reshape sl a -> reshapeD (embedAcc a) (cvtE sl) + Backpermute slr sl p a + -> fuse (into2 (backpermuteD slr) (cvtE sl) (cvtF p)) a + Slice slix a sl -> fuse (into (sliceD slix) (cvtE sl)) a + Replicate slix sh a -> fuse (into (replicateD slix) (cvtE sh)) a + Reshape slr sl a -> reshapeD slr (embedAcc a) (cvtE sl) -- Consumers -- --------- @@ -462,21 +397,23 @@ embedPreAcc config embedAcc elimAcc pacc -- node, so that the producer can be directly embedded into the consumer -- during the code generation phase. -- - Fold f z a -> embed ArraysRarray (into2 Fold (cvtF f) (cvtE z)) a - Fold1 f a -> embed ArraysRarray (into Fold1 (cvtF f)) a - FoldSeg f z a s -> embed2 ArraysRarray (into2 FoldSeg (cvtF f) (cvtE z)) a s - Fold1Seg f a s -> embed2 ArraysRarray (into Fold1Seg (cvtF f)) a s - Scanl f z a -> embed ArraysRarray (into2 Scanl (cvtF f) (cvtE z)) a - Scanl1 f a -> embed ArraysRarray (into Scanl1 (cvtF f)) a - Scanl' f z a -> embed arraysRtuple2 (into2 Scanl' (cvtF f) (cvtE z)) a - Scanr f z a -> embed ArraysRarray (into2 Scanr (cvtF f) (cvtE z)) a - Scanr1 f a -> embed ArraysRarray (into Scanr1 (cvtF f)) a - Scanr' f z a -> embed arraysRtuple2 (into2 Scanr' (cvtF f) (cvtE z)) a - Permute f d p a -> embed2 ArraysRarray (into2 permute (cvtF f) (cvtF p)) d a - Stencil f x a -> embed ArraysRarray (into2 stencil1 (cvtF f) (cvtB x)) a - Stencil2 f x a y b -> embed2 ArraysRarray (into3 stencil2 (cvtF f) (cvtB x) (cvtB y)) a b + Fold f z a -> embed repr (into2 Fold (cvtF f) (cvtE z)) a + Fold1 f a -> embed repr (into Fold1 (cvtF f)) a + FoldSeg i f z a s -> embed2 repr (into2 (FoldSeg i) (cvtF f) (cvtE z)) a s + Fold1Seg i f a s -> embed2 repr (into (Fold1Seg i) (cvtF f)) a s + Scanl f z a -> embed repr (into2 Scanl (cvtF f) (cvtE z)) a + Scanl1 f a -> embed repr (into Scanl1 (cvtF f)) a + Scanl' f z a -> embed repr (into2 Scanl' (cvtF f) (cvtE z)) a + Scanr f z a -> embed repr (into2 Scanr (cvtF f) (cvtE z)) a + Scanr1 f a -> embed repr (into Scanr1 (cvtF f)) a + Scanr' f z a -> embed repr (into2 Scanr' (cvtF f) (cvtE z)) a + Permute f d p a -> embed2 repr (into2 permute (cvtF f) (cvtF p)) d a + Stencil s t f x a -> embed repr (into2 (stencil1 s t) (cvtF f) (cvtB x)) a + Stencil2 s1 s2 t f x a y b + -> embed2 repr (into3 (stencil2 s1 s2 t) (cvtF f) (cvtB x) (cvtB y)) a b where + repr = arraysRepr pacc -- If fusion is not enabled, force terms to the manifest representation -- unembed :: Embed acc aenv arrs -> Embed acc aenv arrs @@ -487,8 +424,8 @@ embedPreAcc config embedAcc elimAcc pacc = case extractArrayVars $ inject pacc of Just vars -> Embed env $ Done vars _ - | DeclareArrays lhs _ value <- declareArrays (arraysRepr pacc) - -> Embed (PushEnv env lhs $ inject pacc) $ Done $ value id + | DeclareVars lhs _ value <- declareVars (arraysRepr pacc) + -> Embed (PushEnv env lhs $ inject pacc) $ Done $ value weakenId cvtA :: acc aenv' a -> acc aenv' a cvtA = computeAcc . embedAcc @@ -509,19 +446,19 @@ embedPreAcc config embedAcc elimAcc pacc -- when this duplication is beneficial (keeping in mind that the stencil -- implementations themselves may share neighbouring elements). -- - stencil1 f x a = Stencil f x a - stencil2 f x y a b = Stencil2 f x a y b + stencil1 s t f x a = Stencil s t f x a + stencil2 s1 s2 t f x y a b = Stencil2 s1 s2 t f x a y b -- Conversions for closed scalar functions and expressions. This just -- applies scalar simplifications. -- - cvtF :: PreFun acc aenv' t -> PreFun acc aenv' t + cvtF :: Fun aenv' t -> Fun aenv' t cvtF = simplify - cvtE :: Elt t => PreExp acc aenv' t -> PreExp acc aenv' t + cvtE :: Exp aenv' t -> Exp aenv' t cvtE = simplify - cvtB :: PreBoundary acc aenv' t -> PreBoundary acc aenv' t + cvtB :: Boundary aenv' t -> Boundary aenv' t cvtB Clamp = Clamp cvtB Mirror = Mirror cvtB Wrap = Wrap @@ -530,36 +467,36 @@ embedPreAcc config embedAcc elimAcc pacc -- Helpers to embed and fuse delayed terms -- - into :: Sink f => (f env' a -> b) -> f env a -> Extend acc env env' -> b - into op a env = op (sink env a) + into :: Sink f => (f env' a -> b) -> f env a -> Extend ArrayR acc env env' -> b + into op a env = op (sinkA env a) into2 :: (Sink f1, Sink f2) - => (f1 env' a -> f2 env' b -> c) -> f1 env a -> f2 env b -> Extend acc env env' -> c - into2 op a b env = op (sink env a) (sink env b) + => (f1 env' a -> f2 env' b -> c) -> f1 env a -> f2 env b -> Extend ArrayR acc env env' -> c + into2 op a b env = op (sinkA env a) (sinkA env b) into3 :: (Sink f1, Sink f2, Sink f3) - => (f1 env' a -> f2 env' b -> f3 env' c -> d) -> f1 env a -> f2 env b -> f3 env c -> Extend acc env env' -> d - into3 op a b c env = op (sink env a) (sink env b) (sink env c) + => (f1 env' a -> f2 env' b -> f3 env' c -> d) -> f1 env a -> f2 env b -> f3 env c -> Extend ArrayR acc env env' -> d + into3 op a b c env = op (sinkA env a) (sinkA env b) (sinkA env c) -- Operations which can be fused into consumers. Move all of the local -- bindings out of the way so that the fusible function operates -- directly on the delayed representation. See also: [Representing -- delayed arrays] -- - fuse :: (forall aenv'. Extend acc aenv aenv' -> Cunctation acc aenv' as -> Cunctation acc aenv' bs) + fuse :: (forall aenv'. Extend ArrayR acc aenv aenv' -> Cunctation aenv' as -> Cunctation aenv' bs) -> acc aenv as -> Embed acc aenv bs fuse op (embedAcc -> Embed env cc) = Embed env (op env cc) - fuse2 :: (forall aenv'. Extend acc aenv aenv' -> Cunctation acc aenv' as -> Cunctation acc aenv' bs -> Cunctation acc aenv' cs) + fuse2 :: (forall aenv'. Extend ArrayR acc aenv aenv' -> Cunctation aenv' as -> Cunctation aenv' bs -> Cunctation aenv' cs) -> acc aenv as -> acc aenv bs -> Embed acc aenv cs fuse2 op a1 a0 | Embed env1 cc1 <- embedAcc a1 - , Embed env0 cc0 <- embedAcc (sink env1 a0) + , Embed env0 cc0 <- embedAcc (sinkA env1 a0) , env <- env1 `append` env0 - = Embed env (op env (sink env0 cc1) cc0) + = Embed env (op env (sinkA env0 cc1) cc0) -- Consumer operations which will be evaluated. -- @@ -593,42 +530,42 @@ embedPreAcc config embedAcc elimAcc pacc -- update the array of default values. -- embed :: ArraysR bs - -> (forall aenv'. Extend acc aenv aenv' -> acc aenv' as -> PreOpenAcc acc aenv' bs) + -> (forall aenv'. Extend ArrayR acc aenv aenv' -> acc aenv' as -> PreOpenAcc acc aenv' bs) -> acc aenv as -> Embed acc aenv bs embed reprBs op (embedAcc -> Embed env cc) | Done{} <- cc - , DeclareArrays lhs _ value <- declareArrays reprBs - = Embed (PushEnv BaseEnv lhs $ inject (op BaseEnv (computeAcc (Embed env cc)))) $ Done $ value id + , DeclareVars lhs _ value <- declareVars reprBs + = Embed (PushEnv BaseEnv lhs $ inject (op BaseEnv (computeAcc (Embed env cc)))) $ Done $ value weakenId | otherwise -- Next line is duplicated for both branches, as the type variable for the environment is instantiated differently - , DeclareArrays lhs _ value <- declareArrays reprBs - = Embed (PushEnv env lhs $ inject (op env (inject (compute cc)))) $ Done $ value id + , DeclareVars lhs _ value <- declareVars reprBs + = Embed (PushEnv env lhs $ inject (op env (inject (compute cc)))) $ Done $ value weakenId embed2 :: ArraysR cs - -> (forall aenv'. Extend acc aenv aenv' -> acc aenv' as -> acc aenv' bs -> PreOpenAcc acc aenv' cs) + -> (forall aenv'. Extend ArrayR acc aenv aenv' -> acc aenv' as -> acc aenv' bs -> PreOpenAcc acc aenv' cs) -> acc aenv as -> acc aenv bs -> Embed acc aenv cs embed2 reprCs op (embedAcc -> Embed env1 cc1) a0 | Done{} <- cc1 , a1 <- computeAcc (Embed env1 cc1) - = embed reprCs (\env0 -> op env0 (sink env0 a1)) a0 + = embed reprCs (\env0 -> op env0 (sinkA env0 a1)) a0 -- - | Embed env0 cc0 <- embedAcc (sink env1 a0) + | Embed env0 cc0 <- embedAcc (sinkA env1 a0) , env <- env1 `append` env0 = case cc0 of Done{} - | DeclareArrays lhs _ value <- declareArrays reprCs - -> Embed (PushEnv env1 lhs $ inject (op env1 (inject (compute cc1)) (computeAcc (Embed env0 cc0)))) $ Done $ value id + | DeclareVars lhs _ value <- declareVars reprCs + -> Embed (PushEnv env1 lhs $ inject (op env1 (inject (compute cc1)) (computeAcc (Embed env0 cc0)))) $ Done $ value weakenId _ -- Next line is duplicated for both branches, as the type variable for the environment is instantiated differently - | DeclareArrays lhs _ value <- declareArrays reprCs - -> Embed (PushEnv env lhs $ inject (op env (inject (compute (sink env0 cc1))) (inject (compute cc0)))) $ Done $ value id + | DeclareVars lhs _ value <- declareVars reprCs + -> Embed (PushEnv env lhs $ inject (op env (inject (compute (sinkA env0 cc1))) (inject (compute cc0)))) $ Done $ value weakenId -- trav1 :: (Arrays as, Arrays bs) -- => (forall aenv'. Embed acc aenv' as -> Embed acc aenv' as) - -- -> (forall aenv'. Extend acc aenv aenv' -> acc aenv' as -> PreOpenAcc acc aenv' bs) + -- -> (forall aenv'. Extend ArrayR acc aenv aenv' -> acc aenv' as -> PreOpenAcc acc aenv' bs) -- -> acc aenv as -- -> Embed acc aenv bs -- trav1 f op (f . embedAcc -> Embed env cc) @@ -637,13 +574,13 @@ embedPreAcc config embedAcc elimAcc pacc -- trav2 :: (Arrays as, Arrays bs, Arrays cs) -- => (forall aenv'. Embed acc aenv' as -> Embed acc aenv' as) -- -> (forall aenv'. Embed acc aenv' bs -> Embed acc aenv' bs) - -- -> (forall aenv'. Extend acc aenv aenv' -> acc aenv' as -> acc aenv' bs -> PreOpenAcc acc aenv' cs) + -- -> (forall aenv'. Extend ArrayR acc aenv aenv' -> acc aenv' as -> acc aenv' bs -> PreOpenAcc acc aenv' cs) -- -> acc aenv as -- -> acc aenv bs -- -> Embed acc aenv cs - -- trav2 f1 f0 op (f1 . embedAcc -> Embed env1 cc1) (f0 . embedAcc . sink env1 -> Embed env0 cc0) + -- trav2 f1 f0 op (f1 . embedAcc -> Embed env1 cc1) (f0 . embedAcc . sinkA env1 -> Embed env0 cc0) -- | env <- env1 `append` env0 - -- , acc1 <- inject . compute $ sink env0 cc1 + -- , acc1 <- inject . compute $ sinkA env0 cc1 -- , acc0 <- inject . compute $ cc0 -- = Embed (env `pushArrayEnv` inject (op env acc1 acc0)) doneZeroIdx @@ -712,10 +649,10 @@ embedSeq embedAcc s cvtCT NilAtup = NilAtup cvtCT (SnocAtup t c) = SnocAtup (cvtCT t) (travC c env) - cvtE :: Elt t => PreExp acc aenv' t -> PreExp acc aenv' t + cvtE :: Elt t => Exp aenv' t -> Exp aenv' t cvtE = simplify - cvtF :: PreFun acc aenv' t -> PreFun acc aenv' t + cvtF :: Fun aenv' t -> Fun aenv' t cvtF = simplify cvtA :: Arrays a => acc aenv' a -> acc aenv' a @@ -759,7 +696,7 @@ data ExtendProducer acc aenv senv arrs where -- are defined with respect to this existentially quantified type, and there is -- no way to directly combine these two environments: -- --- append :: Extend env env1 -> Extend env env2 -> Extend env ??? +-- append :: Extend ArrayR env env1 -> Extend ArrayR env env2 -> Extend ArrayR env ??? -- -- And hence, no way to combine the terms of the delayed representation. -- @@ -773,10 +710,12 @@ data ExtendProducer acc aenv senv arrs where -- number of different rules we have for combining terms. -- data Embed acc aenv a where - Embed :: Extend acc aenv aenv' - -> Cunctation acc aenv' a - -> Embed acc aenv a + Embed :: Extend ArrayR acc aenv aenv' + -> Cunctation aenv' a + -> Embed acc aenv a +instance HasArraysRepr acc => HasArraysRepr (Embed acc) where + arraysRepr (Embed _ c) = arraysRepr c -- Cunctation (n): the action or an instance of delaying; a tardy action. -- @@ -786,23 +725,23 @@ data Embed acc aenv a where -- element at each index, and fusing successive producers by combining these -- scalar functions. -- -data Cunctation acc aenv a where +data Cunctation aenv a where -- The base case is just a real (manifest) array term. No fusion happens here. -- Note that the array is referenced by an index into the extended -- environment, ensuring that the array is manifest and making the term -- non-recursive in 'acc'. -- - Done :: ArrayVars aenv arrs - -> Cunctation acc aenv arrs + Done :: ArrayVars aenv arrs + -> Cunctation aenv arrs -- We can represent an array by its shape and a function to compute an element -- at each index. -- - Yield :: (Shape sh, Elt e) - => PreExp acc aenv sh - -> PreFun acc aenv (sh -> e) - -> Cunctation acc aenv (Array sh e) + Yield :: ArrayR (Array sh e) + -> Exp aenv sh + -> Fun aenv (sh -> e) + -> Cunctation aenv (Array sh e) -- A more restrictive form than 'Yield' may afford greater opportunities for -- optimisation by a backend. This more structured form applies an index and @@ -810,76 +749,80 @@ data Cunctation acc aenv a where -- array stored as an environment index, so that the term is non-recursive and -- it is always possible to embed into a collective operation. -- - Step :: (Shape sh, Shape sh', Elt a, Elt b) - => PreExp acc aenv sh' - -> PreFun acc aenv (sh' -> sh) - -> PreFun acc aenv (a -> b) - -> ArrayVar aenv (Array sh a) - -> Cunctation acc aenv (Array sh' b) - -instance Kit acc => Simplify (Cunctation acc aenv a) where + Step :: ArrayR (Array sh' b) + -> Exp aenv sh' + -> Fun aenv (sh' -> sh) + -> Fun aenv (a -> b) + -> ArrayVar aenv (Array sh a) + -> Cunctation aenv (Array sh' b) + +instance Simplify (Cunctation aenv a) where simplify = \case Done v -> Done v - Yield (simplify -> sh) (simplify -> f) -> Yield sh f - Step (simplify -> sh) (simplify -> p) (simplify -> f) v + Yield repr (simplify -> sh) (simplify -> f) -> Yield repr sh f + Step repr (simplify -> sh) (simplify -> p) (simplify -> f) v | Just Refl <- match sh (arrayShape v) , Just Refl <- isIdentity p - , Just Refl <- isIdentity f -> Done $ ArrayVarsArray v - | otherwise -> Step sh p f v + , Just Refl <- isIdentity f -> Done $ VarsSingle v + | otherwise -> Step repr sh p f v +instance HasArraysRepr Cunctation where + arraysRepr (Done v) = varsType v + arraysRepr (Yield repr _ _) = TupRsingle repr + arraysRepr (Step repr _ _ _ _) = TupRsingle repr -- Convert a real AST node into the internal representation -- done :: Kit acc => PreOpenAcc acc aenv a -> Embed acc aenv a done pacc | Just vars <- extractArrayVars $ inject pacc = Embed BaseEnv (Done vars) - | otherwise = case declareArrays (arraysRepr pacc) of - DeclareArrays lhs _ value -> Embed (PushEnv BaseEnv lhs $ inject pacc) $ Done $ value id + | otherwise = case declareVars (arraysRepr pacc) of + DeclareVars lhs _ value -> Embed (PushEnv BaseEnv lhs $ inject pacc) $ Done $ value weakenId -doneZeroIdx :: (Shape sh, Elt e) => Cunctation acc (aenv, Array sh e) (Array sh e) -doneZeroIdx = Done $ ArrayVarsArray $ ArrayVar ZeroIdx +doneZeroIdx :: ArrayR (Array sh e) -> Cunctation (aenv, Array sh e) (Array sh e) +doneZeroIdx repr = Done $ VarsSingle $ Var repr ZeroIdx -- Recast a cunctation into a mapping from indices to elements. -- -yield :: Kit acc - => Cunctation acc aenv (Array sh e) - -> Cunctation acc aenv (Array sh e) +yield :: Cunctation aenv (Array sh e) + -> Cunctation aenv (Array sh e) yield cc = case cc of - Yield{} -> cc - Step sh p f v -> Yield sh (f `compose` indexArray v `compose` p) - Done (ArrayVarsArray v@ArrayVar{}) -> Yield (arrayShape v) (indexArray v) + Yield{} -> cc + Step repr sh p f v -> Yield repr sh (f `compose` indexArray v `compose` p) + Done (VarsSingle v@(Var repr _)) + -> Yield repr (arrayShape v) (indexArray v) -- Recast a cunctation into transformation step form. Not possible if the source -- was in the Yield formulation. -- -step :: Kit acc - => Cunctation acc aenv (Array sh e) - -> Maybe (Cunctation acc aenv (Array sh e)) +step :: Cunctation aenv (Array sh e) + -> Maybe (Cunctation aenv (Array sh e)) step cc = case cc of - Yield{} -> Nothing - Step{} -> Just cc - Done (ArrayVarsArray v@ArrayVar{}) -> Just $ Step (arrayShape v) identity identity v + Yield{} -> Nothing + Step{} -> Just cc + Done (VarsSingle v@(Var repr@(ArrayR shr tp) _)) + -> Just $ Step repr (arrayShape v) (identity $ shapeType shr) (identity tp) v -- Get the shape of a delayed array -- -shape :: Kit acc => Cunctation acc aenv (Array sh e) -> PreExp acc aenv sh +shape :: Cunctation aenv (Array sh e) -> Exp aenv sh shape cc - | Just (Step sh _ _ _) <- step cc = sh - | Yield sh _ <- yield cc = sh + | Just (Step _ sh _ _ _) <- step cc = sh + | Yield _ sh _ <- yield cc = sh -- Environment manipulation -- ======================== -instance Kit acc => Sink (Cunctation acc) where +instance Sink Cunctation where weaken k = \case Done v -> Done (weaken k v) - Step sh p f v -> Step (weaken k sh) (weaken k p) (weaken k f) (weaken k v) - Yield sh f -> Yield (weaken k sh) (weaken k f) + Step repr sh p f v -> Step repr (weaken k sh) (weaken k p) (weaken k f) (weaken k v) + Yield repr sh f -> Yield repr (weaken k sh) (weaken k f) -- prjExtend :: Kit acc => Extend acc env env' -> Idx env' t -> PreOpenAcc acc env' t -- prjExtend (PushEnv _ v) ZeroIdx = weakenA rebuildAcc SuccIdx v @@ -944,47 +887,50 @@ computeAcc :: Kit acc => Embed acc aenv arrs -> acc aenv arrs computeAcc (Embed BaseEnv cc) = inject (compute cc) computeAcc (Embed env@(PushEnv bot lhs top) cc) = case simplify cc of - Done v -> bindA env (avarsIn v) - Yield sh f -> bindA env (inject (Generate sh f)) - Step sh p f v@(ArrayVar ix) + Done v -> bindA env (avarsIn v) + Yield repr sh f -> bindA env (inject (Generate repr sh f)) + Step repr sh p f v@(Var _ ix) | Just Refl <- match sh (arrayShape v) , Just Refl <- isIdentity p -> case ix of ZeroIdx - | LeftHandSideArray <- lhs - , Just g <- strengthen noTop f -> bindA bot (inject (Map g top)) - _ -> bindA env (inject (Map f (avarIn v))) + | LeftHandSideSingle ArrayR{} <- lhs + , Just (OpenAccFun g) <- strengthen noTop (OpenAccFun f) + -> bindA bot (inject (Map (arrayRtype repr) g top)) + _ -> bindA env (inject (Map (arrayRtype repr) f (avarIn v))) | Just Refl <- isIdentity f -> case ix of ZeroIdx - | LeftHandSideArray <- lhs - , Just q <- strengthen noTop p - , Just sz <- strengthen noTop sh -> bindA bot (inject (Backpermute sz q top)) - _ -> bindA env (inject (Backpermute sh p (avarIn v))) + | LeftHandSideSingle ArrayR{} <- lhs + , Just (OpenAccFun q) <- strengthen noTop (OpenAccFun p) + , Just (OpenAccExp sz) <- strengthen noTop (OpenAccExp sh) + -> bindA bot (inject (Backpermute (arrayRshape repr) sz q top)) + _ -> bindA env (inject (Backpermute (arrayRshape repr) sh p (avarIn v))) | otherwise -> case ix of ZeroIdx - | LeftHandSideArray <- lhs - , Just g <- strengthen noTop f - , Just q <- strengthen noTop p - , Just sz <- strengthen noTop sh -> bindA bot (inject (Transform sz q g top)) - _ -> bindA env (inject (Transform sh p f (avarIn v))) + | LeftHandSideSingle ArrayR{} <- lhs + , Just (OpenAccFun g) <- strengthen noTop (OpenAccFun f) + , Just (OpenAccFun q) <- strengthen noTop (OpenAccFun p) + , Just (OpenAccExp sz) <- strengthen noTop (OpenAccExp sh) + -> bindA bot (inject (Transform repr sz q g top)) + _ -> bindA env (inject (Transform repr sh p f (avarIn v))) where bindA :: Kit acc - => Extend acc aenv aenv' - -> acc aenv' a - -> acc aenv a + => Extend ArrayR acc aenv aenv' + -> acc aenv' a + -> acc aenv a bindA BaseEnv b = b - bindA (PushEnv env lhs a) b = + bindA (PushEnv env lhs a) b -- If the freshly bound value is directly, returned, we don't have to bind it in a -- let. We can do this if the left hand side does not contain wildcards (other than -- wildcards for unit / nil) and if the value contains the same variables. - case aletBodyIsTrivial lhs b of - Just Refl -> bindA env a - Nothing -> bindA env (inject (Alet lhs a b)) + | Just vars <- extractArrayVars b + , Just Refl <- bindingIsTrivial lhs vars = bindA env a + | otherwise = bindA env (inject (Alet lhs a b)) noTop :: (aenv, a) :?> aenv noTop ZeroIdx = Nothing @@ -994,126 +940,103 @@ computeAcc (Embed env@(PushEnv bot lhs top) cc) = -- Convert the internal representation of delayed arrays into a real AST -- node. Use the most specific version of a combinator whenever possible. -- -compute :: Kit acc => Cunctation acc aenv arrs -> PreOpenAcc acc aenv arrs +compute :: Kit acc => Cunctation aenv arrs -> PreOpenAcc acc aenv arrs compute cc = case simplify cc of - Done ArrayVarsNil -> Anil - Done (ArrayVarsArray v@ArrayVar{}) -> Avar v - Done (ArrayVarsPair v1 v2) -> avarsIn v1 `Apair` avarsIn v2 - Yield sh f -> Generate sh f - Step sh p f v + Done VarsNil -> Anil + Done (VarsSingle v@(Var ArrayR{} _)) -> Avar v + Done (VarsPair v1 v2) -> avarsIn v1 `Apair` avarsIn v2 + Yield repr sh f -> Generate repr sh f + Step (ArrayR shr tp) sh p f v | Just Refl <- match sh (arrayShape v) - , Just Refl <- isIdentity p -> Map f (avarIn v) - | Just Refl <- isIdentity f -> Backpermute sh p (avarIn v) - | otherwise -> Transform sh p f (avarIn v) + , Just Refl <- isIdentity p -> Map tp f (avarIn v) + | Just Refl <- isIdentity f -> Backpermute shr sh p (avarIn v) + | otherwise -> Transform (ArrayR shr tp) sh p f (avarIn v) -- Representation of a generator as a delayed array -- -generateD :: (Shape sh, Elt e) - => PreExp acc aenv sh - -> PreFun acc aenv (sh -> e) +generateD :: ArrayR (Array sh e) + -> Exp aenv sh + -> Fun aenv (sh -> e) -> Embed acc aenv (Array sh e) -generateD sh f +generateD repr sh f = Stats.ruleFired "generateD" - $ Embed BaseEnv (Yield sh f) + $ Embed BaseEnv (Yield repr sh f) -- Fuse a unary function into a delayed array. Also looks for unzips which can -- be executed in constant time; SEE [unzipD] -- -mapD :: (Kit acc, Shape sh, Elt a, Elt b) - => PreFun acc aenv (a -> b) +mapD :: Kit acc + => TupleType b + -> Fun aenv (a -> b) -> Embed acc aenv (Array sh a) -> Embed acc aenv (Array sh b) -mapD f (unzipD f -> Just a) = a -mapD f (Embed env cc) +mapD tp f (unzipD tp f -> Just a) = a +mapD tp f (Embed env cc) = Stats.ruleFired "mapD" $ Embed env (go cc) where - go (step -> Just (Step sh ix g v)) = Step sh ix (sink env f `compose` g) v - go (yield -> Yield sh g) = Yield sh (sink env f `compose` g) + go (step -> Just (Step (ArrayR shr _) sh ix g v)) = Step (ArrayR shr tp) sh ix (sinkA env f `compose` g) v + go (yield -> Yield (ArrayR shr _) sh g) = Yield (ArrayR shr tp) sh (sinkA env f `compose` g) -- If we are unzipping a manifest array then force the term to be computed; --- a backend will be able to execute this in constant time. This operations --- looks for the right terms recursively, splitting operations such as: --- --- map (\x -> fst . fst ... x) arr --- --- into multiple stages so that they can all be executed in constant time: --- --- map fst . map fst ... arr --- --- Note that this is a speculative operation, since we could dig under several --- levels of projection before discovering that the operation can not be --- unzipped. This should be fine though because digging through the terms is --- cheap; no environment changing operations are required. +-- a backend will be able to execute this in constant time. -- unzipD - :: forall acc aenv sh a b. (Kit acc, Shape sh, Elt a, Elt b) - => PreFun acc aenv (a -> b) + :: Kit acc + => TupleType b + -> Fun aenv (a -> b) -> Embed acc aenv (Array sh a) -> Maybe (Embed acc aenv (Array sh b)) -unzipD f (Embed env (Done v)) - | TypeRscalar VectorScalarType{} <- eltType @a - = Nothing +unzipD tp f (Embed env cc@(Done v)) + | Lam lhs (Body a) <- f + , Just vars <- extractExpVars a + , ArrayR shr _ <- arrayRepr cc + , f' <- Lam lhs $ Body $ evars vars = Just $ Embed (env `pushArrayEnv` inject (Map tp f' $ avarsIn v)) $ doneZeroIdx $ ArrayR shr tp - | Lam (Body (Prj tix (Var ZeroIdx))) <- f - = Stats.ruleFired "unzipD" - $ let f' = Lam (Body (Prj tix (Var ZeroIdx))) - a' = avarsIn v - in - Just $ Embed (env `pushArrayEnv` inject (Map f' a')) doneZeroIdx - - | Lam (Body (Prj tix p@Prj{})) <- f - , Just (Embed env' (Done v')) <- unzipD (Lam (Body p)) (Embed env (Done v)) - = Stats.ruleFired "unzipD" - $ let f' = Lam (Body (Prj tix (Var ZeroIdx))) - a' = avarsIn v' - in - Just $ Embed (env' `pushArrayEnv` inject (Map f' a')) doneZeroIdx - -unzipD _ _ +unzipD _ _ _ = Nothing - -- Fuse an index space transformation function that specifies where elements in -- the destination array read there data from in the source array. -- backpermuteD - :: (Kit acc, Shape sh') - => PreExp acc aenv sh' - -> PreFun acc aenv (sh' -> sh) - -> Cunctation acc aenv (Array sh e) - -> Cunctation acc aenv (Array sh' e) -backpermuteD sh' p = Stats.ruleFired "backpermuteD" . go + :: ShapeR sh' + -> Exp aenv sh' + -> Fun aenv (sh' -> sh) + -> Cunctation aenv (Array sh e) + -> Cunctation aenv (Array sh' e) +backpermuteD shr' sh' p = Stats.ruleFired "backpermuteD" . go where - go (step -> Just (Step _ q f v)) = Step sh' (q `compose` p) f v - go (yield -> Yield _ g) = Yield sh' (g `compose` p) + go (step -> Just (Step (ArrayR _ tp) _ q f v)) = Step (ArrayR shr' tp) sh' (q `compose` p) f v + go (yield -> Yield (ArrayR _ tp) _ g) = Yield (ArrayR shr' tp) sh' (g `compose` p) -- Transform as a combined map and backwards permutation -- transformD - :: (Kit acc, Shape sh, Shape sh', Elt a, Elt b) - => PreExp acc aenv sh' - -> PreFun acc aenv (sh' -> sh) - -> PreFun acc aenv (a -> b) + :: Kit acc + => ArrayR (Array sh' b) + -> Exp aenv sh' + -> Fun aenv (sh' -> sh) + -> Fun aenv (a -> b) -> Embed acc aenv (Array sh a) -> Embed acc aenv (Array sh' b) -transformD sh' p f +transformD (ArrayR shr' tp) sh' p f = Stats.ruleFired "transformD" - . fuse (into2 backpermuteD sh' p) - . mapD f + . fuse (into2 (backpermuteD shr') sh' p) + . mapD tp f where - fuse :: (forall aenv'. Extend acc aenv aenv' -> Cunctation acc aenv' as -> Cunctation acc aenv' bs) + fuse :: (forall aenv'. Extend ArrayR acc aenv aenv' -> Cunctation aenv' as -> Cunctation aenv' bs) -> Embed acc aenv as -> Embed acc aenv bs fuse op (Embed env cc) = Embed env (op env cc) into2 :: (Sink f1, Sink f2) - => (f1 env' a -> f2 env' b -> c) -> f1 env a -> f2 env b -> Extend acc env env' -> c - into2 op a b env = op (sink env a) (sink env b) + => (f1 env' a -> f2 env' b -> c) -> f1 env a -> f2 env b -> Extend ArrayR acc env env' -> c + into2 op a b env = op (sinkA env a) (sinkA env b) -- Replicate as a backwards permutation @@ -1123,27 +1046,25 @@ transformD sh' p f -- expensive and/or `sh` is large. -- replicateD - :: (Kit acc, Shape sh, Shape sl, Elt slix) - => SliceIndex (EltRepr slix) (EltRepr sl) co (EltRepr sh) - -> PreExp acc aenv slix - -> Cunctation acc aenv (Array sl e) - -> Cunctation acc aenv (Array sh e) + :: SliceIndex slix sl co sh + -> Exp aenv slix + -> Cunctation aenv (Array sl e) + -> Cunctation aenv (Array sh e) replicateD sliceIndex slix cc = Stats.ruleFired "replicateD" - $ backpermuteD (IndexFull sliceIndex slix (shape cc)) (extend sliceIndex slix) cc + $ backpermuteD (sliceDomainR sliceIndex) (IndexFull sliceIndex slix (shape cc)) (extend sliceIndex slix) cc -- Dimensional slice as a backwards permutation -- sliceD - :: (Kit acc, Shape sh, Shape sl, Elt slix) - => SliceIndex (EltRepr slix) (EltRepr sl) co (EltRepr sh) - -> PreExp acc aenv slix - -> Cunctation acc aenv (Array sh e) - -> Cunctation acc aenv (Array sl e) + :: SliceIndex slix sl co sh + -> Exp aenv slix + -> Cunctation aenv (Array sh e) + -> Cunctation aenv (Array sl e) sliceD sliceIndex slix cc = Stats.ruleFired "sliceD" - $ backpermuteD (IndexSlice sliceIndex slix (shape cc)) (restrict sliceIndex slix) cc + $ backpermuteD (sliceShapeR sliceIndex) (IndexSlice sliceIndex slix (shape cc)) (restrict sliceIndex slix) cc -- Reshape an array @@ -1157,58 +1078,90 @@ sliceD sliceIndex slix cc -- same number of elements: this has been lost for the delayed cases! -- reshapeD - :: (Kit acc, Shape sh, Shape sl, Elt e) - => Embed acc aenv (Array sh e) - -> PreExp acc aenv sl + :: Kit acc + => ShapeR sl + -> Embed acc aenv (Array sh e) + -> Exp aenv sl -> Embed acc aenv (Array sl e) -reshapeD (Embed env cc) (sink env -> sl) +reshapeD slr (Embed env cc) (sinkA env -> sl) | Done v <- cc - = Embed (env `pushArrayEnv` inject (Reshape sl (avarsIn v))) doneZeroIdx + = Embed (env `pushArrayEnv` inject (Reshape slr sl (avarsIn v))) $ doneZeroIdx repr | otherwise = Stats.ruleFired "reshapeD" - $ Embed env (backpermuteD sl (reindex (shape cc) sl) cc) + $ Embed env (backpermuteD slr sl (reindex (arrayRshape $ arrayRepr cc) (shape cc) slr sl) cc) + + where + ArrayR _ tp = arrayRepr cc + repr = ArrayR slr tp -- Combine two arrays element-wise with a binary function to produce a delayed -- array. -- -zipWithD :: (Kit acc, Shape sh, Elt a, Elt b, Elt c) - => PreFun acc aenv (a -> b -> c) - -> Cunctation acc aenv (Array sh a) - -> Cunctation acc aenv (Array sh b) - -> Cunctation acc aenv (Array sh c) -zipWithD f cc1 cc0 +zipWithD :: TupleType c + -> Fun aenv (a -> b -> c) + -> Cunctation aenv (Array sh a) + -> Cunctation aenv (Array sh b) + -> Cunctation aenv (Array sh c) +zipWithD tp f cc1 cc0 -- Two stepper functions identically accessing the same array can be kept in -- stepping form. This might yield a simpler final term. -- - | Just (Step sh1 p1 f1 v1) <- step cc1 - , Just (Step sh0 p0 f0 v0) <- step cc0 + | Just (Step (ArrayR shr _) sh1 p1 f1 v1) <- step cc1 + , Just (Step _ sh0 p0 f0 v0) <- step cc0 , Just Refl <- match v1 v0 , Just Refl <- match p1 p0 = Stats.ruleFired "zipWithD/step" - $ Step (sh1 `Intersect` sh0) p0 (combine f f1 f0) v0 + $ Step (ArrayR shr tp) (mkIntersect shr sh1 sh0) p0 (combine f f1 f0) v0 -- Otherwise transform both delayed terms into (index -> value) mappings and -- combine the two indexing functions that way. -- - | Yield sh1 f1 <- yield cc1 - , Yield sh0 f0 <- yield cc0 + | Yield (ArrayR shr _) sh1 f1 <- yield cc1 + , Yield _ sh0 f0 <- yield cc0 = Stats.ruleFired "zipWithD" - $ Yield (sh1 `Intersect` sh0) (combine f f1 f0) + $ Yield (ArrayR shr tp) (mkIntersect shr sh1 sh0) (combine f f1 f0) where - combine :: forall acc aenv a b c e. (Kit acc, Elt a, Elt b, Elt c) - => PreFun acc aenv (a -> b -> c) - -> PreFun acc aenv (e -> a) - -> PreFun acc aenv (e -> b) - -> PreFun acc aenv (e -> c) + combine :: forall aenv a b c e. + Fun aenv (a -> b -> c) + -> Fun aenv (e -> a) + -> Fun aenv (e -> b) + -> Fun aenv (e -> c) combine c ixa ixb - | Lam (Lam (Body c')) <- weakenE SuccIdx c :: PreOpenFun acc ((),e) aenv (a -> b -> c) - , Lam (Body ixa') <- ixa -- else the skolem 'e' will escape - , Lam (Body ixb') <- ixb - = Lam $ Body $ Let ixa' $ Let (weakenE SuccIdx ixb') c' - + | Lam lhs1 (Body ixa') <- ixa -- else the skolem 'e' will escape + , Lam lhs2 (Body ixb') <- ixb + -- The two LeftHandSides may differ in the use of wildcards. If they do not match, we must + -- combine them as done in `combineLhs`. As this will probably not occur often and requires + -- additional weakening, we do a quick check whether the left hand sides are equal. + -- + = case matchELeftHandSide lhs1 lhs2 of + Just Refl + | Lam lhsA (Lam lhsB (Body c')) <- weakenE (weakenWithLHS lhs1) c + -> Lam lhs1 $ Body $ Let lhsA ixa' $ Let lhsB (weakenE (weakenWithLHS lhsA) ixb') c' + Nothing + | CombinedLHS lhs k1 k2 <- combineLhs lhs1 lhs2 + , Lam lhsA (Lam lhsB (Body c')) <- weakenE (weakenWithLHS lhs) c + , ixa'' <- weakenE k1 ixa' + -> Lam lhs $ Body $ Let lhsA ixa'' $ Let lhsB (weakenE (weakenWithLHS lhsA .> k2) ixb') c' + +combineLhs :: LeftHandSide s t env env1' -> LeftHandSide s t env env2' -> CombinedLHS s t env1' env2' env +combineLhs = go weakenId weakenId + where + go :: env1 :> env -> env2 :> env -> LeftHandSide s t env1 env1' -> LeftHandSide s t env2 env2' -> CombinedLHS s t env1' env2' env + go k1 k2 (LeftHandSideWildcard tp) (LeftHandSideWildcard _) = CombinedLHS (LeftHandSideWildcard tp) k1 k2 + go k1 k2 (LeftHandSideSingle tp) (LeftHandSideSingle _) = CombinedLHS (LeftHandSideSingle tp) (sink k1) (sink k2) + go k1 k2 (LeftHandSidePair l1 h1) (LeftHandSidePair l2 h2) + | CombinedLHS l k1' k2' <- go k1 k2 l1 l2 + , CombinedLHS h k1'' k2'' <- go k1' k2' h1 h2 = CombinedLHS (LeftHandSidePair l h) k1'' k2'' + go k1 k2 (LeftHandSideWildcard _) lhs + | Exists lhs' <- rebuildLHS lhs = CombinedLHS lhs' (weakenWithLHS lhs' .> k1) (sinkWithLHS lhs lhs' k2) + go k1 k2 lhs (LeftHandSideWildcard _) + | Exists lhs' <- rebuildLHS lhs = CombinedLHS lhs' (sinkWithLHS lhs lhs' k1) (weakenWithLHS lhs' .> k2) + +data CombinedLHS s t env1' env2' env where + CombinedLHS :: LeftHandSide s t env env' -> env1' :> env' -> env2' :> env' -> CombinedLHS s t env1' env2' env -- NOTE: [Sharing vs. Fusion] -- @@ -1293,7 +1246,7 @@ zipWithD f cc1 cc0 aletD :: Kit acc => EmbedAcc acc -> ElimAcc acc - -> LeftHandSide arrs aenv aenv' + -> ALeftHandSide arrs aenv aenv' -> acc aenv arrs -> acc aenv' brrs -> Embed acc aenv brrs @@ -1306,8 +1259,8 @@ aletD embedAcc elimAcc lhs (embedAcc -> Embed env1 cc1) acc0 -- body, instead of adding to the environments and creating an indirection -- that must be later eliminated by shrinking. -- - | LeftHandSideArray <- lhs - , Done (ArrayVarsArray v1@ArrayVar{}) <- cc1 + | LeftHandSideSingle _ <- lhs + , Done (VarsSingle v1@(Var ArrayR{} _)) <- cc1 , Embed env0 cc0 <- embedAcc $ rebuildA (subAtop (Avar v1) . sink1 env1) acc0 = Stats.ruleFired "aletD/float" $ Embed (env1 `append` env0) cc0 @@ -1321,11 +1274,11 @@ aletD embedAcc elimAcc lhs (embedAcc -> Embed env1 cc1) acc0 aletD' :: forall acc aenv aenv' arrs brrs. Kit acc => EmbedAcc acc -> ElimAcc acc - -> LeftHandSide arrs aenv aenv' + -> ALeftHandSide arrs aenv aenv' -> Embed acc aenv arrs -> Embed acc aenv' brrs -> Embed acc aenv brrs -aletD' embedAcc elimAcc LeftHandSideArray (Embed env1 cc1) (Embed env0 cc0) +aletD' embedAcc elimAcc (LeftHandSideSingle ArrayR{}) (Embed env1 cc1) (Embed env0 cc0) -- let-binding -- ----------- @@ -1360,25 +1313,25 @@ aletD' embedAcc elimAcc LeftHandSideArray (Embed env1 cc1) (Embed env0 cc0) -- extra type variables, and ensures we don't do extra work manipulating the -- body when not necessary (which can lead to a complexity blowup). -- - eliminate :: forall aenv aenv' sh e brrs. (Shape sh, Elt e) - => Extend acc aenv aenv' - -> Cunctation acc aenv' (Array sh e) - -> acc (aenv', Array sh e) brrs - -> Embed acc aenv brrs + eliminate :: forall aenv aenv' sh e brrs. + Extend ArrayR acc aenv aenv' + -> Cunctation aenv' (Array sh e) + -> acc (aenv', Array sh e) brrs + -> Embed acc aenv brrs eliminate env1 cc1 body - | Done v1 <- cc1 - , ArrayVarsArray v1' <- v1 = elim (arrayShape v1') (indexArray v1') - | Step sh1 p1 f1 v1 <- cc1 = elim sh1 (f1 `compose` indexArray v1 `compose` p1) - | Yield sh1 f1 <- cc1 = elim sh1 f1 + | Done v1 <- cc1 + , VarsSingle v1'@(Var r _) <- v1 = elim r (arrayShape v1') (indexArray v1') + | Step r sh1 p1 f1 v1 <- cc1 = elim r sh1 (f1 `compose` indexArray v1 `compose` p1) + | Yield r sh1 f1 <- cc1 = elim r sh1 f1 where bnd :: PreOpenAcc acc aenv' (Array sh e) bnd = compute cc1 - elim :: PreExp acc aenv' sh -> PreFun acc aenv' (sh -> e) -> Embed acc aenv brrs - elim sh1 f1 - | sh1' <- weaken SuccIdx sh1 - , f1' <- weaken SuccIdx f1 - , Embed env0' cc0' <- embedAcc $ rebuildA (subAtop bnd) $ kmap (replaceA sh1' f1' $ ArrayVar ZeroIdx) body + elim :: ArrayR (Array sh e) -> Exp aenv' sh -> Fun aenv' (sh -> e) -> Embed acc aenv brrs + elim r sh1 f1 + | sh1' <- weaken (weakenSucc' weakenId) sh1 + , f1' <- weaken (weakenSucc' weakenId) f1 + , Embed env0' cc0' <- embedAcc $ rebuildA (subAtop bnd) $ kmap (replaceA sh1' f1' $ Var r ZeroIdx) body = Embed (env1 `append` env0') cc0' -- As part of let-elimination, we need to replace uses of array variables in @@ -1391,72 +1344,64 @@ aletD' embedAcc elimAcc LeftHandSideArray (Embed env1 cc1) (Embed env0 cc0) -- things, but that is limited in what it looks for. -- replaceE :: forall env aenv sh e t. - PreOpenExp acc env aenv sh -> PreOpenFun acc env aenv (sh -> e) -> ArrayVar aenv (Array sh e) - -> PreOpenExp acc env aenv t - -> PreOpenExp acc env aenv t - replaceE sh' f' avar exp = + OpenExp env aenv sh -> OpenFun env aenv (sh -> e) -> ArrayVar aenv (Array sh e) + -> OpenExp env aenv t + -> OpenExp env aenv t + replaceE sh' f' avar@(Var (ArrayR shr _) _) exp = case exp of - Let x y -> Let (cvtE x) (replaceE (weakenE SuccIdx sh') (weakenE SuccIdx f') avar y) - Var i -> Var i - Foreign ff f e -> Foreign ff f (cvtE e) - Const c -> Const c - Undef -> Undef - Tuple t -> Tuple (cvtT t) - Prj ix e -> Prj ix (cvtE e) - IndexNil -> IndexNil - IndexCons sl sz -> IndexCons (cvtE sl) (cvtE sz) - IndexHead sh -> IndexHead (cvtE sh) - IndexTail sz -> IndexTail (cvtE sz) - IndexAny -> IndexAny + Let lhs x y -> let k = weakenWithLHS lhs + in Let lhs (cvtE x) (replaceE (weakenE k sh') (weakenE k f') avar y) + Evar var -> Evar var + Foreign tp ff f e -> Foreign tp ff f (cvtE e) + Const tp c -> Const tp c + Undef tp -> Undef tp + Nil -> Nil + Pair e1 e2 -> Pair (cvtE e1) (cvtE e2) IndexSlice x ix sh -> IndexSlice x (cvtE ix) (cvtE sh) IndexFull x ix sl -> IndexFull x (cvtE ix) (cvtE sl) - ToIndex sh ix -> ToIndex (cvtE sh) (cvtE ix) - FromIndex sh i -> FromIndex (cvtE sh) (cvtE i) + ToIndex shr' sh ix -> ToIndex shr' (cvtE sh) (cvtE ix) + FromIndex shr' sh i -> FromIndex shr' (cvtE sh) (cvtE i) Cond p t e -> Cond (cvtE p) (cvtE t) (cvtE e) PrimConst c -> PrimConst c PrimApp g x -> PrimApp g (cvtE x) - ShapeSize sh -> ShapeSize (cvtE sh) - Intersect sh sl -> Intersect (cvtE sh) (cvtE sl) - Union s t -> Union (cvtE s) (cvtE t) + ShapeSize shr' sh -> ShapeSize shr' (cvtE sh) While p f x -> While (replaceF sh' f' avar p) (replaceF sh' f' avar f) (cvtE x) - Coerce e -> Coerce (cvtE e) + Coerce t1 t2 e -> Coerce t1 t2 (cvtE e) Shape a - | Just Refl <- match a a' -> Stats.substitution "replaceE/shape" sh' + | Just Refl <- match a avar -> Stats.substitution "replaceE/shape" sh' | otherwise -> exp Index a sh - | Just Refl <- match a a' - , Lam (Body b) <- f' -> Stats.substitution "replaceE/!" . cvtE $ Let sh b + | Just Refl <- match a avar + , Lam lhs (Body b) <- f' -> Stats.substitution "replaceE/!" . cvtE $ Let lhs sh b | otherwise -> Index a (cvtE sh) LinearIndex a i - | Just Refl <- match a a' - , Lam (Body b) <- f' -> Stats.substitution "replaceE/!!" . cvtE $ Let (Let i (FromIndex (weakenE SuccIdx sh') (Var ZeroIdx))) b + | Just Refl <- match a avar + , Lam lhs (Body b) <- f' + -> Stats.substitution "replaceE/!!" . cvtE + $ Let lhs + (Let (LeftHandSideSingle scalarTypeInt) i $ FromIndex shr (weakenE (weakenSucc' weakenId) sh') $ Evar $ Var scalarTypeInt ZeroIdx) + b | otherwise -> LinearIndex a (cvtE i) where - a' :: acc aenv (Array sh e) - a' = avarIn avar - - cvtE :: PreOpenExp acc env aenv s -> PreOpenExp acc env aenv s + cvtE :: OpenExp env aenv s -> OpenExp env aenv s cvtE = replaceE sh' f' avar - cvtT :: Tuple (PreOpenExp acc env aenv) s -> Tuple (PreOpenExp acc env aenv) s - cvtT NilTup = NilTup - cvtT (SnocTup t e) = cvtT t `SnocTup` cvtE e - replaceF :: forall env aenv sh e t. - PreOpenExp acc env aenv sh -> PreOpenFun acc env aenv (sh -> e) -> ArrayVar aenv (Array sh e) - -> PreOpenFun acc env aenv t - -> PreOpenFun acc env aenv t + OpenExp env aenv sh -> OpenFun env aenv (sh -> e) -> ArrayVar aenv (Array sh e) + -> OpenFun env aenv t + -> OpenFun env aenv t replaceF sh' f' avar fun = case fun of Body e -> Body (replaceE sh' f' avar e) - Lam f -> Lam (replaceF (weakenE SuccIdx sh') (weakenE SuccIdx f') avar f) + Lam lhs f -> let k = weakenWithLHS lhs + in Lam lhs (replaceF (weakenE k sh') (weakenE k f') avar f) replaceA :: forall aenv sh e a. - PreExp acc aenv sh -> PreFun acc aenv (sh -> e) -> ArrayVar aenv (Array sh e) + Exp aenv sh -> Fun aenv (sh -> e) -> ArrayVar aenv (Array sh e) -> PreOpenAcc acc aenv a -> PreOpenAcc acc aenv a replaceA sh' f' avar pacc = @@ -1473,26 +1418,26 @@ aletD' embedAcc elimAcc LeftHandSideArray (Embed env1 cc1) (Embed env0 cc0) in Alet lhs (cvtA bnd) (kmap (replaceA sh'' f'' (weaken w avar)) body) - Use arrs -> Use arrs - Unit e -> Unit (cvtE e) + Use repr arrs -> Use repr arrs + Unit tp e -> Unit tp (cvtE e) Acond p at ae -> Acond (cvtE p) (cvtA at) (cvtA ae) Anil -> Anil Apair a1 a2 -> Apair (cvtA a1) (cvtA a2) Awhile p f a -> Awhile (cvtAF p) (cvtAF f) (cvtA a) - Apply f a -> Apply (cvtAF f) (cvtA a) - Aforeign ff f a -> Aforeign ff f (cvtA a) -- no sharing between f and a - Generate sh f -> Generate (cvtE sh) (cvtF f) - Map f a -> Map (cvtF f) (cvtA a) - ZipWith f a b -> ZipWith (cvtF f) (cvtA a) (cvtA b) - Backpermute sh p a -> Backpermute (cvtE sh) (cvtF p) (cvtA a) - Transform sh p f a -> Transform (cvtE sh) (cvtF p) (cvtF f) (cvtA a) + Apply repr f a -> Apply repr (cvtAF f) (cvtA a) + Aforeign repr ff f a -> Aforeign repr ff f (cvtA a) -- no sharing between f and a + Generate repr sh f -> Generate repr (cvtE sh) (cvtF f) + Map tp f a -> Map tp (cvtF f) (cvtA a) + ZipWith tp f a b -> ZipWith tp (cvtF f) (cvtA a) (cvtA b) + Backpermute shr sh p a -> Backpermute shr (cvtE sh) (cvtF p) (cvtA a) + Transform repr sh p f a -> Transform repr (cvtE sh) (cvtF p) (cvtF f) (cvtA a) Slice slix a sl -> Slice slix (cvtA a) (cvtE sl) Replicate slix sh a -> Replicate slix (cvtE sh) (cvtA a) - Reshape sl a -> Reshape (cvtE sl) (cvtA a) + Reshape shr sl a -> Reshape shr (cvtE sl) (cvtA a) Fold f z a -> Fold (cvtF f) (cvtE z) (cvtA a) Fold1 f a -> Fold1 (cvtF f) (cvtA a) - FoldSeg f z a s -> FoldSeg (cvtF f) (cvtE z) (cvtA a) (cvtA s) - Fold1Seg f a s -> Fold1Seg (cvtF f) (cvtA a) (cvtA s) + FoldSeg i f z a s -> FoldSeg i (cvtF f) (cvtE z) (cvtA a) (cvtA s) + Fold1Seg i f a s -> Fold1Seg i (cvtF f) (cvtA a) (cvtA s) Scanl f z a -> Scanl (cvtF f) (cvtE z) (cvtA a) Scanl1 f a -> Scanl1 (cvtF f) (cvtA a) Scanl' f z a -> Scanl' (cvtF f) (cvtE z) (cvtA a) @@ -1500,21 +1445,22 @@ aletD' embedAcc elimAcc LeftHandSideArray (Embed env1 cc1) (Embed env0 cc0) Scanr1 f a -> Scanr1 (cvtF f) (cvtA a) Scanr' f z a -> Scanr' (cvtF f) (cvtE z) (cvtA a) Permute f d p a -> Permute (cvtF f) (cvtA d) (cvtF p) (cvtA a) - Stencil f x a -> Stencil (cvtF f) (cvtB x) (cvtA a) - Stencil2 f x a y b -> Stencil2 (cvtF f) (cvtB x) (cvtA a) (cvtB y) (cvtA b) + Stencil s t f x a -> Stencil s t (cvtF f) (cvtB x) (cvtA a) + Stencil2 s1 s2 t f x a y b + -> Stencil2 s1 s2 t (cvtF f) (cvtB x) (cvtA a) (cvtB y) (cvtA b) -- Collect seq -> Collect (cvtSeq seq) where cvtA :: acc aenv s -> acc aenv s cvtA = kmap (replaceA sh' f' avar) - cvtE :: PreExp acc aenv s -> PreExp acc aenv s + cvtE :: Exp aenv s -> Exp aenv s cvtE = replaceE sh' f' avar - cvtF :: PreFun acc aenv s -> PreFun acc aenv s + cvtF :: Fun aenv s -> Fun aenv s cvtF = replaceF sh' f' avar - cvtB :: PreBoundary acc aenv s -> PreBoundary acc aenv s + cvtB :: Boundary aenv s -> Boundary aenv s cvtB Clamp = Clamp cvtB Mirror = Mirror cvtB Wrap = Wrap @@ -1525,7 +1471,7 @@ aletD' embedAcc elimAcc LeftHandSideArray (Embed env1 cc1) (Embed env0 cc0) cvtAF = cvt sh' f' avar where cvt :: forall aenv a. - PreExp acc aenv sh -> PreFun acc aenv (sh -> e) -> ArrayVar aenv (Array sh e) + Exp aenv sh -> Fun aenv (sh -> e) -> ArrayVar aenv (Array sh e) -> PreOpenAfun acc aenv a -> PreOpenAfun acc aenv a cvt sh'' f'' avar' (Abody a) = Abody $ kmap (replaceA sh'' f'' avar') a @@ -1589,13 +1535,13 @@ aletD' _ _ lhs (Embed env1 cc1) (Embed env0 cc0) -- acondD :: Kit acc => EmbedAcc acc - -> PreExp acc aenv Bool + -> Exp aenv Bool -> acc aenv arrs -> acc aenv arrs -> Embed acc aenv arrs acondD embedAcc p t e - | Const True <- p = Stats.knownBranch "True" $ embedAcc t - | Const False <- p = Stats.knownBranch "False" $ embedAcc e + | Const _ True <- p = Stats.knownBranch "True" $ embedAcc t + | Const _ False <- p = Stats.knownBranch "False" $ embedAcc e | Just Refl <- match t e = Stats.knownBranch "redundant" $ embedAcc e | otherwise = done $ Acond p (computeAcc (embedAcc t)) (computeAcc (embedAcc e)) @@ -1604,46 +1550,50 @@ acondD embedAcc p t e -- Scalar expressions -- ------------------ -isIdentity :: PreFun acc aenv (a -> b) -> Maybe (a :~: b) -isIdentity f - | Lam (Body (Var ZeroIdx)) <- f = Just Refl - | otherwise = Nothing - -identity :: Elt a => PreOpenFun acc env aenv (a -> a) -identity = Lam (Body (Var ZeroIdx)) - -toIndex :: (Kit acc, Shape sh) => PreOpenExp acc env aenv sh -> PreOpenFun acc env aenv (sh -> Int) -toIndex sh = Lam (Body (ToIndex (weakenE SuccIdx sh) (Var ZeroIdx))) - -fromIndex :: (Kit acc, Shape sh) => PreOpenExp acc env aenv sh -> PreOpenFun acc env aenv (Int -> sh) -fromIndex sh = Lam (Body (FromIndex (weakenE SuccIdx sh) (Var ZeroIdx))) - -reindex :: (Kit acc, Shape sh, Shape sh') - => PreOpenExp acc env aenv sh' - -> PreOpenExp acc env aenv sh - -> PreOpenFun acc env aenv (sh -> sh') -reindex sh' sh - | Just Refl <- match sh sh' = identity - | otherwise = fromIndex sh' `compose` toIndex sh - -extend :: (Kit acc, Shape sh, Shape sl, Elt slix) - => SliceIndex (EltRepr slix) (EltRepr sl) co (EltRepr sh) - -> PreExp acc aenv slix - -> PreFun acc aenv (sh -> sl) -extend sliceIndex slix = Lam (Body (IndexSlice sliceIndex (weakenE SuccIdx slix) (Var ZeroIdx))) - -restrict :: (Kit acc, Shape sh, Shape sl, Elt slix) - => SliceIndex (EltRepr slix) (EltRepr sl) co (EltRepr sh) - -> PreExp acc aenv slix - -> PreFun acc aenv (sl -> sh) -restrict sliceIndex slix = Lam (Body (IndexFull sliceIndex (weakenE SuccIdx slix) (Var ZeroIdx))) - -arrayShape :: (Kit acc) => ArrayVar aenv (Array sh e) -> PreExp acc aenv sh -arrayShape v@ArrayVar{} = simplify $ Shape $ avarIn v - -indexArray :: (Kit acc) => ArrayVar aenv (Array sh e) -> PreFun acc aenv (sh -> e) -indexArray v@ArrayVar{} = Lam (Body (Index (avarIn v) (Var ZeroIdx))) - -linearIndex :: (Kit acc) => ArrayVar aenv (Array sh e) -> PreFun acc aenv (Int -> e) -linearIndex v@ArrayVar{} = Lam (Body (LinearIndex (avarIn v) (Var ZeroIdx))) +identity :: TupleType a -> OpenFun env aenv (a -> a) +identity tp + | DeclareVars lhs _ value <- declareVars tp + = Lam lhs $ Body $ evars $ value weakenId + +toIndex :: ShapeR sh -> OpenExp env aenv sh -> OpenFun env aenv (sh -> Int) +toIndex shr sh + | DeclareVars lhs k value <- declareVars $ shapeType shr + = Lam lhs $ Body $ ToIndex shr (weakenE k sh) $ evars $ value weakenId + +fromIndex :: ShapeR sh -> OpenExp env aenv sh -> OpenFun env aenv (Int -> sh) +fromIndex shr sh = Lam (LeftHandSideSingle scalarTypeInt) $ Body $ FromIndex shr (weakenE (weakenSucc' weakenId) sh) $ Evar $ Var scalarTypeInt ZeroIdx + +reindex :: ShapeR sh' + -> OpenExp env aenv sh' + -> ShapeR sh + -> OpenExp env aenv sh + -> OpenFun env aenv (sh -> sh') +reindex shr' sh' shr sh + | Just Refl <- match sh sh' = identity (shapeType shr') + | otherwise = fromIndex shr' sh' `compose` toIndex shr sh + +extend :: SliceIndex slix sl co sh + -> Exp aenv slix + -> Fun aenv (sh -> sl) +extend sliceIndex slix + | DeclareVars lhs k value <- declareVars $ shapeType $ sliceDomainR sliceIndex + = Lam lhs $ Body $ IndexSlice sliceIndex (weakenE k slix) $ evars $ value weakenId + +restrict :: SliceIndex slix sl co sh + -> Exp aenv slix + -> Fun aenv (sl -> sh) +restrict sliceIndex slix + | DeclareVars lhs k value <- declareVars $ shapeType $ sliceShapeR sliceIndex + = Lam lhs $ Body $ IndexFull sliceIndex (weakenE k slix) $ evars $ value weakenId + +arrayShape :: ArrayVar aenv (Array sh e) -> Exp aenv sh +arrayShape = simplify . Shape + +indexArray :: ArrayVar aenv (Array sh e) -> Fun aenv (sh -> e) +indexArray v@(Var (ArrayR shr _) _) + | DeclareVars lhs _ value <- declareVars $ shapeType shr + = Lam lhs $ Body $ Index v $ evars $ value weakenId + +linearIndex :: ArrayVar aenv (Array sh e) -> Fun aenv (Int -> e) +linearIndex v = Lam (LeftHandSideSingle scalarTypeInt) $ Body $ LinearIndex v $ Evar $ Var scalarTypeInt ZeroIdx diff --git a/src/Data/Array/Accelerate/Trafo/LetSplit.hs b/src/Data/Array/Accelerate/Trafo/LetSplit.hs new file mode 100644 index 000000000..cbb0e2e7b --- /dev/null +++ b/src/Data/Array/Accelerate/Trafo/LetSplit.hs @@ -0,0 +1,69 @@ +{-# LANGUAGE GADTs #-} +-- | +-- Module : Data.Array.Accelerate.Trafo.LetSplit +-- Copyright : [2012..2019] The Accelerate Team +-- License : BSD3 +-- +-- Maintainer : Trevor L. McDonell +-- Stability : experimental +-- Portability : non-portable (GHC extensions) +-- + +module Data.Array.Accelerate.Trafo.LetSplit ( + + convertAcc, convertAfun + +) where + +import Prelude hiding ( exp ) +import Data.Array.Accelerate.AST +import Data.Array.Accelerate.Trafo.Base + +convertAcc :: Kit acc => acc aenv a -> acc aenv a +convertAcc acc = case extract acc of + Just a -> travA a + Nothing -> acc + +travA :: Kit acc => PreOpenAcc acc aenv a -> acc aenv a +travA (Alet lhs bnd body) = travBinding lhs (convertAcc bnd) (convertAcc body) +travA (Avar var) = inject $ Avar var +travA (Apair a1 a2) = inject $ Apair (convertAcc a1) (convertAcc a2) +travA Anil = inject $ Anil +travA (Apply repr f a) = inject $ Apply repr (convertAfun f) (convertAcc a) +travA (Aforeign repr asm f a) = inject $ Aforeign repr asm (convertAfun f) (convertAcc a) +travA (Acond e a1 a2) = inject $ Acond e (convertAcc a1) (convertAcc a2) +travA (Awhile c f a) = inject $ Awhile (convertAfun c) (convertAfun f) (convertAcc a) +travA (Use repr arr) = inject $ Use repr arr +travA (Unit tp e) = inject $ Unit tp e +travA (Reshape shr e a) = inject $ Reshape shr e a +travA (Generate repr e f) = inject $ Generate repr e f +travA (Transform repr sh f g a) = inject $ Transform repr sh f g (convertAcc a) +travA (Replicate slix sl a) = inject $ Replicate slix sl (convertAcc a) +travA (Slice slix a sl) = inject $ Slice slix (convertAcc a) sl +travA (Map tp f a) = inject $ Map tp f (convertAcc a) +travA (ZipWith tp f a1 a2) = inject $ ZipWith tp f (convertAcc a1) (convertAcc a2) +travA (Fold f e a) = inject $ Fold f e (convertAcc a) +travA (Fold1 f a) = inject $ Fold1 f (convertAcc a) +travA (FoldSeg i f e a s) = inject $ FoldSeg i f e (convertAcc a) (convertAcc s) +travA (Fold1Seg i f a s) = inject $ Fold1Seg i f (convertAcc a) (convertAcc s) +travA (Scanl f e a) = inject $ Scanl f e (convertAcc a) +travA (Scanl' f e a) = inject $ Scanl' f e (convertAcc a) +travA (Scanl1 f a) = inject $ Scanl1 f (convertAcc a) +travA (Scanr f e a) = inject $ Scanr f e (convertAcc a) +travA (Scanr' f e a) = inject $ Scanr' f e (convertAcc a) +travA (Scanr1 f a) = inject $ Scanr1 f (convertAcc a) +travA (Permute f a1 g a2) = inject $ Permute f (convertAcc a1) g (convertAcc a2) +travA (Backpermute shr sh f a) = inject $ Backpermute shr sh f (convertAcc a) +travA (Stencil s tp f b a) = inject $ Stencil s tp f b (convertAcc a) +travA (Stencil2 s1 s2 tp f b1 a1 b2 a2) = inject $ Stencil2 s1 s2 tp f b1 (convertAcc a1) b2 (convertAcc a2) + +travBinding :: Kit acc => ALeftHandSide bnd aenv aenv' -> acc aenv bnd -> acc aenv' a -> acc aenv a +travBinding (LeftHandSideWildcard _) _ a = a +travBinding lhs@(LeftHandSideSingle _) bnd a = inject $ Alet lhs bnd a +travBinding lhs@(LeftHandSidePair l1 l2) bnd a = case extract bnd of + Just (Apair b1 b2) -> travBinding l1 b1 $ travBinding l2 (weaken (weakenWithLHS l1) b2) a + _ -> inject $ Alet lhs bnd a + +convertAfun :: Kit acc => PreOpenAfun acc aenv f -> PreOpenAfun acc aenv f +convertAfun (Alam lhs f) = Alam lhs $ convertAfun f +convertAfun (Abody a) = Abody $ convertAcc a diff --git a/src/Data/Array/Accelerate/Trafo/Normalise.hs b/src/Data/Array/Accelerate/Trafo/Normalise.hs index 69523917b..7124e9889 100644 --- a/src/Data/Array/Accelerate/Trafo/Normalise.hs +++ b/src/Data/Array/Accelerate/Trafo/Normalise.hs @@ -17,7 +17,6 @@ module Data.Array.Accelerate.Trafo.Normalise ( import Prelude hiding ( exp ) import Data.Array.Accelerate.AST -import Data.Array.Accelerate.Product import Data.Array.Accelerate.Trafo.Substitution diff --git a/src/Data/Array/Accelerate/Trafo/Sharing.hs b/src/Data/Array/Accelerate/Trafo/Sharing.hs index da4371271..5f584d1c3 100644 --- a/src/Data/Array/Accelerate/Trafo/Sharing.hs +++ b/src/Data/Array/Accelerate/Trafo/Sharing.hs @@ -1,6 +1,5 @@ {-# LANGUAGE AllowAmbiguousTypes #-} {-# LANGUAGE BangPatterns #-} -{-# LANGUAGE DeriveDataTypeable #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GADTs #-} @@ -37,7 +36,7 @@ module Data.Array.Accelerate.Trafo.Sharing ( Afunction, AfunctionR, AreprFunctionR, AfunctionRepr(..), afunctionRepr, convertAfun, convertAfunWith, - Function, FunctionR, + Function, FunctionR, EltReprFunctionR, FunctionRepr(..), functionRepr, convertExp, convertExpWith, convertFun, convertFunWith, @@ -51,7 +50,6 @@ import Control.Monad.Fix import Data.Hashable import Data.List hiding ( (\\) ) import Data.Maybe -import Data.Typeable import System.IO.Unsafe ( unsafePerformIO ) import System.Mem.StableName import Text.Printf @@ -63,15 +61,18 @@ import Prelude -- friends import Data.BitSet ( (\\), member ) +import Data.Array.Accelerate.Type import Data.Array.Accelerate.Error import Data.Array.Accelerate.Smart import Data.Array.Accelerate.Trafo.Base import Data.Array.Accelerate.Trafo.Config -import Data.Array.Accelerate.Array.Sugar as Sugar hiding ( (!!) ) +import Data.Array.Accelerate.Array.Representation hiding ((!!)) +import Data.Array.Accelerate.Array.Sugar ( Elt, EltRepr, Arrays, ArrRepr, eltType ) +import qualified Data.Array.Accelerate.Array.Sugar as Sugar import Data.Array.Accelerate.AST hiding ( PreOpenAcc(..), OpenAcc(..), Acc - , PreOpenExp(..), OpenExp, PreExp, Exp - , PreBoundary(..), Boundary, Stencil(..) - , showPreAccOp, showPreExpOp ) + , OpenExp(..), Exp + , Boundary(..) + , showPreAccOp, showPreExpOp, expType, HasArraysRepr(..), arraysRepr ) import qualified Data.Array.Accelerate.AST as AST import Data.Array.Accelerate.Debug.Trace as Debug import Data.Array.Accelerate.Debug.Flags as Debug @@ -84,84 +85,54 @@ import Data.Array.Accelerate.Debug.Flags as Debug -- Each entry in the layout holds the de Bruijn index that refers to the -- corresponding entry in the environment. -- -data Layout env env' where - EmptyLayout :: Layout env () - PushLayout :: Typeable t => Layout env env' -> Idx env t -> Layout env (env', t) +data Layout s env env' where + EmptyLayout :: Layout s env () + PushLayout :: Layout s env env1 + -> LeftHandSide s t env1 env2 + -> Vars s env t + -> Layout s env env2 + +type ELayout = Layout ScalarType +type ArrayLayout = Layout ArrayR -data ArrayLayout env env' where - ArrayEmptyLayout :: ArrayLayout env () - ArrayPushLayout - :: Typeable t - => ArrayLayout env env1 - -> LeftHandSide t env1 env2 - -> ArrayVars env t - -> ArrayLayout env env2 -- Project the nth index out of an environment layout. -- -- The first argument provides context information for error messages in the -- case of failure. -- -prjIdx :: Typeable t +prjIdx :: forall s t env env1. Match s => String + -> (forall t'. TupR s t' -> ShowS) + -> TupR s t -> Int - -> Layout env env' - -> Idx env t -prjIdx context = go + -> Layout s env env1 + -> Vars s env t +prjIdx context showTp tp = go where - go :: forall env env' t. Typeable t => Int -> Layout env env' -> Idx env t + go :: forall env'. Int -> Layout s env env' -> Vars s env t go _ EmptyLayout = no "environment does not contain index" - go 0 (PushLayout _ (ix :: Idx env0 s)) - | Just ix' <- gcast ix = ix' + go 0 (PushLayout _ lhs vars) + | Just Refl <- match tp tp' = vars | otherwise = no $ printf "couldn't match expected type `%s' with actual type `%s'" - (show (typeOf (undefined::t))) - (show (typeOf (undefined::s))) - go n (PushLayout l _) = go (n-1) l + (showTp tp "") + (showTp tp' "") + where + tp' = lhsToTupR lhs + go n (PushLayout l _ _) = go (n-1) l no :: String -> a no reason = $internalError "prjIdx" (printf "%s\nin the context: %s" reason context) -prjArrayIdx :: Typeable t - => String - -> Int - -> ArrayLayout env env' - -> AST.OpenAcc env t -prjArrayIdx context = go - where - go :: forall env env' t. Typeable t => Int -> ArrayLayout env env' -> AST.OpenAcc env t - go _ ArrayEmptyLayout = no "environment does not contain index" - go 0 (ArrayPushLayout _ _ (ix :: ArrayVars env0 s)) - | Just ix' <- gcast ix = avarsIn ix' - | otherwise = no $ printf "couldn't match expected type `%s' with actual type `%s'" - (show (typeOf (undefined::t))) - (show (typeOf (undefined::s))) - go n (ArrayPushLayout l _ _) = go (n-1) l - - no :: String -> a - no reason = $internalError "prjArrayIdx" (printf "%s\nin the context: %s" reason context) - -- Add an entry to a layout, incrementing all indices -- -incLayout :: Layout env env' -> Layout (env, t) env' -incLayout EmptyLayout = EmptyLayout -incLayout (PushLayout lyt ix) = PushLayout (incLayout lyt) (SuccIdx ix) - -incArrayLayoutWith :: env1 :> env2 -> ArrayLayout env1 env' -> ArrayLayout env2 env' -incArrayLayoutWith _ ArrayEmptyLayout = ArrayEmptyLayout -incArrayLayoutWith k (ArrayPushLayout lyt lhs t) = ArrayPushLayout (incArrayLayoutWith k lyt) lhs (incVarsWith k t) - -sizeLayout :: Layout env env' -> Int -sizeLayout EmptyLayout = 0 -sizeLayout (PushLayout lyt _) = 1 + sizeLayout lyt +incLayout :: env1 :> env2 -> Layout s env1 env' -> Layout s env2 env' +incLayout _ EmptyLayout = EmptyLayout +incLayout k (PushLayout lyt lhs v) = PushLayout (incLayout k lyt) lhs (weaken k v) -sizeArrayLayout :: ArrayLayout env env' -> Int -sizeArrayLayout ArrayEmptyLayout = 0 -sizeArrayLayout (ArrayPushLayout lyt _ _) = 1 + sizeArrayLayout lyt - -incVarsWith :: env1 :> env2 -> ArrayVars env1 t -> ArrayVars env2 t -incVarsWith _ ArrayVarsNil = ArrayVarsNil -incVarsWith k (ArrayVarsArray (ArrayVar idx)) = ArrayVarsArray $ ArrayVar $ k idx -incVarsWith k (ArrayVarsPair v1 v2) = incVarsWith k v1 `ArrayVarsPair` incVarsWith k v2 +sizeLayout :: Layout s env env' -> Int +sizeLayout EmptyLayout = 0 +sizeLayout (PushLayout lyt _ _) = 1 + sizeLayout lyt -- Conversion from HOAS to de Bruijn computation AST -- ================================================= @@ -172,11 +143,11 @@ incVarsWith k (ArrayVarsPair v1 v2) = incVarsWith k v1 `ArrayVarsPair` -- | Convert a closed array expression to de Bruijn form while also incorporating sharing -- information. -- -convertAcc :: Arrays arrs => Acc arrs -> AST.Acc (ArrRepr arrs) +convertAcc :: Acc arrs -> AST.Acc (ArrRepr arrs) convertAcc = convertAccWith defaultOptions -convertAccWith :: Arrays arrs => Config -> Acc arrs -> AST.Acc (ArrRepr arrs) -convertAccWith config (Acc acc) = convertOpenAcc config ArrayEmptyLayout acc +convertAccWith :: Config -> Acc arrs -> AST.Acc (ArrRepr arrs) +convertAccWith config (Acc acc) = convertOpenAcc config EmptyLayout acc -- | Convert a closed function over array computations, while incorporating @@ -186,7 +157,7 @@ convertAfun :: Afunction f => f -> AST.Afun (AreprFunctionR f) convertAfun = convertAfunWith defaultOptions convertAfunWith :: Afunction f => Config -> f -> AST.Afun (AreprFunctionR f) -convertAfunWith config = convertOpenAfun config ArrayEmptyLayout +convertAfunWith config = convertOpenAfun config EmptyLayout data AfunctionRepr f ar areprr where AfunctionReprBody @@ -215,12 +186,14 @@ instance (Arrays a, Afunction r) => Afunction (Acc a -> r) where type AreprFunctionR (Acc a -> r) = ArrRepr a -> AreprFunctionR r afunctionRepr = AfunctionReprLam $ afunctionRepr @r - convertOpenAfun config alyt f = case declareArrays $ arrays @a of - DeclareArrays lhs k value -> + convertOpenAfun config alyt f + | repr <- Sugar.arrays @a + , DeclareVars lhs k value <- declareVars repr = let - a = Acc $ SmartAcc $ Atag $ sizeArrayLayout alyt - alyt' = ArrayPushLayout (incArrayLayoutWith k alyt) lhs (value id) - in Alam lhs $ convertOpenAfun config alyt' $ f a + a = Acc $ SmartAcc $ Atag repr $ sizeLayout alyt + alyt' = PushLayout (incLayout k alyt) lhs (value weakenId) + in + Alam lhs $ convertOpenAfun config alyt' $ f a instance Arrays b => Afunction (Acc b) where type AfunctionR (Acc b) = b @@ -228,18 +201,25 @@ instance Arrays b => Afunction (Acc b) where afunctionRepr = AfunctionReprBody convertOpenAfun config alyt (Acc body) = Abody $ convertOpenAcc config alyt body +convertSmartAfun1 :: Config -> ArraysR a -> (SmartAcc a -> SmartAcc b) -> AST.Afun (a -> b) +convertSmartAfun1 config repr f + | DeclareVars lhs _ value <- declareVars repr = + let + a = SmartAcc $ Atag repr 0 + alyt' = PushLayout EmptyLayout lhs (value weakenId) + in + Alam lhs $ Abody $ convertOpenAcc config alyt' $ f a -- | Convert an open array expression to de Bruijn form while also incorporating sharing -- information. -- convertOpenAcc - :: Typeable arrs - => Config + :: Config -> ArrayLayout aenv aenv -> SmartAcc arrs -> AST.OpenAcc aenv arrs convertOpenAcc config alyt acc = - let lvl = sizeArrayLayout alyt + let lvl = sizeLayout alyt fvs = [lvl-1, lvl-2 .. 0] (sharingAcc, initialEnv) = recoverSharingAcc config lvl fvs acc in @@ -254,15 +234,15 @@ convertOpenAcc config alyt acc = -- in reverse chronological order (outermost variable is at the end of the list). -- convertSharingAcc - :: forall aenv arrs. Typeable arrs - => Config + :: forall aenv arrs. + Config -> ArrayLayout aenv aenv -> [StableSharingAcc] -> ScopedAcc arrs -> AST.OpenAcc aenv arrs -convertSharingAcc _ alyt aenv (ScopedAcc lams (AvarSharing sa)) +convertSharingAcc _ alyt aenv (ScopedAcc lams (AvarSharing sa repr)) | Just i <- findIndex (matchStableAcc sa) aenv' - = prjArrayIdx (ctxt ++ "; i = " ++ show i) i alyt + = avarsIn $ prjIdx (ctxt ++ "; i = " ++ show i) showArraysR repr i alyt | null aenv' = error $ "Cyclic definition of a value of type 'Acc' (sa = " ++ show (hashStableNameHeight sa) ++ ")" @@ -274,10 +254,10 @@ convertSharingAcc _ alyt aenv (ScopedAcc lams (AvarSharing sa)) err = "inconsistent valuation @ " ++ ctxt ++ ";\n aenv = " ++ show aenv' convertSharingAcc config alyt aenv (ScopedAcc lams (AletSharing sa@(StableSharingAcc (_ :: StableAccName as) boundAcc) bodyAcc)) - = case declareArrays $ arraysRepr bound of - DeclareArrays lhs k value -> + = case declareVars $ AST.arraysRepr bound of + DeclareVars lhs k value -> let - alyt' = ArrayPushLayout (incArrayLayoutWith k alyt) lhs (value id) + alyt' = PushLayout (incLayout k alyt) lhs (value weakenId) in AST.OpenAcc $ AST.Alet lhs @@ -291,53 +271,50 @@ convertSharingAcc config alyt aenv (ScopedAcc lams (AccSharing _ preAcc)) = AST.OpenAcc $ let aenv' = lams ++ aenv - cvtA :: Typeable a => ScopedAcc a -> AST.OpenAcc aenv a + cvtA :: ScopedAcc a -> AST.OpenAcc aenv a cvtA = convertSharingAcc config alyt aenv' - cvtE :: Elt t => ScopedExp t -> AST.Exp aenv t + cvtE :: ScopedExp t -> AST.Exp aenv t cvtE = convertSharingExp config EmptyLayout alyt [] aenv' - cvtF1 :: (Elt a, Elt b) => (Exp a -> ScopedExp b) -> AST.Fun aenv (a -> b) + cvtF1 :: TupleType a -> (SmartExp a -> ScopedExp b) -> AST.Fun aenv (a -> b) cvtF1 = convertSharingFun1 config alyt aenv' - cvtF2 :: (Elt a, Elt b, Elt c) => (Exp a -> Exp b -> ScopedExp c) -> AST.Fun aenv (a -> b -> c) + cvtF2 :: TupleType a -> TupleType b -> (SmartExp a -> SmartExp b -> ScopedExp c) -> AST.Fun aenv (a -> b -> c) cvtF2 = convertSharingFun2 config alyt aenv' - cvtAfun1 :: (Typeable a, Typeable b) => ArraysR a -> (SmartAcc a -> ScopedAcc b) -> AST.OpenAfun aenv (a -> b) + cvtAfun1 :: ArraysR a -> (SmartAcc a -> ScopedAcc b) -> AST.OpenAfun aenv (a -> b) cvtAfun1 = convertSharingAfun1 config alyt aenv' - cvtAprj :: forall a b c. (Typeable a, Typeable b) => PairIdx (a, b) c -> ScopedAcc (a, b) -> AST.OpenAcc aenv c + cvtAprj :: forall a b c. PairIdx (a, b) c -> ScopedAcc (a, b) -> AST.OpenAcc aenv c cvtAprj ix a = cvtAprj' ix $ cvtA a - cvtAprj' :: forall a b c aenv1. (Typeable a, Typeable b) => PairIdx (a, b) c -> AST.OpenAcc aenv1 (a, b) -> AST.OpenAcc aenv1 c + cvtAprj' :: forall a b c aenv1. PairIdx (a, b) c -> AST.OpenAcc aenv1 (a, b) -> AST.OpenAcc aenv1 c cvtAprj' PairIdxLeft (AST.OpenAcc (AST.Apair a _)) = a cvtAprj' PairIdxRight (AST.OpenAcc (AST.Apair _ b)) = b - cvtAprj' ix a = case declareArrays $ arraysRepr a of - DeclareArrays lhs _ value -> - AST.OpenAcc $ AST.Alet lhs a $ cvtAprj' ix $ avarsIn $ value id + cvtAprj' ix a = case declareVars $ AST.arraysRepr a of + DeclareVars lhs _ value -> + AST.OpenAcc $ AST.Alet lhs a $ cvtAprj' ix $ avarsIn $ value weakenId in case preAcc of - Atag i - -> let AST.OpenAcc a = prjArrayIdx ("de Bruijn conversion tag " ++ show i) i alyt + Atag repr i + -> let AST.OpenAcc a = avarsIn $ prjIdx ("de Bruijn conversion tag " ++ show i) showArraysR repr i alyt in a - Pipe reprA reprB (afun1 :: SmartAcc as -> ScopedAcc bs) (afun2 :: SmartAcc bs -> ScopedAcc cs) acc - -> + Pipe reprA reprB reprC (afun1 :: SmartAcc as -> ScopedAcc bs) (afun2 :: SmartAcc bs -> ScopedAcc cs) acc + | DeclareVars lhs k value <- declareVars reprB -> let noStableSharing = StableSharingAcc noStableAccName (undefined :: SharingAcc acc exp ()) - boundAcc = AST.Apply (cvtAfun1 reprA afun1) (cvtA acc) - in case declareArrays reprB of - DeclareArrays lhs k value -> - let - alyt' = ArrayPushLayout (incArrayLayoutWith k alyt) lhs (value id) - bodyAcc = AST.Apply - (convertSharingAfun1 config alyt' (noStableSharing : aenv') reprB afun2) - (avarsIn $ value id) - in AST.Alet lhs (AST.OpenAcc boundAcc) (AST.OpenAcc bodyAcc) - - Aforeign ff afun acc - -> AST.Aforeign ff (convertAfunWith config afun) (cvtA acc) + boundAcc = AST.Apply reprB (cvtAfun1 reprA afun1) (cvtA acc) + alyt' = PushLayout (incLayout k alyt) lhs (value weakenId) + bodyAcc = AST.Apply reprC + (convertSharingAfun1 config alyt' (noStableSharing : aenv') reprB afun2) + (avarsIn $ value weakenId) + in AST.Alet lhs (AST.OpenAcc boundAcc) (AST.OpenAcc bodyAcc) + + Aforeign repr ff afun acc + -> AST.Aforeign repr ff (convertSmartAfun1 config (arraysRepr acc) afun) (cvtA acc) Acond b acc1 acc2 -> AST.Acond (cvtE b) (cvtA acc1) (cvtA acc2) Awhile reprA pred iter init -> AST.Awhile (cvtAfun1 reprA pred) (cvtAfun1 reprA iter) (cvtA init) @@ -345,35 +322,45 @@ convertSharingAcc config alyt aenv (ScopedAcc lams (AccSharing _ preAcc)) Apair acc1 acc2 -> AST.Apair (cvtA acc1) (cvtA acc2) Aprj ix a -> let AST.OpenAcc a' = cvtAprj ix a in a' - Use array -> AST.Use array - Unit e -> AST.Unit (cvtE e) - Generate sh f -> AST.Generate (cvtE sh) (cvtF1 f) - Reshape e acc -> AST.Reshape (cvtE e) (cvtA acc) - Replicate ix acc -> mkReplicate (cvtE ix) (cvtA acc) - Slice acc ix -> mkIndex (cvtA acc) (cvtE ix) - Map f acc -> AST.Map (cvtF1 f) (cvtA acc) - ZipWith f acc1 acc2 -> AST.ZipWith (cvtF2 f) (cvtA acc1) (cvtA acc2) - Fold f e acc -> AST.Fold (cvtF2 f) (cvtE e) (cvtA acc) - Fold1 f acc -> AST.Fold1 (cvtF2 f) (cvtA acc) - FoldSeg f e acc1 acc2 -> AST.FoldSeg (cvtF2 f) (cvtE e) (cvtA acc1) (cvtA acc2) - Fold1Seg f acc1 acc2 -> AST.Fold1Seg (cvtF2 f) (cvtA acc1) (cvtA acc2) - Scanl f e acc -> AST.Scanl (cvtF2 f) (cvtE e) (cvtA acc) - Scanl' f e acc -> AST.Scanl' (cvtF2 f) (cvtE e) (cvtA acc) - Scanl1 f acc -> AST.Scanl1 (cvtF2 f) (cvtA acc) - Scanr f e acc -> AST.Scanr (cvtF2 f) (cvtE e) (cvtA acc) - Scanr' f e acc -> AST.Scanr' (cvtF2 f) (cvtE e) (cvtA acc) - Scanr1 f acc -> AST.Scanr1 (cvtF2 f) (cvtA acc) - Permute f dftAcc perm acc -> AST.Permute (cvtF2 f) (cvtA dftAcc) (cvtF1 perm) (cvtA acc) - Backpermute newDim perm acc -> AST.Backpermute (cvtE newDim) (cvtF1 perm) (cvtA acc) - Stencil stencil boundary acc - -> AST.Stencil (convertSharingStencilFun1 config acc alyt aenv' stencil) - (convertSharingBoundary config alyt aenv' boundary) + Use repr array -> AST.Use repr array + Unit tp e -> AST.Unit tp (cvtE e) + Generate repr@(ArrayR shr _) sh f + -> AST.Generate repr (cvtE sh) (cvtF1 (shapeType shr) f) + Reshape shr e acc -> AST.Reshape shr (cvtE e) (cvtA acc) + Replicate si ix acc -> AST.Replicate si (cvtE ix) (cvtA acc) + Slice si acc ix -> AST.Slice si (cvtA acc) (cvtE ix) + Map t1 t2 f acc -> AST.Map t2 (cvtF1 t1 f) (cvtA acc) + ZipWith t1 t2 t3 f acc1 acc2 + -> AST.ZipWith t3 (cvtF2 t1 t2 f) (cvtA acc1) (cvtA acc2) + Fold tp f e acc -> AST.Fold (cvtF2 tp tp f) (cvtE e) (cvtA acc) + Fold1 tp f acc -> AST.Fold1 (cvtF2 tp tp f) (cvtA acc) + FoldSeg i tp f e acc1 acc2 -> AST.FoldSeg i (cvtF2 tp tp f) (cvtE e) (cvtA acc1) (cvtA acc2) + Fold1Seg i tp f acc1 acc2 -> AST.Fold1Seg i (cvtF2 tp tp f) (cvtA acc1) (cvtA acc2) + Scanl tp f e acc -> AST.Scanl (cvtF2 tp tp f) (cvtE e) (cvtA acc) + Scanl' tp f e acc -> AST.Scanl' (cvtF2 tp tp f) (cvtE e) (cvtA acc) + Scanl1 tp f acc -> AST.Scanl1 (cvtF2 tp tp f) (cvtA acc) + Scanr tp f e acc -> AST.Scanr (cvtF2 tp tp f) (cvtE e) (cvtA acc) + Scanr' tp f e acc -> AST.Scanr' (cvtF2 tp tp f) (cvtE e) (cvtA acc) + Scanr1 tp f acc -> AST.Scanr1 (cvtF2 tp tp f) (cvtA acc) + Permute (ArrayR shr tp) f dftAcc perm acc + -> AST.Permute (cvtF2 tp tp f) (cvtA dftAcc) (cvtF1 (shapeType shr) perm) (cvtA acc) + Backpermute shr newDim perm acc + -> AST.Backpermute shr (cvtE newDim) (cvtF1 (shapeType shr) perm) (cvtA acc) + Stencil stencil tp f boundary acc + -> AST.Stencil stencil + tp + (convertSharingStencilFun1 config alyt aenv' stencil f) + (convertSharingBoundary config alyt aenv' (stencilShape stencil) boundary) (cvtA acc) - Stencil2 stencil bndy1 acc1 bndy2 acc2 - -> AST.Stencil2 (convertSharingStencilFun2 config acc1 acc2 alyt aenv' stencil) - (convertSharingBoundary config alyt aenv' bndy1) + Stencil2 stencil1 stencil2 tp f bndy1 acc1 bndy2 acc2 + | shr <- stencilShape stencil1 + -> AST.Stencil2 stencil1 + stencil2 + tp + (convertSharingStencilFun2 config alyt aenv' stencil1 stencil2 f) + (convertSharingBoundary config alyt aenv' shr bndy1) (cvtA acc1) - (convertSharingBoundary config alyt aenv' bndy2) + (convertSharingBoundary config alyt aenv' shr bndy2) (cvtA acc2) -- Collect seq -> AST.Collect (convertSharingSeq config alyt EmptyLayout aenv' [] seq) @@ -526,84 +513,43 @@ convertSharingSeq config alyt slyt aenv senv s --} convertSharingAfun1 - :: forall aenv a b. (Typeable a, Typeable b) - => Config + :: forall aenv a b. + Config -> ArrayLayout aenv aenv -> [StableSharingAcc] -> ArraysR a -> (SmartAcc a -> ScopedAcc b) -> OpenAfun aenv (a -> b) -convertSharingAfun1 config alyt aenv reprA f = case declareArrays reprA of - DeclareArrays lhs k value -> +convertSharingAfun1 config alyt aenv reprA f + | DeclareVars lhs k value <- declareVars reprA = let - alyt' = ArrayPushLayout (incArrayLayoutWith k alyt) lhs (value id) + alyt' = PushLayout (incLayout k alyt) lhs (value weakenId) body = f undefined in Alam lhs (Abody (convertSharingAcc config alyt' aenv body)) -{-- -convertSharingAfun2 - :: forall aenv a b c. (Arrays a, Arrays b, Arrays c) - => Config - -> Layout aenv aenv - -> [StableSharingAcc] - -> (Acc a -> Acc b -> ScopedAcc c) - -> OpenAfun aenv (a -> b -> c) -convertSharingAfun2 config alyt aenv f - = Alam (Alam (Abody (convertSharingAcc config alyt' aenv body))) - where - alyt' = incLayout (incLayout alyt `PushLayout` ZeroIdx) `PushLayout` ZeroIdx - body = f undefined undefined - -convertSharingAfun3 - :: forall aenv a b c d. (Arrays a, Arrays b, Arrays c, Arrays d) - => Config - -> Layout aenv aenv - -> [StableSharingAcc] - -> (Acc a -> Acc b -> Acc c -> ScopedAcc d) - -> OpenAfun aenv (a -> b -> c -> d) -convertSharingAfun3 config alyt aenv f - = Alam (Alam (Alam (Abody (convertSharingAcc config alyt' aenv body)))) - where - alyt' = incLayout (incLayout (incLayout alyt `PushLayout` ZeroIdx) `PushLayout` ZeroIdx) `PushLayout` ZeroIdx - body = f undefined undefined undefined ---} - -- | Convert a boundary condition -- convertSharingBoundary - :: forall aenv t. + :: forall aenv sh e. Config -> ArrayLayout aenv aenv -> [StableSharingAcc] - -> PreBoundary ScopedAcc ScopedExp t - -> AST.PreBoundary AST.OpenAcc aenv t -convertSharingBoundary config alyt aenv = cvt + -> ShapeR sh + -> PreBoundary ScopedAcc ScopedExp (Array sh e) + -> AST.Boundary aenv (Array sh e) +convertSharingBoundary config alyt aenv shr = cvt where - cvt :: PreBoundary ScopedAcc ScopedExp t -> AST.Boundary aenv t + cvt :: PreBoundary ScopedAcc ScopedExp (Array sh e) -> AST.Boundary aenv (Array sh e) cvt bndy = case bndy of Clamp -> AST.Clamp Mirror -> AST.Mirror Wrap -> AST.Wrap - Constant v -> AST.Constant $ fromElt v - Function f -> AST.Function $ convertSharingFun1 config alyt aenv f + Constant v -> AST.Constant v + Function f -> AST.Function $ convertSharingFun1 config alyt aenv (shapeType shr) f --- Smart constructors to represent AST forms --- -mkIndex :: forall slix e aenv. (Slice slix, Elt e) - => AST.OpenAcc aenv (Array (FullShape slix) e) - -> AST.Exp aenv slix - -> AST.PreOpenAcc AST.OpenAcc aenv (Array (SliceShape slix) e) -mkIndex = AST.Slice (sliceIndex @slix) - -mkReplicate :: forall slix e aenv. (Slice slix, Elt e) - => AST.Exp aenv slix - -> AST.OpenAcc aenv (Array (SliceShape slix) e) - -> AST.PreOpenAcc AST.OpenAcc aenv (Array (FullShape slix) e) -mkReplicate = AST.Replicate (sliceIndex @slix) - -- mkToSeq :: forall slsix slix e aenv senv. (Division slsix, DivisionSlice slsix ~ slix, Elt e, Elt slix, Slice slix) -- => slsix -- -> AST.OpenAcc aenv (Array (FullShape slix) e) @@ -624,29 +570,59 @@ mkReplicate = AST.Replicate (sliceIndex @slix) -- In higher-order abstract syntax, this represents an n-ary, polyvariadic -- function. -- -convertFun :: Function f => f -> AST.Fun () (FunctionR f) +convertFun :: Function f => f -> AST.Fun () (EltReprFunctionR f) convertFun = convertFunWith - $ defaultOptions { options = options defaultOptions \\ [seq_sharing, acc_sharing, float_out_acc] } + $ defaultOptions { options = options defaultOptions \\ [seq_sharing, acc_sharing] } -convertFunWith :: Function f => Config -> f -> AST.Fun () (FunctionR f) +convertFunWith :: Function f => Config -> f -> AST.Fun () (EltReprFunctionR f) convertFunWith config = convertOpenFun config EmptyLayout +data FunctionRepr f r reprr where + FunctionReprBody + :: Elt b => FunctionRepr (Exp b) b (EltRepr b) + + FunctionReprLam + :: Elt a + => FunctionRepr b br breprr + -> FunctionRepr (Exp a -> b) (a -> br) (EltRepr a -> breprr) + class Function f where type FunctionR f - convertOpenFun :: Config -> Layout env env -> f -> AST.OpenFun env () (FunctionR f) + type EltReprFunctionR f + + functionRepr :: FunctionRepr f (FunctionR f) (EltReprFunctionR f) + convertOpenFun :: Config -> ELayout env env -> f -> AST.OpenFun env () (EltReprFunctionR f) instance (Elt a, Function r) => Function (Exp a -> r) where type FunctionR (Exp a -> r) = a -> FunctionR r - convertOpenFun config lyt f = - let x = Exp $ Tag (sizeLayout lyt) - lyt' = incLayout lyt `PushLayout` ZeroIdx - in Lam $ convertOpenFun config lyt' (f x) + type EltReprFunctionR (Exp a -> r) = EltRepr a -> EltReprFunctionR r + + functionRepr = FunctionReprLam $ functionRepr @r + convertOpenFun config lyt f + | tp <- eltType @a + , DeclareVars lhs k value <- declareVars tp = + let + e = Exp $ SmartExp $ Tag tp $ sizeLayout lyt + lyt' = PushLayout (incLayout k lyt) lhs (value weakenId) + in + Lam lhs $ convertOpenFun config lyt' $ f e instance Elt b => Function (Exp b) where type FunctionR (Exp b) = b - convertOpenFun config lyt body = Body $ convertOpenExp config lyt body + type EltReprFunctionR (Exp b) = EltRepr b + + functionRepr = FunctionReprBody + convertOpenFun config lyt (Exp body) = Body $ convertOpenExp config lyt body +convertSmartFun :: Config -> TupleType a -> (SmartExp a -> SmartExp b) -> AST.Fun () (a -> b) +convertSmartFun config tp f + | DeclareVars lhs _ value <- declareVars tp = + let + e = SmartExp $ Tag tp 0 + lyt' = PushLayout EmptyLayout lhs (value weakenId) + in + Lam lhs $ Body $ convertOpenExp config lyt' $ f e -- Scalar expressions -- ------------------ @@ -654,26 +630,25 @@ instance Elt b => Function (Exp b) where -- | Convert a closed scalar expression to de Bruijn form while incorporating -- sharing information. -- -convertExp :: Elt e => Exp e -> AST.Exp () e +convertExp :: Exp e -> AST.Exp () (EltRepr e) convertExp = convertExpWith - $ defaultOptions { options = options defaultOptions \\ [seq_sharing, acc_sharing, float_out_acc] } + $ defaultOptions { options = options defaultOptions \\ [seq_sharing, acc_sharing] } -convertExpWith :: Elt e => Config -> Exp e -> AST.Exp () e -convertExpWith config = convertOpenExp config EmptyLayout +convertExpWith :: Config -> Exp e -> AST.Exp () (EltRepr e) +convertExpWith config (Exp e) = convertOpenExp config EmptyLayout e convertOpenExp - :: Elt e - => Config - -> Layout env env - -> Exp e + :: Config + -> ELayout env env + -> SmartExp e -> AST.OpenExp env () e convertOpenExp config lyt exp = let lvl = sizeLayout lyt fvs = [lvl-1, lvl-2 .. 0] (sharingExp, initialEnv) = recoverSharingExp config lvl fvs exp in - convertSharingExp config lyt ArrayEmptyLayout initialEnv [] sharingExp + convertSharingExp config lyt EmptyLayout initialEnv [] sharingExp -- | Convert an open expression with given environment layouts and sharing information into @@ -684,9 +659,9 @@ convertOpenExp config lyt exp = -- keeping them in reverse chronological order (outermost variable is at the end of the list). -- convertSharingExp - :: forall t env aenv. Elt t - => Config - -> Layout env env -- scalar environment + :: forall t env aenv. + Config + -> ELayout env env -- scalar environment -> ArrayLayout aenv aenv -- array environment -> [StableSharingExp] -- currently bound sharing variables of expressions -> [StableSharingAcc] -- currently bound sharing variables of array computations @@ -697,9 +672,9 @@ convertSharingExp config lyt alyt env aenv exp@(ScopedExp lams _) = cvt exp -- scalar environment with any lambda bound variables this expression is rooted in env' = lams ++ env - cvt :: Elt t' => ScopedExp t' -> AST.OpenExp env aenv t' - cvt (ScopedExp _ (VarSharing se)) - | Just i <- findIndex (matchStableExp se) env' = AST.Var (prjIdx (ctx i) i lyt) + cvt :: ScopedExp t' -> AST.OpenExp env aenv t' + cvt (ScopedExp _ (VarSharing se tp)) + | Just i <- findIndex (matchStableExp se) env' = evars (prjIdx (ctx i) showType tp i lyt) | otherwise = $internalError "convertSharingExp" msg where ctx i = printf "shared 'Exp' tree with stable name %d; i=%d" (hashStableNameHeight se) i @@ -747,161 +722,130 @@ convertSharingExp config lyt alyt env aenv exp@(ScopedExp lams _) = cvt exp ] cvt (ScopedExp _ (LetSharing se@(StableSharingExp _ boundExp) bodyExp)) - = let lyt' = incLayout lyt `PushLayout` ZeroIdx - in - AST.Let (cvt (ScopedExp [] boundExp)) (convertSharingExp config lyt' alyt (se:env') aenv bodyExp) + | DeclareVars lhs k value <- declareVars $ expType boundExp + = let + lyt' = PushLayout (incLayout k lyt) lhs (value weakenId) + in + AST.Let lhs (cvt (ScopedExp [] boundExp)) (convertSharingExp config lyt' alyt (se:env') aenv bodyExp) cvt (ScopedExp _ (ExpSharing _ pexp)) = case pexp of - Tag i -> AST.Var (prjIdx ("de Bruijn conversion tag " ++ show i) i lyt) - Const v -> AST.Const (fromElt v) - Undef -> AST.Undef - Tuple tup -> AST.Tuple (cvtT tup) - Prj idx e -> AST.Prj idx (cvt e) - IndexNil -> AST.IndexNil - IndexCons ix i -> AST.IndexCons (cvt ix) (cvt i) - IndexHead i -> AST.IndexHead (cvt i) - IndexTail ix -> AST.IndexTail (cvt ix) - IndexAny -> AST.IndexAny - ToIndex sh ix -> AST.ToIndex (cvt sh) (cvt ix) - FromIndex sh e -> AST.FromIndex (cvt sh) (cvt e) + Tag tp i -> evars $ prjIdx ("de Bruijn conversion tag " ++ show i) showType tp i lyt + Const tp v -> AST.Const tp v + Undef tp -> AST.Undef tp + Prj idx e -> cvtPrj idx (cvt e) + Nil -> AST.Nil + Pair e1 e2 -> AST.Pair (cvt e1) (cvt e2) + VecPack vec e -> AST.VecPack vec (cvt e) + VecUnpack vec e -> AST.VecUnpack vec (cvt e) + ToIndex shr sh ix -> AST.ToIndex shr (cvt sh) (cvt ix) + FromIndex shr sh e -> AST.FromIndex shr (cvt sh) (cvt e) Cond e1 e2 e3 -> AST.Cond (cvt e1) (cvt e2) (cvt e3) - While p it i -> AST.While (cvtFun1 p) (cvtFun1 it) (cvt i) + While tp p it i -> AST.While (cvtFun1 tp p) (cvtFun1 tp it) (cvt i) PrimConst c -> AST.PrimConst c PrimApp f e -> cvtPrimFun f (cvt e) - Index a e -> AST.Index (cvtA a) (cvt e) - LinearIndex a i -> AST.LinearIndex (cvtA a) (cvt i) - Shape a -> AST.Shape (cvtA a) - ShapeSize e -> AST.ShapeSize (cvt e) - Intersect sh1 sh2 -> AST.Intersect (cvt sh1) (cvt sh2) - Union sh1 sh2 -> AST.Union (cvt sh1) (cvt sh2) - Foreign ff f e -> AST.Foreign ff (convertFunWith config f) (cvt e) - Coerce e -> AST.Coerce (cvt e) - - cvtA :: Typeable a => ScopedAcc a -> AST.OpenAcc aenv a + Index _ a e -> AST.Index (cvtAvar a) (cvt e) + LinearIndex _ a i -> AST.LinearIndex (cvtAvar a) (cvt i) + Shape _ a -> AST.Shape (cvtAvar a) + ShapeSize shr e -> AST.ShapeSize shr (cvt e) + Foreign repr ff f e -> AST.Foreign repr ff (convertSmartFun config (expType e) f) (cvt e) + Coerce t1 t2 e -> AST.Coerce t1 t2 (cvt e) + + cvtPrj :: forall a b c env1 aenv1. PairIdx (a, b) c -> AST.OpenExp env1 aenv1 (a, b) -> AST.OpenExp env1 aenv1 c + cvtPrj PairIdxLeft (AST.Pair a _) = a + cvtPrj PairIdxRight (AST.Pair _ b) = b + cvtPrj ix a + | DeclareVars lhs _ value <- declareVars $ AST.expType a + = AST.Let lhs a $ cvtPrj ix $ evars $ value weakenId + + cvtA :: ScopedAcc a -> AST.OpenAcc aenv a cvtA = convertSharingAcc config alyt aenv - cvtT :: Tuple ScopedExp tup -> Tuple (AST.OpenExp env aenv) tup - cvtT = convertSharingTuple config lyt alyt env' aenv + cvtAvar :: ScopedAcc a -> AST.ArrayVar aenv a + cvtAvar a = case cvtA a of + AST.OpenAcc (AST.Avar var) -> var + _ -> $internalError "convertSharingExp" "Expected array computation in expression to be floated out" - cvtFun1 :: (Elt a, Elt b) => (Exp a -> ScopedExp b) -> AST.OpenFun env aenv (a -> b) - cvtFun1 f = Lam (Body (convertSharingExp config lyt' alyt env' aenv body)) - where - lyt' = incLayout lyt `PushLayout` ZeroIdx - body = f undefined + cvtFun1 :: TupleType a -> (SmartExp a -> ScopedExp b) -> AST.OpenFun env aenv (a -> b) + cvtFun1 tp f + | DeclareVars lhs k value <- declareVars tp = + let + lyt' = PushLayout (incLayout k lyt) lhs (value weakenId) + body = f undefined + in + Lam lhs $ Body $ convertSharingExp config lyt' alyt env' aenv body -- Push primitive function applications down through let bindings so that -- they are adjacent to their arguments. It looks a bit nicer this way. -- - cvtPrimFun :: (Elt a, Elt r) - => AST.PrimFun (a -> r) -> AST.OpenExp env' aenv' a -> AST.OpenExp env' aenv' r + cvtPrimFun :: AST.PrimFun (a -> r) -> AST.OpenExp env' aenv' a -> AST.OpenExp env' aenv' r cvtPrimFun f e = case e of - AST.Let bnd body -> AST.Let bnd (cvtPrimFun f body) - x -> AST.PrimApp f x - --- | Convert a tuple expression --- -convertSharingTuple - :: Config - -> Layout env env - -> ArrayLayout aenv aenv - -> [StableSharingExp] -- currently bound scalar sharing-variables - -> [StableSharingAcc] -- currently bound array sharing-variables - -> Tuple ScopedExp t - -> Tuple (AST.OpenExp env aenv) t -convertSharingTuple config lyt alyt env aenv tup = - case tup of - NilTup -> NilTup - SnocTup t e -> convertSharingTuple config lyt alyt env aenv t - `SnocTup` convertSharingExp config lyt alyt env aenv e + AST.Let lhs bnd body -> AST.Let lhs bnd (cvtPrimFun f body) + x -> AST.PrimApp f x -- | Convert a unary functions -- convertSharingFun1 - :: forall a b aenv. (Elt a, Elt b) - => Config + :: Config -> ArrayLayout aenv aenv -> [StableSharingAcc] -- currently bound array sharing-variables - -> (Exp a -> ScopedExp b) + -> TupleType a + -> (SmartExp a -> ScopedExp b) -> AST.Fun aenv (a -> b) -convertSharingFun1 config alyt aenv f = Lam (Body openF) - where - a = Exp undefined -- the 'tag' was already embedded in Phase 1 - lyt = EmptyLayout - `PushLayout` - (ZeroIdx :: Idx ((), a) a) - openF = convertSharingExp config lyt alyt [] aenv (f a) +convertSharingFun1 config alyt aenv tp f + | DeclareVars lhs _ value <- declareVars tp = + let + a = SmartExp undefined -- the 'tag' was already embedded in Phase 1 + lyt = PushLayout EmptyLayout lhs (value weakenId) + openF = convertSharingExp config lyt alyt [] aenv (f a) + in + Lam lhs (Body openF) -- | Convert a binary functions -- convertSharingFun2 - :: forall a b c aenv. (Elt a, Elt b, Elt c) - => Config + :: Config -> ArrayLayout aenv aenv -> [StableSharingAcc] -- currently bound array sharing-variables - -> (Exp a -> Exp b -> ScopedExp c) + -> TupleType a + -> TupleType b + -> (SmartExp a -> SmartExp b -> ScopedExp c) -> AST.Fun aenv (a -> b -> c) -convertSharingFun2 config alyt aenv f = Lam (Lam (Body openF)) - where - a = Exp undefined - b = Exp undefined - lyt = EmptyLayout - `PushLayout` - (SuccIdx ZeroIdx :: Idx (((), a), b) a) - `PushLayout` - (ZeroIdx :: Idx (((), a), b) b) - openF = convertSharingExp config lyt alyt [] aenv (f a b) +convertSharingFun2 config alyt aenv ta tb f + | DeclareVars lhs1 _ value1 <- declareVars ta + , DeclareVars lhs2 k2 value2 <- declareVars tb = + let + a = SmartExp undefined + b = SmartExp undefined + lyt1 = PushLayout EmptyLayout lhs1 (value1 k2) + lyt2 = PushLayout lyt1 lhs2 (value2 weakenId) + openF = convertSharingExp config lyt2 alyt [] aenv (f a b) + in + Lam lhs1 $ Lam lhs2 $ Body openF -- | Convert a unary stencil function -- convertSharingStencilFun1 - :: forall sh a stencil b aenv. (Elt a, Stencil sh a stencil, Elt b) - => Config - -> ScopedAcc (Array sh a) -- just passed to fix the type variables + :: Config -> ArrayLayout aenv aenv -> [StableSharingAcc] -- currently bound array sharing-variables - -> (stencil -> ScopedExp b) - -> AST.Fun aenv (StencilRepr sh stencil -> b) -convertSharingStencilFun1 config _ alyt aenv stencilFun = Lam (Body openStencilFun) - where - stencil = Exp undefined :: Exp (StencilRepr sh stencil) - lyt = EmptyLayout - `PushLayout` - (ZeroIdx :: Idx ((), StencilRepr sh stencil) - (StencilRepr sh stencil)) - - body = stencilFun (stencilPrj @sh @a stencil) - openStencilFun = convertSharingExp config lyt alyt [] aenv body + -> StencilR sh a stencil + -> (SmartExp stencil -> ScopedExp b) + -> AST.Fun aenv (stencil -> b) +convertSharingStencilFun1 config alyt aenv stencil stencilFun + = convertSharingFun1 config alyt aenv (stencilType stencil) stencilFun -- | Convert a binary stencil function -- convertSharingStencilFun2 - :: forall sh a b stencil1 stencil2 c aenv. - (Elt a, Stencil sh a stencil1, - Elt b, Stencil sh b stencil2, - Elt c) - => Config - -> ScopedAcc (Array sh a) -- just passed to fix the type variables - -> ScopedAcc (Array sh b) -- just passed to fix the type variables + :: Config -> ArrayLayout aenv aenv -> [StableSharingAcc] -- currently bound array sharing-variables - -> (stencil1 -> stencil2 -> ScopedExp c) - -> AST.Fun aenv (StencilRepr sh stencil1 -> StencilRepr sh stencil2 -> c) -convertSharingStencilFun2 config _ _ alyt aenv stencilFun = Lam (Lam (Body openStencilFun)) - where - stencil1 = Exp undefined :: Exp (StencilRepr sh stencil1) - stencil2 = Exp undefined :: Exp (StencilRepr sh stencil2) - lyt = EmptyLayout - `PushLayout` - (SuccIdx ZeroIdx :: Idx (((), StencilRepr sh stencil1), - StencilRepr sh stencil2) - (StencilRepr sh stencil1)) - `PushLayout` - (ZeroIdx :: Idx (((), StencilRepr sh stencil1), - StencilRepr sh stencil2) - (StencilRepr sh stencil2)) - - body = stencilFun (stencilPrj @sh @a stencil1) (stencilPrj @sh @b stencil2) - openStencilFun = convertSharingExp config lyt alyt [] aenv body + -> StencilR sh a stencil1 + -> StencilR sh b stencil2 + -> (SmartExp stencil1 -> SmartExp stencil2 -> ScopedExp c) + -> AST.Fun aenv (stencil1 -> stencil2 -> c) +convertSharingStencilFun2 config alyt aenv stencil1 stencil2 stencilFun + = convertSharingFun2 config alyt aenv (stencilType stencil1) (stencilType stencil2) stencilFun -- Sharing recovery @@ -968,15 +912,13 @@ convertSharingStencilFun2 config _ _ alyt aenv stencilFun = Lam (Lam (Body openS -- Opaque stable name for AST nodes — used to key the occurrence map. -- data StableASTName c where - StableASTName :: (Typeable c, Typeable t) => StableName (c t) -> StableASTName c + StableASTName :: StableName (c t) -> StableASTName c instance Show (StableASTName c) where show (StableASTName sn) = show $ hashStableName sn instance Eq (StableASTName c) where - StableASTName sn1 == StableASTName sn2 - | Just sn1' <- gcast sn1 = sn1' == sn2 - | otherwise = False + StableASTName sn1 == StableASTName sn2 = eqStableName sn1 sn2 instance Hashable (StableASTName c) where hashWithSalt s (StableASTName sn) = hashWithSalt s sn @@ -989,7 +931,7 @@ makeStableAST e = e `seq` makeStableName e data StableNameHeight t = StableNameHeight (StableName t) Int instance Eq (StableNameHeight t) where - (StableNameHeight sn1 _) == (StableNameHeight sn2 _) = sn1 == sn2 + (StableNameHeight sn1 _) == (StableNameHeight sn2 _) = eqStableName sn1 sn2 higherSNH :: StableNameHeight t1 -> StableNameHeight t2 -> Bool StableNameHeight _ h1 `higherSNH` StableNameHeight _ h2 = h1 > h2 @@ -1074,7 +1016,7 @@ lookupWithSharingAcc oc (StableSharingAcc (StableNameHeight sn _) _) -- Look up the occurrence map keyed by scalar expressions using a sharing expression. If an -- the key does not exist in the map, return an occurrence count of '1'. -- -lookupWithSharingExp :: OccMap Exp -> StableSharingExp -> Int +lookupWithSharingExp :: OccMap SmartExp -> StableSharingExp -> Int lookupWithSharingExp oc (StableSharingExp (StableNameHeight sn _) _) = lookupWithASTName oc (StableASTName sn) @@ -1084,33 +1026,44 @@ lookupWithSharingExp oc (StableSharingExp (StableNameHeight sn _) _) -- Stable name for 'SmartAcc' nodes including the height of the AST. -- -type StableAccName arrs = StableNameHeight (SmartAcc arrs) +type StableAccName t = StableNameHeight (SmartAcc t) -- Interleave sharing annotations into an array computation AST. Subtrees can be marked as being -- represented by variable (binding a shared subtree) using 'AvarSharing' and as being prefixed by -- a let binding (for a shared subtree) using 'AletSharing'. -- data SharingAcc acc exp arrs where - AvarSharing :: Typeable arrs - => StableAccName arrs -> SharingAcc acc exp arrs + AvarSharing :: StableAccName arrs -> ArraysR arrs -> SharingAcc acc exp arrs AletSharing :: StableSharingAcc -> acc arrs -> SharingAcc acc exp arrs - AccSharing :: Typeable arrs - => StableAccName arrs -> PreSmartAcc acc exp arrs -> SharingAcc acc exp arrs + AccSharing :: StableAccName arrs -> PreSmartAcc acc exp arrs -> SharingAcc acc exp arrs + +instance HasArraysRepr acc => HasArraysRepr (SharingAcc acc exp) where + arraysRepr (AvarSharing _ repr) = repr + arraysRepr (AletSharing _ acc) = arraysRepr acc + arraysRepr (AccSharing _ acc) = arraysRepr acc + -- Array expression with sharing but shared values have not been scoped; i.e. no let bindings. If -- the expression is rooted in a function, the list contains the tags of the variables bound by the -- immediate surrounding lambdas. data UnscopedAcc t = UnscopedAcc [Int] (SharingAcc UnscopedAcc RootExp t) +instance HasArraysRepr UnscopedAcc where + arraysRepr (UnscopedAcc _ acc) = arraysRepr acc + + -- Array expression with sharing. For expressions rooted in functions the list holds a sorted -- environment corresponding to the variables bound in the immediate surounding lambdas. data ScopedAcc t = ScopedAcc [StableSharingAcc] (SharingAcc ScopedAcc ScopedExp t) +instance HasArraysRepr ScopedAcc where + arraysRepr (ScopedAcc _ acc) = arraysRepr acc + + -- Stable name for an array computation associated with its sharing-annotated version. -- data StableSharingAcc where - StableSharingAcc :: Typeable arrs - => StableAccName arrs + StableSharingAcc :: StableAccName arrs -> SharingAcc ScopedAcc ScopedExp arrs -> StableSharingAcc @@ -1118,19 +1071,17 @@ instance Show StableSharingAcc where show (StableSharingAcc sn _) = show $ hashStableNameHeight sn instance Eq StableSharingAcc where - StableSharingAcc sn1 _ == StableSharingAcc sn2 _ - | Just sn1' <- gcast sn1 = sn1' == sn2 - | otherwise = False + StableSharingAcc (StableNameHeight sn1 _) _ == StableSharingAcc (StableNameHeight sn2 _) _ + = eqStableName sn1 sn2 higherSSA :: StableSharingAcc -> StableSharingAcc -> Bool StableSharingAcc sn1 _ `higherSSA` StableSharingAcc sn2 _ = sn1 `higherSNH` sn2 -- Test whether the given stable names matches an array computation with sharing. -- -matchStableAcc :: Typeable arrs => StableAccName arrs -> StableSharingAcc -> Bool -matchStableAcc sn1 (StableSharingAcc sn2 _) - | Just sn1' <- gcast sn1 = sn1' == sn2 - | otherwise = False +matchStableAcc :: StableAccName arrs -> StableSharingAcc -> Bool +matchStableAcc (StableNameHeight sn1 _) (StableSharingAcc (StableNameHeight sn2 _) _) + = eqStableName sn1 sn2 -- Dummy entry for environments to be used for unused variables. -- @@ -1143,58 +1094,64 @@ noStableAccName = unsafePerformIO $ StableNameHeight <$> makeStableName undefine -- Stable name for 'Exp' nodes including the height of the AST. -- -type StableExpName t = StableNameHeight (Exp t) +type StableExpName t = StableNameHeight (SmartExp t) -- Interleave sharing annotations into a scalar expressions AST in the same manner as 'SharingAcc' -- do for array computations. -- data SharingExp acc exp t where - VarSharing :: Elt t - => StableExpName t -> SharingExp acc exp t - LetSharing :: StableSharingExp -> exp t -> SharingExp acc exp t - ExpSharing :: Elt t - => StableExpName t -> PreExp acc exp t -> SharingExp acc exp t + VarSharing :: StableExpName t -> TupleType t -> SharingExp acc exp t + LetSharing :: StableSharingExp -> exp t -> SharingExp acc exp t + ExpSharing :: StableExpName t -> PreSmartExp acc exp t -> SharingExp acc exp t + +instance HasExpType exp => HasExpType (SharingExp acc exp) where + expType (VarSharing _ tp) = tp + expType (LetSharing _ exp) = expType exp + expType (ExpSharing _ exp) = expType exp -- Specifies a scalar expression AST with sharing annotations but no scoping; i.e. no LetSharing -- constructors. If the expression is rooted in a function, the list contains the tags of the -- variables bound by the immediate surrounding lambdas. data UnscopedExp t = UnscopedExp [Int] (SharingExp UnscopedAcc UnscopedExp t) +instance HasExpType UnscopedExp where + expType (UnscopedExp _ exp) = expType exp + -- Specifies a scalar expression AST with sharing. For expressions rooted in functions the list -- holds a sorted environment corresponding to the variables bound in the immediate surounding -- lambdas. data ScopedExp t = ScopedExp [StableSharingExp] (SharingExp ScopedAcc ScopedExp t) +instance HasExpType ScopedExp where + expType (ScopedExp _ exp) = expType exp + -- Expressions rooted in 'SmartAcc' computations. -- -- * When counting occurrences, the root of every expression embedded in an 'SmartAcc' is annotated by -- an occurrence map for that one expression (excluding any subterms that are rooted in embedded -- 'SmartAcc's.) -- -data RootExp t = RootExp (OccMap Exp) (UnscopedExp t) +data RootExp t = RootExp (OccMap SmartExp) (UnscopedExp t) -- Stable name for an expression associated with its sharing-annotated version. -- data StableSharingExp where - StableSharingExp :: Elt t => StableExpName t -> SharingExp ScopedAcc ScopedExp t -> StableSharingExp + StableSharingExp :: StableExpName t -> SharingExp ScopedAcc ScopedExp t -> StableSharingExp instance Show StableSharingExp where show (StableSharingExp sn _) = show $ hashStableNameHeight sn instance Eq StableSharingExp where - StableSharingExp sn1 _ == StableSharingExp sn2 _ - | Just sn1' <- gcast sn1 = sn1' == sn2 - | otherwise = False + StableSharingExp (StableNameHeight sn1 _) _ == StableSharingExp (StableNameHeight sn2 _) _ = + eqStableName sn1 sn2 higherSSE :: StableSharingExp -> StableSharingExp -> Bool StableSharingExp sn1 _ `higherSSE` StableSharingExp sn2 _ = sn1 `higherSNH` sn2 -- Test whether the given stable names matches an expression with sharing. -- -matchStableExp :: Typeable t => StableExpName t -> StableSharingExp -> Bool -matchStableExp sn1 (StableSharingExp sn2 _) - | Just sn1' <- gcast sn1 = sn1' == sn2 - | otherwise = False +matchStableExp :: StableExpName t -> StableSharingExp -> Bool +matchStableExp (StableNameHeight sn1 _) (StableSharingExp (StableNameHeight sn2 _) _) = eqStableName sn1 sn2 -- Dummy entry for environments to be used for unused variables. -- @@ -1295,8 +1252,7 @@ matchStableSeq sn1 (StableSharingSeq sn2 _) -- They are /not/ directly used to compute the de Brujin indices. -- makeOccMapAcc - :: Typeable arrs - => Config + :: Config -> Level -> SmartAcc arrs -> IO (UnscopedAcc arrs, OccMap SmartAcc) @@ -1310,41 +1266,42 @@ makeOccMapAcc config lvl acc = do makeOccMapSharingAcc - :: Typeable arrs - => Config + :: Config -> OccMapHash SmartAcc -> Level -> SmartAcc arrs -> IO (UnscopedAcc arrs, Int) makeOccMapSharingAcc config accOccMap = traverseAcc where - traverseFun1 :: (Elt a, Typeable b) => Level -> (Exp a -> Exp b) -> IO (Exp a -> RootExp b, Int) + traverseFun1 :: Level -> TupleType a -> (SmartExp a -> SmartExp b) -> IO (SmartExp a -> RootExp b, Int) traverseFun1 = makeOccMapFun1 config accOccMap - traverseFun2 :: (Elt a, Elt b, Typeable c) - => Level - -> (Exp a -> Exp b -> Exp c) - -> IO (Exp a -> Exp b -> RootExp c, Int) + traverseFun2 :: Level + -> TupleType a + -> TupleType b + -> (SmartExp a -> SmartExp b -> SmartExp c) + -> IO (SmartExp a -> SmartExp b -> RootExp c, Int) traverseFun2 = makeOccMapFun2 config accOccMap - traverseAfun1 :: (Typeable a, Typeable b) => Level -> (SmartAcc a -> SmartAcc b) -> IO (SmartAcc a -> UnscopedAcc b, Int) + traverseAfun1 :: Level -> ArraysR a -> (SmartAcc a -> SmartAcc b) -> IO (SmartAcc a -> UnscopedAcc b, Int) traverseAfun1 = makeOccMapAfun1 config accOccMap - traverseExp :: Typeable e => Level -> Exp e -> IO (RootExp e, Int) + traverseExp :: Level -> SmartExp e -> IO (RootExp e, Int) traverseExp = makeOccMapExp config accOccMap traverseBoundary :: Level - -> PreBoundary SmartAcc Exp t - -> IO (PreBoundary UnscopedAcc RootExp t, Int) - traverseBoundary lvl bndy = + -> ShapeR sh + -> PreBoundary SmartAcc SmartExp (Array sh e) + -> IO (PreBoundary UnscopedAcc RootExp (Array sh e), Int) + traverseBoundary lvl shr bndy = case bndy of Clamp -> return (Clamp, 0) Mirror -> return (Mirror, 0) Wrap -> return (Wrap, 0) Constant v -> return (Constant v, 0) Function f -> do - (f', h) <- traverseFun1 lvl f + (f', h) <- traverseFun1 lvl (shapeType shr) f return (Function f', h) -- traverseSeq :: forall arrs. Typeable arrs @@ -1352,7 +1309,7 @@ makeOccMapSharingAcc config accOccMap = traverseAcc -- -> IO (RootSeq arrs, Int) -- traverseSeq = makeOccMapRootSeq config accOccMap - traverseAcc :: forall arrs. Typeable arrs => Level -> SmartAcc arrs -> IO (UnscopedAcc arrs, Int) + traverseAcc :: forall arrs. Level -> SmartAcc arrs -> IO (UnscopedAcc arrs, Int) traverseAcc lvl acc@(SmartAcc pacc) = mfix $ \ ~(_, height) -> do -- Compute stable name and enter it into the occurrence map @@ -1371,169 +1328,168 @@ makeOccMapSharingAcc config accOccMap = traverseAcc -- In case of a repeated occurrence, the height comes from the occurrence map; otherwise -- it is computed by the traversal function passed in 'newAcc'. See also 'enterOcc'. -- - -- NB: This function can only be used in the case alternatives below; outside of the - -- case we cannot discharge the 'Arrays arrs' constraint. - -- let reconstruct :: IO (PreSmartAcc UnscopedAcc RootExp arrs, Int) -> IO (UnscopedAcc arrs, Int) reconstruct newAcc = case heightIfRepeatedOccurrence of Just height | acc_sharing `member` options config - -> return (UnscopedAcc [] (AvarSharing (StableNameHeight sn height)), height) + -> return (UnscopedAcc [] (AvarSharing (StableNameHeight sn height) (arraysRepr pacc)), height) _ -> do (acc, height) <- newAcc return (UnscopedAcc [] (AccSharing (StableNameHeight sn height) acc), height) - case pacc of - Atag i -> reconstruct $ return (Atag i, 0) -- height is 0! - Pipe repr1 repr2 afun1 afun2 acc - -> reconstruct $ do - (afun1', h1) <- traverseAfun1 lvl afun1 - (afun2', h2) <- traverseAfun1 lvl afun2 + reconstruct $ case pacc of + Atag repr i -> return (Atag repr i, 0) -- height is 0! + Pipe repr1 repr2 repr3 afun1 afun2 acc + -> do + (afun1', h1) <- traverseAfun1 lvl repr1 afun1 + (afun2', h2) <- traverseAfun1 lvl repr2 afun2 (acc', h3) <- traverseAcc lvl acc - return (Pipe repr1 repr2 afun1' afun2' acc' + return (Pipe repr1 repr2 repr3 afun1' afun2' acc' , h1 `max` h2 `max` h3 + 1) - Aforeign ff afun acc -> reconstruct $ travA (Aforeign ff afun) acc - Acond e acc1 acc2 -> reconstruct $ do + Aforeign repr ff afun acc -> travA (Aforeign repr ff afun) acc + Acond e acc1 acc2 -> do (e' , h1) <- traverseExp lvl e (acc1', h2) <- traverseAcc lvl acc1 (acc2', h3) <- traverseAcc lvl acc2 return (Acond e' acc1' acc2', h1 `max` h2 `max` h3 + 1) - Awhile repr pred iter init -> reconstruct $ do - (pred', h1) <- traverseAfun1 lvl pred - (iter', h2) <- traverseAfun1 lvl iter + Awhile repr pred iter init -> do + (pred', h1) <- traverseAfun1 lvl repr pred + (iter', h2) <- traverseAfun1 lvl repr iter (init', h3) <- traverseAcc lvl init return (Awhile repr pred' iter' init' , h1 `max` h2 `max` h3 + 1) - Anil -> reconstruct $ return (Anil, 0) - Apair acc1 acc2 -> reconstruct $ do + Anil -> return (Anil, 0) + Apair acc1 acc2 -> do (a', h1) <- traverseAcc lvl acc1 (b', h2) <- traverseAcc lvl acc2 return (Apair a' b', h1 `max` h2 + 1) - Aprj ix a -> reconstruct $ travA (Aprj ix) a + Aprj ix a -> travA (Aprj ix) a - Use arr -> reconstruct $ return (Use arr, 1) - Unit e -> reconstruct $ do + Use repr arr -> return (Use repr arr, 1) + Unit tp e -> do (e', h) <- traverseExp lvl e - return (Unit e', h + 1) - Generate e f -> reconstruct $ do + return (Unit tp e', h + 1) + Generate repr@(ArrayR shr _) e f + -> do (e', h1) <- traverseExp lvl e - (f', h2) <- traverseFun1 lvl f - return (Generate e' f', h1 `max` h2 + 1) - Reshape e acc -> reconstruct $ travEA Reshape e acc - Replicate e acc -> reconstruct $ travEA Replicate e acc - Slice acc e -> reconstruct $ travEA (flip Slice) e acc - Map f acc -> reconstruct $ do - (f' , h1) <- traverseFun1 lvl f + (f', h2) <- traverseFun1 lvl (shapeType shr) f + return (Generate repr e' f', h1 `max` h2 + 1) + Reshape shr e acc -> travEA (Reshape shr) e acc + Replicate si e acc -> travEA (Replicate si) e acc + Slice si acc e -> travEA (flip $ Slice si) e acc + Map t1 t2 f acc -> do + (f' , h1) <- traverseFun1 lvl t1 f (acc', h2) <- traverseAcc lvl acc - return (Map f' acc', h1 `max` h2 + 1) - ZipWith f acc1 acc2 -> reconstruct $ travF2A2 ZipWith f acc1 acc2 - Fold f e acc -> reconstruct $ travF2EA Fold f e acc - Fold1 f acc -> reconstruct $ travF2A Fold1 f acc - FoldSeg f e acc1 acc2 -> reconstruct $ do - (f' , h1) <- traverseFun2 lvl f + return (Map t1 t2 f' acc', h1 `max` h2 + 1) + ZipWith t1 t2 t3 f acc1 acc2 + -> travF2A2 (ZipWith t1 t2 t3) t1 t2 f acc1 acc2 + Fold tp f e acc -> travF2EA (Fold tp) tp tp f e acc + Fold1 tp f acc -> travF2A (Fold1 tp) tp tp f acc + FoldSeg i tp f e acc1 acc2 -> do + (f' , h1) <- traverseFun2 lvl tp tp f (e' , h2) <- traverseExp lvl e (acc1', h3) <- traverseAcc lvl acc1 (acc2', h4) <- traverseAcc lvl acc2 - return (FoldSeg f' e' acc1' acc2', + return (FoldSeg i tp f' e' acc1' acc2', h1 `max` h2 `max` h3 `max` h4 + 1) - Fold1Seg f acc1 acc2 -> reconstruct $ travF2A2 Fold1Seg f acc1 acc2 - Scanl f e acc -> reconstruct $ travF2EA Scanl f e acc - Scanl' f e acc -> reconstruct $ travF2EA Scanl' f e acc - Scanl1 f acc -> reconstruct $ travF2A Scanl1 f acc - Scanr f e acc -> reconstruct $ travF2EA Scanr f e acc - Scanr' f e acc -> reconstruct $ travF2EA Scanr' f e acc - Scanr1 f acc -> reconstruct $ travF2A Scanr1 f acc - Permute c acc1 p acc2 -> reconstruct $ do - (c' , h1) <- traverseFun2 lvl c - (p' , h2) <- traverseFun1 lvl p + Fold1Seg i tp f acc1 acc2 -> travF2A2 (Fold1Seg i tp) tp tp f acc1 acc2 + Scanl tp f e acc -> travF2EA (Scanl tp) tp tp f e acc + Scanl' tp f e acc -> travF2EA (Scanl' tp) tp tp f e acc + Scanl1 tp f acc -> travF2A (Scanl1 tp) tp tp f acc + Scanr tp f e acc -> travF2EA (Scanr tp) tp tp f e acc + Scanr' tp f e acc -> travF2EA (Scanr' tp) tp tp f e acc + Scanr1 tp f acc -> travF2A (Scanr1 tp) tp tp f acc + Permute repr@(ArrayR shr tp) c acc1 p acc2 + -> do + (c' , h1) <- traverseFun2 lvl tp tp c + (p' , h2) <- traverseFun1 lvl (shapeType shr) p (acc1', h3) <- traverseAcc lvl acc1 (acc2', h4) <- traverseAcc lvl acc2 - return (Permute c' acc1' p' acc2', + return (Permute repr c' acc1' p' acc2', h1 `max` h2 `max` h3 `max` h4 + 1) - Backpermute e p acc -> reconstruct $ do + Backpermute shr e p acc -> do (e' , h1) <- traverseExp lvl e - (p' , h2) <- traverseFun1 lvl p + (p' , h2) <- traverseFun1 lvl (shapeType shr) p (acc', h3) <- traverseAcc lvl acc - return (Backpermute e' p' acc', h1 `max` h2 `max` h3 + 1) - Stencil s bnd acc -> reconstruct $ do - (s' , h1) <- makeOccMapStencil1 config accOccMap acc lvl s - (bnd', h2) <- traverseBoundary lvl bnd + return (Backpermute shr e' p' acc', h1 `max` h2 `max` h3 + 1) + Stencil s tp f bnd acc -> do + (f' , h1) <- makeOccMapStencil1 config accOccMap s lvl f + (bnd', h2) <- traverseBoundary lvl (stencilShape s) bnd (acc', h3) <- traverseAcc lvl acc - return (Stencil s' bnd' acc', h1 `max` h2 `max` h3 + 1) - Stencil2 s bnd1 acc1 - bnd2 acc2 -> reconstruct $ do - (s' , h1) <- makeOccMapStencil2 config accOccMap acc1 acc2 lvl s - (bnd1', h2) <- traverseBoundary lvl bnd1 + return (Stencil s tp f' bnd' acc', h1 `max` h2 `max` h3 + 1) + Stencil2 s1 s2 tp f bnd1 acc1 + bnd2 acc2 -> do + let shr = stencilShape s1 + (f' , h1) <- makeOccMapStencil2 config accOccMap s1 s2 lvl f + (bnd1', h2) <- traverseBoundary lvl shr bnd1 (acc1', h3) <- traverseAcc lvl acc1 - (bnd2', h4) <- traverseBoundary lvl bnd2 + (bnd2', h4) <- traverseBoundary lvl shr bnd2 (acc2', h5) <- traverseAcc lvl acc2 - return (Stencil2 s' bnd1' acc1' bnd2' acc2', + return (Stencil2 s1 s2 tp f' bnd1' acc1' bnd2' acc2', h1 `max` h2 `max` h3 `max` h4 `max` h5 + 1) - -- Collect s -> reconstruct $ do + -- Collect s -> do -- (s', h) <- traverseSeq lvl s -- return (Collect s', h + 1) where - travA :: Typeable arrs' - => (UnscopedAcc arrs' -> PreSmartAcc UnscopedAcc RootExp arrs) + travA :: (UnscopedAcc arrs' -> PreSmartAcc UnscopedAcc RootExp arrs) -> SmartAcc arrs' -> IO (PreSmartAcc UnscopedAcc RootExp arrs, Int) travA c acc = do (acc', h) <- traverseAcc lvl acc return (c acc', h + 1) - travEA :: (Typeable arrs', Typeable b) - => (RootExp b -> UnscopedAcc arrs' -> PreSmartAcc UnscopedAcc RootExp arrs) - -> Exp b -> SmartAcc arrs' -> IO (PreSmartAcc UnscopedAcc RootExp arrs, Int) + travEA :: (RootExp b -> UnscopedAcc arrs' -> PreSmartAcc UnscopedAcc RootExp arrs) + -> SmartExp b -> SmartAcc arrs' -> IO (PreSmartAcc UnscopedAcc RootExp arrs, Int) travEA c exp acc = do (exp', h1) <- traverseExp lvl exp (acc', h2) <- traverseAcc lvl acc return (c exp' acc', h1 `max` h2 + 1) - travF2A :: (Elt b, Elt c, Typeable d, Arrays arrs') - => ((Exp b -> Exp c -> RootExp d) -> UnscopedAcc arrs' + travF2A :: ((SmartExp b -> SmartExp c -> RootExp d) -> UnscopedAcc arrs' -> PreSmartAcc UnscopedAcc RootExp arrs) - -> (Exp b -> Exp c -> Exp d) -> SmartAcc arrs' + -> TupleType b -> TupleType c + -> (SmartExp b -> SmartExp c -> SmartExp d) -> SmartAcc arrs' -> IO (PreSmartAcc UnscopedAcc RootExp arrs, Int) - travF2A c fun acc + travF2A c t1 t2 fun acc = do - (fun', h1) <- traverseFun2 lvl fun + (fun', h1) <- traverseFun2 lvl t1 t2 fun (acc', h2) <- traverseAcc lvl acc return (c fun' acc', h1 `max` h2 + 1) - travF2EA :: (Elt b, Elt c, Typeable d, Typeable e, Arrays arrs') - => ((Exp b -> Exp c -> RootExp d) -> RootExp e -> UnscopedAcc arrs' -> PreSmartAcc UnscopedAcc RootExp arrs) - -> (Exp b -> Exp c -> Exp d) -> Exp e -> SmartAcc arrs' + travF2EA :: ((SmartExp b -> SmartExp c -> RootExp d) -> RootExp e -> UnscopedAcc arrs' -> PreSmartAcc UnscopedAcc RootExp arrs) + -> TupleType b -> TupleType c + -> (SmartExp b -> SmartExp c -> SmartExp d) -> SmartExp e -> SmartAcc arrs' -> IO (PreSmartAcc UnscopedAcc RootExp arrs, Int) - travF2EA c fun exp acc + travF2EA c t1 t2 fun exp acc = do - (fun', h1) <- traverseFun2 lvl fun + (fun', h1) <- traverseFun2 lvl t1 t2 fun (exp', h2) <- traverseExp lvl exp (acc', h3) <- traverseAcc lvl acc return (c fun' exp' acc', h1 `max` h2 `max` h3 + 1) - travF2A2 :: (Elt b, Elt c, Typeable d, Arrays arrs1, Arrays arrs2) - => ((Exp b -> Exp c -> RootExp d) -> UnscopedAcc arrs1 -> UnscopedAcc arrs2 -> PreSmartAcc UnscopedAcc RootExp arrs) - -> (Exp b -> Exp c -> Exp d) -> SmartAcc arrs1 -> SmartAcc arrs2 + travF2A2 :: ((SmartExp b -> SmartExp c -> RootExp d) -> UnscopedAcc arrs1 -> UnscopedAcc arrs2 -> PreSmartAcc UnscopedAcc RootExp arrs) + -> TupleType b -> TupleType c + -> (SmartExp b -> SmartExp c -> SmartExp d) -> SmartAcc arrs1 -> SmartAcc arrs2 -> IO (PreSmartAcc UnscopedAcc RootExp arrs, Int) - travF2A2 c fun acc1 acc2 + travF2A2 c t1 t2 fun acc1 acc2 = do - (fun' , h1) <- traverseFun2 lvl fun + (fun' , h1) <- traverseFun2 lvl t1 t2 fun (acc1', h2) <- traverseAcc lvl acc1 (acc2', h3) <- traverseAcc lvl acc2 return (c fun' acc1' acc2', h1 `max` h2 `max` h3 + 1) -makeOccMapAfun1 :: (Typeable a, Typeable b) - => Config +makeOccMapAfun1 :: Config -> OccMapHash SmartAcc -> Level + -> ArraysR a -> (SmartAcc a -> SmartAcc b) -> IO (SmartAcc a -> UnscopedAcc b, Int) -makeOccMapAfun1 config accOccMap lvl f = do - let x = SmartAcc (Atag lvl) +makeOccMapAfun1 config accOccMap lvl repr f = do + let x = SmartAcc (Atag repr lvl) -- (UnscopedAcc [] body, height) <- makeOccMapSharingAcc config accOccMap (lvl+1) (f x) return (const (UnscopedAcc [lvl] body), height) @@ -1573,72 +1529,69 @@ makeOccMapAfun3 config accOccMap lvl f = do -- See Note [Traversing functions and side effects] -- makeOccMapExp - :: Typeable e - => Config + :: Config -> OccMapHash SmartAcc -> Level - -> Exp e + -> SmartExp e -> IO (RootExp e, Int) makeOccMapExp config accOccMap lvl = makeOccMapRootExp config accOccMap lvl [] makeOccMapFun1 - :: (Elt a, Typeable b) - => Config + :: Config -> OccMapHash SmartAcc -> Level - -> (Exp a -> Exp b) - -> IO (Exp a -> RootExp b, Int) -makeOccMapFun1 config accOccMap lvl f = do - let x = Exp (Tag lvl) + -> TupleType a + -> (SmartExp a -> SmartExp b) + -> IO (SmartExp a -> RootExp b, Int) +makeOccMapFun1 config accOccMap lvl tp f = do + let x = SmartExp (Tag tp lvl) -- (body, height) <- makeOccMapRootExp config accOccMap (lvl+1) [lvl] (f x) return (const body, height) makeOccMapFun2 - :: (Elt a, Elt b, Typeable c) - => Config + :: Config -> OccMapHash SmartAcc -> Level - -> (Exp a -> Exp b -> Exp c) - -> IO (Exp a -> Exp b -> RootExp c, Int) -makeOccMapFun2 config accOccMap lvl f = do - let x = Exp (Tag (lvl+1)) - y = Exp (Tag lvl) + -> TupleType a + -> TupleType b + -> (SmartExp a -> SmartExp b -> SmartExp c) + -> IO (SmartExp a -> SmartExp b -> RootExp c, Int) +makeOccMapFun2 config accOccMap lvl t1 t2 f = do + let x = SmartExp (Tag t1 (lvl+1)) + y = SmartExp (Tag t2 lvl) -- (body, height) <- makeOccMapRootExp config accOccMap (lvl+2) [lvl, lvl+1] (f x y) return (\_ _ -> body, height) makeOccMapStencil1 - :: forall sh a b stencil. (Stencil sh a stencil, Typeable b) - => Config + :: forall sh a b stencil. + Config -> OccMapHash SmartAcc - -> SmartAcc (Array sh a) {- dummy -} + -> StencilR sh a stencil -> Level - -> (stencil -> Exp b) - -> IO (stencil -> RootExp b, Int) -makeOccMapStencil1 config accOccMap _ lvl stencil = do - let x = Exp (Tag lvl) - f = stencil . stencilPrj @sh @a + -> (SmartExp stencil -> SmartExp b) + -> IO (SmartExp stencil -> RootExp b, Int) +makeOccMapStencil1 config accOccMap s lvl stencil = do + let x = SmartExp (Tag (stencilType s) lvl) -- - (body, height) <- makeOccMapRootExp config accOccMap (lvl+1) [lvl] (f x) + (body, height) <- makeOccMapRootExp config accOccMap (lvl+1) [lvl] (stencil x) return (const body, height) makeOccMapStencil2 - :: forall sh a b c stencil1 stencil2. (Stencil sh a stencil1, Stencil sh b stencil2, Typeable c) - => Config + :: forall sh a b c stencil1 stencil2. + Config -> OccMapHash SmartAcc - -> SmartAcc (Array sh a) {- dummy -} - -> SmartAcc (Array sh b) {- dummy -} + -> StencilR sh a stencil1 + -> StencilR sh b stencil2 -> Level - -> (stencil1 -> stencil2 -> Exp c) - -> IO (stencil1 -> stencil2 -> RootExp c, Int) -makeOccMapStencil2 config accOccMap _ _ lvl stencil = do - let x = Exp (Tag (lvl+1)) - y = Exp (Tag lvl) - f a b = stencil (stencilPrj @sh @a a) - (stencilPrj @sh @b b) + -> (SmartExp stencil1 -> SmartExp stencil2 -> SmartExp c) + -> IO (SmartExp stencil1 -> SmartExp stencil2 -> RootExp c, Int) +makeOccMapStencil2 config accOccMap s1 s2 lvl stencil = do + let x = SmartExp (Tag (stencilType s1) (lvl+1)) + y = SmartExp (Tag (stencilType s2) lvl) -- - (body, height) <- makeOccMapRootExp config accOccMap (lvl+2) [lvl, lvl+1] (f x y) + (body, height) <- makeOccMapRootExp config accOccMap (lvl+2) [lvl, lvl+1] (stencil x y) return (\_ _ -> body, height) @@ -1649,12 +1602,11 @@ makeOccMapStencil2 config accOccMap _ _ lvl stencil = do -- 2) a local occurrence map for that expression. -- makeOccMapRootExp - :: Typeable e - => Config + :: Config -> OccMapHash SmartAcc -> Level -- The level of currently bound scalar variables -> [Int] -- The tags of newly introduced free scalar variables in this expression - -> Exp e + -> SmartExp e -> IO (RootExp e, Int) makeOccMapRootExp config accOccMap lvl fvs exp = do traceLine "makeOccMapRootExp" "Enter" @@ -1668,17 +1620,16 @@ makeOccMapRootExp config accOccMap lvl fvs exp = do -- Generate sharing information for an open scalar expression. -- makeOccMapSharingExp - :: Typeable e - => Config + :: Config -> OccMapHash SmartAcc - -> OccMapHash Exp + -> OccMapHash SmartExp -> Level -- The level of currently bound variables - -> Exp e + -> SmartExp e -> IO (UnscopedExp e, Int) makeOccMapSharingExp config accOccMap expOccMap = travE where - travE :: forall a. Typeable a => Level -> Exp a -> IO (UnscopedExp a, Int) - travE lvl exp@(Exp pexp) + travE :: forall a. Level -> SmartExp a -> IO (UnscopedExp a, Int) + travE lvl exp@(SmartExp pexp) = mfix $ \ ~(_, height) -> do -- Compute stable name and enter it into the occurrence map -- @@ -1696,89 +1647,77 @@ makeOccMapSharingExp config accOccMap expOccMap = travE -- In case of a repeated occurrence, the height comes from the occurrence map; otherwise -- it is computed by the traversal function passed in 'newExp'. See also 'enterOcc'. -- - -- NB: This function can only be used in the case alternatives below; outside of the - -- case we cannot discharge the 'Elt a' constraint. - -- - let reconstruct :: Elt a - => IO (PreExp UnscopedAcc UnscopedExp a, Int) + let reconstruct :: IO (PreSmartExp UnscopedAcc UnscopedExp a, Int) -> IO (UnscopedExp a, Int) reconstruct newExp = case heightIfRepeatedOccurrence of Just height | exp_sharing `member` options config - -> return (UnscopedExp [] (VarSharing (StableNameHeight sn height)), height) + -> return (UnscopedExp [] (VarSharing (StableNameHeight sn height) (expType pexp)), height) _ -> do (exp, height) <- newExp return (UnscopedExp [] (ExpSharing (StableNameHeight sn height) exp), height) - case pexp of - Tag i -> reconstruct $ return (Tag i, 0) -- height is 0! - Const c -> reconstruct $ return (Const c, 1) - Undef -> reconstruct $ return (Undef, 1) - Tuple tup -> reconstruct $ do - (tup', h) <- travTup tup - return (Tuple tup', h) - Prj i e -> reconstruct $ travE1 (Prj i) e - IndexNil -> reconstruct $ return (IndexNil, 1) - IndexCons ix i -> reconstruct $ travE2 IndexCons ix i - IndexHead i -> reconstruct $ travE1 IndexHead i - IndexTail ix -> reconstruct $ travE1 IndexTail ix - IndexAny -> reconstruct $ return (IndexAny, 1) - ToIndex sh ix -> reconstruct $ travE2 ToIndex sh ix - FromIndex sh e -> reconstruct $ travE2 FromIndex sh e - Cond e1 e2 e3 -> reconstruct $ travE3 Cond e1 e2 e3 - While p iter init -> reconstruct $ do - (p' , h1) <- traverseFun1 lvl p - (iter', h2) <- traverseFun1 lvl iter + reconstruct $ case pexp of + Tag tp i -> return (Tag tp i, 0) -- height is 0! + Const tp c -> return (Const tp c, 1) + Undef tp -> return (Undef tp, 1) + Nil -> return (Nil, 1) + Pair e1 e2 -> travE2 Pair e1 e2 + Prj i e -> travE1 (Prj i) e + VecPack vec e -> travE1 (VecPack vec) e + VecUnpack vec e -> travE1 (VecUnpack vec) e + ToIndex shr sh ix -> travE2 (ToIndex shr) sh ix + FromIndex shr sh e -> travE2 (FromIndex shr) sh e + Cond e1 e2 e3 -> travE3 Cond e1 e2 e3 + While t p iter init -> do + (p' , h1) <- traverseFun1 lvl t p + (iter', h2) <- traverseFun1 lvl t iter (init', h3) <- travE lvl init - return (While p' iter' init', h1 `max` h2 `max` h3 + 1) - PrimConst c -> reconstruct $ return (PrimConst c, 1) - PrimApp p e -> reconstruct $ travE1 (PrimApp p) e - Index a e -> reconstruct $ travAE Index a e - LinearIndex a i -> reconstruct $ travAE LinearIndex a i - Shape a -> reconstruct $ travA Shape a - ShapeSize e -> reconstruct $ travE1 ShapeSize e - Intersect sh1 sh2 -> reconstruct $ travE2 Intersect sh1 sh2 - Union sh1 sh2 -> reconstruct $ travE2 Union sh1 sh2 - Foreign ff f e -> reconstruct $ do + return (While t p' iter' init', h1 `max` h2 `max` h3 + 1) + PrimConst c -> return (PrimConst c, 1) + PrimApp p e -> travE1 (PrimApp p) e + Index tp a e -> travAE (Index tp) a e + LinearIndex tp a i -> travAE (LinearIndex tp) a i + Shape shr a -> travA (Shape shr) a + ShapeSize shr e -> travE1 (ShapeSize shr) e + Foreign tp ff f e -> do (e', h) <- travE lvl e - return (Foreign ff f e', h+1) - Coerce e -> reconstruct $ travE1 Coerce e + return (Foreign tp ff f e', h+1) + Coerce t1 t2 e -> travE1 (Coerce t1 t2) e where - traverseAcc :: Typeable arrs => Level -> SmartAcc arrs -> IO (UnscopedAcc arrs, Int) + traverseAcc :: Level -> SmartAcc arrs -> IO (UnscopedAcc arrs, Int) traverseAcc = makeOccMapSharingAcc config accOccMap - traverseFun1 :: (Elt a, Typeable b) - => Level - -> (Exp a -> Exp b) - -> IO (Exp a -> UnscopedExp b, Int) - traverseFun1 lvl f + traverseFun1 :: Level + -> TupleType a + -> (SmartExp a -> SmartExp b) + -> IO (SmartExp a -> UnscopedExp b, Int) + traverseFun1 lvl tp f = do - let x = Exp (Tag lvl) + let x = SmartExp (Tag tp lvl) (UnscopedExp [] body, height) <- travE (lvl+1) (f x) return (const (UnscopedExp [lvl] body), height + 1) - travE1 :: Typeable b => (UnscopedExp b -> PreExp UnscopedAcc UnscopedExp a) -> Exp b - -> IO (PreExp UnscopedAcc UnscopedExp a, Int) + travE1 :: (UnscopedExp b -> PreSmartExp UnscopedAcc UnscopedExp a) -> SmartExp b + -> IO (PreSmartExp UnscopedAcc UnscopedExp a, Int) travE1 c e = do (e', h) <- travE lvl e return (c e', h + 1) - travE2 :: (Typeable b, Typeable c) - => (UnscopedExp b -> UnscopedExp c -> PreExp UnscopedAcc UnscopedExp a) - -> Exp b -> Exp c - -> IO (PreExp UnscopedAcc UnscopedExp a, Int) + travE2 :: (UnscopedExp b -> UnscopedExp c -> PreSmartExp UnscopedAcc UnscopedExp a) + -> SmartExp b -> SmartExp c + -> IO (PreSmartExp UnscopedAcc UnscopedExp a, Int) travE2 c e1 e2 = do (e1', h1) <- travE lvl e1 (e2', h2) <- travE lvl e2 return (c e1' e2', h1 `max` h2 + 1) - travE3 :: (Typeable b, Typeable c, Typeable d) - => (UnscopedExp b -> UnscopedExp c -> UnscopedExp d -> PreExp UnscopedAcc UnscopedExp a) - -> Exp b -> Exp c -> Exp d - -> IO (PreExp UnscopedAcc UnscopedExp a, Int) + travE3 :: (UnscopedExp b -> UnscopedExp c -> UnscopedExp d -> PreSmartExp UnscopedAcc UnscopedExp a) + -> SmartExp b -> SmartExp c -> SmartExp d + -> IO (PreSmartExp UnscopedAcc UnscopedExp a, Int) travE3 c e1 e2 e3 = do (e1', h1) <- travE lvl e1 @@ -1786,31 +1725,22 @@ makeOccMapSharingExp config accOccMap expOccMap = travE (e3', h3) <- travE lvl e3 return (c e1' e2' e3', h1 `max` h2 `max` h3 + 1) - travA :: Typeable b => (UnscopedAcc b -> PreExp UnscopedAcc UnscopedExp a) -> SmartAcc b - -> IO (PreExp UnscopedAcc UnscopedExp a, Int) + travA :: (UnscopedAcc b -> PreSmartExp UnscopedAcc UnscopedExp a) -> SmartAcc b + -> IO (PreSmartExp UnscopedAcc UnscopedExp a, Int) travA c acc = do (acc', h) <- traverseAcc lvl acc return (c acc', h + 1) - travAE :: (Typeable b, Typeable c) - => (UnscopedAcc b -> UnscopedExp c -> PreExp UnscopedAcc UnscopedExp a) - -> SmartAcc b -> Exp c - -> IO (PreExp UnscopedAcc UnscopedExp a, Int) + travAE :: (UnscopedAcc b -> UnscopedExp c -> PreSmartExp UnscopedAcc UnscopedExp a) + -> SmartAcc b -> SmartExp c + -> IO (PreSmartExp UnscopedAcc UnscopedExp a, Int) travAE c acc e = do (acc', h1) <- traverseAcc lvl acc (e' , h2) <- travE lvl e return (c acc' e', h1 `max` h2 + 1) - travTup :: Tuple Exp tup -> IO (Tuple UnscopedExp tup, Int) - travTup NilTup = return (NilTup, 1) - travTup (SnocTup tup e) = do - (tup', h1) <- travTup tup - (e' , h2) <- travE lvl e - return (SnocTup tup' e', h1 `max` h2 + 1) - - {-- makeOccMapRootSeq :: Typeable arrs @@ -1971,11 +1901,10 @@ makeOccMapSharingSeq config accOccMap seqOccMap = traverseSeq type NodeCounts = ([NodeCount], Map.HashMap NodeName (Set.HashSet NodeName)) data NodeName where - NodeName :: Typeable a => StableName a -> NodeName + NodeName :: StableName a -> NodeName instance Eq NodeName where - (NodeName sn1) == (NodeName sn2) | Just sn2' <- gcast sn2 = sn1 == sn2' - | otherwise = False + (NodeName sn1) == (NodeName sn2) = eqStableName sn1 sn2 instance Hashable NodeName where hashWithSalt hash (NodeName sn1) = hash + hashStableName sn1 @@ -2082,11 +2011,11 @@ nodeName (ExpNodeCount (StableSharingExp (StableNameHeight sn _) _) _) = NodeNam -- insert x@(SeqNodeCount _ _) (y@(AccNodeCount _ _) : ys') -- = x : insert y ys' - (StableSharingAcc _ (AvarSharing _)) `pickNoneAvar` sa2 = sa2 - sa1 `pickNoneAvar` _sa2 = sa1 + (StableSharingAcc _ (AvarSharing _ _)) `pickNoneAvar` sa2 = sa2 + sa1 `pickNoneAvar` _sa2 = sa1 - (StableSharingExp _ (VarSharing _)) `pickNoneVar` sa2 = sa2 - sa1 `pickNoneVar` _sa2 = sa1 + (StableSharingExp _ (VarSharing _ _)) `pickNoneVar` sa2 = sa2 + sa1 `pickNoneVar` _sa2 = sa1 -- pickNoneSvar :: StableSharingSeq -> StableSharingSeq -> StableSharingSeq -- (StableSharingSeq _ (SvarSharing _)) `pickNoneSvar` sa2 = sa2 @@ -2110,7 +2039,7 @@ buildInitialEnvAcc tags sas = map (lookupSA sas) tags sas2 -> $internalError "buildInitialEnvAcc" $ "Encountered duplicate 'ATag's\n " ++ intercalate ", " (map showSA sas2) where - hasTag (StableSharingAcc _ (AccSharing _ (Atag tag2))) = tag1 == tag2 + hasTag (StableSharingAcc _ (AccSharing _ (Atag _ tag2))) = tag1 == tag2 hasTag sa = $internalError "buildInitialEnvAcc" $ "Encountered a node that is not a plain 'Atag'\n " ++ showSA sa @@ -2120,8 +2049,8 @@ buildInitialEnvAcc tags sas = map (lookupSA sas) tags showSA (StableSharingAcc _ (AccSharing sn acc)) = show (hashStableNameHeight sn) ++ ": " ++ showPreAccOp acc - showSA (StableSharingAcc _ (AvarSharing sn)) = "AvarSharing " ++ show (hashStableNameHeight sn) - showSA (StableSharingAcc _ (AletSharing sa _ )) = "AletSharing " ++ show sa ++ "..." + showSA (StableSharingAcc _ (AvarSharing sn _)) = "AvarSharing " ++ show (hashStableNameHeight sn) + showSA (StableSharingAcc _ (AletSharing sa _)) = "AletSharing " ++ show sa ++ "..." -- Build an initial environment for the tag values given in the first argument for traversing a -- scalar expression. The 'StableSharingExp's for all tags /actually used/ in the expressions are @@ -2141,7 +2070,7 @@ buildInitialEnvExp tags ses = map (lookupSE ses) tags ses2 -> $internalError "buildInitialEnvExp" ("Encountered a duplicate 'Tag'\n " ++ intercalate ", " (map showSE ses2)) where - hasTag (StableSharingExp _ (ExpSharing _ (Tag tag2))) = tag1 == tag2 + hasTag (StableSharingExp _ (ExpSharing _ (Tag _ tag2))) = tag1 == tag2 hasTag se = $internalError "buildInitialEnvExp" ("Encountered a node that is not a plain 'Tag'\n " ++ showSE se) @@ -2151,15 +2080,15 @@ buildInitialEnvExp tags ses = map (lookupSE ses) tags showSE (StableSharingExp _ (ExpSharing sn exp)) = show (hashStableNameHeight sn) ++ ": " ++ showPreExpOp exp - showSE (StableSharingExp _ (VarSharing sn)) = "VarSharing " ++ show (hashStableNameHeight sn) + showSE (StableSharingExp _ (VarSharing sn _ )) = "VarSharing " ++ show (hashStableNameHeight sn) showSE (StableSharingExp _ (LetSharing se _ )) = "LetSharing " ++ show se ++ "..." -- Determine whether a 'NodeCount' is for an 'Atag' or 'Tag', which represent free variables. -- isFreeVar :: NodeCount -> Bool -isFreeVar (AccNodeCount (StableSharingAcc _ (AccSharing _ (Atag _))) _) = True -isFreeVar (ExpNodeCount (StableSharingExp _ (ExpSharing _ (Tag _))) _) = True -isFreeVar _ = False +isFreeVar (AccNodeCount (StableSharingAcc _ (AccSharing _ (Atag _ _))) _) = True +isFreeVar (ExpNodeCount (StableSharingExp _ (ExpSharing _ (Tag _ _))) _) = True +isFreeVar _ = False -- Determine scope of shared subterms @@ -2177,8 +2106,7 @@ isFreeVar _ = False -- Precondition: there are only 'AvarSharing' and 'AccSharing' nodes in the argument. -- determineScopesAcc - :: Typeable a - => Config + :: Config -> [Level] -> OccMap SmartAcc -> UnscopedAcc a @@ -2203,25 +2131,25 @@ determineScopesSharingAcc config accOccMap = scopesAcc scopesAcc (UnscopedAcc _ (AletSharing _ _)) = $internalError "determineScopesSharingAcc: scopesAcc" "unexpected 'AletSharing'" - scopesAcc (UnscopedAcc _ (AvarSharing sn)) - = (ScopedAcc [] (AvarSharing sn), StableSharingAcc sn (AvarSharing sn) `insertAccNode` noNodeCounts) + scopesAcc (UnscopedAcc _ (AvarSharing sn tp)) + = (ScopedAcc [] (AvarSharing sn tp), StableSharingAcc sn (AvarSharing sn tp) `insertAccNode` noNodeCounts) scopesAcc (UnscopedAcc _ (AccSharing sn pacc)) = case pacc of - Atag i -> reconstruct (Atag i) noNodeCounts - Pipe repr1 repr2 afun1 afun2 acc + Atag tp i -> reconstruct (Atag tp i) noNodeCounts + Pipe repr1 repr2 repr3 afun1 afun2 acc -> let (afun1', accCount1) = scopesAfun1 afun1 (afun2', accCount2) = scopesAfun1 afun2 (acc', accCount3) = scopesAcc acc in - reconstruct (Pipe repr1 repr2 afun1' afun2' acc') + reconstruct (Pipe repr1 repr2 repr3 afun1' afun2' acc') (accCount1 +++ accCount2 +++ accCount3) - Aforeign ff afun acc -> let + Aforeign r ff afun acc -> let (acc', accCount) = scopesAcc acc in - reconstruct (Aforeign ff afun acc') accCount + reconstruct (Aforeign r ff afun acc') accCount Acond e acc1 acc2 -> let (e' , accCount1) = scopesExp e (acc1', accCount2) = scopesAcc acc1 @@ -2247,64 +2175,67 @@ determineScopesSharingAcc config accOccMap = scopesAcc reconstruct (Apair a1' a2') (accCount1 +++ accCount2) Aprj ix a -> travA (Aprj ix) a - Use arr -> reconstruct (Use arr) noNodeCounts - Unit e -> let + Use repr arr -> reconstruct (Use repr arr) noNodeCounts + Unit tp e -> let (e', accCount) = scopesExp e in - reconstruct (Unit e') accCount - Generate sh f -> let + reconstruct (Unit tp e') accCount + Generate repr sh f -> let (sh', accCount1) = scopesExp sh (f' , accCount2) = scopesFun1 f in - reconstruct (Generate sh' f') (accCount1 +++ accCount2) - Reshape sh acc -> travEA Reshape sh acc - Replicate n acc -> travEA Replicate n acc - Slice acc i -> travEA (flip Slice) i acc - Map f acc -> let + reconstruct (Generate repr sh' f') (accCount1 +++ accCount2) + Reshape shr sh acc -> travEA (Reshape shr) sh acc + Replicate si n acc -> travEA (Replicate si) n acc + Slice si acc i -> travEA (flip $ Slice si) i acc + Map t1 t2 f acc -> let (f' , accCount1) = scopesFun1 f (acc', accCount2) = scopesAcc acc in - reconstruct (Map f' acc') (accCount1 +++ accCount2) - ZipWith f acc1 acc2 -> travF2A2 ZipWith f acc1 acc2 - Fold f z acc -> travF2EA Fold f z acc - Fold1 f acc -> travF2A Fold1 f acc - FoldSeg f z acc1 acc2 -> let + reconstruct (Map t1 t2 f' acc') (accCount1 +++ accCount2) + ZipWith t1 t2 t3 f acc1 acc2 + -> travF2A2 (ZipWith t1 t2 t3) f acc1 acc2 + Fold tp f z acc -> travF2EA (Fold tp) f z acc + Fold1 tp f acc -> travF2A (Fold1 tp) f acc + FoldSeg i tp f z acc1 acc2 -> let (f' , accCount1) = scopesFun2 f (z' , accCount2) = scopesExp z (acc1', accCount3) = scopesAcc acc1 (acc2', accCount4) = scopesAcc acc2 in - reconstruct (FoldSeg f' z' acc1' acc2') + reconstruct (FoldSeg i tp f' z' acc1' acc2') (accCount1 +++ accCount2 +++ accCount3 +++ accCount4) - Fold1Seg f acc1 acc2 -> travF2A2 Fold1Seg f acc1 acc2 - Scanl f z acc -> travF2EA Scanl f z acc - Scanl' f z acc -> travF2EA Scanl' f z acc - Scanl1 f acc -> travF2A Scanl1 f acc - Scanr f z acc -> travF2EA Scanr f z acc - Scanr' f z acc -> travF2EA Scanr' f z acc - Scanr1 f acc -> travF2A Scanr1 f acc - Permute fc acc1 fp acc2 -> let + Fold1Seg i tp f acc1 acc2 -> travF2A2 (Fold1Seg i tp) f acc1 acc2 + Scanl tp f z acc -> travF2EA (Scanl tp) f z acc + Scanl' tp f z acc -> travF2EA (Scanl' tp) f z acc + Scanl1 tp f acc -> travF2A (Scanl1 tp) f acc + Scanr tp f z acc -> travF2EA (Scanr tp) f z acc + Scanr' tp f z acc -> travF2EA (Scanr' tp) f z acc + Scanr1 tp f acc -> travF2A (Scanr1 tp) f acc + Permute repr fc acc1 fp acc2 + -> let (fc' , accCount1) = scopesFun2 fc (acc1', accCount2) = scopesAcc acc1 (fp' , accCount3) = scopesFun1 fp (acc2', accCount4) = scopesAcc acc2 in - reconstruct (Permute fc' acc1' fp' acc2') + reconstruct (Permute repr fc' acc1' fp' acc2') (accCount1 +++ accCount2 +++ accCount3 +++ accCount4) - Backpermute sh fp acc -> let + Backpermute shr sh fp acc + -> let (sh' , accCount1) = scopesExp sh (fp' , accCount2) = scopesFun1 fp (acc', accCount3) = scopesAcc acc in - reconstruct (Backpermute sh' fp' acc') + reconstruct (Backpermute shr sh' fp' acc') (accCount1 +++ accCount2 +++ accCount3) - Stencil st bnd acc -> let + Stencil sr tp st bnd acc -> let (st' , accCount1) = scopesStencil1 acc st (bnd', accCount2) = scopesBoundary bnd (acc', accCount3) = scopesAcc acc in - reconstruct (Stencil st' bnd' acc') (accCount1 +++ accCount2 +++ accCount3) - Stencil2 st bnd1 acc1 bnd2 acc2 + reconstruct (Stencil sr tp st' bnd' acc') (accCount1 +++ accCount2 +++ accCount3) + Stencil2 s1 s2 tp st bnd1 acc1 bnd2 acc2 -> let (st' , accCount1) = scopesStencil2 acc1 acc2 st (bnd1', accCount2) = scopesBoundary bnd1 @@ -2312,7 +2243,7 @@ determineScopesSharingAcc config accOccMap = scopesAcc (bnd2', accCount4) = scopesBoundary bnd2 (acc2', accCount5) = scopesAcc acc2 in - reconstruct (Stencil2 st' bnd1' acc1' bnd2' acc2') + reconstruct (Stencil2 s1 s2 tp st' bnd1' acc1' bnd2' acc2') (accCount1 +++ accCount2 +++ accCount3 +++ accCount4 +++ accCount5) -- Collect seq -> let -- (seq', accCount1) = scopesSeq seq @@ -2329,10 +2260,9 @@ determineScopesSharingAcc config accOccMap = scopesAcc (e' , accCount1) = scopesExp e (acc', accCount2) = scopesAcc acc - travF2A :: (Elt a, Elt b) - => ((Exp a -> Exp b -> ScopedExp c) -> ScopedAcc arrs' + travF2A :: ((SmartExp a -> SmartExp b -> ScopedExp c) -> ScopedAcc arrs' -> PreSmartAcc ScopedAcc ScopedExp arrs) - -> (Exp a -> Exp b -> RootExp c) + -> (SmartExp a -> SmartExp b -> RootExp c) -> UnscopedAcc arrs' -> (ScopedAcc arrs, NodeCounts) travF2A c f acc = reconstruct (c f' acc') (accCount1 +++ accCount2) @@ -2340,10 +2270,9 @@ determineScopesSharingAcc config accOccMap = scopesAcc (f' , accCount1) = scopesFun2 f (acc', accCount2) = scopesAcc acc - travF2EA :: (Elt a, Elt b) - => ((Exp a -> Exp b -> ScopedExp c) -> ScopedExp e + travF2EA :: ((SmartExp a -> SmartExp b -> ScopedExp c) -> ScopedExp e -> ScopedAcc arrs' -> PreSmartAcc ScopedAcc ScopedExp arrs) - -> (Exp a -> Exp b -> RootExp c) + -> (SmartExp a -> SmartExp b -> RootExp c) -> RootExp e -> UnscopedAcc arrs' -> (ScopedAcc arrs, NodeCounts) @@ -2353,10 +2282,9 @@ determineScopesSharingAcc config accOccMap = scopesAcc (e' , accCount2) = scopesExp e (acc', accCount3) = scopesAcc acc - travF2A2 :: (Elt a, Elt b) - => ((Exp a -> Exp b -> ScopedExp c) -> ScopedAcc arrs1 + travF2A2 :: ((SmartExp a -> SmartExp b -> ScopedExp c) -> ScopedAcc arrs1 -> ScopedAcc arrs2 -> PreSmartAcc ScopedAcc ScopedExp arrs) - -> (Exp a -> Exp b -> RootExp c) + -> (SmartExp a -> SmartExp b -> RootExp c) -> UnscopedAcc arrs1 -> UnscopedAcc arrs2 -> (ScopedAcc arrs, NodeCounts) @@ -2394,20 +2322,20 @@ determineScopesSharingAcc config accOccMap = scopesAcc reconstruct :: PreSmartAcc ScopedAcc ScopedExp arrs -> NodeCounts -> (ScopedAcc arrs, NodeCounts) - reconstruct newAcc@(Atag _) _subCount + reconstruct newAcc@(Atag tp _) _subCount -- free variable => replace by a sharing variable regardless of the number of -- occurrences = let thisCount = StableSharingAcc sn (AccSharing sn newAcc) `insertAccNode` noNodeCounts in tracePure "FREE" (show thisCount) - (ScopedAcc [] (AvarSharing sn), thisCount) + (ScopedAcc [] (AvarSharing sn tp), thisCount) reconstruct newAcc subCount -- shared subtree => replace by a sharing variable (if 'recoverAccSharing' enabled) | accOccCount > 1 && acc_sharing `member` options config = let allCount = (StableSharingAcc sn sharingAcc `insertAccNode` newCount) in tracePure ("SHARED" ++ completed) (show allCount) - (ScopedAcc [] (AvarSharing sn), allCount) + (ScopedAcc [] (AvarSharing sn $ arraysRepr newAcc), allCount) -- neither shared nor free variable => leave it as it is | otherwise = tracePure ("Normal" ++ completed) (show newCount) @@ -2471,13 +2399,13 @@ determineScopesSharingAcc config accOccMap = scopesAcc (freeCounts, counts') = partition isBoundHere counts ssa = buildInitialEnvAcc fvs [sa | AccNodeCount sa _ <- freeCounts] - isBoundHere (AccNodeCount (StableSharingAcc _ (AccSharing _ (Atag i))) _) = i `elem` fvs - isBoundHere _ = False + isBoundHere (AccNodeCount (StableSharingAcc _ (AccSharing _ (Atag _ i))) _) = i `elem` fvs + isBoundHere _ = False -- The lambda bound variable is at this point already irrelevant; for details, see -- Note [Traversing functions and side effects] -- - scopesFun1 :: Elt e1 => (Exp e1 -> RootExp e2) -> (Exp e1 -> ScopedExp e2, NodeCounts) + scopesFun1 :: (SmartExp e1 -> RootExp e2) -> (SmartExp e1 -> ScopedExp e2, NodeCounts) scopesFun1 f = (const body, counts) where (body, counts) = scopesExp (f undefined) @@ -2485,9 +2413,8 @@ determineScopesSharingAcc config accOccMap = scopesAcc -- The lambda bound variable is at this point already irrelevant; for details, see -- Note [Traversing functions and side effects] -- - scopesFun2 :: (Elt e1, Elt e2) - => (Exp e1 -> Exp e2 -> RootExp e3) - -> (Exp e1 -> Exp e2 -> ScopedExp e3, NodeCounts) + scopesFun2 :: (SmartExp e1 -> SmartExp e2 -> RootExp e3) + -> (SmartExp e1 -> SmartExp e2 -> ScopedExp e3, NodeCounts) scopesFun2 f = (\_ _ -> body, counts) where (body, counts) = scopesExp (f undefined undefined) @@ -2495,8 +2422,8 @@ determineScopesSharingAcc config accOccMap = scopesAcc -- The lambda bound variable is at this point already irrelevant; for details, see -- Note [Traversing functions and side effects] -- - scopesStencil1 :: forall sh e1 e2 stencil. Stencil sh e1 stencil - => UnscopedAcc (Array sh e1){-dummy-} + scopesStencil1 :: forall sh e1 e2 stencil. + UnscopedAcc (Array sh e1){-dummy-} -> (stencil -> RootExp e2) -> (stencil -> ScopedExp e2, NodeCounts) scopesStencil1 _ stencilFun = (const body, counts) @@ -2507,8 +2434,7 @@ determineScopesSharingAcc config accOccMap = scopesAcc -- Note [Traversing functions and side effects] -- scopesStencil2 :: forall sh e1 e2 e3 stencil1 stencil2. - (Stencil sh e1 stencil1, Stencil sh e2 stencil2) - => UnscopedAcc (Array sh e1){-dummy-} + UnscopedAcc (Array sh e1){-dummy-} -> UnscopedAcc (Array sh e2){-dummy-} -> (stencil1 -> stencil2 -> RootExp e3) -> (stencil1 -> stencil2 -> ScopedExp e3, NodeCounts) @@ -2547,7 +2473,7 @@ determineScopesExp config accOccMap (RootExp expOccMap exp@(UnscopedExp fvs _)) determineScopesSharingExp :: Config -> OccMap SmartAcc - -> OccMap Exp + -> OccMap SmartExp -> UnscopedExp t -> (ScopedExp t, NodeCounts) determineScopesSharingExp config accOccMap expOccMap = scopesExp @@ -2555,7 +2481,7 @@ determineScopesSharingExp config accOccMap expOccMap = scopesExp scopesAcc :: UnscopedAcc a -> (ScopedAcc a, NodeCounts) scopesAcc = determineScopesSharingAcc config accOccMap - scopesFun1 :: (Exp a -> UnscopedExp b) -> (Exp a -> ScopedExp b, NodeCounts) + scopesFun1 :: (SmartExp a -> UnscopedExp b) -> (SmartExp a -> ScopedExp b, NodeCounts) scopesFun1 f = tracePure ("LAMBDA " ++ show ssa) (show counts) (const (ScopedExp ssa body'), (counts',graph)) where body@(UnscopedExp fvs _) = f undefined @@ -2563,65 +2489,51 @@ determineScopesSharingExp config accOccMap expOccMap = scopesExp (freeCounts, counts') = partition isBoundHere counts ssa = buildInitialEnvExp fvs [se | ExpNodeCount se _ <- freeCounts] - isBoundHere (ExpNodeCount (StableSharingExp _ (ExpSharing _ (Tag i))) _) = i `elem` fvs - isBoundHere _ = False + isBoundHere (ExpNodeCount (StableSharingExp _ (ExpSharing _ (Tag _ i))) _) = i `elem` fvs + isBoundHere _ = False scopesExp :: forall t. UnscopedExp t -> (ScopedExp t, NodeCounts) scopesExp (UnscopedExp _ (LetSharing _ _)) = $internalError "determineScopesSharingExp: scopesExp" "unexpected 'LetSharing'" - scopesExp (UnscopedExp _ (VarSharing sn)) - = (ScopedExp [] (VarSharing sn), StableSharingExp sn (VarSharing sn) `insertExpNode` noNodeCounts) + scopesExp (UnscopedExp _ (VarSharing sn tp)) + = (ScopedExp [] (VarSharing sn tp), StableSharingExp sn (VarSharing sn tp) `insertExpNode` noNodeCounts) scopesExp (UnscopedExp _ (ExpSharing sn pexp)) = case pexp of - Tag i -> reconstruct (Tag i) noNodeCounts - Const c -> reconstruct (Const c) noNodeCounts - Undef -> reconstruct Undef noNodeCounts - Tuple tup -> let (tup', accCount) = travTup tup - in - reconstruct (Tuple tup') accCount + Tag tp i -> reconstruct (Tag tp i) noNodeCounts + Const tp c -> reconstruct (Const tp c) noNodeCounts + Undef tp -> reconstruct (Undef tp) noNodeCounts + Pair e1 e2 -> travE2 Pair e1 e2 + Nil -> reconstruct Nil noNodeCounts Prj i e -> travE1 (Prj i) e - IndexNil -> reconstruct IndexNil noNodeCounts - IndexCons ix i -> travE2 IndexCons ix i - IndexHead i -> travE1 IndexHead i - IndexTail ix -> travE1 IndexTail ix - IndexAny -> reconstruct IndexAny noNodeCounts - ToIndex sh ix -> travE2 ToIndex sh ix - FromIndex sh e -> travE2 FromIndex sh e + VecPack vec e -> travE1 (VecPack vec) e + VecUnpack vec e -> travE1 (VecUnpack vec) e + ToIndex shr sh ix -> travE2 (ToIndex shr) sh ix + FromIndex shr sh e -> travE2 (FromIndex shr) sh e Cond e1 e2 e3 -> travE3 Cond e1 e2 e3 - While p it i -> let + While tp p it i -> let (p' , accCount1) = scopesFun1 p (it', accCount2) = scopesFun1 it (i' , accCount3) = scopesExp i - in reconstruct (While p' it' i') (accCount1 +++ accCount2 +++ accCount3) + in reconstruct (While tp p' it' i') (accCount1 +++ accCount2 +++ accCount3) PrimConst c -> reconstruct (PrimConst c) noNodeCounts PrimApp p e -> travE1 (PrimApp p) e - Index a e -> travAE Index a e - LinearIndex a e -> travAE LinearIndex a e - Shape a -> travA Shape a - ShapeSize e -> travE1 ShapeSize e - Intersect sh1 sh2 -> travE2 Intersect sh1 sh2 - Union sh1 sh2 -> travE2 Union sh1 sh2 - Foreign ff f e -> travE1 (Foreign ff f) e - Coerce e -> travE1 Coerce e + Index tp a e -> travAE (Index tp) a e + LinearIndex tp a e -> travAE (LinearIndex tp) a e + Shape shr a -> travA (Shape shr) a + ShapeSize shr e -> travE1 (ShapeSize shr) e + Foreign tp ff f e -> travE1 (Foreign tp ff f) e + Coerce t1 t2 e -> travE1 (Coerce t1 t2) e where - travTup :: Tuple UnscopedExp tup -> (Tuple ScopedExp tup, NodeCounts) - travTup NilTup = (NilTup, noNodeCounts) - travTup (SnocTup tup e) = let - (tup', accCountT) = travTup tup - (e' , accCountE) = scopesExp e - in - (SnocTup tup' e', accCountT +++ accCountE) - - travE1 :: (ScopedExp a -> PreExp ScopedAcc ScopedExp t) -> UnscopedExp a + travE1 :: (ScopedExp a -> PreSmartExp ScopedAcc ScopedExp t) -> UnscopedExp a -> (ScopedExp t, NodeCounts) travE1 c e = reconstruct (c e') accCount where (e', accCount) = scopesExp e - travE2 :: (ScopedExp a -> ScopedExp b -> PreExp ScopedAcc ScopedExp t) + travE2 :: (ScopedExp a -> ScopedExp b -> PreSmartExp ScopedAcc ScopedExp t) -> UnscopedExp a -> UnscopedExp b -> (ScopedExp t, NodeCounts) @@ -2630,7 +2542,7 @@ determineScopesSharingExp config accOccMap expOccMap = scopesExp (e1', accCount1) = scopesExp e1 (e2', accCount2) = scopesExp e2 - travE3 :: (ScopedExp a -> ScopedExp b -> ScopedExp c -> PreExp ScopedAcc ScopedExp t) + travE3 :: (ScopedExp a -> ScopedExp b -> ScopedExp c -> PreSmartExp ScopedAcc ScopedExp t) -> UnscopedExp a -> UnscopedExp b -> UnscopedExp c @@ -2641,38 +2553,37 @@ determineScopesSharingExp config accOccMap expOccMap = scopesExp (e2', accCount2) = scopesExp e2 (e3', accCount3) = scopesExp e3 - travA :: (ScopedAcc a -> PreExp ScopedAcc ScopedExp t) -> UnscopedAcc a + travA :: (ScopedAcc a -> PreSmartExp ScopedAcc ScopedExp t) -> UnscopedAcc a -> (ScopedExp t, NodeCounts) - travA c acc = maybeFloatOutAcc c acc' accCount + travA c acc = floatOutAcc c acc' accCount where (acc', accCount) = scopesAcc acc - travAE :: (ScopedAcc a -> ScopedExp b -> PreExp ScopedAcc ScopedExp t) + travAE :: (ScopedAcc a -> ScopedExp b -> PreSmartExp ScopedAcc ScopedExp t) -> UnscopedAcc a -> UnscopedExp b -> (ScopedExp t, NodeCounts) - travAE c acc e = maybeFloatOutAcc (`c` e') acc' (accCountA +++ accCountE) + travAE c acc e = floatOutAcc (`c` e') acc' (accCountA +++ accCountE) where (acc', accCountA) = scopesAcc acc (e' , accCountE) = scopesExp e - maybeFloatOutAcc :: (ScopedAcc a -> PreExp ScopedAcc ScopedExp t) + floatOutAcc :: (ScopedAcc a -> PreSmartExp ScopedAcc ScopedExp t) -> ScopedAcc a -> NodeCounts -> (ScopedExp t, NodeCounts) - maybeFloatOutAcc c acc@(ScopedAcc _ (AvarSharing _)) accCount -- nothing to float out + floatOutAcc c acc@(ScopedAcc _ (AvarSharing _ _)) accCount -- nothing to float out = reconstruct (c acc) accCount - maybeFloatOutAcc c acc accCount - | float_out_acc `member` options config = reconstruct (c var) ((stableAcc `insertAccNode` noNodeCounts) +++ accCount) - | otherwise = reconstruct (c acc) accCount + floatOutAcc c acc accCount + = reconstruct (c var) ((stableAcc `insertAccNode` noNodeCounts) +++ accCount) where (var, stableAcc) = abstract acc (\(ScopedAcc _ s) -> s) abstract :: ScopedAcc a -> (ScopedAcc a -> SharingAcc ScopedAcc ScopedExp a) -> (ScopedAcc a, StableSharingAcc) - abstract (ScopedAcc _ (AvarSharing _)) _ = $internalError "sharingAccToVar" "AvarSharing" + abstract (ScopedAcc _ (AvarSharing _ _)) _ = $internalError "sharingAccToVar" "AvarSharing" abstract (ScopedAcc ssa (AletSharing sa acc)) lets = abstract acc (lets . ScopedAcc ssa . AletSharing sa) - abstract acc@(ScopedAcc ssa (AccSharing sn _)) lets = (ScopedAcc ssa (AvarSharing sn), StableSharingAcc sn (lets acc)) + abstract acc@(ScopedAcc ssa (AccSharing sn a)) lets = (ScopedAcc ssa (AvarSharing sn $ arraysRepr a), StableSharingAcc sn (lets acc)) -- Occurrence count of the currently processed node expOccCount = let StableNameHeight sn' _ = sn @@ -2690,22 +2601,22 @@ determineScopesSharingExp config accOccMap expOccMap = scopesExp -- In either case, any completed 'NodeCounts' are injected as bindings using 'LetSharing' -- node. -- - reconstruct :: PreExp ScopedAcc ScopedExp t -> NodeCounts + reconstruct :: PreSmartExp ScopedAcc ScopedExp t -> NodeCounts -> (ScopedExp t, NodeCounts) - reconstruct newExp@(Tag _) _subCount + reconstruct newExp@(Tag tp _) _subCount -- free variable => replace by a sharing variable regardless of the number of -- occurrences = let thisCount = StableSharingExp sn (ExpSharing sn newExp) `insertExpNode` noNodeCounts in tracePure "FREE" (show thisCount) - (ScopedExp [] (VarSharing sn), thisCount) + (ScopedExp [] (VarSharing sn tp), thisCount) reconstruct newExp subCount -- shared subtree => replace by a sharing variable (if 'recoverExpSharing' enabled) | expOccCount > 1 && exp_sharing `member` options config = let allCount = StableSharingExp sn sharingExp `insertExpNode` newCount in tracePure ("SHARED" ++ completed) (show allCount) - (ScopedExp [] (VarSharing sn), allCount) + (ScopedExp [] (VarSharing sn $ expType newExp), allCount) -- neither shared nor free variable => leave it as it is | otherwise = tracePure ("Normal" ++ completed) (show newCount) @@ -2918,8 +2829,7 @@ determineScopesSharingSeq config accOccMap _seqOccMap = scopesSeq -- {-# NOINLINE recoverSharingAcc #-} recoverSharingAcc - :: Typeable a - => Config + :: Config -> Level -- The level of currently bound array variables -> [Level] -- The tags of newly introduced free array variables -> SmartAcc a @@ -2934,11 +2844,10 @@ recoverSharingAcc config alvl avars acc {-# NOINLINE recoverSharingExp #-} recoverSharingExp - :: Typeable e - => Config + :: Config -> Level -- The level of currently bound scalar variables -> [Level] -- The tags of newly introduced free scalar variables - -> Exp e + -> SmartExp e -> (ScopedExp e, [StableSharingExp]) recoverSharingExp config lvl fvar exp = let @@ -2958,8 +2867,7 @@ recoverSharingExp config lvl fvar exp {-- {-# NOINLINE recoverSharingSeq #-} recoverSharingSeq - :: Typeable e - => Config + :: Config -> Seq e -> (ScopedSeq e, [StableSharingSeq]) recoverSharingSeq config seq diff --git a/src/Data/Array/Accelerate/Trafo/Shrink.hs b/src/Data/Array/Accelerate/Trafo/Shrink.hs index 0dfa875b9..057b0223d 100644 --- a/src/Data/Array/Accelerate/Trafo/Shrink.hs +++ b/src/Data/Array/Accelerate/Trafo/Shrink.hs @@ -1,8 +1,12 @@ +{-# LANGUAGE CPP #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE PatternGuards #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TemplateHaskell #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeOperators #-} {-# LANGUAGE ViewPatterns #-} -- | -- Module : Data.Array.Accelerate.Trafo.Shrink @@ -38,15 +42,21 @@ module Data.Array.Accelerate.Trafo.Shrink ( ) where -- standard library -import Data.Monoid import Control.Applicative hiding ( Const ) import Prelude hiding ( exp, seq ) +import Data.Maybe ( isJust ) + +#if __GLASGOW_HASKELL__ < 804 +import Data.Semigroup +#else +import Data.Monoid +#endif -- friends import Data.Array.Accelerate.AST -import Data.Array.Accelerate.Array.Sugar hiding ( Any ) import Data.Array.Accelerate.Trafo.Base import Data.Array.Accelerate.Trafo.Substitution +import Data.Array.Accelerate.Error import qualified Data.Array.Accelerate.Debug.Stats as Stats @@ -57,12 +67,149 @@ class Shrink f where shrink = snd . shrink' -instance Kit acc => Shrink (PreOpenExp acc env aenv e) where +instance Shrink (OpenExp env aenv e) where shrink' = shrinkExp -instance Kit acc => Shrink (PreOpenFun acc env aenv f) where +instance Shrink (OpenFun env aenv f) where shrink' = shrinkFun +data VarsRange env = VarsRange !(Exists (Idx env)) !Int !(Maybe RangeTuple) -- rightmost variable, count, tuple + +data RangeTuple + = RTNil + | RTSingle + | RTPair !RangeTuple !RangeTuple + +lhsVarsRange :: LeftHandSide s v env env' -> Either (env :~: env') (VarsRange env') +lhsVarsRange lhs = case rightIx lhs of + Left eq -> Left eq + Right ix -> let (n, rt) = go lhs + in Right $ VarsRange ix n rt + where + rightIx :: LeftHandSide s v env env' -> Either (env :~: env') (Exists (Idx env')) + rightIx (LeftHandSideWildcard _) = Left Refl + rightIx (LeftHandSideSingle _) = Right $ Exists ZeroIdx + rightIx (LeftHandSidePair l1 l2) = case rightIx l2 of + Right ix -> Right ix + Left Refl -> rightIx l1 + + go :: LeftHandSide s v env env' -> (Int, Maybe (RangeTuple)) + go (LeftHandSideWildcard TupRunit) = (0, Just RTNil) + go (LeftHandSideWildcard _) = (0, Nothing) + go (LeftHandSideSingle _) = (1, Just RTSingle) + go (LeftHandSidePair l1 l2) = (n1 + n2, RTPair <$> t1 <*> t2) + where + (n1, t1) = go l1 + (n2, t2) = go l2 + +weakenVarsRange :: LeftHandSide s v env env' -> VarsRange env -> VarsRange env' +weakenVarsRange lhs (VarsRange ix n t) = VarsRange (go lhs ix) n t + where + go :: LeftHandSide s v env env' -> Exists (Idx env) -> Exists (Idx env') + go (LeftHandSideWildcard _) ix' = ix' + go (LeftHandSideSingle _) (Exists ix') = Exists (SuccIdx ix') + go (LeftHandSidePair l1 l2) ix' = go l2 $ go l1 ix' + +matchEVarsRange :: VarsRange env -> OpenExp env aenv t -> Bool +matchEVarsRange (VarsRange (Exists first) _ (Just rt)) expr = isJust $ go (idxToInt first) rt expr + where + go :: Int -> RangeTuple -> OpenExp env aenv t -> Maybe Int + go i RTNil Nil = Just i + go i RTSingle (Evar (Var _ ix)) + | checkIdx i ix = Just (i + 1) + go i (RTPair t1 t2) (Pair e1 e2) + | Just i' <- go i t2 e2 = go i' t1 e1 + go _ _ _ = Nothing + + checkIdx :: Int -> Idx env t -> Bool + checkIdx 0 ZeroIdx = True + checkIdx i (SuccIdx ix) = checkIdx (i - 1) ix + checkIdx _ _ = False +matchEVarsRange _ _ = False + +varInRange :: VarsRange env -> Var s env t -> Maybe Usages +varInRange (VarsRange (Exists rangeIx) n _) (Var _ varIx) = case go rangeIx varIx of + Nothing -> Nothing + Just j -> Just $ replicate j False ++ [True] ++ replicate (n - j - 1) False + where + -- `go ix ix'` checks whether ix <= ix' with recursion, and then checks + -- whether ix' < ix + n in go'. Returns a Just if both checks + -- are successful, containing an integer j such that ix + j = ix'. + go :: Idx env u -> Idx env t -> Maybe Int + go (SuccIdx ix) (SuccIdx ix') = go ix ix' + go ZeroIdx ix' = go' ix' 0 + go _ ZeroIdx = Nothing + + go' :: Idx env t -> Int -> Maybe Int + go' _ j | j >= n = Nothing + go' ZeroIdx j = Just j + go' (SuccIdx ix') j = go' ix' (j + 1) + +-- Describes how often the variables defined in a LHS are used together. +data Count + = Impossible !Usages -- Cannot inline this definition. This happens when the definition declares multiple variables (the right hand side returns a tuple) and the variables are used seperately. + | Infinity -- The variable is used in a loop. Inlining should only proceed if the computation is cheap. + | Finite {-# UNPACK #-} !Int + +type Usages = [Bool] -- Per variable a Bool denoting whether that variable is used. + +instance Semigroup Count where + Impossible u1 <> Impossible u2 = Impossible $ zipWith (||) u1 u2 + Impossible u <> Finite 0 = Impossible u + Finite 0 <> Impossible u = Impossible u + Impossible u <> _ = Impossible $ map (const True) u + _ <> Impossible u = Impossible $ map (const True) u + Infinity <> _ = Infinity + _ <> Infinity = Infinity + Finite a <> Finite b = Finite $ a + b + +loopCount :: Count -> Count +loopCount (Finite n) | n > 0 = Infinity +loopCount c = c + +shrinkLhs :: Count -> LeftHandSide s t env1 env2 -> Maybe (Exists (LeftHandSide s t env1)) +shrinkLhs _ (LeftHandSideWildcard _) = Nothing -- We cannot shrink this +shrinkLhs (Finite 0) lhs = Just $ Exists $ LeftHandSideWildcard $ lhsToTupR lhs -- LHS isn't used at all, replace with a wildcard +shrinkLhs (Impossible usages) lhs = case go usages lhs of + (True , [], lhs') -> Just lhs' + (False, [], _ ) -> Nothing -- No variables were dropped. Thus lhs == lhs'. + _ -> $internalError "shrinkLhs" "Mismatch in length of usages array and LHS" + where + go :: Usages -> LeftHandSide s t env1 env2 -> (Bool, Usages, Exists (LeftHandSide s t env1)) + go us (LeftHandSideWildcard tp) = (False, us, Exists $ LeftHandSideWildcard tp) + go (True : us) (LeftHandSideSingle tp) = (False, us, Exists $ LeftHandSideSingle tp) + go (False : us) (LeftHandSideSingle tp) = (True , us, Exists $ LeftHandSideWildcard $ TupRsingle tp) + go us (LeftHandSidePair l1 l2) + | (c2, us' , Exists l2') <- go us l2 + , (c1, us'', Exists l1') <- go us' l1 + , Exists l2'' <- rebuildLHS l2' + = let + lhs' + | LeftHandSideWildcard t1 <- l1' + , LeftHandSideWildcard t2 <- l2'' = LeftHandSideWildcard $ TupRpair t1 t2 + | otherwise = LeftHandSidePair l1' l2'' + in + (c1 || c2, us'', Exists lhs') + go _ _ = $internalError "shrinkLhs" "Empty array, mismatch in length of usages array and LHS" +shrinkLhs _ _ = Nothing + +-- The first LHS should be 'larger' than the second, eg the second may have a wildcard if the first LHS does bind variables there, +-- but not the other way around. +strengthenShrunkLHS :: LeftHandSide s t env1 env2 -> LeftHandSide s t env1' env2' -> env1 :?> env1' -> env2 :?> env2' +strengthenShrunkLHS (LeftHandSideWildcard _) (LeftHandSideWildcard _) k = k +strengthenShrunkLHS (LeftHandSideSingle _) (LeftHandSideSingle _) k = \ix -> case ix of + ZeroIdx -> Just ZeroIdx + SuccIdx ix' -> SuccIdx <$> k ix' +strengthenShrunkLHS (LeftHandSidePair lA hA) (LeftHandSidePair lB hB) k = strengthenShrunkLHS hA hB $ strengthenShrunkLHS lA lB k +strengthenShrunkLHS (LeftHandSideSingle _) (LeftHandSideWildcard _) k = \ix -> case ix of + ZeroIdx -> Nothing + SuccIdx ix' -> k ix' +strengthenShrunkLHS (LeftHandSidePair l h) (LeftHandSideWildcard t) k = strengthenShrunkLHS h (LeftHandSideWildcard t2) $ strengthenShrunkLHS l (LeftHandSideWildcard t1) k + where + TupRpair t1 t2 = t +strengthenShrunkLHS (LeftHandSideWildcard _) _ _ = $internalError "strengthenShrunkLHS" "Second LHS defines more variables" +strengthenShrunkLHS _ _ _ = $internalError "strengthenShrunkLHS" "Mismatch LHS single with LHS pair" + -- Shrinking -- ========= @@ -70,7 +217,7 @@ instance Kit acc => Shrink (PreOpenFun acc env aenv f) where -- instance of beta-reduction to cases where the bound variable is used zero -- (dead-code elimination) or one (linear inlining) times. -- -shrinkExp :: Kit acc => PreOpenExp acc env aenv t -> (Bool, PreOpenExp acc env aenv t) +shrinkExp :: OpenExp env aenv t -> (Bool, OpenExp env aenv t) shrinkExp = Stats.substitution "shrinkE" . first getAny . shrinkE where -- If the bound variable is used at most this many times, it will be inlined @@ -80,35 +227,62 @@ shrinkExp = Stats.substitution "shrinkE" . first getAny . shrinkE lIMIT :: Int lIMIT = 1 - shrinkE :: Kit acc => PreOpenExp acc env aenv t -> (Any, PreOpenExp acc env aenv t) + cheap :: OpenExp env aenv t -> Bool + cheap (Evar _) = True + cheap (Pair e1 e2) = cheap e1 && cheap e2 + cheap Nil = True + cheap Const{} = True + cheap PrimConst{} = True + cheap Undef{} = True + cheap (Coerce _ _ e) = cheap e + cheap _ = False + + shrinkE :: OpenExp env aenv t -> (Any, OpenExp env aenv t) shrinkE exp = case exp of - Let bnd body - | Var _ <- bnd -> Stats.inline "Var" . yes $ shrinkE (inline body bnd) - | uses <= lIMIT -> Stats.betaReduce msg . yes $ shrinkE (inline (snd body') (snd bnd')) - | otherwise -> Let <$> bnd' <*> body' + Let (LeftHandSideSingle _) bnd@Evar{} body -> Stats.inline "Var" . yes $ shrinkE (inline body bnd) + Let lhs bnd body + | shouldInline -> case inlineVars lhs (snd body') (snd bnd') of + Just inlined -> Stats.betaReduce msg . yes $ shrinkE inlined + _ -> error "shrinkExp: Unexpected failure while trying to inline some expression." + | Just (Exists lhs') <- shrinkLhs count lhs -> case strengthenE (strengthenShrunkLHS lhs lhs' Just) (snd body') of + Just body'' -> (Any True, Let lhs' (snd bnd') body'') + Nothing -> error "shrinkExp: Unexpected failure in strenthenE. Variable was analysed to be unused in usesOfExp, but appeared to be used in strenthenE." + | otherwise -> Let lhs <$> bnd' <*> body' where + shouldInline = case count of + Finite 0 -> False -- Handled by shrinkLhs + Finite n -> n <= lIMIT || cheap (snd bnd') + Infinity -> cheap (snd bnd') + Impossible _ -> False + bnd' = shrinkE bnd body' = shrinkE body - uses = usesOfExp ZeroIdx (snd body') - msg = case uses of - 0 -> "dead exp" - _ -> "inline exp" -- forced inlining when lIMIT > 1 + -- If the lhs includes non-trivial wildcards (the last field of range is Nothing), + -- then we cannot inline the binding. We can only check which variables are not used, + -- to detect unused variables. + -- If the lhs does not include non-trivial wildcards (the last field of range is a Just), + -- we can both analyse whether we can inline the binding, and check which variables are + -- not used, to detect unused variables. + count = case lhsVarsRange lhs of + Left _ -> Finite 0 + Right range -> usesOfExp range (snd body') + + msg = case count of + Finite 0 -> "dead exp" + _ -> "inline exp" -- forced inlining when lIMIT > 1 -- - Var idx -> pure (Var idx) - Const c -> pure (Const c) - Undef -> pure Undef - Tuple t -> Tuple <$> shrinkT t - Prj tup e -> Prj tup <$> shrinkE e - IndexNil -> pure IndexNil - IndexCons sl sz -> IndexCons <$> shrinkE sl <*> shrinkE sz - IndexHead sh -> IndexHead <$> shrinkE sh - IndexTail sh -> IndexTail <$> shrinkE sh + Evar v -> pure (Evar v) + Const t c -> pure (Const t c) + Undef t -> pure (Undef t) + Nil -> pure Nil + Pair x y -> Pair <$> shrinkE x <*> shrinkE y + VecPack vec e -> VecPack vec <$> shrinkE e + VecUnpack vec e -> VecUnpack vec <$> shrinkE e IndexSlice x ix sh -> IndexSlice x <$> shrinkE ix <*> shrinkE sh IndexFull x ix sl -> IndexFull x <$> shrinkE ix <*> shrinkE sl - IndexAny -> pure IndexAny - ToIndex sh ix -> ToIndex <$> shrinkE sh <*> shrinkE ix - FromIndex sh i -> FromIndex <$> shrinkE sh <*> shrinkE i + ToIndex shr sh ix -> ToIndex shr <$> shrinkE sh <*> shrinkE ix + FromIndex shr sh i -> FromIndex shr <$> shrinkE sh <*> shrinkE i Cond p t e -> Cond <$> shrinkE p <*> shrinkE t <*> shrinkE e While p f x -> While <$> shrinkF p <*> shrinkF f <*> shrinkE x PrimConst c -> pure (PrimConst c) @@ -116,17 +290,11 @@ shrinkExp = Stats.substitution "shrinkE" . first getAny . shrinkE Index a sh -> Index a <$> shrinkE sh LinearIndex a i -> LinearIndex a <$> shrinkE i Shape a -> pure (Shape a) - ShapeSize sh -> ShapeSize <$> shrinkE sh - Intersect sh sz -> Intersect <$> shrinkE sh <*> shrinkE sz - Union sh sz -> Union <$> shrinkE sh <*> shrinkE sz - Foreign ff f e -> Foreign ff <$> shrinkF f <*> shrinkE e - Coerce e -> Coerce <$> shrinkE e - - shrinkT :: Kit acc => Tuple (PreOpenExp acc env aenv) t -> (Any, Tuple (PreOpenExp acc env aenv) t) - shrinkT NilTup = pure NilTup - shrinkT (SnocTup t e) = SnocTup <$> shrinkT t <*> shrinkE e + ShapeSize shr sh -> ShapeSize shr <$> shrinkE sh + Foreign repr ff f e -> Foreign repr ff <$> shrinkF f <*> shrinkE e + Coerce t1 t2 e -> Coerce t1 t2 <$> shrinkE e - shrinkF :: Kit acc => PreOpenFun acc env aenv t -> (Any, PreOpenFun acc env aenv t) + shrinkF :: OpenFun env aenv t -> (Any, OpenFun env aenv t) shrinkF = first Any . shrinkFun first :: (a -> a') -> (a,b) -> (a',b) @@ -135,10 +303,25 @@ shrinkExp = Stats.substitution "shrinkE" . first getAny . shrinkE yes :: (Any, x) -> (Any, x) yes (_, x) = (Any True, x) -shrinkFun :: Kit acc => PreOpenFun acc env aenv f -> (Bool, PreOpenFun acc env aenv f) -shrinkFun (Lam f) = Lam <$> shrinkFun f -shrinkFun (Body b) = Body <$> shrinkExp b +shrinkFun :: OpenFun env aenv f -> (Bool, OpenFun env aenv f) +shrinkFun (Lam lhs f) = case lhsVarsRange lhs of + Left Refl -> + let b' = case lhs of + LeftHandSideWildcard TupRunit -> b + _ -> True + in (b', Lam (LeftHandSideWildcard $ lhsToTupR lhs) f') + Right range -> + let + count = usesOfFun range f + in case shrinkLhs count lhs of + Just (Exists lhs') -> case strengthenE (strengthenShrunkLHS lhs lhs' Just) f' of + Just f'' -> (True, Lam lhs' f'') + Nothing -> error "shrinkFun: Unexpected failure in strenthenE. Variable was analysed to be unused in usesOfExp, but appeared to be used in strenthenE." + Nothing -> (b, Lam lhs f') + where + (b, f') = shrinkFun f +shrinkFun (Body b) = Body <$> shrinkExp b -- The shrinking substitution for array computations. This is further limited to -- dead-code elimination only, primarily because linear inlining may inline @@ -226,7 +409,7 @@ shrinkPreAcc shrinkAcc reduceAcc = Stats.substitution "shrinkA" shrinkA shrinkCT (SnocAtup t c) = SnocAtup (shrinkCT t) (shrinkC c) --} - shrinkE :: PreOpenExp acc env aenv' t -> PreOpenExp acc env aenv' t + shrinkE :: OpenExp env aenv' t -> OpenExp env aenv' t shrinkE exp = case exp of Let bnd body -> Let (shrinkE bnd) (shrinkE body) Var idx -> Var idx @@ -256,11 +439,11 @@ shrinkPreAcc shrinkAcc reduceAcc = Stats.substitution "shrinkA" shrinkA Foreign ff f e -> Foreign ff (shrinkF f) (shrinkE e) Coerce e -> Coerce (shrinkE e) - shrinkF :: PreOpenFun acc env aenv' f -> PreOpenFun acc env aenv' f + shrinkF :: OpenFun env aenv' f -> OpenFun env aenv' f shrinkF (Lam f) = Lam (shrinkF f) shrinkF (Body b) = Body (shrinkE b) - shrinkT :: Tuple (PreOpenExp acc env aenv') t -> Tuple (PreOpenExp acc env aenv') t + shrinkT :: Tuple (OpenExp env aenv') t -> Tuple (OpenExp env aenv') t shrinkT NilTup = NilTup shrinkT (SnocTup t e) = shrinkT t `SnocTup` shrinkE e @@ -274,49 +457,41 @@ shrinkPreAcc shrinkAcc reduceAcc = Stats.substitution "shrinkA" shrinkA -- Count the number of occurrences an in-scope scalar expression bound at the -- given variable index recursively in a term. -- -usesOfExp :: forall acc env aenv s t. Idx env s -> PreOpenExp acc env aenv t -> Int -usesOfExp idx = countE +usesOfExp :: forall env aenv t. VarsRange env -> OpenExp env aenv t -> Count +usesOfExp range = countE where - countE :: PreOpenExp acc env aenv e -> Int + countE :: OpenExp env aenv e -> Count + countE exp | matchEVarsRange range exp = Finite 1 countE exp = case exp of - Var this - | Just Refl <- match this idx -> 1 - | otherwise -> 0 + Evar v -> case varInRange range v of + Just cs -> Impossible cs + Nothing -> Finite 0 -- - Let bnd body -> countE bnd + usesOfExp (SuccIdx idx) body - Const _ -> 0 - Undef -> 0 - Tuple t -> countT t - Prj _ e -> countE e - IndexNil -> 0 - IndexCons sl sz -> countE sl + countE sz - IndexHead sh -> countE sh - IndexTail sh -> countE sh - IndexSlice _ ix sh -> countE ix + countE sh - IndexFull _ ix sl -> countE ix + countE sl - IndexAny -> 0 - ToIndex sh ix -> countE sh + countE ix - FromIndex sh i -> countE sh + countE i - Cond p t e -> countE p + countE t + countE e - While p f x -> countE x + countF idx p + countF idx f - PrimConst _ -> 0 + Let lhs bnd body -> countE bnd <> usesOfExp (weakenVarsRange lhs range) body + Const _ _ -> Finite 0 + Undef _ -> Finite 0 + Nil -> Finite 0 + Pair e1 e2 -> countE e1 <> countE e2 + VecPack _ e -> countE e + VecUnpack _ e -> countE e + IndexSlice _ ix sh -> countE ix <> countE sh + IndexFull _ ix sl -> countE ix <> countE sl + FromIndex _ sh i -> countE sh <> countE i + ToIndex _ sh e -> countE sh <> countE e + Cond p t e -> countE p <> countE t <> countE e + While p f x -> countE x <> loopCount (usesOfFun range p) <> usesOfFun range f + PrimConst _ -> Finite 0 PrimApp _ x -> countE x Index _ sh -> countE sh LinearIndex _ i -> countE i - Shape _ -> 0 - ShapeSize sh -> countE sh - Intersect sh sz -> countE sh + countE sz - Union sh sz -> countE sh + countE sz - Foreign _ _ e -> countE e - Coerce e -> countE e + Shape _ -> Finite 0 + ShapeSize _ sh -> countE sh + Foreign _ _ _ e -> countE e + Coerce _ _ e -> countE e - countF :: Idx env' s -> PreOpenFun acc env' aenv f -> Int - countF idx' (Lam f) = countF (SuccIdx idx') f - countF idx' (Body b) = usesOfExp idx' b - - countT :: Tuple (PreOpenExp acc env aenv) e -> Int - countT NilTup = 0 - countT (SnocTup t e) = countT t + countE e +usesOfFun :: VarsRange env -> OpenFun env aenv f -> Count +usesOfFun range (Lam lhs f) = usesOfFun (weakenVarsRange lhs range) f +usesOfFun range (Body b) = usesOfExp range b -- Count the number of occurrences of the array term bound at the given -- environment index. If the first argument is 'True' then it includes in the @@ -340,88 +515,86 @@ usesOfPreAcc withShape countAcc idx = count count :: PreOpenAcc acc aenv a -> Int count pacc = case pacc of - Avar (ArrayVar this) -> countIdx this + Avar var -> countAvar var -- - Alet lhs bnd body -> countA bnd + countAcc withShape (weakenWithLHS lhs idx) body - Apair a1 a2 -> countA a1 + countA a2 - Anil -> 0 - Apply _ a -> countA a - Aforeign _ _ a -> countA a - Acond p t e -> countE p + countA t + countA e - Awhile _ _ a -> countA a - Use _ -> 0 - Unit e -> countE e - Reshape e a -> countE e + countA a - Generate e f -> countE e + countF f - Transform sh ix f a -> countE sh + countF ix + countF f + countA a - Replicate _ sh a -> countE sh + countA a - Slice _ a sl -> countE sl + countA a - Map f a -> countF f + countA a - ZipWith f a1 a2 -> countF f + countA a1 + countA a2 - Fold f z a -> countF f + countE z + countA a - Fold1 f a -> countF f + countA a - FoldSeg f z a s -> countF f + countE z + countA a + countA s - Fold1Seg f a s -> countF f + countA a + countA s - Scanl f z a -> countF f + countE z + countA a - Scanl' f z a -> countF f + countE z + countA a - Scanl1 f a -> countF f + countA a - Scanr f z a -> countF f + countE z + countA a - Scanr' f z a -> countF f + countE z + countA a - Scanr1 f a -> countF f + countA a - Permute f1 a1 f2 a2 -> countF f1 + countA a1 + countF f2 + countA a2 - Backpermute sh f a -> countE sh + countF f + countA a - Stencil f _ a -> countF f + countA a - Stencil2 f _ a1 _ a2 -> countF f + countA a1 + countA a2 + Alet lhs bnd body -> countA bnd + countAcc withShape (weakenWithLHS lhs >:> idx) body + Apair a1 a2 -> countA a1 + countA a2 + Anil -> 0 + Apply _ f a -> countAF f idx + countA a + Aforeign _ _ _ a -> countA a + Acond p t e -> countE p + countA t + countA e + -- Body and condition of the while loop may be evaluated multiple times. + -- We multiply the usage count, as a practical solution to this. As + -- we will check whether the count is at most 1, we will thus never + -- inline variables used in while loops. + Awhile c f a -> 2 * countAF c idx + 2 * countAF f idx + countA a + Use _ _ -> 0 + Unit _ e -> countE e + Reshape _ e a -> countE e + countA a + Generate _ e f -> countE e + countF f + Transform _ sh ix f a -> countE sh + countF ix + countF f + countA a + Replicate _ sh a -> countE sh + countA a + Slice _ a sl -> countE sl + countA a + Map _ f a -> countF f + countA a + ZipWith _ f a1 a2 -> countF f + countA a1 + countA a2 + Fold f z a -> countF f + countE z + countA a + Fold1 f a -> countF f + countA a + FoldSeg _ f z a s -> countF f + countE z + countA a + countA s + Fold1Seg _ f a s -> countF f + countA a + countA s + Scanl f z a -> countF f + countE z + countA a + Scanl' f z a -> countF f + countE z + countA a + Scanl1 f a -> countF f + countA a + Scanr f z a -> countF f + countE z + countA a + Scanr' f z a -> countF f + countE z + countA a + Scanr1 f a -> countF f + countA a + Permute f1 a1 f2 a2 -> countF f1 + countA a1 + countF f2 + countA a2 + Backpermute _ sh f a -> countE sh + countF f + countA a + Stencil _ _ f _ a -> countF f + countA a + Stencil2 _ _ _ f _ a1 _ a2 -> countF f + countA a1 + countA a2 -- Collect s -> countS s - countE :: PreOpenExp acc env aenv e -> Int + countE :: OpenExp env aenv e -> Int countE exp = case exp of - Let bnd body -> countE bnd + countE body - Var _ -> 0 - Const _ -> 0 - Undef -> 0 - Tuple t -> countT t - Prj _ e -> countE e - IndexNil -> 0 - IndexCons sl sz -> countE sl + countE sz - IndexHead sh -> countE sh - IndexTail sh -> countE sh - IndexSlice _ ix sh -> countE ix + countE sh - IndexFull _ ix sl -> countE ix + countE sl - IndexAny -> 0 - ToIndex sh ix -> countE sh + countE ix - FromIndex sh i -> countE sh + countE i - Cond p t e -> countE p + countE t + countE e - While p f x -> countF p + countF f + countE x - PrimConst _ -> 0 - PrimApp _ x -> countE x - Index a sh -> countA a + countE sh - LinearIndex a i -> countA a + countE i - ShapeSize sh -> countE sh - Intersect sh sz -> countE sh + countE sz - Union sh sz -> countE sh + countE sz + Let _ bnd body -> countE bnd + countE body + Evar _ -> 0 + Const _ _ -> 0 + Undef _ -> 0 + Nil -> 0 + Pair x y -> countE x + countE y + VecPack _ e -> countE e + VecUnpack _ e -> countE e + IndexSlice _ ix sh -> countE ix + countE sh + IndexFull _ ix sl -> countE ix + countE sl + ToIndex _ sh ix -> countE sh + countE ix + FromIndex _ sh i -> countE sh + countE i + Cond p t e -> countE p + countE t + countE e + While p f x -> countF p + countF f + countE x + PrimConst _ -> 0 + PrimApp _ x -> countE x + Index a sh -> countAvar a + countE sh + LinearIndex a i -> countAvar a + countE i + ShapeSize _ sh -> countE sh Shape a - | withShape -> countA a - | otherwise -> 0 - Foreign _ _ e -> countE e - Coerce e -> countE e + | withShape -> countAvar a + | otherwise -> 0 + Foreign _ _ _ e -> countE e + Coerce _ _ e -> countE e countA :: acc aenv a -> Int countA = countAcc withShape idx - -- countAF :: PreOpenAfun acc aenv' f - -- -> Idx aenv' s - -- -> Int - -- countAF (Alam f) v = countAF f (SuccIdx v) - -- countAF (Abody a) v = countAcc withShape v a + countAvar :: ArrayVar aenv a -> Int + countAvar (Var _ this) = countIdx this - countF :: PreOpenFun acc env aenv f -> Int - countF (Lam f) = countF f - countF (Body b) = countE b + countAF :: PreOpenAfun acc aenv' f + -> Idx aenv' s + -> Int + countAF (Alam lhs f) v = countAF f (weakenWithLHS lhs >:> v) + countAF (Abody a) v = countAcc withShape v a - countT :: Tuple (PreOpenExp acc env aenv) e -> Int - countT NilTup = 0 - countT (SnocTup t e) = countT t + countE e + countF :: OpenFun env aenv f -> Int + countF (Lam _ f) = countF f + countF (Body b) = countE b {-- countS :: PreOpenSeq acc aenv senv arrs -> Int diff --git a/src/Data/Array/Accelerate/Trafo/Simplify.hs b/src/Data/Array/Accelerate/Trafo/Simplify.hs index dee10f711..c92701e05 100644 --- a/src/Data/Array/Accelerate/Trafo/Simplify.hs +++ b/src/Data/Array/Accelerate/Trafo/Simplify.hs @@ -31,25 +31,19 @@ module Data.Array.Accelerate.Trafo.Simplify ( -- standard library import Control.Applicative hiding ( Const ) import Control.Lens hiding ( Const, ix ) -import Data.List ( nubBy ) import Data.Maybe import Data.Monoid -import Data.Typeable import Text.Printf import Prelude hiding ( exp, iterate ) -- friends import Data.Array.Accelerate.AST hiding ( prj ) -import Data.Array.Accelerate.Analysis.Match -import Data.Array.Accelerate.Analysis.Shape import Data.Array.Accelerate.Error -import Data.Array.Accelerate.Product import Data.Array.Accelerate.Trafo.Algebra import Data.Array.Accelerate.Trafo.Base import Data.Array.Accelerate.Trafo.Shrink import Data.Array.Accelerate.Type -import Data.Array.Accelerate.Array.Sugar ( Array, Shape, Elt(..), Z(..), (:.)(..) - , Tuple(..), IsTuple, fromTuple, TupleRepr, shapeToList ) +import Data.Array.Accelerate.Array.Representation ( Array, shapeToList ) import qualified Data.Array.Accelerate.Debug.Stats as Stats import qualified Data.Array.Accelerate.Debug.Flags as Debug import qualified Data.Array.Accelerate.Debug.Trace as Debug @@ -58,10 +52,10 @@ import qualified Data.Array.Accelerate.Debug.Trace as Debug class Simplify f where simplify :: f -> f -instance Kit acc => Simplify (PreFun acc aenv f) where +instance Simplify (Fun aenv f) where simplify = simplifyFun -instance (Kit acc, Elt e) => Simplify (PreExp acc aenv e) where +instance Simplify (Exp aenv e) where simplify = simplifyExp @@ -92,9 +86,9 @@ instance (Kit acc, Elt e) => Simplify (PreExp acc aenv e) where -- localCSE :: (Kit acc, Elt a) => Gamma acc env env aenv - -> PreOpenExp acc env aenv a - -> PreOpenExp acc (env,a) aenv b - -> Maybe (PreOpenExp acc env aenv b) + -> OpenExp env aenv a + -> OpenExp (env,a) aenv b + -> Maybe (OpenExp env aenv b) localCSE env bnd body | Just ix <- lookupExp env bnd = Stats.ruleFired "CSE" . Just $ inline body (Var ix) | otherwise = Nothing @@ -108,8 +102,8 @@ localCSE env bnd body -- globalCSE :: (Kit acc, Elt t) => Gamma acc env env aenv - -> PreOpenExp acc env aenv t - -> Maybe (PreOpenExp acc env aenv t) + -> OpenExp env aenv t + -> Maybe (OpenExp env aenv t) globalCSE env exp | Just ix <- lookupExp env exp = Stats.ruleFired "CSE" . Just $ Var ix | otherwise = Nothing @@ -146,9 +140,9 @@ globalCSE env exp recoverLoops :: (Kit acc, Elt b) => Gamma acc env env aenv - -> PreOpenExp acc env aenv a - -> PreOpenExp acc (env,a) aenv b - -> Maybe (PreOpenExp acc env aenv b) + -> OpenExp env aenv a + -> OpenExp (env,a) aenv b + -> Maybe (OpenExp env aenv b) recoverLoops _ bnd e3 -- To introduce scaler loops, we look for expressions of the form: -- @@ -183,15 +177,15 @@ recoverLoops _ bnd e3 = Nothing where - plus :: PreOpenExp acc env aenv Int -> PreOpenExp acc env aenv Int -> PreOpenExp acc env aenv Int + plus :: OpenExp env aenv Int -> OpenExp env aenv Int -> OpenExp env aenv Int plus x y = PrimApp (PrimAdd numType) $ Tuple $ NilTup `SnocTup` x `SnocTup` y - constant :: Int -> PreOpenExp acc env aenv Int + constant :: Int -> OpenExp env aenv Int constant i = Const ((),i) matchEnvTop :: (Elt s, Elt t) - => PreOpenExp acc (env,s) aenv f - -> PreOpenExp acc (env,t) aenv g + => OpenExp (env,s) aenv f + -> OpenExp (env,t) aenv g -> Maybe (s :=: t) matchEnvTop _ _ = gcast Refl --} @@ -204,38 +198,34 @@ recoverLoops _ bnd e3 -- introduced by the fusion transformation. This would benefit from a -- rewrite rule schema. -- +-- TODO: We currently pass around an environment Gamma, but we do not use it. +-- It might be helpful to do some inlining if this enables other optimizations. +-- Eg, for `let x = -y in -x`, the inlining would allow us to shorten it to `y`. +-- If we do not want to do inlining, we should remove the environment here. simplifyOpenExp - :: forall acc env aenv e. (Kit acc, Elt e) - => Gamma acc env env aenv - -> PreOpenExp acc env aenv e - -> (Bool, PreOpenExp acc env aenv e) + :: forall env aenv e. + Gamma env env aenv + -> OpenExp env aenv e + -> (Bool, OpenExp env aenv e) simplifyOpenExp env = first getAny . cvtE where - cvtE :: Elt t => PreOpenExp acc env aenv t -> (Any, PreOpenExp acc env aenv t) + cvtE :: OpenExp env aenv t -> (Any, OpenExp env aenv t) cvtE exp = case exp of - Let bnd body - -- Just reduct <- recoverLoops env (snd bnd') (snd body') -> yes . snd $ cvtE reduct - -- Just reduct <- localCSE env (snd bnd') (snd body') -> yes . snd $ cvtE reduct - | otherwise -> Let <$> bnd' <*> body' + Let lhs bnd body -> (u <> v, exp') where - bnd' = cvtE bnd - env' = env `pushExp` snd bnd' - body' = cvtE' (incExp env') body - - Var ix -> pure $ Var ix - Const c -> pure $ Const c - Undef -> pure Undef - Tuple tup -> Tuple <$> cvtT tup - Prj ix t -> prj env ix (cvtE t) - IndexNil -> pure IndexNil - IndexAny -> pure IndexAny - IndexCons sh sz -> indexCons (cvtE sh) (cvtE sz) - IndexHead sh -> indexHead (cvtE sh) - IndexTail sh -> indexTail (cvtE sh) + (u, bnd') = cvtE bnd + (v, exp') = cvtLet env lhs bnd' (\env' -> cvtE' env' body) + Evar var -> pure $ Evar var + Const tp c -> pure $ Const tp c + Undef tp -> pure $ Undef tp + Nil -> pure Nil + Pair e1 e2 -> Pair <$> cvtE e1 <*> cvtE e2 + VecPack vec e -> VecPack vec <$> cvtE e + VecUnpack vec e -> VecUnpack vec <$> cvtE e IndexSlice x ix sh -> IndexSlice x <$> cvtE ix <*> cvtE sh IndexFull x ix sl -> IndexFull x <$> cvtE ix <*> cvtE sl - ToIndex sh ix -> toIndex (cvtE sh) (cvtE ix) - FromIndex sh ix -> fromIndex (cvtE sh) (cvtE ix) + ToIndex shr sh ix -> toIndex shr (cvtE sh) (cvtE ix) + FromIndex shr sh ix -> fromIndex shr (cvtE sh) (cvtE ix) Cond p t e -> cond (cvtE p) (cvtE t) (cvtE e) PrimConst c -> pure $ PrimConst c PrimApp f x -> (u<>v, fx) @@ -245,180 +235,62 @@ simplifyOpenExp env = first getAny . cvtE Index a sh -> Index a <$> cvtE sh LinearIndex a i -> LinearIndex a <$> cvtE i Shape a -> shape a - ShapeSize sh -> shapeSize (cvtE sh) - Intersect s t -> cvtE s `intersect` cvtE t - Union s t -> cvtE s `union` cvtE t - Foreign ff f e -> Foreign ff <$> first Any (simplifyOpenFun EmptyExp f) <*> cvtE e + ShapeSize shr sh -> shapeSize shr (cvtE sh) + Foreign tp ff f e -> Foreign tp ff <$> first Any (simplifyOpenFun EmptyExp f) <*> cvtE e While p f x -> While <$> cvtF env p <*> cvtF env f <*> cvtE x - Coerce e -> Coerce <$> cvtE e + Coerce t1 t2 e -> Coerce t1 t2 <$> cvtE e - cvtT :: Tuple (PreOpenExp acc env aenv) t -> (Any, Tuple (PreOpenExp acc env aenv) t) - cvtT NilTup = pure NilTup - cvtT (SnocTup t e) = SnocTup <$> cvtT t <*> cvtE e - - cvtE' :: Elt e' => Gamma acc env' env' aenv -> PreOpenExp acc env' aenv e' -> (Any, PreOpenExp acc env' aenv e') + cvtE' :: Gamma env' env' aenv -> OpenExp env' aenv e' -> (Any, OpenExp env' aenv e') cvtE' env' = first Any . simplifyOpenExp env' - cvtF :: Gamma acc env' env' aenv -> PreOpenFun acc env' aenv f -> (Any, PreOpenFun acc env' aenv f) + cvtF :: Gamma env' env' aenv -> OpenFun env' aenv f -> (Any, OpenFun env' aenv f) cvtF env' = first Any . simplifyOpenFun env' - -- Return the minimal set of unique shapes to intersect. This is a bit - -- inefficient, but the number of shapes is expected to be small so should - -- be fine in practice. - -- - intersect :: Shape t - => (Any, PreOpenExp acc env aenv t) - -> (Any, PreOpenExp acc env aenv t) - -> (Any, PreOpenExp acc env aenv t) - intersect (c1, sh1) (c2, sh2) - | Nothing <- match sh sh' = Stats.ruleFired "intersect" (yes sh') - | otherwise = (c1 <> c2, sh') - where - sh = Intersect sh1 sh2 - sh' = foldl1 Intersect - $ nubBy (\x y -> isJust (match x y)) - $ leaves sh1 ++ leaves sh2 - - leaves :: Shape t => PreOpenExp acc env aenv t -> [PreOpenExp acc env aenv t] - leaves (Intersect x y) = leaves x ++ leaves y - leaves rest = [rest] - - -- Return the minimal set of unique shapes to take the union of. This is a bit - -- inefficient, but the number of shapes is expected to be small so should - -- be fine in practice. - -- - union :: Shape t - => (Any, PreOpenExp acc env aenv t) - -> (Any, PreOpenExp acc env aenv t) - -> (Any, PreOpenExp acc env aenv t) - union (c1, sh1) (c2, sh2) - | Nothing <- match sh sh' = Stats.ruleFired "union" (yes sh') - | otherwise = (c1 <> c2, sh') - where - sh = Union sh1 sh2 - sh' = foldl1 Union - $ nubBy (\x y -> isJust (match x y)) - $ leaves sh1 ++ leaves sh2 - - leaves :: Shape t => PreOpenExp acc env aenv t -> [PreOpenExp acc env aenv t] - leaves (Union x y) = leaves x ++ leaves y - leaves rest = [rest] - + cvtLet :: Gamma env' env' aenv -> ELeftHandSide bnd env' env'' -> OpenExp env' aenv bnd -> (Gamma env'' env'' aenv -> (Any, OpenExp env'' aenv t)) -> (Any, OpenExp env' aenv t) + cvtLet env' lhs@(LeftHandSideSingle _) bnd body = Let lhs bnd <$> body (incExp $ env' `pushExp` bnd) -- Single variable on the LHS, add binding to the environment + cvtLet env' (LeftHandSideWildcard _) _ body = body env' -- Binding not used, remove let binding + cvtLet env' (LeftHandSidePair l1 l2) (Pair e1 e2) body = -- Split binding to multiple bindings + first (const $ Any True) $ + cvtLet env' l1 e1 $ + \env'' -> cvtLet env'' l2 (weakenE (weakenWithLHS l1) e2) body + cvtLet env' lhs bnd body = Let lhs bnd <$> body (lhsExpr lhs env') -- Cannot split this binding. -- Simplify conditional expressions, in particular by eliminating branches -- when the predicate is a known constant. -- - cond :: forall t. Elt t - => (Any, PreOpenExp acc env aenv Bool) - -> (Any, PreOpenExp acc env aenv t) - -> (Any, PreOpenExp acc env aenv t) - -> (Any, PreOpenExp acc env aenv t) + cond :: forall t. + (Any, OpenExp env aenv Bool) + -> (Any, OpenExp env aenv t) + -> (Any, OpenExp env aenv t) + -> (Any, OpenExp env aenv t) cond p@(_,p') t@(_,t') e@(_,e') - | Const True <- p' = Stats.knownBranch "True" (yes t') - | Const False <- p' = Stats.knownBranch "False" (yes e') + | Const _ True <- p' = Stats.knownBranch "True" (yes t') + | Const _ False <- p' = Stats.knownBranch "False" (yes e') | Just Refl <- match t' e' = Stats.knownBranch "redundant" (yes e') | otherwise = Cond <$> p <*> t <*> e - -- If we are projecting elements from a tuple structure or tuple of constant - -- valued tuple, pick out the appropriate component directly. - -- - -- Follow variable bindings, but only if they result in a simplification. - -- - prj :: forall env' s t. (Elt s, Elt t, IsTuple t) - => Gamma acc env' env' aenv - -> TupleIdx (TupleRepr t) s - -> (Any, PreOpenExp acc env' aenv t) - -> (Any, PreOpenExp acc env' aenv s) - prj env' ix top@(_,e) = case e of - Tuple t -> Stats.inline "prj/Tuple" . yes $ prjT ix t - Const c -> Stats.inline "prj/Const" . yes $ prjC ix (fromTuple (toElt c :: t)) - Var v | Just x <- prjV v -> Stats.inline "prj/Var" . yes $ x - Let a b | Just x <- prjL a b -> Stats.inline "prj/Let" . yes $ x - _ -> Prj ix <$> top - where - prjT :: TupleIdx tup s -> Tuple (PreOpenExp acc env' aenv) tup -> PreOpenExp acc env' aenv s - prjT ZeroTupIdx (SnocTup _ v) = v - prjT (SuccTupIdx idx) (SnocTup t _) = prjT idx t -#if __GLASGOW_HASKELL__ < 800 - prjT _ _ = error "DO MORE OF WHAT MAKES YOU HAPPY" -#endif - - prjC :: TupleIdx tup s -> tup -> PreOpenExp acc env' aenv s - prjC ZeroTupIdx (_, v) = Const (fromElt v) - prjC (SuccTupIdx idx) (tup, _) = prjC idx tup - - prjV :: Idx env' t -> Maybe (PreOpenExp acc env' aenv s) - prjV var - | e' <- prjExp var env' - , Nothing <- match e e' - = case e' of - -- Don't push through nested let-bindings; this leads to code explosion - Let _ _ -> Nothing - _ | (Any True, x) <- prj env' ix (pure e') -> Just x - _ -> Nothing - | otherwise - = Nothing - - prjL :: Elt a - => PreOpenExp acc env' aenv a - -> PreOpenExp acc (env',a) aenv t - -> Maybe (PreOpenExp acc env' aenv s) - prjL a b - | (Any True, c) <- prj (incExp $ pushExp env' a) ix (pure b) = Just (Let a c) - prjL _ _ = Nothing - -- Shape manipulations -- - indexCons :: (Elt sl, Elt sz) - => (Any, PreOpenExp acc env aenv sl) - -> (Any, PreOpenExp acc env aenv sz) - -> (Any, PreOpenExp acc env aenv (sl :. sz)) - indexCons (_,IndexNil) (_,Const c) - | Just c' <- cast c -- EltRepr Z ~ EltRepr () - = Stats.ruleFired "Z:.const" $ yes (Const c') - indexCons (_,IndexNil) (_,IndexHead sz') - | 1 <- expDim sz' -- no type information that this is a 1D shape, hence gcast next - , Just sh' <- gcast sz' - = Stats.ruleFired "Z:.indexHead" $ yes sh' - indexCons (_,IndexTail sl') (_,IndexHead sz') - | Just Refl <- match sl' sz' - = Stats.ruleFired "indexTail:.indexHead" $ yes sl' - indexCons sl sz - = IndexCons <$> sl <*> sz - - indexHead :: forall sl sz. (Elt sl, Elt sz) => (Any, PreOpenExp acc env aenv (sl :. sz)) -> (Any, PreOpenExp acc env aenv sz) - indexHead (_, Const c) - | _ :. sz <- toElt c :: sl :. sz = Stats.ruleFired "indexHead/const" $ yes (Const (fromElt sz)) - indexHead (_, IndexCons _ sz) = Stats.ruleFired "indexHead/indexCons" $ yes sz - indexHead sh = IndexHead <$> sh - - indexTail :: forall sl sz. (Elt sl, Elt sz) => (Any, PreOpenExp acc env aenv (sl :. sz)) -> (Any, PreOpenExp acc env aenv sl) - indexTail (_, Const c) - | sl :. _ <- toElt c :: sl :. sz = Stats.ruleFired "indexTail/const" $ yes (Const (fromElt sl)) - indexTail (_, IndexCons sl _) = Stats.ruleFired "indexTail/indexCons" $ yes sl - indexTail sh = IndexTail <$> sh - - shape :: forall sh t. (Shape sh, Elt t) => acc aenv (Array sh t) -> (Any, PreOpenExp acc env aenv sh) - shape _ - | Just Refl <- matchTupleType (eltType @sh) (eltType @Z) - = Stats.ruleFired "shape/Z" $ yes (Const (fromElt Z)) + shape :: forall sh t. ArrayVar aenv (Array sh t) -> (Any, OpenExp env aenv sh) + shape (Var (ArrayR ShapeRz _) _) + = Stats.ruleFired "shape/Z" $ yes Nil shape a - = pure $ Shape a + = pure $ Shape a - shapeSize :: forall sh. Shape sh => (Any, PreOpenExp acc env aenv sh) -> (Any, PreOpenExp acc env aenv Int) - shapeSize (_, Const c) = Stats.ruleFired "shapeSize/const" $ yes (Const (product (shapeToList (toElt c :: sh)))) - shapeSize sh = ShapeSize <$> sh + shapeSize :: forall sh. ShapeR sh -> (Any, OpenExp env aenv sh) -> (Any, OpenExp env aenv Int) + shapeSize shr (_, extractConstTuple -> Just c) = Stats.ruleFired "shapeSize/const" $ yes (Const scalarTypeInt (product (shapeToList shr c))) + shapeSize shr sh = ShapeSize shr <$> sh - toIndex :: forall sh. Shape sh => (Any, PreOpenExp acc env aenv sh) -> (Any, PreOpenExp acc env aenv sh) -> (Any, PreOpenExp acc env aenv Int) - toIndex (_,sh) (_,FromIndex sh' ix) + toIndex :: forall sh. ShapeR sh -> (Any, OpenExp env aenv sh) -> (Any, OpenExp env aenv sh) -> (Any, OpenExp env aenv Int) + toIndex _ (_,sh) (_,FromIndex _ sh' ix) | Just Refl <- match sh sh' = Stats.ruleFired "toIndex/fromIndex" $ yes ix - toIndex sh ix = ToIndex <$> sh <*> ix + toIndex shr sh ix = ToIndex shr <$> sh <*> ix - fromIndex :: forall sh. Shape sh => (Any, PreOpenExp acc env aenv sh) -> (Any, PreOpenExp acc env aenv Int) -> (Any, PreOpenExp acc env aenv sh) - fromIndex (_,sh) (_,ToIndex sh' ix) + fromIndex :: forall sh. ShapeR sh -> (Any, OpenExp env aenv sh) -> (Any, OpenExp env aenv Int) -> (Any, OpenExp env aenv sh) + fromIndex _ (_,sh) (_,ToIndex _ sh' ix) | Just Refl <- match sh sh' = Stats.ruleFired "fromIndex/toIndex" $ yes ix - fromIndex sh ix = FromIndex <$> sh <*> ix + fromIndex shr sh ix = FromIndex shr <$> sh <*> ix first :: (a -> a') -> (a,b) -> (a',b) first f (x,y) = (f x, y) @@ -426,27 +298,35 @@ simplifyOpenExp env = first getAny . cvtE yes :: x -> (Any, x) yes x = (Any True, x) +extractConstTuple :: OpenExp env aenv t -> Maybe t +extractConstTuple Nil = Just () +extractConstTuple (Pair e1 e2) = (,) <$> extractConstTuple e1 <*> extractConstTuple e2 +extractConstTuple (Const _ c) = Just c +extractConstTuple _ = Nothing -- Simplification for open functions -- simplifyOpenFun - :: Kit acc - => Gamma acc env env aenv - -> PreOpenFun acc env aenv f - -> (Bool, PreOpenFun acc env aenv f) -simplifyOpenFun env (Body e) = Body <$> simplifyOpenExp env e -simplifyOpenFun env (Lam f) = Lam <$> simplifyOpenFun env' f + :: Gamma env env aenv + -> OpenFun env aenv f + -> (Bool, OpenFun env aenv f) +simplifyOpenFun env (Body e) = Body <$> simplifyOpenExp env e +simplifyOpenFun env (Lam lhs f) = Lam lhs <$> simplifyOpenFun env' f where - env' = incExp env `pushExp` Var ZeroIdx + env' = lhsExpr lhs env +lhsExpr :: ELeftHandSide t env env' -> Gamma env env aenv -> Gamma env' env' aenv +lhsExpr (LeftHandSideWildcard _) env = env +lhsExpr (LeftHandSideSingle tp) env = incExp env `pushExp` Evar (Var tp ZeroIdx) +lhsExpr (LeftHandSidePair l1 l2) env = lhsExpr l2 $ lhsExpr l1 env -- Simplify closed expressions and functions. The process is applied -- repeatedly until no more changes are made. -- -simplifyExp :: (Elt t, Kit acc) => PreExp acc aenv t -> PreExp acc aenv t +simplifyExp :: Exp aenv t -> Exp aenv t simplifyExp = iterate summariseOpenExp (simplifyOpenExp EmptyExp) -simplifyFun :: Kit acc => PreFun acc aenv f -> PreFun acc aenv f +simplifyFun :: Fun aenv f -> Fun aenv f simplifyFun = iterate summariseOpenFun (simplifyOpenFun EmptyExp) @@ -546,32 +426,24 @@ ops = lens _ops (\Stats{..} v -> Stats { _ops = v, ..}) {-# INLINE vars #-} {-# INLINE ops #-} -summariseOpenFun :: PreOpenFun acc env aenv f -> Stats -summariseOpenFun (Body e) = summariseOpenExp e & terms +~ 1 -summariseOpenFun (Lam f) = summariseOpenFun f & terms +~ 1 & binders +~ 1 +summariseOpenFun :: OpenFun env aenv f -> Stats +summariseOpenFun (Body e) = summariseOpenExp e & terms +~ 1 +summariseOpenFun (Lam _ f) = summariseOpenFun f & terms +~ 1 & binders +~ 1 -summariseOpenExp :: PreOpenExp acc env aenv t -> Stats +summariseOpenExp :: OpenExp env aenv t -> Stats summariseOpenExp = (terms +~ 1) . goE where zero = Stats 0 0 0 0 0 - travE :: PreOpenExp acc env aenv t -> Stats + travE :: OpenExp env aenv t -> Stats travE = summariseOpenExp - travF :: PreOpenFun acc env aenv t -> Stats + travF :: OpenFun env aenv t -> Stats travF = summariseOpenFun travA :: acc aenv a -> Stats travA _ = zero & vars +~ 1 -- assume an array index, else we should have failed elsewhere - travT :: Tuple (PreOpenExp acc env aenv) t -> Stats - travT NilTup = zero & terms +~ 1 - travT (SnocTup t e) = travT t +++ travE e & terms +~ 1 - - travTix :: TupleIdx t e -> Stats - travTix ZeroTupIdx = zero & terms +~ 1 - travTix (SuccTupIdx t) = travTix t & terms +~ 1 - travC :: PrimConst c -> Stats travC (PrimMinBound t) = travBoundedType t & terms +~ 1 travC (PrimMaxBound t) = travBoundedType t & terms +~ 1 @@ -610,36 +482,31 @@ summariseOpenExp = (terms +~ 1) . goE -- travVectorType (Vector16Type t) = travSingleType t & types +~ 1 -- The scrutinee has already been counted - goE :: PreOpenExp acc env aenv t -> Stats + goE :: OpenExp env aenv t -> Stats goE exp = case exp of - Let bnd body -> travE bnd +++ travE body & binders +~ 1 - Var{} -> zero & vars +~ 1 - Foreign _ _ x -> travE x & terms +~ 1 -- +1 for asm, ignore fallback impls. + Let _ bnd body -> travE bnd +++ travE body & binders +~ 1 + Evar{} -> zero & vars +~ 1 + Foreign _ _ _ x -> travE x & terms +~ 1 -- +1 for asm, ignore fallback impls. Const{} -> zero - Undef -> zero - Tuple tup -> travT tup & terms +~ 1 - Prj ix e -> travTix ix +++ travE e - IndexNil -> zero - IndexCons sh sz -> travE sh +++ travE sz - IndexHead sh -> travE sh - IndexTail sh -> travE sh - IndexAny -> zero + Undef _ -> zero + Nil -> zero & terms +~ 1 + Pair e1 e2 -> travE e1 +++ travE e2 & terms +~ 1 + VecPack _ e -> travE e + VecUnpack _ e -> travE e IndexSlice _ slix sh -> travE slix +++ travE sh & terms +~ 1 -- +1 for sliceIndex IndexFull _ slix sl -> travE slix +++ travE sl & terms +~ 1 -- +1 for sliceIndex - ToIndex sh ix -> travE sh +++ travE ix - FromIndex sh ix -> travE sh +++ travE ix + ToIndex _ sh ix -> travE sh +++ travE ix + FromIndex _ sh ix -> travE sh +++ travE ix Cond p t e -> travE p +++ travE t +++ travE e While p f x -> travF p +++ travF f +++ travE x PrimConst c -> travC c Index a ix -> travA a +++ travE ix LinearIndex a ix -> travA a +++ travE ix Shape a -> travA a - ShapeSize sh -> travE sh - Intersect sh1 sh2 -> travE sh1 +++ travE sh2 - Union sh1 sh2 -> travE sh1 +++ travE sh2 + ShapeSize _ sh -> travE sh PrimApp f x -> travPrimFun f +++ travE x - Coerce e -> travE e + Coerce _ _ e -> travE e travPrimFun :: PrimFun f -> Stats travPrimFun = (ops +~ 1) . goF diff --git a/src/Data/Array/Accelerate/Trafo/Substitution.hs b/src/Data/Array/Accelerate/Trafo/Substitution.hs index 6e6991127..e20ea4d90 100644 --- a/src/Data/Array/Accelerate/Trafo/Substitution.hs +++ b/src/Data/Array/Accelerate/Trafo/Substitution.hs @@ -6,9 +6,11 @@ {-# LANGUAGE PatternGuards #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TemplateHaskell #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} {-# LANGUAGE TypeApplications #-} +{-# LANGUAGE ViewPatterns #-} {-# OPTIONS_HADDOCK hide #-} -- | -- Module : Data.Array.Accelerate.Trafo.Substitution @@ -23,7 +25,7 @@ module Data.Array.Accelerate.Trafo.Substitution ( -- ** Renaming & Substitution - inline, substitute, compose, + inline, inlineVars, compose, subTop, subAtop, -- ** Weakening @@ -34,16 +36,24 @@ module Data.Array.Accelerate.Trafo.Substitution ( -- ** Rebuilding terms RebuildAcc, Rebuildable(..), RebuildableAcc, - RebuildableExp(..), RebuildTup(..), rebuildWeakenVar + RebuildableExp(..), rebuildWeakenVar, rebuildLHS, + OpenAccFun(..), OpenAccExp(..), + + -- ** Checks + isIdentity, isIdentityIndexing, extractExpVars, + bindingIsTrivial, ) where import Data.Kind import Control.Applicative hiding ( Const ) +import Control.Monad import Prelude hiding ( exp, seq ) import Data.Array.Accelerate.AST -import Data.Array.Accelerate.Array.Sugar ( Elt, Tuple(..), Array ) +import Data.Array.Accelerate.Array.Representation +import Data.Array.Accelerate.Analysis.Match +import Data.Array.Accelerate.Error import qualified Data.Array.Accelerate.Debug.Stats as Stats @@ -72,51 +82,151 @@ import qualified Data.Array.Accelerate.Debug.Stats as Stats -- a class of operations on variables that is closed under shifting. -- infixr `compose` -infixr `substitute` +-- infixr `substitute` + +lhsFullVars :: forall s a env1 env2. LeftHandSide s a env1 env2 -> Maybe (Vars s env2 a) +lhsFullVars = fmap snd . go weakenId + where + go :: forall env env' b. (env' :> env2) -> LeftHandSide s b env env' -> Maybe (env :> env2, Vars s env2 b) + go k (LeftHandSideWildcard TupRunit) = Just (k, VarsNil) + go k (LeftHandSideSingle s) = Just $ (weakenSucc $ k, VarsSingle $ Var s $ k >:> ZeroIdx) + go k (LeftHandSidePair l1 l2) + | Just (k', v2) <- go k l2 + , Just (k'', v1) <- go k' l1 = Just (k'', VarsPair v1 v2) + go _ _ = Nothing + +bindingIsTrivial :: LeftHandSide s a env1 env2 -> Vars s env2 b -> Maybe (a :~: b) +bindingIsTrivial lhs vars + | Just lhsVars <- lhsFullVars lhs + , Just Refl <- matchVars vars lhsVars = Just Refl +bindingIsTrivial _ _ = Nothing + +isIdentity :: OpenFun env aenv (a -> b) -> Maybe (a :~: b) +isIdentity (Lam lhs (Body (extractExpVars -> Just vars))) = bindingIsTrivial lhs vars +isIdentity _ = Nothing + +-- Detects whether the function is of the form \ix -> a ! ix +isIdentityIndexing :: OpenFun env aenv (a -> b) -> Maybe (ArrayVar aenv (Array a b)) +isIdentityIndexing (Lam lhs (Body body)) + | Index avar ix <- body + , Just vars <- extractExpVars ix + , Just Refl <- bindingIsTrivial lhs vars = Just avar +isIdentityIndexing _ = Nothing -- | Replace the first variable with the given expression. The environment -- shrinks. -- -inline :: RebuildableAcc acc - => PreOpenExp acc (env, s) aenv t - -> PreOpenExp acc env aenv s - -> PreOpenExp acc env aenv t +inline :: OpenExp (env, s) aenv t + -> OpenExp env aenv s + -> OpenExp env aenv t inline f g = Stats.substitution "inline" $ rebuildE (subTop g) f +inlineVars :: forall env env' aenv t1 t2. + ELeftHandSide t1 env env' + -> OpenExp env' aenv t2 + -> OpenExp env aenv t1 + -> Maybe (OpenExp env aenv t2) +inlineVars lhsBound expr bound + | Just vars <- lhsFullVars lhsBound = substitute (strengthenWithLHS lhsBound) weakenId vars expr + where + substitute :: forall env1 env2 t. + env1 :?> env2 + -> env :> env2 + -> ExpVars env1 t1 + -> OpenExp env1 aenv t + -> Maybe (OpenExp env2 aenv t) + substitute _ k2 vars (extractExpVars -> Just vars') + | Just Refl <- matchVars vars vars' = Just $ weakenE k2 bound + substitute k1 k2 vars e = case e of + Let lhs e1 e2 + | Exists lhs' <- rebuildLHS lhs + -> Let lhs' <$> travE e1 <*> substitute (strengthenAfter lhs lhs' k1) (weakenWithLHS lhs' .> k2) (weakenWithLHS lhs `weaken` vars) e2 + Evar (Var t ix) -> Evar . Var t <$> k1 ix + Foreign tp asm f e1 -> Foreign tp asm f <$> travE e1 + Pair e1 e2 -> Pair <$> travE e1 <*> travE e2 + Nil -> Just Nil + VecPack vec e1 -> VecPack vec <$> travE e1 + VecUnpack vec e1 -> VecUnpack vec <$> travE e1 + IndexSlice si e1 e2 -> IndexSlice si <$> travE e1 <*> travE e2 + IndexFull si e1 e2 -> IndexFull si <$> travE e1 <*> travE e2 + ToIndex shr e1 e2 -> ToIndex shr <$> travE e1 <*> travE e2 + FromIndex shr e1 e2 -> FromIndex shr <$> travE e1 <*> travE e2 + Cond e1 e2 e3 -> Cond <$> travE e1 <*> travE e2 <*> travE e3 + While f1 f2 e1 -> While <$> travF f1 <*> travF f2 <*> travE e1 + Const t c -> Just $ Const t c + PrimConst c -> Just $ PrimConst c + PrimApp p e1 -> PrimApp p <$> travE e1 + Index a e1 -> Index a <$> travE e1 + LinearIndex a e1 -> LinearIndex a <$> travE e1 + Shape a -> Just $ Shape a + ShapeSize shr e1 -> ShapeSize shr <$> travE e1 + Undef t -> Just $ Undef t + Coerce t1 t2 e1 -> Coerce t1 t2 <$> travE e1 + + where + travE :: OpenExp env1 aenv s -> Maybe (OpenExp env2 aenv s) + travE = substitute k1 k2 vars + + travF :: OpenFun env1 aenv s -> Maybe (OpenFun env2 aenv s) + travF = substituteF k1 k2 vars + + substituteF :: forall env1 env2 t. + env1 :?> env2 + -> env :> env2 + -> ExpVars env1 t1 + -> OpenFun env1 aenv t + -> Maybe (OpenFun env2 aenv t) + substituteF k1 k2 vars (Body e) = Body <$> substitute k1 k2 vars e + substituteF k1 k2 vars (Lam lhs f) + | Exists lhs' <- rebuildLHS lhs = Lam lhs' <$> substituteF (strengthenAfter lhs lhs' k1) (weakenWithLHS lhs' .> k2) (weakenWithLHS lhs `weaken` vars) f + +inlineVars _ _ _ = Nothing + + -- | Replace an expression that uses the top environment variable with another. -- The result of the first is let bound into the second. -- -substitute :: (RebuildableAcc acc, Elt b, Elt c) - => PreOpenExp acc (env, b) aenv c - -> PreOpenExp acc (env, a) aenv b - -> PreOpenExp acc (env, a) aenv c -substitute f g +{- substitute' :: OpenExp (env, b) aenv c + -> OpenExp (env, a) aenv b + -> OpenExp (env, a) aenv c +substitute' f g | Stats.substitution "substitute" False = undefined - - | Var ZeroIdx <- g = f -- don't rebind an identity function - | otherwise = Let g $ rebuildE split f + | isIdentity f = g -- don't rebind an identity function + | isIdentity g = f + | otherwise = Let g $ rebuildE split f where - split :: Elt c => Idx (env,b) c -> PreOpenExp acc ((env,a),b) aenv c + split :: Idx (env,b) c -> OpenExp ((env,a),b) aenv c split ZeroIdx = Var ZeroIdx split (SuccIdx ix) = Var (SuccIdx (SuccIdx ix)) +substitute :: LeftHandSide b env envb + -> OpenExp envb c + -> LeftHandSide a env enva + -> OpenExp enva b +-} -- | Composition of unary functions. -- -compose :: (RebuildableAcc acc, Elt c) - => PreOpenFun acc env aenv (b -> c) - -> PreOpenFun acc env aenv (a -> b) - -> PreOpenFun acc env aenv (a -> c) -compose (Lam (Body f)) (Lam (Body g)) = Stats.substitution "compose" . Lam . Body $ substitute f g -compose _ _ = error "compose: impossible evaluation" - -subTop :: Elt t => PreOpenExp acc env aenv s -> Idx (env, s) t -> PreOpenExp acc env aenv t -subTop s ZeroIdx = s -subTop _ (SuccIdx ix) = Var ix +compose :: OpenFun env aenv (b -> c) + -> OpenFun env aenv (a -> b) + -> OpenFun env aenv (a -> c) +compose f@(Lam lhsB (Body c)) g@(Lam lhsA (Body b)) + | Stats.substitution "compose" False = undefined + | Just Refl <- isIdentity f = g -- don't rebind an identity function + | Just Refl <- isIdentity g = f + + | Exists lhsB' <- rebuildLHS lhsB + = Lam lhsA $ Body $ Let lhsB' b (weakenE (sinkWithLHS lhsB lhsB' $ weakenWithLHS lhsA) c) + -- = Stats.substitution "compose" . Lam lhs2 . Body $ substitute' f g +compose _ _ = error "compose: impossible evaluation" + +subTop :: OpenExp env aenv s -> ExpVar (env, s) t -> OpenExp env aenv t +subTop s (Var _ ZeroIdx ) = s +subTop _ (Var tp (SuccIdx ix)) = Evar $ Var tp ix subAtop :: PreOpenAcc acc aenv t -> ArrayVar (aenv, t) (Array sh2 e2) -> PreOpenAcc acc aenv (Array sh2 e2) -subAtop t (ArrayVar ZeroIdx ) = t -subAtop _ (ArrayVar (SuccIdx idx)) = Avar $ ArrayVar idx +subAtop t (Var _ ZeroIdx ) = t +subAtop _ (Var repr (SuccIdx idx)) = Avar $ Var repr idx data Identity a = Identity { runIdentity :: a } @@ -153,13 +263,13 @@ class Rebuildable f where class RebuildableExp f where {-# MINIMAL rebuildPartialE #-} rebuildPartialE :: (Applicative f', SyntacticExp fe) - => (forall e'. Elt e' => Idx env e' -> f' (fe (AccClo (f env)) env' aenv e')) - -> f env aenv e + => (forall e'. ExpVar env e' -> f' (fe env' aenv e')) + -> f env aenv e -> f' (f env' aenv e) {-# INLINEABLE rebuildE #-} rebuildE :: SyntacticExp fe - => (forall e'. Elt e' => Idx env e' -> fe (AccClo (f env)) env' aenv e') + => (forall e'. ExpVar env e' -> fe env' aenv e') -> f env aenv e -> f env' aenv e rebuildE v = runIdentity . rebuildPartialE (Identity . v) @@ -168,17 +278,25 @@ class RebuildableExp f where -- type RebuildableAcc acc = (Rebuildable acc, AccClo acc ~ acc) +-- Wrappers which add the 'acc' type argument +-- +data OpenAccExp (acc :: Type -> Type -> Type) env aenv a where + OpenAccExp :: { unOpenAccExp :: OpenExp env aenv a } -> OpenAccExp acc env aenv a + +data OpenAccFun (acc :: Type -> Type -> Type) env aenv a where + OpenAccFun :: { unOpenAccFun :: OpenFun env aenv a } -> OpenAccFun acc env aenv a + -- We can use the same plumbing to rebuildPartial all the things we want to rebuild. -- -instance RebuildableAcc acc => Rebuildable (PreOpenExp acc env) where - type AccClo (PreOpenExp acc env) = acc +instance Rebuildable (OpenAccExp acc env) where + type AccClo (OpenAccExp acc env) = acc {-# INLINEABLE rebuildPartial #-} - rebuildPartial x = Stats.substitution "rebuild" $ rebuildPreOpenExp rebuildPartial (pure . IE) x + rebuildPartial v (OpenAccExp e) = OpenAccExp <$> Stats.substitution "rebuild" (rebuildOpenExp (pure . IE) (reindexAvar v) e) -instance RebuildableAcc acc => Rebuildable (PreOpenFun acc env) where - type AccClo (PreOpenFun acc env) = acc +instance Rebuildable (OpenAccFun acc env) where + type AccClo (OpenAccFun acc env) = acc {-# INLINEABLE rebuildPartial #-} - rebuildPartial x = Stats.substitution "rebuild" $ rebuildFun rebuildPartial (pure . IE) x + rebuildPartial v (OpenAccFun f) = OpenAccFun <$> Stats.substitution "rebuild" (rebuildFun (pure . IE) (reindexAvar v) f) instance RebuildableAcc acc => Rebuildable (PreOpenAcc acc) where type AccClo (PreOpenAcc acc) = acc @@ -190,26 +308,18 @@ instance RebuildableAcc acc => Rebuildable (PreOpenAfun acc) where {-# INLINEABLE rebuildPartial #-} rebuildPartial x = Stats.substitution "rebuild" $ rebuildAfun rebuildPartial x --- Tuples have to be handled specially. -newtype RebuildTup acc env aenv t = RebuildTup { unRTup :: Tuple (PreOpenExp acc env aenv) t } - -instance RebuildableAcc acc => Rebuildable (RebuildTup acc env) where - type AccClo (RebuildTup acc env) = acc - {-# INLINEABLE rebuildPartial #-} - rebuildPartial v t = Stats.substitution "rebuild" . RebuildTup <$> rebuildTup rebuildPartial (pure . IE) v (unRTup t) - instance Rebuildable OpenAcc where type AccClo OpenAcc = OpenAcc {-# INLINEABLE rebuildPartial #-} rebuildPartial x = Stats.substitution "rebuild" $ rebuildOpenAcc x -instance RebuildableAcc acc => RebuildableExp (PreOpenExp acc) where +instance RebuildableExp OpenExp where {-# INLINEABLE rebuildPartialE #-} - rebuildPartialE v x = Stats.substitution "rebuild" $ rebuildPreOpenExp rebuildPartial v (pure . IA) x + rebuildPartialE v x = Stats.substitution "rebuild" $ rebuildOpenExp v (ReindexAvar pure) x -instance RebuildableAcc acc => RebuildableExp (PreOpenFun acc) where +instance RebuildableExp OpenFun where {-# INLINEABLE rebuildPartialE #-} - rebuildPartialE v x = Stats.substitution "rebuild" $ rebuildFun rebuildPartial v (pure . IA) x + rebuildPartialE v x = Stats.substitution "rebuild" $ rebuildFun v (ReindexAvar pure) x -- NOTE: [Weakening] -- @@ -240,20 +350,23 @@ class Sink f where instance Sink Idx where {-# INLINEABLE weaken #-} - weaken k = k + weaken = (>:>) -instance Sink ArrayVar where +instance Sink (Var s) where {-# INLINEABLE weaken #-} - weaken k (ArrayVar ix) = ArrayVar (k ix) + weaken k (Var s ix) = Var s (k >:> ix) -instance Sink ArrayVars where +instance Sink (Vars s) where {-# INLINEABLE weaken #-} - weaken _ ArrayVarsNil = ArrayVarsNil - weaken k (ArrayVarsArray v) = ArrayVarsArray $ weaken k v - weaken k (ArrayVarsPair v w) = ArrayVarsPair (weaken k v) (weaken k w) + weaken _ VarsNil = VarsNil + weaken k (VarsSingle v) = VarsSingle $ weaken k v + weaken k (VarsPair v w) = VarsPair (weaken k v) (weaken k w) rebuildWeakenVar :: env :> env' -> ArrayVar env (Array sh e) -> PreOpenAcc acc env' (Array sh e) -rebuildWeakenVar k (ArrayVar idx) = Avar $ ArrayVar $ k idx +rebuildWeakenVar k (Var s idx) = Avar $ Var s $ k >:> idx + +rebuildWeakenEvar :: env :> env' -> ExpVar env t -> OpenExp env' aenv t +rebuildWeakenEvar k (Var s idx) = Evar $ Var s $ k >:> idx instance RebuildableAcc acc => Sink (PreOpenAcc acc) where {-# INLINEABLE weaken #-} @@ -263,19 +376,15 @@ instance RebuildableAcc acc => Sink (PreOpenAfun acc) where {-# INLINEABLE weaken #-} weaken k = Stats.substitution "weaken" . rebuildA (rebuildWeakenVar k) -instance RebuildableAcc acc => Sink (PreOpenExp acc env) where +instance Sink (OpenExp env) where {-# INLINEABLE weaken #-} - weaken k = Stats.substitution "weaken" . rebuildA (rebuildWeakenVar k) + weaken k = Stats.substitution "weaken" . runIdentity . rebuildOpenExp (Identity . Evar) (ReindexAvar (Identity . weaken k)) -instance RebuildableAcc acc => Sink (PreOpenFun acc env) where +instance Sink (OpenFun env) where {-# INLINEABLE weaken #-} - weaken k = Stats.substitution "weaken" . rebuildA (rebuildWeakenVar k) + weaken k = Stats.substitution "weaken" . runIdentity . rebuildFun (Identity . Evar) (ReindexAvar (Identity . weaken k)) -instance RebuildableAcc acc => Sink (RebuildTup acc env) where - {-# INLINEABLE weaken #-} - weaken k = Stats.substitution "weaken" . rebuildA (rebuildWeakenVar k) - -instance RebuildableAcc acc => Sink (PreBoundary acc) where +instance Sink Boundary where {-# INLINEABLE weaken #-} weaken k bndy = case bndy of @@ -305,13 +414,13 @@ class SinkExp f where -- default weakenE :: RebuildableExp f => env :> env' -> f env aenv t -> f env' aenv t -- weakenE v = Stats.substitution "weakenE" . rebuildE (IE . v) -instance RebuildableAcc acc => SinkExp (PreOpenExp acc) where +instance SinkExp OpenExp where {-# INLINEABLE weakenE #-} - weakenE v = Stats.substitution "weakenE" . rebuildE (IE . v) + weakenE v = Stats.substitution "weakenE" . rebuildE (rebuildWeakenEvar v) -instance RebuildableAcc acc => SinkExp (PreOpenFun acc) where +instance SinkExp OpenFun where {-# INLINEABLE weakenE #-} - weakenE v = Stats.substitution "weakenE" . rebuildE (IE . v) + weakenE v = Stats.substitution "weakenE" . rebuildE (rebuildWeakenEvar v) -- See above for why this is disabled. -- {-# RULES @@ -332,11 +441,27 @@ type env :?> env' = forall t'. Idx env t' -> Maybe (Idx env' t') {-# INLINEABLE strengthen #-} strengthen :: forall f env env' t. Rebuildable f => env :?> env' -> f env t -> Maybe (f env' t) -strengthen k x = Stats.substitution "strengthen" $ rebuildPartial @f @Maybe @IdxA (\(ArrayVar idx) -> fmap (IA . ArrayVar) $ k idx) x -- (\(ArrayVar idx) -> fmap (IA . ArrayVar) $ k idx) +strengthen k x = Stats.substitution "strengthen" $ rebuildPartial @f @Maybe @IdxA (\(Var s ix) -> fmap (IA . Var s) $ k ix) x {-# INLINEABLE strengthenE #-} -strengthenE :: RebuildableExp f => env :?> env' -> f env aenv t -> Maybe (f env' aenv t) -strengthenE k x = Stats.substitution "strengthenE" $ rebuildPartialE (fmap IE . k) x +strengthenE :: forall f env env' aenv t. RebuildableExp f => env :?> env' -> f env aenv t -> Maybe (f env' aenv t) +strengthenE k x = Stats.substitution "strengthenE" $ rebuildPartialE @f @Maybe @IdxE (\(Var tp ix) -> fmap (IE . Var tp) $ k ix) x + +strengthenWithLHS :: LeftHandSide s t env1 env2 -> env2 :?> env1 +strengthenWithLHS (LeftHandSideWildcard _) = Just +strengthenWithLHS (LeftHandSideSingle _) = \ix -> case ix of + ZeroIdx -> Nothing + SuccIdx i -> Just i +strengthenWithLHS (LeftHandSidePair l1 l2) = strengthenWithLHS l2 >=> strengthenWithLHS l1 + +strengthenAfter :: LeftHandSide s t env1 env2 -> LeftHandSide s t env1' env2' -> env1 :?> env1' -> env2 :?> env2' +strengthenAfter (LeftHandSideWildcard _) (LeftHandSideWildcard _) k = k +strengthenAfter (LeftHandSideSingle _) (LeftHandSideSingle _) k = \ix -> case ix of + ZeroIdx -> Just ZeroIdx + SuccIdx i -> SuccIdx <$> k i +strengthenAfter (LeftHandSidePair l1 l2) (LeftHandSidePair l1' l2') k + = strengthenAfter l2 l2' $ strengthenAfter l1 l1' k +strengthenAfter _ _ _ = error "Substitution.strengthenAfter: left hand sides do not match" -- Simultaneous Substitution =================================================== -- @@ -348,98 +473,90 @@ strengthenE k x = Stats.substitution "strengthenE" $ rebuildPartialE (fmap IE . -- SEE: [Weakening] -- class SyntacticExp f where - varIn :: Elt t => Idx env t -> f acc env aenv t - expOut :: Elt t => f acc env aenv t -> PreOpenExp acc env aenv t - weakenExp :: Elt t => RebuildAcc acc -> f acc env aenv t -> f acc (env, s) aenv t - -- weakenExpAcc :: Elt t => RebuildAcc acc -> f acc env aenv t -> f acc env (aenv, s) t + varIn :: ExpVar env t -> f env aenv t + expOut :: f env aenv t -> OpenExp env aenv t + weakenExp :: f env aenv t -> f (env, s) aenv t -newtype IdxE (acc :: Type -> Type -> Type) env aenv t = IE { unIE :: Idx env t } +newtype IdxE env aenv t = IE { unIE :: ExpVar env t } instance SyntacticExp IdxE where varIn = IE - expOut = Var . unIE - weakenExp _ = IE . SuccIdx . unIE - -- weakenExpAcc _ = IE . unIE + expOut = Evar . unIE + weakenExp (IE (Var tp ix)) = IE $ Var tp $ SuccIdx ix -instance SyntacticExp PreOpenExp where - varIn = Var +instance SyntacticExp OpenExp where + varIn = Evar expOut = id - weakenExp k = runIdentity . rebuildPreOpenExp k (Identity . weakenExp k . IE) (Identity . IA) - -- weakenExpAcc k = runIdentity . rebuildPreOpenExp k (Identity . IE) (Identity . weakenAcc k . IA) + weakenExp = runIdentity . rebuildOpenExp (Identity . weakenExp . IE) (ReindexAvar Identity) {-# INLINEABLE shiftE #-} shiftE - :: (Applicative f, SyntacticExp fe, Elt t) - => RebuildAcc acc - -> (forall t'. Elt t' => Idx env t' -> f (fe acc env' aenv t')) - -> Idx (env, s) t - -> f (fe acc (env', s) aenv t) -shiftE _ _ ZeroIdx = pure $ varIn ZeroIdx -shiftE k v (SuccIdx ix) = weakenExp k <$> (v ix) - -{-# INLINEABLE rebuildPreOpenExp #-} -rebuildPreOpenExp - :: (Applicative f, SyntacticExp fe, SyntacticAcc fa) - => RebuildAcc acc - -> (forall t'. Elt t' => Idx env t' -> f (fe acc env' aenv' t')) - -> RebuildAvar f fa acc aenv aenv' - -> PreOpenExp acc env aenv t - -> f (PreOpenExp acc env' aenv' t) -rebuildPreOpenExp k v av exp = + :: (Applicative f, SyntacticExp fe) + => RebuildEvar f fe env env' aenv + -> RebuildEvar f fe (env, s) (env', s) aenv +shiftE _ (Var tp ZeroIdx) = pure $ varIn (Var tp ZeroIdx) +shiftE v (Var tp (SuccIdx ix)) = weakenExp <$> v (Var tp ix) + +{-# INLINEABLE shiftE' #-} +shiftE' + :: (Applicative f, SyntacticExp fa) + => ELeftHandSide t env1 env1' + -> ELeftHandSide t env2 env2' + -> RebuildEvar f fa env1 env2 aenv + -> RebuildEvar f fa env1' env2' aenv +shiftE' (LeftHandSideWildcard _) (LeftHandSideWildcard _) v = v +shiftE' (LeftHandSideSingle _) (LeftHandSideSingle _) v = shiftE v +shiftE' (LeftHandSidePair a1 b1) (LeftHandSidePair a2 b2) v = shiftE' b1 b2 $ shiftE' a1 a2 v +shiftE' _ _ _ = error "Substitution: left hand sides do not match" + + +{-# INLINEABLE rebuildOpenExp #-} +rebuildOpenExp + :: (Applicative f, SyntacticExp fe) + => RebuildEvar f fe env env' aenv' + -> ReindexAvar f aenv aenv' + -> OpenExp env aenv t + -> f (OpenExp env' aenv' t) +rebuildOpenExp v av@(ReindexAvar reindex) exp = case exp of - Const c -> pure (Const c) - PrimConst c -> pure (PrimConst c) - Undef -> pure Undef - IndexNil -> pure IndexNil - IndexAny -> pure IndexAny - Var ix -> expOut <$> v ix - Let a b -> Let <$> rebuildPreOpenExp k v av a <*> rebuildPreOpenExp k (shiftE k v) av b - Tuple tup -> Tuple <$> rebuildTup k v av tup - Prj tup e -> Prj tup <$> rebuildPreOpenExp k v av e - IndexCons sh sz -> IndexCons <$> rebuildPreOpenExp k v av sh <*> rebuildPreOpenExp k v av sz - IndexHead sh -> IndexHead <$> rebuildPreOpenExp k v av sh - IndexTail sh -> IndexTail <$> rebuildPreOpenExp k v av sh - IndexSlice x ix sh -> IndexSlice x <$> rebuildPreOpenExp k v av ix <*> rebuildPreOpenExp k v av sh - IndexFull x ix sl -> IndexFull x <$> rebuildPreOpenExp k v av ix <*> rebuildPreOpenExp k v av sl - ToIndex sh ix -> ToIndex <$> rebuildPreOpenExp k v av sh <*> rebuildPreOpenExp k v av ix - FromIndex sh ix -> FromIndex <$> rebuildPreOpenExp k v av sh <*> rebuildPreOpenExp k v av ix - Cond p t e -> Cond <$> rebuildPreOpenExp k v av p <*> rebuildPreOpenExp k v av t <*> rebuildPreOpenExp k v av e - While p f x -> While <$> rebuildFun k v av p <*> rebuildFun k v av f <*> rebuildPreOpenExp k v av x - PrimApp f x -> PrimApp f <$> rebuildPreOpenExp k v av x - Index a sh -> Index <$> k av a <*> rebuildPreOpenExp k v av sh - LinearIndex a i -> LinearIndex <$> k av a <*> rebuildPreOpenExp k v av i - Shape a -> Shape <$> k av a - ShapeSize sh -> ShapeSize <$> rebuildPreOpenExp k v av sh - Intersect s t -> Intersect <$> rebuildPreOpenExp k v av s <*> rebuildPreOpenExp k v av t - Union s t -> Union <$> rebuildPreOpenExp k v av s <*> rebuildPreOpenExp k v av t - Foreign ff f e -> Foreign ff f <$> rebuildPreOpenExp k v av e - Coerce e -> Coerce <$> rebuildPreOpenExp k v av e - -{-# INLINEABLE rebuildTup #-} -rebuildTup - :: (Applicative f, SyntacticExp fe, SyntacticAcc fa) - => RebuildAcc acc - -> (forall t'. Elt t' => Idx env t' -> f (fe acc env' aenv' t')) - -> RebuildAvar f fa acc aenv aenv' - -> Tuple (PreOpenExp acc env aenv) t - -> f (Tuple (PreOpenExp acc env' aenv') t) -rebuildTup k v av tup = - case tup of - NilTup -> pure NilTup - SnocTup t e -> SnocTup <$> rebuildTup k v av t <*> rebuildPreOpenExp k v av e + Const t c -> pure $ Const t c + PrimConst c -> pure $ PrimConst c + Undef t -> pure $ Undef t + Evar var -> expOut <$> v var + Let lhs a b + | Exists lhs' <- rebuildLHS lhs + -> Let lhs' <$> rebuildOpenExp v av a <*> rebuildOpenExp (shiftE' lhs lhs' v) av b + Pair e1 e2 -> Pair <$> rebuildOpenExp v av e1 <*> rebuildOpenExp v av e2 + Nil -> pure $ Nil + VecPack vec e -> VecPack vec <$> rebuildOpenExp v av e + VecUnpack vec e -> VecUnpack vec <$> rebuildOpenExp v av e + IndexSlice x ix sh -> IndexSlice x <$> rebuildOpenExp v av ix <*> rebuildOpenExp v av sh + IndexFull x ix sl -> IndexFull x <$> rebuildOpenExp v av ix <*> rebuildOpenExp v av sl + ToIndex shr sh ix -> ToIndex shr <$> rebuildOpenExp v av sh <*> rebuildOpenExp v av ix + FromIndex shr sh ix -> FromIndex shr <$> rebuildOpenExp v av sh <*> rebuildOpenExp v av ix + Cond p t e -> Cond <$> rebuildOpenExp v av p <*> rebuildOpenExp v av t <*> rebuildOpenExp v av e + While p f x -> While <$> rebuildFun v av p <*> rebuildFun v av f <*> rebuildOpenExp v av x + PrimApp f x -> PrimApp f <$> rebuildOpenExp v av x + Index a sh -> Index <$> reindex a <*> rebuildOpenExp v av sh + LinearIndex a i -> LinearIndex <$> reindex a <*> rebuildOpenExp v av i + Shape a -> Shape <$> reindex a + ShapeSize shr sh -> ShapeSize shr <$> rebuildOpenExp v av sh + Foreign tp ff f e -> Foreign tp ff f <$> rebuildOpenExp v av e + Coerce t1 t2 e -> Coerce t1 t2 <$> rebuildOpenExp v av e {-# INLINEABLE rebuildFun #-} rebuildFun - :: (Applicative f, SyntacticExp fe, SyntacticAcc fa) - => RebuildAcc acc - -> (forall t'. Elt t' => Idx env t' -> f (fe acc env' aenv' t')) - -> RebuildAvar f fa acc aenv aenv' - -> PreOpenFun acc env aenv t - -> f (PreOpenFun acc env' aenv' t) -rebuildFun k v av fun = + :: (Applicative f, SyntacticExp fe) + => RebuildEvar f fe env env' aenv' + -> ReindexAvar f aenv aenv' + -> OpenFun env aenv t + -> f (OpenFun env' aenv' t) +rebuildFun v av fun = case fun of - Body e -> Body <$> rebuildPreOpenExp k v av e - Lam f -> Lam <$> rebuildFun k (shiftE k v) av f + Body e -> Body <$> rebuildOpenExp v av e + Lam lhs f + | Exists lhs' <- rebuildLHS lhs + -> Lam lhs' <$> rebuildFun (shiftE' lhs lhs' v) av f -- The array environment -- ----------------- @@ -458,9 +575,9 @@ class SyntacticAcc f where weakenAcc :: RebuildAcc acc -> f acc aenv (Array sh e) -> f acc (aenv, s) (Array sh e) instance SyntacticAcc IdxA where - avarIn = IA - accOut = Avar . unIA - weakenAcc _ (IA (ArrayVar idx)) = IA $ ArrayVar $ SuccIdx idx + avarIn = IA + accOut = Avar . unIA + weakenAcc _ (IA (Var s idx)) = IA $ Var s $ SuccIdx idx instance SyntacticAcc PreOpenAcc where avarIn = Avar @@ -470,25 +587,46 @@ instance SyntacticAcc PreOpenAcc where type RebuildAvar f (fa :: (Type -> Type -> Type) -> Type -> Type -> Type) acc aenv aenv' = forall sh e. ArrayVar aenv (Array sh e) -> f (fa acc aenv' (Array sh e)) +type RebuildEvar f fe env env' aenv' = + forall t'. ExpVar env t' -> f (fe env' aenv' t') + +newtype ReindexAvar f aenv aenv' = + ReindexAvar (forall sh e. ArrayVar aenv (Array sh e) -> f (ArrayVar aenv' (Array sh e))) + +reindexAvar + :: forall f fa acc aenv aenv'. + (Applicative f, SyntacticAcc fa) + => RebuildAvar f fa acc aenv aenv' + -> ReindexAvar f aenv aenv' +reindexAvar v = ReindexAvar f where + f :: forall sh e. ArrayVar aenv (Array sh e) -> f (ArrayVar aenv' (Array sh e)) + f var = g <$> v var + + g :: fa acc aenv' (Array sh e) -> ArrayVar aenv' (Array sh e) + g fa = case accOut fa of + Avar var' -> var' + _ -> $internalError "reindexAvar" "An Avar which was used in an Exp was mapped to an array term other than Avar. This mapping is invalid as an Exp can only contain array variables." + + {-# INLINEABLE shiftA #-} shiftA :: (Applicative f, SyntacticAcc fa) => RebuildAcc acc -> RebuildAvar f fa acc aenv aenv' - -> ArrayVar (aenv, s) (Array sh e) - -> f (fa acc (aenv', s) (Array sh e)) -shiftA _ _ (ArrayVar ZeroIdx) = pure $ avarIn $ ArrayVar ZeroIdx -shiftA k v (ArrayVar (SuccIdx ix)) = weakenAcc k <$> v (ArrayVar ix) + -> ArrayVar (aenv, s) (Array sh e) + -> f (fa acc (aenv', s) (Array sh e)) +shiftA _ _ (Var s ZeroIdx) = pure $ avarIn $ Var s ZeroIdx +shiftA k v (Var s (SuccIdx ix)) = weakenAcc k <$> v (Var s ix) shiftA' :: (Applicative f, SyntacticAcc fa) - => LeftHandSide t aenv1 aenv1' - -> LeftHandSide t aenv2 aenv2' + => ALeftHandSide t aenv1 aenv1' + -> ALeftHandSide t aenv2 aenv2' -> RebuildAcc acc -> RebuildAvar f fa acc aenv1 aenv2 -> RebuildAvar f fa acc aenv1' aenv2' shiftA' (LeftHandSideWildcard _) (LeftHandSideWildcard _) _ v = v -shiftA' LeftHandSideArray LeftHandSideArray k v = shiftA k v +shiftA' (LeftHandSideSingle _) (LeftHandSideSingle _) k v = shiftA k v shiftA' (LeftHandSidePair a1 b1) (LeftHandSidePair a2 b2) k v = shiftA' b1 b2 k $ shiftA' a1 a2 k v shiftA' _ _ _ _ = error "Substitution: left hand sides do not match" @@ -509,38 +647,40 @@ rebuildPreOpenAcc -> f (PreOpenAcc acc aenv' t) rebuildPreOpenAcc k av acc = case acc of - Use a -> pure (Use a) + Use repr a -> pure $ Use repr a Alet lhs a b -> rebuildAlet k av lhs a b Avar ix -> accOut <$> av ix Apair as bs -> Apair <$> k av as <*> k av bs Anil -> pure Anil - Apply f a -> Apply <$> rebuildAfun k av f <*> k av a - Acond p t e -> Acond <$> rebuildPreOpenExp k (pure . IE) av p <*> k av t <*> k av e + Apply repr f a -> Apply repr <$> rebuildAfun k av f <*> k av a + Acond p t e -> Acond <$> rebuildOpenExp (pure . IE) av' p <*> k av t <*> k av e Awhile p f a -> Awhile <$> rebuildAfun k av p <*> rebuildAfun k av f <*> k av a - Unit e -> Unit <$> rebuildPreOpenExp k (pure . IE) av e - Reshape e a -> Reshape <$> rebuildPreOpenExp k (pure . IE) av e <*> k av a - Generate e f -> Generate <$> rebuildPreOpenExp k (pure . IE) av e <*> rebuildFun k (pure . IE) av f - Transform sh ix f a -> Transform <$> rebuildPreOpenExp k (pure . IE) av sh <*> rebuildFun k (pure . IE) av ix <*> rebuildFun k (pure . IE) av f <*> k av a - Replicate sl slix a -> Replicate sl <$> rebuildPreOpenExp k (pure . IE) av slix <*> k av a - Slice sl a slix -> Slice sl <$> k av a <*> rebuildPreOpenExp k (pure . IE) av slix - Map f a -> Map <$> rebuildFun k (pure . IE) av f <*> k av a - ZipWith f a1 a2 -> ZipWith <$> rebuildFun k (pure . IE) av f <*> k av a1 <*> k av a2 - Fold f z a -> Fold <$> rebuildFun k (pure . IE) av f <*> rebuildPreOpenExp k (pure . IE) av z <*> k av a - Fold1 f a -> Fold1 <$> rebuildFun k (pure . IE) av f <*> k av a - FoldSeg f z a s -> FoldSeg <$> rebuildFun k (pure . IE) av f <*> rebuildPreOpenExp k (pure . IE) av z <*> k av a <*> k av s - Fold1Seg f a s -> Fold1Seg <$> rebuildFun k (pure . IE) av f <*> k av a <*> k av s - Scanl f z a -> Scanl <$> rebuildFun k (pure . IE) av f <*> rebuildPreOpenExp k (pure . IE) av z <*> k av a - Scanl' f z a -> Scanl' <$> rebuildFun k (pure . IE) av f <*> rebuildPreOpenExp k (pure . IE) av z <*> k av a - Scanl1 f a -> Scanl1 <$> rebuildFun k (pure . IE) av f <*> k av a - Scanr f z a -> Scanr <$> rebuildFun k (pure . IE) av f <*> rebuildPreOpenExp k (pure . IE) av z <*> k av a - Scanr' f z a -> Scanr' <$> rebuildFun k (pure . IE) av f <*> rebuildPreOpenExp k (pure . IE) av z <*> k av a - Scanr1 f a -> Scanr1 <$> rebuildFun k (pure . IE) av f <*> k av a - Permute f1 a1 f2 a2 -> Permute <$> rebuildFun k (pure . IE) av f1 <*> k av a1 <*> rebuildFun k (pure . IE) av f2 <*> k av a2 - Backpermute sh f a -> Backpermute <$> rebuildPreOpenExp k (pure . IE) av sh <*> rebuildFun k (pure . IE) av f <*> k av a - Stencil f b a -> Stencil <$> rebuildFun k (pure . IE) av f <*> rebuildBoundary k av b <*> k av a - Stencil2 f b1 a1 b2 a2 -> Stencil2 <$> rebuildFun k (pure . IE) av f <*> rebuildBoundary k av b1 <*> k av a1 <*> rebuildBoundary k av b2 <*> k av a2 + Unit tp e -> Unit tp <$> rebuildOpenExp (pure . IE) av' e + Reshape shr e a -> Reshape shr <$> rebuildOpenExp (pure . IE) av' e <*> k av a + Generate repr e f -> Generate repr <$> rebuildOpenExp (pure . IE) av' e <*> rebuildFun (pure . IE) av' f + Transform repr sh ix f a -> Transform repr <$> rebuildOpenExp (pure . IE) av' sh <*> rebuildFun (pure . IE) av' ix <*> rebuildFun (pure . IE) av' f <*> k av a + Replicate sl slix a -> Replicate sl <$> rebuildOpenExp (pure . IE) av' slix <*> k av a + Slice sl a slix -> Slice sl <$> k av a <*> rebuildOpenExp (pure . IE) av' slix + Map tp f a -> Map tp <$> rebuildFun (pure . IE) av' f <*> k av a + ZipWith tp f a1 a2 -> ZipWith tp <$> rebuildFun (pure . IE) av' f <*> k av a1 <*> k av a2 + Fold f z a -> Fold <$> rebuildFun (pure . IE) av' f <*> rebuildOpenExp (pure . IE) av' z <*> k av a + Fold1 f a -> Fold1 <$> rebuildFun (pure . IE) av' f <*> k av a + FoldSeg itp f z a s -> FoldSeg itp <$> rebuildFun (pure . IE) av' f <*> rebuildOpenExp (pure . IE) av' z <*> k av a <*> k av s + Fold1Seg itp f a s -> Fold1Seg itp <$> rebuildFun (pure . IE) av' f <*> k av a <*> k av s + Scanl f z a -> Scanl <$> rebuildFun (pure . IE) av' f <*> rebuildOpenExp (pure . IE) av' z <*> k av a + Scanl' f z a -> Scanl' <$> rebuildFun (pure . IE) av' f <*> rebuildOpenExp (pure . IE) av' z <*> k av a + Scanl1 f a -> Scanl1 <$> rebuildFun (pure . IE) av' f <*> k av a + Scanr f z a -> Scanr <$> rebuildFun (pure . IE) av' f <*> rebuildOpenExp (pure . IE) av' z <*> k av a + Scanr' f z a -> Scanr' <$> rebuildFun (pure . IE) av' f <*> rebuildOpenExp (pure . IE) av' z <*> k av a + Scanr1 f a -> Scanr1 <$> rebuildFun (pure . IE) av' f <*> k av a + Permute f1 a1 f2 a2 -> Permute <$> rebuildFun (pure . IE) av' f1 <*> k av a1 <*> rebuildFun (pure . IE) av' f2 <*> k av a2 + Backpermute shr sh f a -> Backpermute shr <$> rebuildOpenExp (pure . IE) av' sh <*> rebuildFun (pure . IE) av' f <*> k av a + Stencil sr tp f b a -> Stencil sr tp <$> rebuildFun (pure . IE) av' f <*> rebuildBoundary av' b <*> k av a + Stencil2 s1 s2 tp f b1 a1 b2 a2 -> Stencil2 s1 s2 tp <$> rebuildFun (pure . IE) av' f <*> rebuildBoundary av' b1 <*> k av a1 <*> rebuildBoundary av' b2 <*> k av a2 -- Collect seq -> Collect <$> rebuildSeq k av seq - Aforeign ff afun as -> Aforeign ff afun <$> k av as + Aforeign repr ff afun as -> Aforeign repr ff afun <$> k av as + where + av' = reindexAvar av {-# INLINEABLE rebuildAfun #-} rebuildAfun @@ -559,7 +699,7 @@ rebuildAlet :: forall f fa acc aenv1 aenv1' aenv2 bndArrs arrs. (Applicative f, SyntacticAcc fa) => RebuildAcc acc -> RebuildAvar f fa acc aenv1 aenv2 - -> LeftHandSide bndArrs aenv1 aenv1' + -> ALeftHandSide bndArrs aenv1 aenv1' -> acc aenv1 bndArrs -> acc aenv1' arrs -> f (PreOpenAcc acc aenv2 arrs) @@ -567,27 +707,26 @@ rebuildAlet k av lhs1 bind1 body1 = case rebuildLHS lhs1 of Exists lhs2 -> Alet lhs2 <$> k av bind1 <*> k (shiftA' lhs1 lhs2 k av) body1 {-# INLINEABLE rebuildLHS #-} -rebuildLHS :: LeftHandSide arr aenv1 aenv1' -> Exists (LeftHandSide arr aenv2) +rebuildLHS :: LeftHandSide s t aenv1 aenv1' -> Exists (LeftHandSide s t aenv2) rebuildLHS (LeftHandSideWildcard r) = Exists $ LeftHandSideWildcard r -rebuildLHS LeftHandSideArray = Exists $ LeftHandSideArray +rebuildLHS (LeftHandSideSingle s) = Exists $ LeftHandSideSingle s rebuildLHS (LeftHandSidePair as bs) = case rebuildLHS as of Exists as' -> case rebuildLHS bs of Exists bs' -> Exists $ LeftHandSidePair as' bs' {-# INLINEABLE rebuildBoundary #-} rebuildBoundary - :: (Applicative f, SyntacticAcc fa) - => RebuildAcc acc - -> RebuildAvar f fa acc aenv aenv' - -> PreBoundary acc aenv t - -> f (PreBoundary acc aenv' t) -rebuildBoundary k av bndy = + :: Applicative f + => ReindexAvar f aenv aenv' + -> Boundary aenv t + -> f (Boundary aenv' t) +rebuildBoundary av bndy = case bndy of Clamp -> pure Clamp Mirror -> pure Mirror Wrap -> pure Wrap Constant v -> pure (Constant v) - Function f -> Function <$> rebuildFun k (pure . IE) av f + Function f -> Function <$> rebuildFun (pure . IE) av f {-- {-# INLINEABLE rebuildSeq #-} @@ -616,7 +755,7 @@ rebuildP k v p = MapSeq f x -> MapSeq <$> rebuildAfun k v f <*> pure x ChunkedMapSeq f x -> ChunkedMapSeq <$> rebuildAfun k v f <*> pure x ZipWithSeq f x y -> ZipWithSeq <$> rebuildAfun k v f <*> pure x <*> pure y - ScanSeq f e x -> ScanSeq <$> rebuildFun k (pure . IE) v f <*> rebuildPreOpenExp k (pure . IE) v e <*> pure x + ScanSeq f e x -> ScanSeq <$> rebuildFun (pure . IE) v f <*> rebuildOpenExp (pure . IE) v e <*> pure x {-# INLINEABLE rebuildC #-} rebuildC :: forall acc fa f aenv aenv' senv a. (SyntacticAcc fa, Applicative f) @@ -626,7 +765,7 @@ rebuildC :: forall acc fa f aenv aenv' senv a. (SyntacticAcc fa, Applicative f) -> f (Consumer acc aenv' senv a) rebuildC k v c = case c of - FoldSeq f e x -> FoldSeq <$> rebuildFun k (pure . IE) v f <*> rebuildPreOpenExp k (pure . IE) v e <*> pure x + FoldSeq f e x -> FoldSeq <$> rebuildFun (pure . IE) v f <*> rebuildOpenExp (pure . IE) v e <*> pure x FoldSeqFlatten f acc x -> FoldSeqFlatten <$> rebuildAfun k v f <*> k v acc <*> pure x Stuple t -> Stuple <$> rebuildT t where @@ -635,3 +774,8 @@ rebuildC k v c = rebuildT (SnocAtup t s) = SnocAtup <$> (rebuildT t) <*> (rebuildC k v s) --} +extractExpVars :: OpenExp env aenv a -> Maybe (ExpVars env a) +extractExpVars Nil = Just VarsNil +extractExpVars (Pair e1 e2) = VarsPair <$> extractExpVars e1 <*> extractExpVars e2 +extractExpVars (Evar v) = Just $ VarsSingle v +extractExpVars _ = Nothing diff --git a/src/Data/Array/Accelerate/Trafo/Vectorise.hs b/src/Data/Array/Accelerate/Trafo/Vectorise.hs index 4333c1735..b61d5b7d2 100644 --- a/src/Data/Array/Accelerate/Trafo/Vectorise.hs +++ b/src/Data/Array/Accelerate/Trafo/Vectorise.hs @@ -1,7 +1,6 @@ {-# LANGUAGE AllowAmbiguousTypes #-} {-# LANGUAGE CPP #-} {-# LANGUAGE ConstraintKinds #-} -{-# LANGUAGE DeriveDataTypeable #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GADTs #-} @@ -45,7 +44,6 @@ module Data.Array.Accelerate.Trafo.Vectorise ( import Prelude hiding ( exp, replicate, concat ) import qualified Prelude as P -import Data.Typeable import Control.Applicative hiding ( Const ) import Data.Maybe @@ -58,7 +56,6 @@ import Data.Array.Accelerate.Array.Sugar import Data.Array.Accelerate.Trafo.Base import Data.Array.Accelerate.Pretty () import Data.Array.Accelerate.Trafo.Substitution -import Data.Array.Accelerate.Product import Data.Array.Accelerate.Type import qualified Data.Array.Accelerate.Classes.Eq as S import qualified Data.Array.Accelerate.Language as S @@ -109,7 +106,7 @@ type VectoriseAcc acc = forall aenv aenv' t. -> LiftedAcc acc aenv' t data None sh = None sh - deriving (Typeable, Show, Eq) + deriving (Show, Eq) type instance EltRepr (None sh) = EltRepr sh @@ -124,12 +121,6 @@ instance Shape sh => Slice (None sh) where type FullShape (None sh) = sh sliceIndex _ = sliceNoneIndex (undefined :: sh) -instance Shape sh => IsProduct Elt (None sh) where - type ProdRepr (None sh) = ((),sh) - fromProd _ (None sh) = ((),sh) - toProd _ ((),sh) = None sh - prod _ _ = ProdRsnoc ProdRunit - -- Lifting terms -- ------------- diff --git a/src/Data/Array/Accelerate/Type.hs b/src/Data/Array/Accelerate/Type.hs index 869263a25..f7a1c0c49 100644 --- a/src/Data/Array/Accelerate/Type.hs +++ b/src/Data/Array/Accelerate/Type.hs @@ -2,7 +2,6 @@ {-# LANGUAGE CPP #-} {-# LANGUAGE ConstraintKinds #-} {-# LANGUAGE DataKinds #-} -{-# LANGUAGE DeriveDataTypeable #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE MagicHash #-} @@ -74,6 +73,7 @@ module Data.Array.Accelerate.Type ( ) where import Data.Orphans () -- orphan instances for 8-tuples and beyond +import Data.Array.Accelerate.Orphans () -- Prim Half import Control.Monad.ST import Data.Bits @@ -82,7 +82,6 @@ import Data.Primitive.ByteArray import Data.Primitive.Types import Data.Text.Prettyprint.Doc import Data.Type.Equality -import Data.Typeable import Data.Word import Foreign.C.Types import Foreign.Storable ( Storable ) @@ -124,29 +123,29 @@ data NonNumDict a where -- | Integral types supported in array computations. -- data IntegralType a where - TypeInt :: IntegralDict Int -> IntegralType Int - TypeInt8 :: IntegralDict Int8 -> IntegralType Int8 - TypeInt16 :: IntegralDict Int16 -> IntegralType Int16 - TypeInt32 :: IntegralDict Int32 -> IntegralType Int32 - TypeInt64 :: IntegralDict Int64 -> IntegralType Int64 - TypeWord :: IntegralDict Word -> IntegralType Word - TypeWord8 :: IntegralDict Word8 -> IntegralType Word8 - TypeWord16 :: IntegralDict Word16 -> IntegralType Word16 - TypeWord32 :: IntegralDict Word32 -> IntegralType Word32 - TypeWord64 :: IntegralDict Word64 -> IntegralType Word64 + TypeInt :: IntegralType Int + TypeInt8 :: IntegralType Int8 + TypeInt16 :: IntegralType Int16 + TypeInt32 :: IntegralType Int32 + TypeInt64 :: IntegralType Int64 + TypeWord :: IntegralType Word + TypeWord8 :: IntegralType Word8 + TypeWord16 :: IntegralType Word16 + TypeWord32 :: IntegralType Word32 + TypeWord64 :: IntegralType Word64 -- | Floating-point types supported in array computations. -- data FloatingType a where - TypeHalf :: FloatingDict Half -> FloatingType Half - TypeFloat :: FloatingDict Float -> FloatingType Float - TypeDouble :: FloatingDict Double -> FloatingType Double + TypeHalf :: FloatingType Half + TypeFloat :: FloatingType Float + TypeDouble :: FloatingType Double -- | Non-numeric types supported in array computations. -- data NonNumType a where - TypeBool :: NonNumDict Bool -> NonNumType Bool -- marshalled to Word8 - TypeChar :: NonNumDict Char -> NonNumType Char + TypeBool :: NonNumType Bool -- marshalled to Word8 + TypeChar :: NonNumType Char -- | Numeric element types implement Num & Real -- @@ -171,31 +170,31 @@ data SingleType a where NonNumSingleType :: NonNumType a -> SingleType a data VectorType a where - VectorType :: {-# UNPACK #-} !Int -> SingleType a -> VectorType (Vec n a) + VectorType :: KnownNat n => {-# UNPACK #-} !Int -> SingleType a -> VectorType (Vec n a) -- Showing type names -- instance Show (IntegralType a) where - show TypeInt{} = "Int" - show TypeInt8{} = "Int8" - show TypeInt16{} = "Int16" - show TypeInt32{} = "Int32" - show TypeInt64{} = "Int64" - show TypeWord{} = "Word" - show TypeWord8{} = "Word8" - show TypeWord16{} = "Word16" - show TypeWord32{} = "Word32" - show TypeWord64{} = "Word64" + show TypeInt = "Int" + show TypeInt8 = "Int8" + show TypeInt16 = "Int16" + show TypeInt32 = "Int32" + show TypeInt64 = "Int64" + show TypeWord = "Word" + show TypeWord8 = "Word8" + show TypeWord16 = "Word16" + show TypeWord32 = "Word32" + show TypeWord64 = "Word64" instance Show (FloatingType a) where - show TypeHalf{} = "Half" - show TypeFloat{} = "Float" - show TypeDouble{} = "Double" + show TypeHalf = "Half" + show TypeFloat = "Float" + show TypeDouble = "Double" instance Show (NonNumType a) where - show TypeBool{} = "Bool" - show TypeChar{} = "Char" + show TypeBool = "Bool" + show TypeChar = "Char" instance Show (NumType a) where show (IntegralNumType ty) = show ty @@ -216,7 +215,6 @@ instance Show (ScalarType a) where show (SingleScalarType ty) = show ty show (VectorScalarType ty) = show ty - -- Querying scalar type representations -- @@ -252,7 +250,7 @@ class IsScalar a => IsSingle a where -- | All scalar types -- -class Typeable a => IsScalar a where +class IsScalar a where scalarType :: ScalarType a @@ -260,51 +258,123 @@ class Typeable a => IsScalar a where -- integralDict :: IntegralType a -> IntegralDict a -integralDict (TypeInt dict) = dict -integralDict (TypeInt8 dict) = dict -integralDict (TypeInt16 dict) = dict -integralDict (TypeInt32 dict) = dict -integralDict (TypeInt64 dict) = dict -integralDict (TypeWord dict) = dict -integralDict (TypeWord8 dict) = dict -integralDict (TypeWord16 dict) = dict -integralDict (TypeWord32 dict) = dict -integralDict (TypeWord64 dict) = dict +integralDict TypeInt = IntegralDict +integralDict TypeInt8 = IntegralDict +integralDict TypeInt16 = IntegralDict +integralDict TypeInt32 = IntegralDict +integralDict TypeInt64 = IntegralDict +integralDict TypeWord = IntegralDict +integralDict TypeWord8 = IntegralDict +integralDict TypeWord16 = IntegralDict +integralDict TypeWord32 = IntegralDict +integralDict TypeWord64 = IntegralDict floatingDict :: FloatingType a -> FloatingDict a -floatingDict (TypeHalf dict) = dict -floatingDict (TypeFloat dict) = dict -floatingDict (TypeDouble dict) = dict +floatingDict TypeHalf = FloatingDict +floatingDict TypeFloat = FloatingDict +floatingDict TypeDouble = FloatingDict nonNumDict :: NonNumType a -> NonNumDict a -nonNumDict (TypeBool dict) = dict -nonNumDict (TypeChar dict) = dict - - --- Type representation +nonNumDict TypeBool = NonNumDict +nonNumDict TypeChar = NonNumDict + +showType :: TupleType tp -> ShowS +showType TupRunit = showString "()" +showType (TupRsingle tp) = showString $ showScalarType tp +showType (TupRpair t1 t2) = showString "(" . showType t1 . showString ", " . showType t2 . showString ")" + +showScalarType :: ScalarType tp -> String +showScalarType (SingleScalarType tp) = showSingleType tp +showScalarType (VectorScalarType (VectorType n tp)) = "Vec " ++ show n ++ " " ++ showSingleType tp + +showSingleType :: SingleType tp -> String +showSingleType (NumSingleType (IntegralNumType tp)) = case tp of + TypeInt -> "Int" + TypeInt8 -> "Int8" + TypeInt16 -> "Int16" + TypeInt32 -> "Int32" + TypeInt64 -> "Int64" + TypeWord -> "Word" + TypeWord8 -> "Word8" + TypeWord16 -> "Word16" + TypeWord32 -> "Word32" + TypeWord64 -> "Word64" +showSingleType (NumSingleType (FloatingNumType tp)) = case tp of + TypeHalf -> "Half" + TypeFloat -> "Float" + TypeDouble -> "Double" +showSingleType (NonNumSingleType TypeChar) = "Char" +showSingleType (NonNumSingleType TypeBool) = "Bool" + +-- Common used types in the compiler. +scalarTypeBool :: ScalarType Bool +scalarTypeBool = SingleScalarType $ NonNumSingleType TypeBool + +scalarTypeInt :: ScalarType Int +scalarTypeInt = SingleScalarType $ NumSingleType $ IntegralNumType TypeInt + +scalarTypeInt32 :: ScalarType Int32 +scalarTypeInt32 = SingleScalarType $ NumSingleType $ IntegralNumType TypeInt32 + +scalarTypeWord8 :: ScalarType Word8 +scalarTypeWord8 = SingleScalarType $ NumSingleType $ IntegralNumType TypeWord8 + +scalarTypeWord32 :: ScalarType Word32 +scalarTypeWord32 = SingleScalarType $ NumSingleType $ IntegralNumType TypeWord32 + +-- Tuple representation -- ------------------- -- --- Representation of product types, consisting of: +-- Both arrays (Acc) and expressions (Exp) may form tuples. These are represented +-- using as product types, consisting of: -- -- * unit (void) -- --- * scalar types: values which go in registers. These may be single value +-- * single array / scalar types +-- in case of expressions: values which go in registers. These may be single value -- types such as int and float, or SIMD vectors of single value types such -- as <4 * float>. We do not allow vectors-of-vectors. -- -- * pairs: representing compound values (i.e. tuples) where each component -- will be stored in a separate array. -- -data TupleType a where - TypeRunit :: TupleType () - TypeRscalar :: ScalarType a -> TupleType a - TypeRpair :: TupleType a -> TupleType b -> TupleType (a, b) +data TupR s a where + TupRunit :: TupR s () + TupRsingle :: s a -> TupR s a + TupRpair :: TupR s a -> TupR s b -> TupR s (a, b) + +type TupleType = TupR ScalarType -- Rename to EltR? + +instance Show (TupR ScalarType a) where + show TupRunit = "()" + show (TupRsingle t) = show t + show (TupRpair a b) = "(" ++ show a ++ "," ++ show b ++")" + +type Tup2 a b = (((), a), b) +type Tup3 a b c = ((((), a), b), c) +type Tup4 a b c d = (((((), a), b), c), d) +type Tup5 a b c d e = ((((((), a), b), c), d), e) +type Tup6 a b c d e f = (((((((), a), b), c), d), e), f) +type Tup7 a b c d e f g = ((((((((), a), b), c), d), e), f), g) +type Tup8 a b c d e f g h = (((((((((), a), b), c), d), e), f), g), h) +type Tup9 a b c d e f g h i = ((((((((((), a), b), c), d), e), f), g), h), i) +type Tup16 a b c d e f g h + i j k l m n o p = (((((((((((((((((), a), b), c), d), e), f), g), h), i), j), k), l), m), n), o), p) + +tupR2 :: TupR s t1 -> TupR s t2 -> TupR s (Tup2 t1 t2) +tupR2 t1 t2 = TupRunit `TupRpair` t1 `TupRpair` t2 + +tupR3 :: TupR s t1 -> TupR s t2 -> TupR s t3 -> TupR s (Tup3 t1 t2 t3) +tupR3 t1 t2 t3 = TupRunit `TupRpair` t1 `TupRpair` t2 `TupRpair` t3 -instance Show (TupleType a) where - show TypeRunit = "()" - show (TypeRscalar t) = show t - show (TypeRpair a b) = printf "(%s,%s)" (show a) (show b) +tupR5 :: TupR s t1 -> TupR s t2 -> TupR s t3 -> TupR s t4 -> TupR s t5 -> TupR s (Tup5 t1 t2 t3 t4 t5) +tupR5 t1 t2 t3 t4 t5 = TupRunit `TupRpair` t1 `TupRpair` t2 `TupRpair` t3 `TupRpair` t4 `TupRpair` t5 +tupR7 :: TupR s t1 -> TupR s t2 -> TupR s t3 -> TupR s t4 -> TupR s t5 -> TupR s t6 -> TupR s t7 -> TupR s (Tup7 t1 t2 t3 t4 t5 t6 t7) +tupR7 t1 t2 t3 t4 t5 t6 t7 = TupRunit `TupRpair` t1 `TupRpair` t2 `TupRpair` t3 `TupRpair` t4 `TupRpair` t5 `TupRpair` t6 `TupRpair` t7 + +tupR9 :: TupR s t1 -> TupR s t2 -> TupR s t3 -> TupR s t4 -> TupR s t5 -> TupR s t6 -> TupR s t7 -> TupR s t8 -> TupR s t9 -> TupR s (Tup9 t1 t2 t3 t4 t5 t6 t7 t8 t9) +tupR9 t1 t2 t3 t4 t5 t6 t7 t8 t9 = TupRunit `TupRpair` t1 `TupRpair` t2 `TupRpair` t3 `TupRpair` t4 `TupRpair` t5 `TupRpair` t6 `TupRpair` t7 `TupRpair` t8 `TupRpair` t9 -- Type-level bit sizes -- -------------------- @@ -352,27 +422,52 @@ type family BitSize a :: Nat -- which redundant for our use case (derivable from type level information). -- data Vec (n::Nat) a = Vec ByteArray# - deriving Typeable type role Vec nominal representational instance (Show a, Prim a, KnownNat n) => Show (Vec n a) where - show (Vec ba#) = vec (go 0#) + show = vec . vecToArray where vec :: [a] -> String vec = show . group . encloseSep (flatAlt "< " "<") (flatAlt " >" ">") ", " . map viaShow - -- - go :: Int# -> [a] - go i# | isTrue# (i# <# n#) = indexByteArray# ba# i# : go (i# +# 1#) - | otherwise = [] - -- - !(I# n#) = fromIntegral (natVal' (proxy# :: Proxy# n)) + +vecToArray :: forall a n. (Prim a, KnownNat n) => Vec n a -> [a] +vecToArray (Vec ba#) = go 0# + where + go :: Int# -> [a] + go i# | isTrue# (i# <# n#) = indexByteArray# ba# i# : go (i# +# 1#) + | otherwise = [] + + !(I# n#) = fromIntegral (natVal' (proxy# :: Proxy# n)) instance Eq (Vec n a) where Vec ba1# == Vec ba2# = ByteArray ba1# == ByteArray ba2# +data PrimDict a where + PrimDict :: Prim a => PrimDict a + +getPrim :: SingleType a -> PrimDict a +getPrim (NumSingleType (IntegralNumType tp)) = case tp of + TypeInt -> PrimDict + TypeInt8 -> PrimDict + TypeInt16 -> PrimDict + TypeInt32 -> PrimDict + TypeInt64 -> PrimDict + TypeWord -> PrimDict + TypeWord8 -> PrimDict + TypeWord16 -> PrimDict + TypeWord32 -> PrimDict + TypeWord64 -> PrimDict +getPrim (NumSingleType (FloatingNumType tp)) = case tp of + TypeHalf -> PrimDict + TypeFloat -> PrimDict + TypeDouble -> PrimDict +getPrim (NonNumSingleType TypeChar) = PrimDict +getPrim (NonNumSingleType TypeBool) = error "prim: We don't support vector of bools yet" + + -- Type synonyms for common SIMD vector types -- @@ -571,7 +666,7 @@ $(runQ $ do mkIntegral :: Name -> Integer -> Q [Dec] mkIntegral t n = [d| instance IsIntegral $(conT t) where - integralType = $(conE (mkName ("Type" ++ nameBase t))) IntegralDict + integralType = $(conE (mkName ("Type" ++ nameBase t))) instance IsNum $(conT t) where numType = IntegralNumType integralType @@ -591,7 +686,7 @@ $(runQ $ do mkFloating :: Name -> Integer -> Q [Dec] mkFloating t n = [d| instance IsFloating $(conT t) where - floatingType = $(conE (mkName ("Type" ++ nameBase t))) FloatingDict + floatingType = $(conE (mkName ("Type" ++ nameBase t))) instance IsNum $(conT t) where numType = FloatingNumType floatingType @@ -608,7 +703,7 @@ $(runQ $ do mkNonNum :: Name -> Integer -> Q [Dec] mkNonNum t n = [d| instance IsNonNum $(conT t) where - nonNumType = $(conE (mkName ("Type" ++ nameBase t))) NonNumDict + nonNumType = $(conE (mkName ("Type" ++ nameBase t))) instance IsBounded $(conT t) where boundedType = NonNumBoundedType nonNumType diff --git a/src/Data/Array/Accelerate/Unsafe.hs b/src/Data/Array/Accelerate/Unsafe.hs index 82824e28f..4701618e9 100644 --- a/src/Data/Array/Accelerate/Unsafe.hs +++ b/src/Data/Array/Accelerate/Unsafe.hs @@ -1,3 +1,5 @@ +{-# LANGUAGE MonoLocalBinds #-} +{-# LANGUAGE FlexibleContexts #-} -- | -- Module : Data.Array.Accelerate.Unsafe -- Copyright : [2009..2019] The Accelerate Team @@ -15,7 +17,7 @@ module Data.Array.Accelerate.Unsafe ( -- ** Unsafe operations - undef, coerce, + undef, coerce, Coerce ) where @@ -39,11 +41,10 @@ import Data.Array.Accelerate.Smart -- abstract type to the concrete type by dropping the extra @()@ from the -- representation, and vice-versa. -- --- You will get a runtime error if it fails to find a coercion between the two --- representations. +-- The type class 'Coerce' assures that there is a coercion between the two +-- types. -- -- @since 1.2.0.0 -- -coerce :: (Elt a, Elt b) => Exp a -> Exp b -coerce = mkUnsafeCoerce - +coerce :: Coerce (EltRepr a) (EltRepr b) => Exp a -> Exp b +coerce = mkCoerce