diff --git a/src/Data/Array/Accelerate/AST.hs b/src/Data/Array/Accelerate/AST.hs index 704083907..b201c0784 100644 --- a/src/Data/Array/Accelerate/AST.hs +++ b/src/Data/Array/Accelerate/AST.hs @@ -92,7 +92,7 @@ module Data.Array.Accelerate.AST ( -- * Accelerated array expressions PreOpenAfun(..), OpenAfun, PreAfun, Afun, PreOpenAcc(..), OpenAcc(..), Acc, - PreBoundary(..), Boundary, StencilR(..), + Boundary(..), StencilR(..), HasArraysRepr(..), arrayRepr, lhsToTupR, PairIdx(..), ArrayR(..), ArraysR, ShapeR(..), SliceIndex(..), VecR(..), vecRvector, vecRtuple, @@ -101,19 +101,19 @@ module Data.Array.Accelerate.AST ( -- Producer(..), Consumer(..), -- * Scalar expressions - PreOpenFun(..), OpenFun, PreFun, Fun, PreOpenExp(..), OpenExp, PreExp, Exp, PrimConst(..), + OpenFun(..), Fun, OpenExp(..), Exp, PrimConst(..), PrimFun(..), expType, primConstType, primFunType, -- NFData NFDataAcc, - rnfPreOpenAfun, rnfPreOpenAcc, rnfPreOpenFun, rnfPreOpenExp, + rnfPreOpenAfun, rnfPreOpenAcc, rnfOpenFun, rnfOpenExp, rnfArrays, rnfArrayR, -- TemplateHaskell LiftAcc, liftIdx, liftConst, liftSliceIndex, liftPrimConst, liftPrimFun, - liftPreOpenAfun, liftPreOpenAcc, liftPreOpenFun, liftPreOpenExp, + liftPreOpenAfun, liftPreOpenAcc, liftOpenFun, liftOpenExp, liftALhs, liftELhs, liftArray, liftArraysR, liftTupleType, liftArrayR, liftScalarType, liftShapeR, liftVecR, liftIntegralType, @@ -319,7 +319,7 @@ data Vars s env t where VarsNil :: Vars s aenv () VarsPair :: Vars s aenv a -> Vars s aenv b -> Vars s aenv (a, b) -evars :: ExpVars env tp -> PreOpenExp acc env aenv tp +evars :: ExpVars env tp -> OpenExp env aenv tp evars VarsNil = Nil evars (VarsSingle var) = Evar var evars (VarsPair v1 v2) = evars v1 `Pair` evars v2 @@ -348,7 +348,7 @@ varsType (VarsPair v1 v2) = varsType v1 `TupRpair` varsType v2 -- We use a non-recursive variant parametrised over the recursive closure, -- to facilitate attribute calculation in the backend. -- -data PreOpenAcc acc aenv a where +data PreOpenAcc (acc :: Type -> Type -> Type) aenv a where -- Local non-recursive binding to represent sharing and demand -- explicitly. Note this is an eager binding! @@ -394,7 +394,7 @@ data PreOpenAcc acc aenv a where -- If-then-else for array-level computations -- - Acond :: PreExp acc aenv Bool + Acond :: Exp aenv Bool -> acc aenv arrs -> acc aenv arrs -> PreOpenAcc acc aenv arrs @@ -417,7 +417,7 @@ data PreOpenAcc acc aenv a where -- Capture a scalar (or a tuple of scalars) in a singleton array -- Unit :: TupleType e - -> PreExp acc aenv e + -> Exp aenv e -> PreOpenAcc acc aenv (Scalar e) -- Change the shape of an array without altering its contents. @@ -426,24 +426,24 @@ data PreOpenAcc acc aenv a where -- > dim == size dim' -- Reshape :: ShapeR sh - -> PreExp acc aenv sh -- new shape + -> Exp aenv sh -- new shape -> acc aenv (Array sh' e) -- array to be reshaped -> PreOpenAcc acc aenv (Array sh e) -- Construct a new array by applying a function to each index. -- Generate :: ArrayR (Array sh e) - -> PreExp acc aenv sh -- output shape - -> PreFun acc aenv (sh -> e) -- representation function + -> Exp aenv sh -- output shape + -> Fun aenv (sh -> e) -- representation function -> PreOpenAcc acc aenv (Array sh e) -- Hybrid map/backpermute, where we separate the index and value -- transformations. -- Transform :: ArrayR (Array sh' b) - -> PreExp acc aenv sh' -- dimension of the result - -> PreFun acc aenv (sh' -> sh) -- index permutation function - -> PreFun acc aenv (a -> b) -- function to apply at each element + -> Exp aenv sh' -- dimension of the result + -> Fun aenv (sh' -> sh) -- index permutation function + -> Fun aenv (a -> b) -- function to apply at each element -> acc aenv (Array sh a) -- source array -> PreOpenAcc acc aenv (Array sh' b) @@ -454,7 +454,7 @@ data PreOpenAcc acc aenv a where sl co sh - -> PreExp acc aenv slix -- slice value specification + -> Exp aenv slix -- slice value specification -> acc aenv (Array sl e) -- data to be replicated -> PreOpenAcc acc aenv (Array sh e) @@ -466,13 +466,13 @@ data PreOpenAcc acc aenv a where co sh -> acc aenv (Array sh e) -- array to be indexed - -> PreExp acc aenv slix -- slice value specification + -> Exp aenv slix -- slice value specification -> PreOpenAcc acc aenv (Array sl e) -- Apply the given unary function to all elements of the given array -- Map :: TupleType e' - -> PreFun acc aenv (e -> e') + -> Fun aenv (e -> e') -> acc aenv (Array sh e) -> PreOpenAcc acc aenv (Array sh e') @@ -481,7 +481,7 @@ data PreOpenAcc acc aenv a where -- two argument arrays. -- ZipWith :: TupleType e3 - -> PreFun acc aenv (e1 -> e2 -> e3) + -> Fun aenv (e1 -> e2 -> e3) -> acc aenv (Array sh e1) -> acc aenv (Array sh e2) -> PreOpenAcc acc aenv (Array sh e3) @@ -489,14 +489,14 @@ data PreOpenAcc acc aenv a where -- Fold along the innermost dimension of an array with a given -- /associative/ function. -- - Fold :: PreFun acc aenv (e -> e -> e) -- combination function - -> PreExp acc aenv e -- default value + Fold :: Fun aenv (e -> e -> e) -- combination function + -> Exp aenv e -- default value -> acc aenv (Array (sh, Int) e) -- folded array -> PreOpenAcc acc aenv (Array sh e) -- As 'Fold' without a default value -- - Fold1 :: PreFun acc aenv (e -> e -> e) -- combination function + Fold1 :: Fun aenv (e -> e -> e) -- combination function -> acc aenv (Array (sh, Int) e) -- folded array -> PreOpenAcc acc aenv (Array sh e) @@ -504,8 +504,8 @@ data PreOpenAcc acc aenv a where -- /associative/ function -- FoldSeg :: IntegralType i - -> PreFun acc aenv (e -> e -> e) -- combination function - -> PreExp acc aenv e -- default value + -> Fun aenv (e -> e -> e) -- combination function + -> Exp aenv e -- default value -> acc aenv (Array (sh, Int) e) -- folded array -> acc aenv (Segments i) -- segment descriptor -> PreOpenAcc acc aenv (Array (sh, Int) e) @@ -513,7 +513,7 @@ data PreOpenAcc acc aenv a where -- As 'FoldSeg' without a default value -- Fold1Seg :: IntegralType i - -> PreFun acc aenv (e -> e -> e) -- combination function + -> Fun aenv (e -> e -> e) -- combination function -> acc aenv (Array (sh, Int) e) -- folded array -> acc aenv (Segments i) -- segment descriptor -> PreOpenAcc acc aenv (Array (sh, Int) e) @@ -522,8 +522,8 @@ data PreOpenAcc acc aenv a where -- /associative/ function and an initial element (which does not need to -- be the neutral of the associative operations) -- - Scanl :: PreFun acc aenv (e -> e -> e) -- combination function - -> PreExp acc aenv e -- initial value + Scanl :: Fun aenv (e -> e -> e) -- combination function + -> Exp aenv e -- initial value -> acc aenv (Array (sh, Int) e) -> PreOpenAcc acc aenv (Array (sh, Int) e) @@ -531,34 +531,34 @@ data PreOpenAcc acc aenv a where -- same length as the input array (the fold value would be the rightmost -- element in a Haskell-style scan) -- - Scanl' :: PreFun acc aenv (e -> e -> e) -- combination function - -> PreExp acc aenv e -- initial value + Scanl' :: Fun aenv (e -> e -> e) -- combination function + -> Exp aenv e -- initial value -> acc aenv (Array (sh, Int) e) -> PreOpenAcc acc aenv (Array (sh, Int) e, Array sh e) -- Haskell-style scan without an initial value -- - Scanl1 :: PreFun acc aenv (e -> e -> e) -- combination function + Scanl1 :: Fun aenv (e -> e -> e) -- combination function -> acc aenv (Array (sh, Int) e) -> PreOpenAcc acc aenv (Array (sh, Int) e) -- Right-to-left version of 'Scanl' -- - Scanr :: PreFun acc aenv (e -> e -> e) -- combination function - -> PreExp acc aenv e -- initial value + Scanr :: Fun aenv (e -> e -> e) -- combination function + -> Exp aenv e -- initial value -> acc aenv (Array (sh, Int) e) -> PreOpenAcc acc aenv (Array (sh, Int) e) -- Right-to-left version of 'Scanl\'' -- - Scanr' :: PreFun acc aenv (e -> e -> e) -- combination function - -> PreExp acc aenv e -- initial value + Scanr' :: Fun aenv (e -> e -> e) -- combination function + -> Exp aenv e -- initial value -> acc aenv (Array (sh, Int) e) -> PreOpenAcc acc aenv (Array (sh, Int) e, Array sh e) -- Right-to-left version of 'Scanl1' -- - Scanr1 :: PreFun acc aenv (e -> e -> e) -- combination function + Scanr1 :: Fun aenv (e -> e -> e) -- combination function -> acc aenv (Array (sh, Int) e) -> PreOpenAcc acc aenv (Array (sh, Int) e) @@ -582,9 +582,9 @@ data PreOpenAcc acc aenv a where -- function is used to combine elements, which needs to be /associative/ -- and /commutative/. -- - Permute :: PreFun acc aenv (e -> e -> e) -- combination function + Permute :: Fun aenv (e -> e -> e) -- combination function -> acc aenv (Array sh' e) -- default values - -> PreFun acc aenv (sh -> sh') -- permutation function + -> Fun aenv (sh -> sh') -- permutation function -> acc aenv (Array sh e) -- source array -> PreOpenAcc acc aenv (Array sh' e) @@ -592,8 +592,8 @@ data PreOpenAcc acc aenv a where -- be between arrays of varying shape; the permutation function must be total -- Backpermute :: ShapeR sh' - -> PreExp acc aenv sh' -- dimensions of the result - -> PreFun acc aenv (sh' -> sh) -- permutation function + -> Exp aenv sh' -- dimensions of the result + -> Fun aenv (sh' -> sh) -- permutation function -> acc aenv (Array sh e) -- source array -> PreOpenAcc acc aenv (Array sh' e) @@ -602,8 +602,8 @@ data PreOpenAcc acc aenv a where -- Stencil :: StencilR sh e stencil -> TupleType e' - -> PreFun acc aenv (stencil -> e') -- stencil function - -> PreBoundary acc aenv (Array sh e) -- boundary condition + -> Fun aenv (stencil -> e') -- stencil function + -> Boundary aenv (Array sh e) -- boundary condition -> acc aenv (Array sh e) -- source array -> PreOpenAcc acc aenv (Array sh e') @@ -612,10 +612,10 @@ data PreOpenAcc acc aenv a where Stencil2 :: StencilR sh a stencil1 -> StencilR sh b stencil2 -> TupleType c - -> PreFun acc aenv (stencil1 -> stencil2 -> c) -- stencil function - -> PreBoundary acc aenv (Array sh a) -- boundary condition #1 + -> Fun aenv (stencil1 -> stencil2 -> c) -- stencil function + -> Boundary aenv (Array sh a) -- boundary condition #1 -> acc aenv (Array sh a) -- source array #1 - -> PreBoundary acc aenv (Array sh b) -- boundary condition #2 + -> Boundary aenv (Array sh b) -- boundary condition #2 -> acc aenv (Array sh b) -- source array #2 -> PreOpenAcc acc aenv (Array sh c) @@ -742,29 +742,25 @@ type Seq = PreOpenSeq OpenAcc () () --} --- | Vanilla stencil boundary condition +-- | Vanilla boundary condition specification for stencil operations -- -type Boundary = PreBoundary OpenAcc - --- | Boundary condition specification for stencil operations --- -data PreBoundary acc aenv t where +data Boundary aenv t where -- Clamp coordinates to the extent of the array - Clamp :: PreBoundary acc aenv t + Clamp :: Boundary aenv t -- Mirror coordinates beyond the array extent - Mirror :: PreBoundary acc aenv t + Mirror :: Boundary aenv t -- Wrap coordinates around on each dimension - Wrap :: PreBoundary acc aenv t + Wrap :: Boundary aenv t -- Use a constant value for outlying coordinates Constant :: e - -> PreBoundary acc aenv (Array sh e) + -> Boundary aenv (Array sh e) -- Apply the given function to outlying coordinates - Function :: PreFun acc aenv (sh -> e) - -> PreBoundary acc aenv (Array sh e) + Function :: Fun aenv (sh -> e) + -> Boundary aenv (Array sh e) data PairIdx p a where PairIdxLeft :: PairIdx (a, b) a @@ -828,161 +824,145 @@ instance HasArraysRepr OpenAcc where -- Embedded expressions -- -------------------- --- |Parametrised open function abstraction --- -data PreOpenFun acc env aenv t where - Body :: PreOpenExp acc env aenv t -> PreOpenFun acc env aenv t - Lam :: ELeftHandSide a env env' -> PreOpenFun acc env' aenv t -> PreOpenFun acc env aenv (a -> t) - -- |Vanilla open function abstraction -- -type OpenFun = PreOpenFun OpenAcc - --- |Parametrised function without free scalar variables --- -type PreFun acc = PreOpenFun acc () +data OpenFun env aenv t where + Body :: OpenExp env aenv t -> OpenFun env aenv t + Lam :: ELeftHandSide a env env' -> OpenFun env' aenv t -> OpenFun env aenv (a -> t) -- |Vanilla function without free scalar variables -- type Fun = OpenFun () --- |Vanilla open expression --- -type OpenExp = PreOpenExp OpenAcc - --- |Parametrised expression without free scalar variables --- -type PreExp acc = PreOpenExp acc () - -- |Vanilla expression without free scalar variables -- type Exp = OpenExp () --- |Parametrised open expressions using de Bruijn indices for variables ranging over tuples +-- |Vanilla open expressions using de Bruijn indices for variables ranging over tuples -- of scalars and arrays of tuples. All code, except Cond, is evaluated eagerly. N-tuples are -- represented as nested pairs. -- -- The data type is parametrised over the representation type (not the surface types). -- -data PreOpenExp acc env aenv t where +data OpenExp env aenv t where -- Local binding of a scalar expression Let :: ELeftHandSide bnd_t env env' - -> PreOpenExp acc env aenv bnd_t - -> PreOpenExp acc env' aenv body_t - -> PreOpenExp acc env aenv body_t + -> OpenExp env aenv bnd_t + -> OpenExp env' aenv body_t + -> OpenExp env aenv body_t -- Variable index, ranging only over tuples or scalars Evar :: ExpVar env t - -> PreOpenExp acc env aenv t + -> OpenExp env aenv t -- Apply a backend-specific foreign function Foreign :: Sugar.Foreign asm => TupleType y - -> asm (x -> y) -- foreign function - -> PreFun acc () (x -> y) -- alternate implementation (for other backends) - -> PreOpenExp acc env aenv x - -> PreOpenExp acc env aenv y + -> asm (x -> y) -- foreign function + -> Fun () (x -> y) -- alternate implementation (for other backends) + -> OpenExp env aenv x + -> OpenExp env aenv y -- Tuples - Pair :: PreOpenExp acc env aenv t1 - -> PreOpenExp acc env aenv t2 - -> PreOpenExp acc env aenv (t1, t2) + Pair :: OpenExp env aenv t1 + -> OpenExp env aenv t2 + -> OpenExp env aenv (t1, t2) - Nil :: PreOpenExp acc env aenv () + Nil :: OpenExp env aenv () -- SIMD vectors VecPack :: KnownNat n => VecR n s tup - -> PreOpenExp acc env aenv tup - -> PreOpenExp acc env aenv (Vec n s) + -> OpenExp env aenv tup + -> OpenExp env aenv (Vec n s) VecUnpack :: KnownNat n => VecR n s tup - -> PreOpenExp acc env aenv (Vec n s) - -> PreOpenExp acc env aenv tup + -> OpenExp env aenv (Vec n s) + -> OpenExp env aenv tup -- Array indices & shapes IndexSlice :: SliceIndex slix sl co sh - -> PreOpenExp acc env aenv slix - -> PreOpenExp acc env aenv sh - -> PreOpenExp acc env aenv sl + -> OpenExp env aenv slix + -> OpenExp env aenv sh + -> OpenExp env aenv sl IndexFull :: SliceIndex slix sl co sh - -> PreOpenExp acc env aenv slix - -> PreOpenExp acc env aenv sl - -> PreOpenExp acc env aenv sh + -> OpenExp env aenv slix + -> OpenExp env aenv sl + -> OpenExp env aenv sh -- Shape and index conversion ToIndex :: ShapeR sh - -> PreOpenExp acc env aenv sh -- shape of the array - -> PreOpenExp acc env aenv sh -- index into the array - -> PreOpenExp acc env aenv Int + -> OpenExp env aenv sh -- shape of the array + -> OpenExp env aenv sh -- index into the array + -> OpenExp env aenv Int FromIndex :: ShapeR sh - -> PreOpenExp acc env aenv sh -- shape of the array - -> PreOpenExp acc env aenv Int -- index into linear representation - -> PreOpenExp acc env aenv sh + -> OpenExp env aenv sh -- shape of the array + -> OpenExp env aenv Int -- index into linear representation + -> OpenExp env aenv sh -- Conditional expression (non-strict in 2nd and 3rd argument) - Cond :: PreOpenExp acc env aenv Bool - -> PreOpenExp acc env aenv t - -> PreOpenExp acc env aenv t - -> PreOpenExp acc env aenv t + Cond :: OpenExp env aenv Bool + -> OpenExp env aenv t + -> OpenExp env aenv t + -> OpenExp env aenv t -- Value recursion - While :: PreOpenFun acc env aenv (a -> Bool) -- continue while true - -> PreOpenFun acc env aenv (a -> a) -- function to iterate - -> PreOpenExp acc env aenv a -- initial value - -> PreOpenExp acc env aenv a + While :: OpenFun env aenv (a -> Bool) -- continue while true + -> OpenFun env aenv (a -> a) -- function to iterate + -> OpenExp env aenv a -- initial value + -> OpenExp env aenv a -- Constant values Const :: ScalarType t -> t - -> PreOpenExp acc env aenv t + -> OpenExp env aenv t PrimConst :: PrimConst t - -> PreOpenExp acc env aenv t + -> OpenExp env aenv t -- Primitive scalar operations PrimApp :: PrimFun (a -> r) - -> PreOpenExp acc env aenv a - -> PreOpenExp acc env aenv r + -> OpenExp env aenv a + -> OpenExp env aenv r -- Project a single scalar from an array. -- The array expression can not contain any free scalar variables. - Index :: acc aenv (Array dim t) - -> PreOpenExp acc env aenv dim - -> PreOpenExp acc env aenv t + Index :: ArrayVar aenv (Array dim t) + -> OpenExp env aenv dim + -> OpenExp env aenv t - LinearIndex :: acc aenv (Array dim t) - -> PreOpenExp acc env aenv Int - -> PreOpenExp acc env aenv t + LinearIndex :: ArrayVar aenv (Array dim t) + -> OpenExp env aenv Int + -> OpenExp env aenv t -- Array shape. -- The array expression can not contain any free scalar variables. - Shape :: acc aenv (Array dim e) - -> PreOpenExp acc env aenv dim + Shape :: ArrayVar aenv (Array dim e) + -> OpenExp env aenv dim -- Number of elements of an array given its shape ShapeSize :: ShapeR dim - -> PreOpenExp acc env aenv dim - -> PreOpenExp acc env aenv Int + -> OpenExp env aenv dim + -> OpenExp env aenv Int -- Unsafe operations (may fail or result in undefined behaviour) -- An unspecified bit pattern Undef :: ScalarType t - -> PreOpenExp acc env aenv t + -> OpenExp env aenv t -- Reinterpret the bits of a value as a different type Coerce :: BitSizeEq a b => ScalarType a -> ScalarType b - -> PreOpenExp acc env aenv a - -> PreOpenExp acc env aenv b + -> OpenExp env aenv a + -> OpenExp env aenv b -expType :: HasArraysRepr acc => PreOpenExp acc aenv env t -> TupleType t +expType :: OpenExp aenv env t -> TupleType t expType expr = case expr of Let _ _ body -> expType body Evar (Var tp _) -> TupRsingle tp @@ -1001,9 +981,9 @@ expType expr = case expr of Const tp _ -> TupRsingle tp PrimConst c -> TupRsingle $ SingleScalarType $ primConstType c PrimApp f _ -> snd $ primFunType f - Index a _ -> arrayRtype $ arrayRepr a - LinearIndex a _ -> arrayRtype $ arrayRepr a - Shape a -> shapeType $ arrayRshape $ arrayRepr a + Index (Var repr _) _ -> arrayRtype repr + LinearIndex (Var repr _) _ -> arrayRtype repr + Shape (Var repr _) -> shapeType $ arrayRshape repr ShapeSize _ _ -> TupRsingle $ SingleScalarType $ NumSingleType $ IntegralNumType $ TypeInt Undef tp -> TupRsingle tp Coerce _ tp _ -> TupRsingle tp @@ -1274,10 +1254,10 @@ instance NFData (OpenAcc aenv t) where -- rnf = rnfPreOpenSeq rnfOpenAcc instance NFData (OpenExp env aenv t) where - rnf = rnfPreOpenExp rnfOpenAcc + rnf = rnfOpenExp instance NFData (OpenFun env aenv t) where - rnf = rnfPreOpenFun rnfOpenAcc + rnf = rnfOpenFun -- Array expressions @@ -1305,21 +1285,21 @@ rnfPreOpenAcc rnfA pacc = rnfAF :: PreOpenAfun acc aenv' t' -> () rnfAF = rnfPreOpenAfun rnfA - rnfE :: PreOpenExp acc env' aenv' t' -> () - rnfE = rnfPreOpenExp rnfA + rnfE :: OpenExp env' aenv' t' -> () + rnfE = rnfOpenExp - rnfF :: PreOpenFun acc env' aenv' t' -> () - rnfF = rnfPreOpenFun rnfA + rnfF :: OpenFun env' aenv' t' -> () + rnfF = rnfOpenFun -- rnfS :: PreOpenSeq acc aenv' senv' t' -> () -- rnfS = rnfPreOpenSeq rnfA - rnfB :: ArrayR (Array sh e) -> PreBoundary acc aenv' (Array sh e) -> () - rnfB = rnfBoundary rnfA + rnfB :: ArrayR (Array sh e) -> Boundary aenv' (Array sh e) -> () + rnfB = rnfBoundary in case pacc of Alet lhs bnd body -> rnfALhs lhs `seq` rnfA bnd `seq` rnfA body - Avar (Var repr ix) -> rnfArrayR repr `seq` rnfIdx ix + Avar var -> rnfArrayVar var Apair as bs -> rnfA as `seq` rnfA bs Anil -> () Apply repr afun acc -> rnfTupR rnfArrayR repr `seq` rnfAF afun `seq` rnfA acc @@ -1360,6 +1340,9 @@ rnfPreOpenAcc rnfA pacc = in rnfStencilR sr1 `seq` rnfStencilR sr2 `seq` rnfTupR rnfScalarType tp `seq` rnfF f `seq` rnfB repr1 b1 `seq` rnfB repr2 b2 `seq` rnfA a1 `seq` rnfA a2 -- Collect s -> rnfS s +rnfArrayVar :: ArrayVar aenv a -> () +rnfArrayVar (Var repr ix) = rnfArrayR repr `seq` rnfIdx ix + rnfLhs :: (forall b. s b -> ()) -> LeftHandSide s arrs env env' -> () rnfLhs rnfS (LeftHandSideWildcard r) = rnfTupR rnfS r rnfLhs rnfS (LeftHandSideSingle s) = rnfS s @@ -1404,12 +1387,12 @@ rnfStencilR (StencilRtup9 s1 s2 s3 s4 s5 s6 s7 s8 s9) = rnfStencilR s1 `seq` rnfStencilR s2 `seq` rnfStencilR s3 `seq` rnfStencilR s4 `seq` rnfStencilR s5 `seq` rnfStencilR s6 `seq` rnfStencilR s7 `seq` rnfStencilR s8 `seq` rnfStencilR s9 -rnfBoundary :: forall acc aenv sh e. NFDataAcc acc -> ArrayR (Array sh e) -> PreBoundary acc aenv (Array sh e) -> () -rnfBoundary _ _ Clamp = () -rnfBoundary _ _ Mirror = () -rnfBoundary _ _ Wrap = () -rnfBoundary _ (ArrayR _ tp) (Constant c) = rnfConst tp c -rnfBoundary rnfA _ (Function f) = rnfPreOpenFun rnfA f +rnfBoundary :: forall aenv sh e. ArrayR (Array sh e) -> Boundary aenv (Array sh e) -> () +rnfBoundary _ Clamp = () +rnfBoundary _ Mirror = () +rnfBoundary _ Wrap = () +rnfBoundary (ArrayR _ tp) (Constant c) = rnfConst tp c +rnfBoundary _ (Function f) = rnfOpenFun f @@ -1444,11 +1427,11 @@ rnfSeqProducer rnfA topSeq = rnfAF :: PreOpenAfun acc aenv' t' -> () rnfAF = rnfPreOpenAfun rnfA - rnfF :: PreOpenFun acc env' aenv' t' -> () - rnfF = rnfPreOpenFun rnfA + rnfF :: OpenFun env' aenv' t' -> () + rnfF = rnfOpenFun rnfA - rnfE :: PreOpenExp acc env' aenv' t' -> () - rnfE = rnfPreOpenExp rnfA + rnfE :: OpenExp env' aenv' t' -> () + rnfE = rnfOpenExp rnfA in case topSeq of StreamIn as -> rnfArrs as @@ -1464,11 +1447,11 @@ rnfSeqConsumer rnfA topSeq = rnfAF :: PreOpenAfun acc aenv' t' -> () rnfAF = rnfPreOpenAfun rnfA - rnfF :: PreOpenFun acc env' aenv' t' -> () - rnfF = rnfPreOpenFun rnfA + rnfF :: OpenFun env' aenv' t' -> () + rnfF = rnfOpenFun rnfA - rnfE :: PreOpenExp acc env' aenv' t' -> () - rnfE = rnfPreOpenExp rnfA + rnfE :: OpenExp env' aenv' t' -> () + rnfE = rnfOpenExp rnfA in case topSeq of FoldSeq f z ix -> rnfF f `seq` rnfE z `seq` rnfIdx ix @@ -1483,18 +1466,18 @@ rnfStuple rnfA (SnocAtup tup c) = rnfStuple rnfA tup `seq` rnfSeqConsumer rnfA c -- Scalar expressions -- ------------------ -rnfPreOpenFun :: NFDataAcc acc -> PreOpenFun acc env aenv t -> () -rnfPreOpenFun rnfA (Body b) = rnfPreOpenExp rnfA b -rnfPreOpenFun rnfA (Lam lhs f) = rnfELhs lhs `seq` rnfPreOpenFun rnfA f +rnfOpenFun :: OpenFun env aenv t -> () +rnfOpenFun (Body b) = rnfOpenExp b +rnfOpenFun (Lam lhs f) = rnfELhs lhs `seq` rnfOpenFun f -rnfPreOpenExp :: forall acc env aenv t. NFDataAcc acc -> PreOpenExp acc env aenv t -> () -rnfPreOpenExp rnfA topExp = +rnfOpenExp :: forall env aenv t. OpenExp env aenv t -> () +rnfOpenExp topExp = let - rnfF :: PreOpenFun acc env' aenv' t' -> () - rnfF = rnfPreOpenFun rnfA + rnfF :: OpenFun env' aenv' t' -> () + rnfF = rnfOpenFun - rnfE :: PreOpenExp acc env' aenv' t' -> () - rnfE = rnfPreOpenExp rnfA + rnfE :: OpenExp env' aenv' t' -> () + rnfE = rnfOpenExp in case topExp of Let lhs bnd body -> rnfELhs lhs `seq` rnfE bnd `seq` rnfE body @@ -1514,9 +1497,9 @@ rnfPreOpenExp rnfA topExp = While p f x -> rnfF p `seq` rnfF f `seq` rnfE x PrimConst c -> rnfPrimConst c PrimApp f x -> rnfPrimFun f `seq` rnfE x - Index a ix -> rnfA a `seq` rnfE ix - LinearIndex a ix -> rnfA a `seq` rnfE ix - Shape a -> rnfA a + Index a ix -> rnfArrayVar a `seq` rnfE ix + LinearIndex a ix -> rnfArrayVar a `seq` rnfE ix + Shape a -> rnfArrayVar a ShapeSize shr sh -> rnfShapeR shr `seq` rnfE sh Coerce t1 t2 e -> rnfScalarType t1 `seq` rnfScalarType t2 `seq` rnfE e @@ -1671,22 +1654,22 @@ liftPreOpenAcc -> Q (TExp (PreOpenAcc acc aenv a)) liftPreOpenAcc liftA pacc = let - liftE :: PreOpenExp acc env aenv t -> Q (TExp (PreOpenExp acc env aenv t)) - liftE = liftPreOpenExp liftA + liftE :: OpenExp env aenv t -> Q (TExp (OpenExp env aenv t)) + liftE = liftOpenExp - liftF :: PreOpenFun acc env aenv t -> Q (TExp (PreOpenFun acc env aenv t)) - liftF = liftPreOpenFun liftA + liftF :: OpenFun env aenv t -> Q (TExp (OpenFun env aenv t)) + liftF = liftOpenFun liftAF :: PreOpenAfun acc aenv f -> Q (TExp (PreOpenAfun acc aenv f)) liftAF = liftPreOpenAfun liftA - liftB :: ArrayR (Array sh e) -> PreBoundary acc aenv (Array sh e) -> Q (TExp (PreBoundary acc aenv (Array sh e))) - liftB = liftBoundary liftA + liftB :: ArrayR (Array sh e) -> Boundary aenv (Array sh e) -> Q (TExp (Boundary aenv (Array sh e))) + liftB = liftBoundary in case pacc of Alet lhs bnd body -> [|| Alet $$(liftALhs lhs) $$(liftA bnd) $$(liftA body) ||] - Avar (Var tp ix) -> [|| Avar (Var $$(liftArrayR tp) $$(liftIdx ix)) ||] + Avar var -> [|| Avar $$(liftArrayVar var) ||] Apair as bs -> [|| Apair $$(liftA as) $$(liftA bs) ||] Anil -> [|| Anil ||] Apply repr f a -> [|| Apply $$(liftArraysR repr) $$(liftAF f) $$(liftA a) ||] @@ -1764,30 +1747,28 @@ liftStencilR (StencilRtup9 s1 s2 s3 s4 s5 s6 s7 s8 s9) = [|| StencilRtup9 $$(liftStencilR s1) $$(liftStencilR s2) $$(liftStencilR s3) $$(liftStencilR s4) $$(liftStencilR s5) $$(liftStencilR s6) $$(liftStencilR s7) $$(liftStencilR s8) $$(liftStencilR s9) ||] -liftPreOpenFun - :: LiftAcc acc - -> PreOpenFun acc env aenv t - -> Q (TExp (PreOpenFun acc env aenv t)) -liftPreOpenFun liftA (Lam lhs f) = [|| Lam $$(liftELhs lhs) $$(liftPreOpenFun liftA f) ||] -liftPreOpenFun liftA (Body b) = [|| Body $$(liftPreOpenExp liftA b) ||] - -liftPreOpenExp - :: forall acc env aenv t. - LiftAcc acc - -> PreOpenExp acc env aenv t - -> Q (TExp (PreOpenExp acc env aenv t)) -liftPreOpenExp liftA pexp = +liftOpenFun + :: OpenFun env aenv t + -> Q (TExp (OpenFun env aenv t)) +liftOpenFun (Lam lhs f) = [|| Lam $$(liftELhs lhs) $$(liftOpenFun f) ||] +liftOpenFun (Body b) = [|| Body $$(liftOpenExp b) ||] + +liftOpenExp + :: forall env aenv t. + OpenExp env aenv t + -> Q (TExp (OpenExp env aenv t)) +liftOpenExp pexp = let - liftE :: PreOpenExp acc env aenv e -> Q (TExp (PreOpenExp acc env aenv e)) - liftE = liftPreOpenExp liftA + liftE :: OpenExp env aenv e -> Q (TExp (OpenExp env aenv e)) + liftE = liftOpenExp - liftF :: PreOpenFun acc env aenv f -> Q (TExp (PreOpenFun acc env aenv f)) - liftF = liftPreOpenFun liftA + liftF :: OpenFun env aenv f -> Q (TExp (OpenFun env aenv f)) + liftF = liftOpenFun in case pexp of - Let lhs bnd body -> [|| Let $$(liftELhs lhs) $$(liftPreOpenExp liftA bnd) $$(liftPreOpenExp liftA body) ||] + Let lhs bnd body -> [|| Let $$(liftELhs lhs) $$(liftOpenExp bnd) $$(liftOpenExp body) ||] Evar var -> [|| Evar $$(liftExpVar var) ||] - Foreign repr asm f x -> [|| Foreign $$(liftTupleType repr) $$(Sugar.liftForeign asm) $$(liftPreOpenFun liftA f) $$(liftE x) ||] + Foreign repr asm f x -> [|| Foreign $$(liftTupleType repr) $$(Sugar.liftForeign asm) $$(liftOpenFun f) $$(liftE x) ||] Const tp c -> [|| Const $$(liftScalarType tp) $$(liftConst (TupRsingle tp) c) ||] Undef tp -> [|| Undef $$(liftScalarType tp) ||] Pair a b -> [|| Pair $$(liftE a) $$(liftE b) ||] @@ -1802,15 +1783,18 @@ liftPreOpenExp liftA pexp = While p f x -> [|| While $$(liftF p) $$(liftF f) $$(liftE x) ||] PrimConst t -> [|| PrimConst $$(liftPrimConst t) ||] PrimApp f x -> [|| PrimApp $$(liftPrimFun f) $$(liftE x) ||] - Index a ix -> [|| Index $$(liftA a) $$(liftE ix) ||] - LinearIndex a ix -> [|| LinearIndex $$(liftA a) $$(liftE ix) ||] - Shape a -> [|| Shape $$(liftA a) ||] + Index a ix -> [|| Index $$(liftArrayVar a) $$(liftE ix) ||] + LinearIndex a ix -> [|| LinearIndex $$(liftArrayVar a) $$(liftE ix) ||] + Shape a -> [|| Shape $$(liftArrayVar a) ||] ShapeSize shr ix -> [|| ShapeSize $$(liftShapeR shr) $$(liftE ix) ||] Coerce t1 t2 e -> [|| Coerce $$(liftScalarType t1) $$(liftScalarType t2) $$(liftE e) ||] liftExpVar :: ExpVar env t -> Q (TExp (ExpVar env t)) liftExpVar (Var tp ix) = [|| Var $$(liftScalarType tp) $$(liftIdx ix) ||] +liftArrayVar :: ArrayVar aenv a -> Q (TExp (ArrayVar aenv a)) +liftArrayVar (Var repr ix) = [|| Var $$(liftArrayR repr) $$(liftIdx ix) ||] + liftArray :: forall sh e. ArrayR (Array sh e) -> Array sh e -> Q (TExp (Array sh e)) liftArray (ArrayR shr tp) (Array sh adata) = [|| Array $$(liftConst (shapeType shr) sh) $$(go tp adata) ||] `sigE` [t| Array $(typeToQType $ shapeType shr) $(typeToQType tp) |] @@ -1914,16 +1898,15 @@ liftArray (ArrayR shr tp) (Array sh adata) = goVector (NonNumSingleType TypeBool) = arr liftBoundary - :: forall acc aenv sh e. - LiftAcc acc - -> ArrayR (Array sh e) - -> PreBoundary acc aenv (Array sh e) - -> Q (TExp (PreBoundary acc aenv (Array sh e))) -liftBoundary _ _ Clamp = [|| Clamp ||] -liftBoundary _ _ Mirror = [|| Mirror ||] -liftBoundary _ _ Wrap = [|| Wrap ||] -liftBoundary _ (ArrayR _ tp) (Constant v) = [|| Constant $$(liftConst tp v) ||] -liftBoundary liftA _ (Function f) = [|| Function $$(liftPreOpenFun liftA f) ||] + :: forall aenv sh e. + ArrayR (Array sh e) + -> Boundary aenv (Array sh e) + -> Q (TExp (Boundary aenv (Array sh e))) +liftBoundary _ Clamp = [|| Clamp ||] +liftBoundary _ Mirror = [|| Mirror ||] +liftBoundary _ Wrap = [|| Wrap ||] +liftBoundary (ArrayR _ tp) (Constant v) = [|| Constant $$(liftConst tp v) ||] +liftBoundary _ (Function f) = [|| Function $$(liftOpenFun f) ||] liftSliceIndex :: SliceIndex ix slice coSlice sliceDim -> Q (TExp (SliceIndex ix slice coSlice sliceDim)) liftSliceIndex SliceNil = [|| SliceNil ||] @@ -2174,7 +2157,7 @@ showShortendArr repr@(ArrayR _ tp) arr elements = intercalate ", " $ map (showElement tp) $ take cutoff l -showPreExpOp :: forall acc aenv env t. PreOpenExp acc aenv env t -> String +showPreExpOp :: forall aenv env t. OpenExp aenv env t -> String showPreExpOp Let{} = "Let" showPreExpOp (Evar (Var _ ix)) = "Var x" ++ show (idxToInt ix) showPreExpOp (Const tp c) = "Const " ++ showElement (TupRsingle tp) c diff --git a/src/Data/Array/Accelerate/Analysis/Hash.hs b/src/Data/Array/Accelerate/Analysis/Hash.hs index 862cbee09..f1eb6b5e1 100644 --- a/src/Data/Array/Accelerate/Analysis/Hash.hs +++ b/src/Data/Array/Accelerate/Analysis/Hash.hs @@ -23,14 +23,13 @@ module Data.Array.Accelerate.Analysis.Hash ( Hash, HashOptions(..), defaultHashOptions, hashPreOpenAcc, hashPreOpenAccWith, - hashPreOpenFun, hashPreOpenFunWith, - hashPreOpenExp, hashPreOpenExpWith, + hashOpenFun, hashOpenExp, -- auxiliary EncodeAcc, encodePreOpenAcc, - encodePreOpenExp, - encodePreOpenFun, + encodeOpenExp, + encodeOpenFun, encodeArraysType, hashQ, @@ -95,14 +94,6 @@ defaultHashOptions = HashOptions True hashPreOpenAcc :: HasArraysRepr acc => EncodeAcc acc -> PreOpenAcc acc aenv a -> Hash hashPreOpenAcc = hashPreOpenAccWith defaultHashOptions -{-# INLINEABLE hashPreOpenFun #-} -hashPreOpenFun :: HasArraysRepr acc => EncodeAcc acc -> PreOpenFun acc env aenv f -> Hash -hashPreOpenFun = hashPreOpenFunWith defaultHashOptions - -{-# INLINEABLE hashPreOpenExp #-} -hashPreOpenExp :: HasArraysRepr acc => EncodeAcc acc -> PreOpenExp acc env aenv t -> Hash -hashPreOpenExp = hashPreOpenExpWith defaultHashOptions - {-# INLINEABLE hashPreOpenAccWith #-} hashPreOpenAccWith :: HasArraysRepr acc => HashOptions -> EncodeAcc acc -> PreOpenAcc acc aenv a -> Hash hashPreOpenAccWith options encodeAcc @@ -110,19 +101,19 @@ hashPreOpenAccWith options encodeAcc . toLazyByteString . encodePreOpenAcc options encodeAcc -{-# INLINEABLE hashPreOpenFunWith #-} -hashPreOpenFunWith :: HasArraysRepr acc => HashOptions -> EncodeAcc acc -> PreOpenFun acc env aenv f -> Hash -hashPreOpenFunWith options encodeAcc +{-# INLINEABLE hashOpenFun #-} +hashOpenFun :: OpenFun env aenv f -> Hash +hashOpenFun = hashlazy . toLazyByteString - . encodePreOpenFun options encodeAcc + . encodeOpenFun -{-# INLINEABLE hashPreOpenExpWith #-} -hashPreOpenExpWith :: HasArraysRepr acc => HashOptions -> EncodeAcc acc -> PreOpenExp acc env aenv t -> Hash -hashPreOpenExpWith options encodeAcc +{-# INLINEABLE hashOpenExp #-} +hashOpenExp :: OpenExp env aenv t -> Hash +hashOpenExp = hashlazy . toLazyByteString - . encodePreOpenExp options encodeAcc + . encodeOpenExp -- Array computations @@ -145,20 +136,17 @@ encodePreOpenAcc options encodeAcc pacc = travAF :: PreOpenAfun acc aenv' f -> Builder travAF = encodePreOpenAfun options encodeAcc - travE :: PreOpenExp acc env' aenv' e -> Builder - travE = encodePreOpenExp options encodeAcc + travE :: OpenExp env' aenv' e -> Builder + travE = encodeOpenExp - travF :: PreOpenFun acc env' aenv' f -> Builder - travF = encodePreOpenFun options encodeAcc - - travB :: TupleType e -> PreBoundary acc aenv' (Array sh e) -> Builder - travB = encodePreBoundary options encodeAcc + travF :: OpenFun env' aenv' f -> Builder + travF = encodeOpenFun deep :: Builder -> Builder deep | perfect options = id | otherwise = const mempty - deepE :: forall env' aenv' e. PreOpenExp acc env' aenv' e -> Builder + deepE :: forall env' aenv' e. OpenExp env' aenv' e -> Builder deepE e | perfect options = travE e | otherwise = encodeTupleType $ expType e @@ -195,8 +183,8 @@ encodePreOpenAcc options encodeAcc pacc = Scanr' f e a -> intHost $(hashQ "Scanr'") <> travF f <> travE e <> travA a Scanr1 f a -> intHost $(hashQ "Scanr1") <> travF f <> travA a Permute f1 a1 f2 a2 -> intHost $(hashQ "Permute") <> travF f1 <> travA a1 <> travF f2 <> travA a2 - Stencil s _ f b a -> intHost $(hashQ "Stencil") <> travF f <> travB (stencilElt s) b <> travA a - Stencil2 s1 s2 _ f b1 a1 b2 a2 -> intHost $(hashQ "Stencil2") <> travF f <> travB (stencilElt s1) b1 <> travA a1 <> travB (stencilElt s2) b2 <> travA a2 + Stencil s _ f b a -> intHost $(hashQ "Stencil") <> travF f <> encodeBoundary (stencilElt s) b <> travA a + Stencil2 s1 s2 _ f b1 a1 b2 a2 -> intHost $(hashQ "Stencil2") <> travF f <> encodeBoundary (stencilElt s1) b1 <> travA a1 <> encodeBoundary (stencilElt s2) b2 <> travA a2 {-- {-# INLINEABLE encodePreOpenSeq #-} @@ -206,14 +194,14 @@ encodePreOpenSeq encodeAcc s = travA :: acc aenv' a -> Builder travA = encodeAcc -- XXX: plus type information? - travE :: PreOpenExp acc env' aenv' e -> Builder - travE = encodePreOpenExp encodeAcc + travE :: OpenExp env' aenv' e -> Builder + travE = encodeOpenExp encodeAcc travAF :: PreOpenAfun acc aenv' f -> Builder travAF = encodePreOpenAfun encodeAcc - travF :: PreOpenFun acc env' aenv' f -> Builder - travF = encodePreOpenFun encodeAcc + travF :: OpenFun env' aenv' f -> Builder + travF = encodeOpenFun encodeAcc travS :: PreOpenSeq acc aenv senv' arrs' -> Builder travS = encodePreOpenSeq encodeAcc @@ -285,18 +273,15 @@ encodePreOpenAfun options travA afun = Alam lhs l -> intHost $(hashQ "Alam") <> travL lhs l -encodePreBoundary - :: forall acc aenv sh e. - HashOptions - -> EncodeAcc acc - -> TupleType e - -> PreBoundary acc aenv (Array sh e) +encodeBoundary + :: TupleType e + -> Boundary aenv (Array sh e) -> Builder -encodePreBoundary _ _ _ Wrap = intHost $(hashQ "Wrap") -encodePreBoundary _ _ _ Clamp = intHost $(hashQ "Clamp") -encodePreBoundary _ _ _ Mirror = intHost $(hashQ "Mirror") -encodePreBoundary _ _ tp (Constant c) = intHost $(hashQ "Constant") <> encodeConst tp c -encodePreBoundary o h _ (Function f) = intHost $(hashQ "Function") <> encodePreOpenFun o h f +encodeBoundary _ Wrap = intHost $(hashQ "Wrap") +encodeBoundary _ Clamp = intHost $(hashQ "Clamp") +encodeBoundary _ Mirror = intHost $(hashQ "Mirror") +encodeBoundary tp (Constant c) = intHost $(hashQ "Constant") <> encodeConst tp c +encodeBoundary _ (Function f) = intHost $(hashQ "Function") <> encodeOpenFun f encodeSliceIndex :: SliceIndex slix sl co sh -> Builder encodeSliceIndex SliceNil = intHost $(hashQ "SliceNil") @@ -307,31 +292,18 @@ encodeSliceIndex (SliceFixed r) = intHost $(hashQ "sliceFixed") <> encodeSlice -- Scalar expressions -- ------------------ -{-# INLINEABLE encodePreOpenExp #-} -encodePreOpenExp - :: forall acc env aenv exp. - HashOptions - -> EncodeAcc acc - -> PreOpenExp acc env aenv exp +{-# INLINEABLE encodeOpenExp #-} +encodeOpenExp + :: forall env aenv exp. + OpenExp env aenv exp -> Builder -encodePreOpenExp options encodeAcc exp = +encodeOpenExp exp = let - -- XXX: Temporary fix for hashing expressions which only depend on - -- free array variables. For the code generating backends it will - -- never pick up expressions which differ only at free array - -- variables. We know that this will always be an Avar (we depend on - -- array expressions being floated out already) so we should change - -- this in the AST. This problem occurred in the Quickhull program. - -- -- TLM 2020-01-08 - -- - travA :: forall aenv' a. acc aenv' a -> Builder - travA a = encodeAcc (options {perfect=True}) a - - travE :: forall env' aenv' e. PreOpenExp acc env' aenv' e -> Builder - travE e = encodePreOpenExp options encodeAcc e + travE :: forall env' aenv' e. OpenExp env' aenv' e -> Builder + travE e = encodeOpenExp e - travF :: PreOpenFun acc env' aenv' f -> Builder - travF = encodePreOpenFun options encodeAcc + travF :: OpenFun env' aenv' f -> Builder + travF = encodeOpenFun in case exp of Let lhs bnd body -> intHost $(hashQ "Let") <> encodeLeftHandSide encodeScalarType lhs <> travE bnd <> travE body @@ -350,32 +322,22 @@ encodePreOpenExp options encodeAcc exp = While p f x -> intHost $(hashQ "While") <> travF p <> travF f <> travE x PrimApp f x -> intHost $(hashQ "PrimApp") <> encodePrimFun f <> travE x PrimConst c -> intHost $(hashQ "PrimConst") <> encodePrimConst c - Index a ix -> intHost $(hashQ "Index") <> travA a <> travE ix - LinearIndex a ix -> intHost $(hashQ "LinearIndex") <> travA a <> travE ix - Shape a -> intHost $(hashQ "Shape") <> travA a + Index a ix -> intHost $(hashQ "Index") <> encodeArrayVar a <> travE ix + LinearIndex a ix -> intHost $(hashQ "LinearIndex") <> encodeArrayVar a <> travE ix + Shape a -> intHost $(hashQ "Shape") <> encodeArrayVar a ShapeSize _ sh -> intHost $(hashQ "ShapeSize") <> travE sh Foreign _ _ f e -> intHost $(hashQ "Foreign") <> travF f <> travE e Coerce _ tp e -> intHost $(hashQ "Coerce") <> encodeScalarType tp <> travE e +encodeArrayVar :: ArrayVar aenv a -> Builder +encodeArrayVar (Var repr v) = encodeArrayType repr <> encodeIdx v -{-# INLINEABLE encodePreOpenFun #-} -encodePreOpenFun - :: forall acc env aenv f. - HashOptions - -> EncodeAcc acc - -> PreOpenFun acc env aenv f +{-# INLINEABLE encodeOpenFun #-} +encodeOpenFun + :: OpenFun env aenv f -> Builder -encodePreOpenFun options travA fun = - let - travB :: forall env' aenv' e. PreOpenExp acc env' aenv' e -> Builder - travB b = encodePreOpenExp options travA b - - travL :: forall env' aenv' b. PreOpenFun acc env' aenv' b -> Builder - travL l = encodePreOpenFun options travA l - in - case fun of - Body b -> intHost $(hashQ "Body") <> travB b - Lam lhs l -> intHost $(hashQ "Lam") <> encodeLeftHandSide encodeScalarType lhs <> travL l +encodeOpenFun (Body b) = intHost $(hashQ "Body") <> encodeOpenExp b +encodeOpenFun (Lam lhs l) = intHost $(hashQ "Lam") <> encodeLeftHandSide encodeScalarType lhs <> encodeOpenFun l encodeConst :: TupleType t -> t -> Builder diff --git a/src/Data/Array/Accelerate/Analysis/Match.hs b/src/Data/Array/Accelerate/Analysis/Match.hs index 6c0e8db77..d88c27b2f 100644 --- a/src/Data/Array/Accelerate/Analysis/Match.hs +++ b/src/Data/Array/Accelerate/Analysis/Match.hs @@ -24,8 +24,8 @@ module Data.Array.Accelerate.Analysis.Match ( (:~:)(..), matchPreOpenAcc, matchPreOpenAfun, - matchPreOpenExp, - matchPreOpenFun, + matchOpenExp, + matchOpenFun, matchPrimFun, matchPrimFun', -- auxiliary @@ -63,17 +63,16 @@ type MatchAcc acc = forall aenv s t. acc aenv s -> acc aenv t -> Maybe (s :~: t) matchPreOpenAcc :: forall acc aenv s t. HasArraysRepr acc => MatchAcc acc - -> EncodeAcc acc -> PreOpenAcc acc aenv s -> PreOpenAcc acc aenv t -> Maybe (s :~: t) -matchPreOpenAcc matchAcc encodeAcc = match +matchPreOpenAcc matchAcc = match where - matchFun :: PreOpenFun acc env' aenv' u -> PreOpenFun acc env' aenv' v -> Maybe (u :~: v) - matchFun = matchPreOpenFun matchAcc encodeAcc + matchFun :: OpenFun env' aenv' u -> OpenFun env' aenv' v -> Maybe (u :~: v) + matchFun = matchOpenFun - matchExp :: PreOpenExp acc env' aenv' u -> PreOpenExp acc env' aenv' v -> Maybe (u :~: v) - matchExp = matchPreOpenExp matchAcc encodeAcc + matchExp :: OpenExp env' aenv' u -> OpenExp env' aenv' v -> Maybe (u :~: v) + matchExp = matchOpenExp match :: PreOpenAcc acc aenv s -> PreOpenAcc acc aenv t -> Maybe (s :~: t) match (Alet lhs1 x1 a1) (Alet lhs2 x2 a2) @@ -242,15 +241,15 @@ matchPreOpenAcc matchAcc encodeAcc = match match (Stencil s1 _ f1 b1 a1) (Stencil _ _ f2 b2 a2) | Just Refl <- matchFun f1 f2 , Just Refl <- matchAcc a1 a2 - , matchBoundary matchAcc encodeAcc (stencilElt s1) b1 b2 + , matchBoundary (stencilElt s1) b1 b2 = Just Refl match (Stencil2 s1 s2 _ f1 b1 a1 b2 a2) (Stencil2 _ _ _ f2 b1' a1' b2' a2') | Just Refl <- matchFun f1 f2 , Just Refl <- matchAcc a1 a1' , Just Refl <- matchAcc a2 a2' - , matchBoundary matchAcc encodeAcc (stencilElt s1) b1 b1' - , matchBoundary matchAcc encodeAcc (stencilElt s2) b2 b2' + , matchBoundary (stencilElt s1) b1 b1' + , matchBoundary (stencilElt s2) b2 b2' = Just Refl -- match (Collect s1) (Collect s2) @@ -310,21 +309,18 @@ matchLeftHandSide _ _ _ = Nothing -- Match stencil boundaries -- matchBoundary - :: HasArraysRepr acc - => MatchAcc acc - -> EncodeAcc acc - -> TupleType t - -> PreBoundary acc aenv (Array sh t) - -> PreBoundary acc aenv (Array sh t) + :: TupleType t + -> Boundary aenv (Array sh t) + -> Boundary aenv (Array sh t) -> Bool -matchBoundary _ _ _ Clamp Clamp = True -matchBoundary _ _ _ Mirror Mirror = True -matchBoundary _ _ _ Wrap Wrap = True -matchBoundary _ _ tp (Constant s) (Constant t) = matchConst tp s t -matchBoundary m h _ (Function f) (Function g) - | Just Refl <- matchPreOpenFun m h f g +matchBoundary _ Clamp Clamp = True +matchBoundary _ Mirror Mirror = True +matchBoundary _ Wrap Wrap = True +matchBoundary tp (Constant s) (Constant t) = matchConst tp s t +matchBoundary _ (Function f) (Function g) + | Just Refl <- matchOpenFun f g = True -matchBoundary _ _ _ _ _ +matchBoundary _ _ _ = False @@ -340,11 +336,11 @@ matchSeq -> Maybe (s :~: t) matchSeq m h = match where - matchFun :: PreOpenFun acc env' aenv' u -> PreOpenFun acc env' aenv' v -> Maybe (u :~: v) - matchFun = matchPreOpenFun m h + matchFun :: OpenFun env' aenv' u -> OpenFun env' aenv' v -> Maybe (u :~: v) + matchFun = matchOpenFun m h - matchExp :: PreOpenExp acc env' aenv' u -> PreOpenExp acc env' aenv' v -> Maybe (u :~: v) - matchExp = matchPreOpenExp m h + matchExp :: OpenExp env' aenv' u -> OpenExp env' aenv' v -> Maybe (u :~: v) + matchExp = matchOpenExp m h match :: PreOpenSeq acc aenv senv' u -> PreOpenSeq acc aenv senv' v -> Maybe (u :~: v) match (Producer p1 s1) (Producer p2 s2) @@ -454,145 +450,135 @@ matchArrayR _ _ = Nothing -- The below attempts to use real typed equality, but occasionally still needs -- to use a cast, particularly when we can only match the representation types. -- -{-# INLINEABLE matchPreOpenExp #-} -matchPreOpenExp - :: forall acc env aenv s t. HasArraysRepr acc - => MatchAcc acc - -> EncodeAcc acc - -> PreOpenExp acc env aenv s - -> PreOpenExp acc env aenv t +{-# INLINEABLE matchOpenExp #-} +matchOpenExp + :: forall env aenv s t. + OpenExp env aenv s + -> OpenExp env aenv t -> Maybe (s :~: t) -matchPreOpenExp matchAcc encodeAcc = match - where - match :: forall env' aenv' s' t'. - PreOpenExp acc env' aenv' s' - -> PreOpenExp acc env' aenv' t' - -> Maybe (s' :~: t') - match (Let lhs1 x1 e1) (Let lhs2 x2 e2) - | Just Refl <- matchELeftHandSide lhs1 lhs2 - , Just Refl <- match x1 x2 - , Just Refl <- match e1 e2 - = Just Refl - match (Evar v1) (Evar v2) - = matchVar v1 v2 +matchOpenExp (Let lhs1 x1 e1) (Let lhs2 x2 e2) + | Just Refl <- matchELeftHandSide lhs1 lhs2 + , Just Refl <- matchOpenExp x1 x2 + , Just Refl <- matchOpenExp e1 e2 + = Just Refl - match (Foreign _ ff1 f1 e1) (Foreign _ ff2 f2 e2) - | Just Refl <- match e1 e2 - , unsafePerformIO $ do - sn1 <- makeStableName ff1 - sn2 <- makeStableName ff2 - return $! hashStableName sn1 == hashStableName sn2 - , Just Refl <- matchPreOpenFun matchAcc encodeAcc f1 f2 - = Just Refl +matchOpenExp (Evar v1) (Evar v2) + = matchVar v1 v2 - match (Const t1 c1) (Const t2 c2) - | Just Refl <- matchScalarType t1 t2 - , matchConst (TupRsingle t1) c1 c2 - = Just Refl +matchOpenExp (Foreign _ ff1 f1 e1) (Foreign _ ff2 f2 e2) + | Just Refl <- matchOpenExp e1 e2 + , unsafePerformIO $ do + sn1 <- makeStableName ff1 + sn2 <- makeStableName ff2 + return $! hashStableName sn1 == hashStableName sn2 + , Just Refl <- matchOpenFun f1 f2 + = Just Refl - match (Undef t1) (Undef t2) = matchScalarType t1 t2 +matchOpenExp (Const t1 c1) (Const t2 c2) + | Just Refl <- matchScalarType t1 t2 + , matchConst (TupRsingle t1) c1 c2 + = Just Refl - match (Coerce _ t1 e1) (Coerce _ t2 e2) - | Just Refl <- matchScalarType t1 t2 - , Just Refl <- match e1 e2 - = Just Refl +matchOpenExp (Undef t1) (Undef t2) = matchScalarType t1 t2 - match (Pair a1 b1) (Pair a2 b2) - | Just Refl <- match a1 a2 - , Just Refl <- match b1 b2 - = Just Refl +matchOpenExp (Coerce _ t1 e1) (Coerce _ t2 e2) + | Just Refl <- matchScalarType t1 t2 + , Just Refl <- matchOpenExp e1 e2 + = Just Refl - match Nil Nil - = Just Refl +matchOpenExp (Pair a1 b1) (Pair a2 b2) + | Just Refl <- matchOpenExp a1 a2 + , Just Refl <- matchOpenExp b1 b2 + = Just Refl - match (IndexSlice sliceIndex1 ix1 sh1) (IndexSlice sliceIndex2 ix2 sh2) - | Just Refl <- match ix1 ix2 - , Just Refl <- match sh1 sh2 - , Just Refl <- matchSliceIndex sliceIndex1 sliceIndex2 - = Just Refl +matchOpenExp Nil Nil + = Just Refl - match (IndexFull sliceIndex1 ix1 sl1) (IndexFull sliceIndex2 ix2 sl2) - | Just Refl <- match ix1 ix2 - , Just Refl <- match sl1 sl2 - , Just Refl <- matchSliceIndex sliceIndex1 sliceIndex2 - = Just Refl +matchOpenExp (IndexSlice sliceIndex1 ix1 sh1) (IndexSlice sliceIndex2 ix2 sh2) + | Just Refl <- matchOpenExp ix1 ix2 + , Just Refl <- matchOpenExp sh1 sh2 + , Just Refl <- matchSliceIndex sliceIndex1 sliceIndex2 + = Just Refl - match (ToIndex _ sh1 i1) (ToIndex _ sh2 i2) - | Just Refl <- match sh1 sh2 - , Just Refl <- match i1 i2 - = Just Refl +matchOpenExp (IndexFull sliceIndex1 ix1 sl1) (IndexFull sliceIndex2 ix2 sl2) + | Just Refl <- matchOpenExp ix1 ix2 + , Just Refl <- matchOpenExp sl1 sl2 + , Just Refl <- matchSliceIndex sliceIndex1 sliceIndex2 + = Just Refl - match (FromIndex _ sh1 i1) (FromIndex _ sh2 i2) - | Just Refl <- match i1 i2 - , Just Refl <- match sh1 sh2 - = Just Refl +matchOpenExp (ToIndex _ sh1 i1) (ToIndex _ sh2 i2) + | Just Refl <- matchOpenExp sh1 sh2 + , Just Refl <- matchOpenExp i1 i2 + = Just Refl - match (Cond p1 t1 e1) (Cond p2 t2 e2) - | Just Refl <- match p1 p2 - , Just Refl <- match t1 t2 - , Just Refl <- match e1 e2 - = Just Refl +matchOpenExp (FromIndex _ sh1 i1) (FromIndex _ sh2 i2) + | Just Refl <- matchOpenExp i1 i2 + , Just Refl <- matchOpenExp sh1 sh2 + = Just Refl - match (While p1 f1 x1) (While p2 f2 x2) - | Just Refl <- match x1 x2 - , Just Refl <- matchPreOpenFun matchAcc encodeAcc p1 p2 - , Just Refl <- matchPreOpenFun matchAcc encodeAcc f1 f2 - = Just Refl +matchOpenExp (Cond p1 t1 e1) (Cond p2 t2 e2) + | Just Refl <- matchOpenExp p1 p2 + , Just Refl <- matchOpenExp t1 t2 + , Just Refl <- matchOpenExp e1 e2 + = Just Refl - match (PrimConst c1) (PrimConst c2) - = matchPrimConst c1 c2 +matchOpenExp (While p1 f1 x1) (While p2 f2 x2) + | Just Refl <- matchOpenExp x1 x2 + , Just Refl <- matchOpenFun p1 p2 + , Just Refl <- matchOpenFun f1 f2 + = Just Refl - match (PrimApp f1 x1) (PrimApp f2 x2) - | Just x1' <- commutes encodeAcc f1 x1 - , Just x2' <- commutes encodeAcc f2 x2 - , Just Refl <- match x1' x2' - , Just Refl <- matchPrimFun f1 f2 - = Just Refl +matchOpenExp (PrimConst c1) (PrimConst c2) + = matchPrimConst c1 c2 - | Just Refl <- match x1 x2 - , Just Refl <- matchPrimFun f1 f2 - = Just Refl +matchOpenExp (PrimApp f1 x1) (PrimApp f2 x2) + | Just x1' <- commutes f1 x1 + , Just x2' <- commutes f2 x2 + , Just Refl <- matchOpenExp x1' x2' + , Just Refl <- matchPrimFun f1 f2 + = Just Refl - match (Index a1 x1) (Index a2 x2) - | Just Refl <- matchAcc a1 a2 -- should only be array indices - , Just Refl <- match x1 x2 - = Just Refl + | Just Refl <- matchOpenExp x1 x2 + , Just Refl <- matchPrimFun f1 f2 + = Just Refl - match (LinearIndex a1 x1) (LinearIndex a2 x2) - | Just Refl <- matchAcc a1 a2 - , Just Refl <- match x1 x2 - = Just Refl +matchOpenExp (Index a1 x1) (Index a2 x2) + | Just Refl <- matchVar a1 a2 -- should only be array indices + , Just Refl <- matchOpenExp x1 x2 + = Just Refl - match (Shape a1) (Shape a2) - | Just Refl <- matchAcc a1 a2 -- should only be array indices - = Just Refl +matchOpenExp (LinearIndex a1 x1) (LinearIndex a2 x2) + | Just Refl <- matchVar a1 a2 + , Just Refl <- matchOpenExp x1 x2 + = Just Refl - match (ShapeSize _ sh1) (ShapeSize _ sh2) - | Just Refl <- match sh1 sh2 - = Just Refl +matchOpenExp (Shape a1) (Shape a2) + | Just Refl <- matchVar a1 a2 -- should only be array indices + = Just Refl - match _ _ - = Nothing +matchOpenExp (ShapeSize _ sh1) (ShapeSize _ sh2) + | Just Refl <- matchOpenExp sh1 sh2 + = Just Refl + +matchOpenExp _ _ + = Nothing -- Match scalar functions -- -{-# INLINEABLE matchPreOpenFun #-} -matchPreOpenFun - :: HasArraysRepr acc - => MatchAcc acc - -> EncodeAcc acc - -> PreOpenFun acc env aenv s - -> PreOpenFun acc env aenv t +{-# INLINEABLE matchOpenFun #-} +matchOpenFun + :: OpenFun env aenv s + -> OpenFun env aenv t -> Maybe (s :~: t) -matchPreOpenFun m h (Lam lhs1 s) (Lam lhs2 t) +matchOpenFun (Lam lhs1 s) (Lam lhs2 t) | Just Refl <- matchELeftHandSide lhs1 lhs2 - , Just Refl <- matchPreOpenFun m h s t + , Just Refl <- matchOpenFun s t = Just Refl -matchPreOpenFun m h (Body s) (Body t) = matchPreOpenExp m h s t -matchPreOpenFun _ _ _ _ = Nothing +matchOpenFun (Body s) (Body t) = matchOpenExp s t +matchOpenFun _ _ = Nothing -- Matching constants -- @@ -940,12 +926,11 @@ matchNonNumType _ _ = Nothing -- commutativity. -- commutes - :: forall acc env aenv a r. HasArraysRepr acc - => EncodeAcc acc - -> PrimFun (a -> r) - -> PreOpenExp acc env aenv a - -> Maybe (PreOpenExp acc env aenv a) -commutes h f x = case f of + :: forall env aenv a r. + PrimFun (a -> r) + -> OpenExp env aenv a + -> Maybe (OpenExp env aenv a) +commutes f x = case f of PrimAdd{} -> Just (swizzle x) PrimMul{} -> Just (swizzle x) PrimBAnd{} -> Just (swizzle x) @@ -959,10 +944,10 @@ commutes h f x = case f of PrimLOr -> Just (swizzle x) _ -> Nothing where - swizzle :: PreOpenExp acc env aenv (a',a') -> PreOpenExp acc env aenv (a',a') + swizzle :: OpenExp env aenv (a',a') -> OpenExp env aenv (a',a') swizzle exp | (a `Pair` b) <- exp - , hashPreOpenExp h a > hashPreOpenExp h b = b `Pair` a + , hashOpenExp a > hashOpenExp b = b `Pair` a -- | otherwise = exp diff --git a/src/Data/Array/Accelerate/Analysis/Shape.hs b/src/Data/Array/Accelerate/Analysis/Shape.hs index 89884389a..416e77389 100644 --- a/src/Data/Array/Accelerate/Analysis/Shape.hs +++ b/src/Data/Array/Accelerate/Analysis/Shape.hs @@ -32,7 +32,7 @@ accDim = rank . arrayRshape . arrayRepr -- |Reify dimensionality of a scalar expression yielding a shape -- -expDim :: forall acc env aenv sh. HasArraysRepr acc => PreOpenExp acc env aenv sh -> Int +expDim :: forall env aenv sh. OpenExp env aenv sh -> Int expDim = ndim . expType -- Count the number of components to a tuple type diff --git a/src/Data/Array/Accelerate/Interpreter.hs b/src/Data/Array/Accelerate/Interpreter.hs index 1a4344e90..f833b4659 100644 --- a/src/Data/Array/Accelerate/Interpreter.hs +++ b/src/Data/Array/Accelerate/Interpreter.hs @@ -68,7 +68,7 @@ import Unsafe.Coerce import Prelude hiding ( (!!), sum ) -- friends -import Data.Array.Accelerate.AST hiding ( Boundary, PreBoundary(..) ) +import Data.Array.Accelerate.AST hiding ( Boundary(..) ) import Data.Array.Accelerate.Analysis.Type ( sizeOfSingleType ) import Data.Array.Accelerate.Array.Data import Data.Array.Accelerate.Array.Representation @@ -152,8 +152,6 @@ data Delayed a where -- Array expression evaluation -- --------------------------- -type EvalAcc acc = forall aenv a. acc aenv a -> Val aenv -> WithReprs a - type WithReprs acc = (ArraysR acc, acc) fromFunction' :: ArrayR (Array sh e) -> sh -> (sh -> e) -> WithReprs (Array sh e) @@ -187,14 +185,14 @@ evalOpenAcc (AST.Manifest pacc) aenv = where (TupRsingle repr, a) = manifest a' - evalE :: DelayedExp aenv t -> t - evalE exp = evalPreExp evalOpenAcc exp aenv + evalE :: Exp aenv t -> t + evalE exp = evalExp exp aenv - evalF :: DelayedFun aenv f -> f - evalF fun = evalPreFun evalOpenAcc fun aenv + evalF :: Fun aenv f -> f + evalF fun = evalFun fun aenv - evalB :: AST.PreBoundary DelayedOpenAcc aenv t -> Boundary t - evalB bnd = evalPreBoundary evalOpenAcc bnd aenv + evalB :: AST.Boundary aenv t -> Boundary t + evalB bnd = evalBoundary bnd aenv in case pacc of Avar (Var repr ix) -> (TupRsingle repr, prj ix aenv) @@ -848,14 +846,14 @@ data Boundary t where Function :: (sh -> e) -> Boundary (Array sh e) -evalPreBoundary :: HasArraysRepr acc => EvalAcc acc -> AST.PreBoundary acc aenv t -> Val aenv -> Boundary t -evalPreBoundary evalAcc bnd aenv = +evalBoundary :: AST.Boundary aenv t -> Val aenv -> Boundary t +evalBoundary bnd aenv = case bnd of AST.Clamp -> Clamp AST.Mirror -> Mirror AST.Wrap -> Wrap AST.Constant v -> Constant v - AST.Function f -> Function (evalPreFun evalAcc f aenv) + AST.Function f -> Function (evalFun f aenv) -- Scalar expression evaluation @@ -863,20 +861,20 @@ evalPreBoundary evalAcc bnd aenv = -- Evaluate a closed scalar expression -- -evalPreExp :: HasArraysRepr acc => EvalAcc acc -> PreExp acc aenv t -> Val aenv -> t -evalPreExp evalAcc e aenv = evalPreOpenExp evalAcc e Empty aenv +evalExp :: Exp aenv t -> Val aenv -> t +evalExp e aenv = evalOpenExp e Empty aenv -- Evaluate a closed scalar function -- -evalPreFun :: HasArraysRepr acc => EvalAcc acc -> PreFun acc aenv t -> Val aenv -> t -evalPreFun evalAcc f aenv = evalPreOpenFun evalAcc f Empty aenv +evalFun :: Fun aenv t -> Val aenv -> t +evalFun f aenv = evalOpenFun f Empty aenv -- Evaluate an open scalar function -- -evalPreOpenFun :: HasArraysRepr acc => EvalAcc acc -> PreOpenFun acc env aenv t -> Val env -> Val aenv -> t -evalPreOpenFun evalAcc (Body e) env aenv = evalPreOpenExp evalAcc e env aenv -evalPreOpenFun evalAcc (Lam lhs f) env aenv = - \x -> evalPreOpenFun evalAcc f (env `push` (lhs, x)) aenv +evalOpenFun :: OpenFun env aenv t -> Val env -> Val aenv -> t +evalOpenFun (Body e) env aenv = evalOpenExp e env aenv +evalOpenFun (Lam lhs f) env aenv = + \x -> evalOpenFun f (env `push` (lhs, x)) aenv -- Evaluate an open scalar expression @@ -887,33 +885,27 @@ evalPreOpenFun evalAcc (Lam lhs f) env aenv = -- mapped over an array, the array argument would be evaluated many times -- leading to a large amount of wasteful recomputation. -- --- TODO: If we change the argument of Shape, Index and LinearIndex to be an array --- variable (instead of an arbitrary array computation), we could remove the --- HasArraysRepr constraint and just pattern match on the Var. --- -evalPreOpenExp - :: forall acc env aenv t. - HasArraysRepr acc - => EvalAcc acc - -> PreOpenExp acc env aenv t +evalOpenExp + :: forall env aenv t. + OpenExp env aenv t -> Val env -> Val aenv -> t -evalPreOpenExp evalAcc pexp env aenv = +evalOpenExp pexp env aenv = let - evalE :: PreOpenExp acc env aenv t' -> t' - evalE e = evalPreOpenExp evalAcc e env aenv + evalE :: OpenExp env aenv t' -> t' + evalE e = evalOpenExp e env aenv - evalF :: PreOpenFun acc env aenv f' -> f' - evalF f = evalPreOpenFun evalAcc f env aenv + evalF :: OpenFun env aenv f' -> f' + evalF f = evalOpenFun f env aenv - evalA :: acc aenv a -> WithReprs a - evalA a = evalAcc a aenv + evalA :: ArrayVar aenv a -> WithReprs a + evalA (Var repr ix) = (TupRsingle repr, prj ix aenv) in case pexp of Let lhs exp1 exp2 -> let !v1 = evalE exp1 env' = env `push` (lhs, v1) - in evalPreOpenExp evalAcc exp2 env' aenv + in evalOpenExp exp2 env' aenv Evar (Var _ ix) -> prj ix env Const _ c -> c Undef tp -> evalUndefScalar tp @@ -969,7 +961,7 @@ evalPreOpenExp evalAcc pexp env aenv = in (repr, a) ! ix Shape acc -> shape $ snd $ evalA acc ShapeSize shr sh -> size shr (evalE sh) - Foreign _ _ f e -> evalPreOpenFun evalAcc f Empty Empty $ evalE e + Foreign _ _ f e -> evalOpenFun f Empty Empty $ evalE e Coerce t1 t2 e -> evalCoerceScalar t1 t2 (evalE e) @@ -1809,10 +1801,10 @@ evalSeq conf s aenv = evalSeq' s evalAF f = evalOpenAfun f aenv evalE :: DelayedExp aenv t -> t - evalE exp = evalPreExp evalOpenAcc exp aenv + evalE exp = evalExp exp aenv evalF :: DelayedFun aenv f -> f - evalF fun = evalPreFun evalOpenAcc fun aenv + evalF fun = evalFun fun aenv initProducer :: forall a senv. Producer DelayedOpenAcc aenv senv a @@ -1858,9 +1850,9 @@ evalSeq conf s aenv = evalSeq' s delayed :: DelayedOpenAcc aenv (Array sh e) -> Delayed (Array sh e) delayed AST.Manifest{} = $internalError "evalOpenAcc" "expected delayed array" - delayed AST.Delayed{..} = Delayed (evalPreExp evalOpenAcc extentD aenv) - (evalPreFun evalOpenAcc indexD aenv) - (evalPreFun evalOpenAcc linearIndexD aenv) + delayed AST.Delayed{..} = Delayed (evalExp extentD aenv) + (evalFun indexD aenv) + (evalFun linearIndexD aenv) produce :: Arrays a => ExecP senv a -> Val' senv -> (Chunk a, Maybe (ExecP senv a)) produce p senv = diff --git a/src/Data/Array/Accelerate/Pretty.hs b/src/Data/Array/Accelerate/Pretty.hs index 82ffffbb5..b5b32a97b 100644 --- a/src/Data/Array/Accelerate/Pretty.hs +++ b/src/Data/Array/Accelerate/Pretty.hs @@ -24,8 +24,8 @@ module Data.Array.Accelerate.Pretty ( PrettyAcc, ExtractAcc, prettyPreOpenAcc, prettyPreOpenAfun, - prettyPreOpenExp, - prettyPreOpenFun, + prettyOpenExp, + prettyOpenFun, -- ** Graphviz Graph, @@ -101,17 +101,11 @@ instance PrettyEnv aenv => Show (DelayedOpenAcc aenv a) where instance PrettyEnv aenv => Show (DelayedOpenAfun aenv f) where show = renderForTerminal . prettyPreOpenAfun prettyDelayedOpenAcc (prettyEnv (pretty 'a')) -instance (PrettyEnv env, PrettyEnv aenv) => Show (PreOpenExp OpenAcc env aenv e) where - show = renderForTerminal . prettyPreOpenExp context0 prettyOpenAcc extractOpenAcc (prettyEnv (pretty 'x')) (prettyEnv (pretty 'a')) +instance (PrettyEnv env, PrettyEnv aenv) => Show (OpenExp env aenv e) where + show = renderForTerminal . prettyOpenExp context0 (prettyEnv (pretty 'x')) (prettyEnv (pretty 'a')) -instance (PrettyEnv env, PrettyEnv aenv) => Show (PreOpenExp DelayedOpenAcc env aenv e) where - show = renderForTerminal . prettyPreOpenExp context0 prettyDelayedOpenAcc extractDelayedOpenAcc (prettyEnv (pretty 'x')) (prettyEnv (pretty 'a')) - -instance (PrettyEnv env, PrettyEnv aenv) => Show (PreOpenFun OpenAcc env aenv e) where - show = renderForTerminal . prettyPreOpenFun prettyOpenAcc extractOpenAcc (prettyEnv (pretty 'x')) (prettyEnv (pretty 'a')) - -instance (PrettyEnv env, PrettyEnv aenv) => Show (PreOpenFun DelayedOpenAcc env aenv e) where - show = renderForTerminal . prettyPreOpenFun prettyDelayedOpenAcc extractDelayedOpenAcc (prettyEnv (pretty 'x')) (prettyEnv (pretty 'a')) +instance (PrettyEnv env, PrettyEnv aenv) => Show (OpenFun env aenv e) where + show = renderForTerminal . prettyOpenFun (prettyEnv (pretty 'x')) (prettyEnv (pretty 'a')) -- Internals @@ -162,8 +156,8 @@ prettyDelayedOpenAcc _ aenv (Delayed _ sh f _) = parens $ nest shiftwidth $ sep [ delayed "delayed" - , prettyPreOpenExp app prettyDelayedOpenAcc extractDelayedOpenAcc Empty aenv sh - , parens $ prettyPreOpenFun prettyDelayedOpenAcc extractDelayedOpenAcc Empty aenv f + , prettyOpenExp app Empty aenv sh + , parens $ prettyOpenFun Empty aenv f ] extractDelayedOpenAcc :: DelayedOpenAcc aenv a -> PreOpenAcc DelayedOpenAcc aenv a diff --git a/src/Data/Array/Accelerate/Pretty/Graphviz.hs b/src/Data/Array/Accelerate/Pretty/Graphviz.hs index 4a82b9418..3533a6472 100644 --- a/src/Data/Array/Accelerate/Pretty/Graphviz.hs +++ b/src/Data/Array/Accelerate/Pretty/Graphviz.hs @@ -280,21 +280,17 @@ prettyDelayedOpenAcc detail ctx aenv atop@(Manifest pacc) = -- Free variables -- - fvA :: FVAcc DelayedOpenAcc - fvA env (Manifest (Avar (Var _ ix))) = [ Vertex (fst $ aprj ix env) Nothing ] - fvA _ _ = $internalError "graphviz" "expected array variable" + fvF :: Fun aenv t -> [Vertex] + fvF = fvOpenFun Empty aenv - fvF :: DelayedFun aenv t -> [Vertex] - fvF = fvPreOpenFun fvA Empty aenv - - fvE :: DelayedExp aenv t -> [Vertex] - fvE = fvPreOpenExp fvA Empty aenv + fvE :: Exp aenv t -> [Vertex] + fvE = fvOpenExp Empty aenv -- Pretty-printing -- avar :: ArrayVar aenv t -> PDoc avar (Var _ ix) = let (ident, v) = aprj ix aenv - in PDoc (pretty v) [Vertex ident Nothing] + in PDoc (pretty v) [Vertex ident Nothing] aenv' :: Val aenv aenv' = avalToVal aenv @@ -312,14 +308,14 @@ prettyDelayedOpenAcc detail ctx aenv atop@(Manifest pacc) = | Shape a <- sh -- identical shape , Just b <- isIdentityIndexing f -- function is `\ix -> b ! ix` , Just Refl <- match a b -- function thus is `\ix -> a ! ix` - = ppA a + = ppA $ Manifest $ Avar a ppA (Delayed _ sh f _) = do PDoc d v <- "Delayed" `fmt` [ ppE sh, ppF f ] return $ PDoc (parens d) v ppB :: forall sh e. TupleType e - -> PreBoundary DelayedOpenAcc aenv (Array sh e) + -> Boundary aenv (Array sh e) -> Dot PDoc ppB _ Clamp = return (PDoc "clamp" []) ppB _ Mirror = return (PDoc "mirror" []) @@ -327,11 +323,11 @@ prettyDelayedOpenAcc detail ctx aenv atop@(Manifest pacc) = ppB tp (Constant e) = return (PDoc (prettyConst tp e) []) ppB _ (Function f) = ppF f - ppF :: DelayedFun aenv t -> Dot PDoc - ppF = return . uncurry PDoc . (parens . prettyDelayedFun aenv' &&& fvF) + ppF :: Fun aenv t -> Dot PDoc + ppF = return . uncurry PDoc . (parens . prettyFun aenv' &&& fvF) - ppE :: DelayedExp aenv t -> Dot PDoc - ppE = return . uncurry PDoc . (prettyDelayedExp aenv' &&& fvE) + ppE :: Exp aenv t -> Dot PDoc + ppE = return . uncurry PDoc . (prettyExp aenv' &&& fvE) lift :: DelayedOpenAcc aenv a -> Dot Vertex lift Delayed{} = $internalError "prettyDelayedOpenAcc" "expected manifest array" @@ -483,49 +479,6 @@ replant pnode@(PNode ident tree _) = -- nodes. -- -prettyDelayedFun :: Val aenv -> DelayedFun aenv f -> Adoc -prettyDelayedFun = prettyDelayedOpenFun Empty - -prettyDelayedExp :: Val aenv -> DelayedExp aenv t -> Adoc -prettyDelayedExp = prettyDelayedOpenExp context0 Empty - - -prettyDelayedOpenFun - :: forall env aenv f. - Val env - -> Val aenv - -> DelayedOpenFun env aenv f - -> Adoc -prettyDelayedOpenFun env0 aenv = next "\\\\" env0 - where - -- graphviz will silently not print a label containing the string "->", - -- so instead we use the special token "&rarr" for a short right arrow. - -- - next :: Adoc -> Val env' -> PreOpenFun DelayedOpenAcc env' aenv f' -> Adoc - next vs env (Body body) = - nest shiftwidth (sep [ vs <> "→" - , prettyDelayedOpenExp context0 env aenv body ]) - next vs env (Lam lhs lam) = - let (env', arg) = prettyELhs True env lhs - in next (vs <> arg <> space) env' lam - -prettyDelayedOpenExp - :: Context - -> Val env - -> Val aenv - -> DelayedOpenExp env aenv t - -> Adoc -prettyDelayedOpenExp context = prettyPreOpenExp context pp ex - where - pp :: PrettyAcc DelayedOpenAcc - pp _ aenv (Manifest (Avar (Var _ ix))) = prj ix aenv - pp _ _ _ = $internalError "prettyDelayedOpenExp" "expected array variable" - - ex :: ExtractAcc DelayedOpenAcc - ex (Manifest pacc) = pacc - ex Delayed{} = $internalError "prettyDelayedOpenExp" "expected manifest array" - - -- Data dependencies -- ----------------- -- @@ -534,38 +487,37 @@ prettyDelayedOpenExp context = prettyPreOpenExp context pp ex -- nodes (vertices) into the current term. -- -type FVAcc acc = forall aenv a. Aval aenv -> acc aenv a -> [Vertex] +fvAvar :: Aval aenv -> ArrayVar aenv a -> [Vertex] +fvAvar env (Var _ ix) = [ Vertex (fst $ aprj ix env) Nothing ] -fvPreOpenFun - :: forall acc env aenv fun. - FVAcc acc - -> Val env +fvOpenFun + :: forall env aenv fun. + Val env -> Aval aenv - -> PreOpenFun acc env aenv fun + -> OpenFun env aenv fun -> [Vertex] -fvPreOpenFun fvA env aenv (Body b) = fvPreOpenExp fvA env aenv b -fvPreOpenFun fvA env aenv (Lam lhs f) = fvPreOpenFun fvA env' aenv f +fvOpenFun env aenv (Body b) = fvOpenExp env aenv b +fvOpenFun env aenv (Lam lhs f) = fvOpenFun env' aenv f where (env', _) = prettyELhs True env lhs -fvPreOpenExp - :: forall acc env aenv exp. - FVAcc acc - -> Val env +fvOpenExp + :: forall env aenv exp. + Val env -> Aval aenv - -> PreOpenExp acc env aenv exp + -> OpenExp env aenv exp -> [Vertex] -fvPreOpenExp fvA env aenv = fv +fvOpenExp env aenv = fv where - fvF :: PreOpenFun acc env aenv f -> [Vertex] - fvF = fvPreOpenFun fvA env aenv + fvF :: OpenFun env aenv f -> [Vertex] + fvF = fvOpenFun env aenv - fv :: PreOpenExp acc env aenv e -> [Vertex] - fv (Shape acc) = if cfgIncludeShape then fvA aenv acc else [] - fv (Index acc i) = concat [ fvA aenv acc, fv i ] - fv (LinearIndex acc i) = concat [ fvA aenv acc, fv i ] + fv :: OpenExp env aenv e -> [Vertex] + fv (Shape acc) = if cfgIncludeShape then fvAvar aenv acc else [] + fv (Index acc i) = concat [ fvAvar aenv acc, fv i ] + fv (LinearIndex acc i) = concat [ fvAvar aenv acc, fv i ] -- - fv (Let lhs e1 e2) = concat [ fv e1, fvPreOpenExp fvA env' aenv e2 ] + fv (Let lhs e1 e2) = concat [ fv e1, fvOpenExp env' aenv e2 ] where (env', _) = prettyELhs False env lhs fv Evar{} = [] diff --git a/src/Data/Array/Accelerate/Pretty/Print.hs b/src/Data/Array/Accelerate/Pretty/Print.hs index b242d1027..f64bb347b 100644 --- a/src/Data/Array/Accelerate/Pretty/Print.hs +++ b/src/Data/Array/Accelerate/Pretty/Print.hs @@ -25,8 +25,8 @@ module Data.Array.Accelerate.Pretty.Print ( PrettyAcc, ExtractAcc, prettyPreOpenAcc, prettyPreOpenAfun, - prettyPreOpenExp, - prettyPreOpenFun, + prettyOpenExp, prettyExp, + prettyOpenFun, prettyFun, prettyArray, prettyConst, prettyELhs, @@ -187,15 +187,15 @@ prettyPreOpenAcc ctx prettyAcc extractAcc aenv pacc = ppAF :: PreOpenAfun acc aenv f -> Adoc ppAF = parens . prettyPreOpenAfun prettyAcc aenv - ppE :: PreExp acc aenv t -> Adoc - ppE = prettyPreOpenExp app prettyAcc extractAcc Empty aenv + ppE :: Exp aenv t -> Adoc + ppE = prettyOpenExp app Empty aenv - ppF :: PreFun acc aenv t -> Adoc - ppF = parens . prettyPreOpenFun prettyAcc extractAcc Empty aenv + ppF :: Fun aenv t -> Adoc + ppF = parens . prettyOpenFun Empty aenv ppB :: forall sh e. TupleType e - -> PreBoundary acc aenv (Array sh e) + -> Boundary aenv (Array sh e) -> Adoc ppB _ Clamp = "clamp" ppB _ Mirror = "mirror" @@ -305,18 +305,21 @@ prettyArray repr = parens . fromString . showArray repr -- Scalar expressions -- ------------------ +prettyFun :: Val aenv -> Fun aenv f -> Adoc +prettyFun = prettyOpenFun Empty -prettyPreOpenFun - :: forall acc env aenv f. - PrettyAcc acc - -> ExtractAcc acc - -> Val env +prettyExp :: Val aenv -> Exp aenv t -> Adoc +prettyExp = prettyOpenExp context0 Empty + +prettyOpenFun + :: forall env aenv f. + Val env -> Val aenv - -> PreOpenFun acc env aenv f + -> OpenFun env aenv f -> Adoc -prettyPreOpenFun prettyAcc extractAcc env0 aenv = next (pretty '\\') env0 +prettyOpenFun env0 aenv = next (pretty '\\') env0 where - next :: Adoc -> Val env' -> PreOpenFun acc env' aenv f' -> Adoc + next :: Adoc -> Val env' -> OpenFun env' aenv f' -> Adoc next vs env (Body body) -- PrimApp f x <- body -- , op <- primOperator f @@ -327,24 +330,22 @@ prettyPreOpenFun prettyAcc extractAcc env0 aenv = next (pretty '\\') env0 -- = opName op -- surrounding context will add parens -- = hang shiftwidth (sep [ vs <> "->" - , prettyPreOpenExp context0 prettyAcc extractAcc env aenv body]) + , prettyOpenExp context0 env aenv body]) next vs env (Lam lhs lam) = let (env', lhs') = prettyELhs True env lhs in next (vs <> lhs' <> space) env' lam -prettyPreOpenExp - :: forall acc env aenv t. +prettyOpenExp + :: forall env aenv t. Context - -> PrettyAcc acc - -> ExtractAcc acc -> Val env -> Val aenv - -> PreOpenExp acc env aenv t + -> OpenExp env aenv t -> Adoc -prettyPreOpenExp ctx prettyAcc extractAcc env aenv exp = +prettyOpenExp ctx env aenv exp = case exp of Evar (Var _ idx) -> prj idx env - Let{} -> prettyLet ctx prettyAcc extractAcc env aenv exp + Let{} -> prettyLet ctx env aenv exp PrimApp f x | a `Pair` b <- x -> ppF2 op (ppE a) (ppE b) | otherwise -> ppF1 op' (ppE x) @@ -354,7 +355,7 @@ prettyPreOpenExp ctx prettyAcc extractAcc env aenv exp = -- PrimConst c -> prettyPrimConst c Const tp c -> prettyConst (TupRsingle tp) c - Pair{} -> prettyTuple ctx prettyAcc extractAcc env aenv exp + Pair{} -> prettyTuple ctx env aenv exp Nil -> "()" VecPack _ e -> ppF1 "vecPack" (ppE e) VecUnpack _ e -> ppF1 "vecUnpack" (ppE e) @@ -385,14 +386,14 @@ prettyPreOpenExp ctx prettyAcc extractAcc env aenv exp = Undef tp -> withTypeRep tp "undef" where - ppE :: PreOpenExp acc env aenv e -> Context -> Adoc - ppE e c = prettyPreOpenExp c prettyAcc extractAcc env aenv e + ppE :: OpenExp env aenv e -> Context -> Adoc + ppE e c = prettyOpenExp c env aenv e - ppA :: acc aenv a -> Context -> Adoc - ppA acc _ = prettyAcc app aenv acc + ppA :: ArrayVar aenv a -> Context -> Adoc + ppA acc _ = prettyArrayVar aenv acc - ppF :: PreOpenFun acc env aenv f -> Context -> Adoc - ppF f _ = parens $ prettyPreOpenFun prettyAcc extractAcc env aenv f + ppF :: OpenFun env aenv f -> Context -> Adoc + ppF f _ = parens $ prettyOpenFun env aenv f ppF1 :: Operator -> (Context -> Adoc) -> Adoc ppF1 op x @@ -418,21 +419,25 @@ prettyPreOpenExp ctx prettyAcc extractAcc env aenv exp = withTypeRep :: ScalarType t -> Adoc -> Adoc withTypeRep tp op = op <> enclose langle rangle (pretty (showScalarType tp)) +prettyArrayVar + :: forall aenv a. + Val aenv + -> ArrayVar aenv a + -> Adoc +prettyArrayVar aenv (Var _ idx) = prj idx aenv prettyLet - :: forall acc env aenv t. + :: forall env aenv t. Context - -> PrettyAcc acc - -> ExtractAcc acc -> Val env -> Val aenv - -> PreOpenExp acc env aenv t + -> OpenExp env aenv t -> Adoc -prettyLet ctx prettyAcc extractAcc env0 aenv +prettyLet ctx env0 aenv = parensIf (needsParens ctx "let") . align . wrap . collect env0 where - collect :: Val env' -> PreOpenExp acc env' aenv e -> ([Adoc], Adoc) + collect :: Val env' -> OpenExp env' aenv e -> ([Adoc], Adoc) collect env = \case Let lhs e1 e2 -> @@ -446,12 +451,12 @@ prettyLet ctx prettyAcc extractAcc env0 aenv -- next -> ([], ppE env next) - isLet :: PreOpenExp acc env' aenv t' -> Bool + isLet :: OpenExp env' aenv t' -> Bool isLet Let{} = True isLet _ = False - ppE :: Val env' -> PreOpenExp acc env' aenv t' -> Adoc - ppE env = prettyPreOpenExp context0 prettyAcc extractAcc env aenv + ppE :: Val env' -> OpenExp env' aenv t' -> Adoc + ppE env = prettyOpenExp context0 env aenv wrap :: ([Adoc], Adoc) -> Adoc wrap ([], body) = body -- shouldn't happen! @@ -464,27 +469,25 @@ prettyLet ctx prettyAcc extractAcc env0 aenv ] prettyTuple - :: forall acc env aenv t. + :: forall env aenv t. Context - -> PrettyAcc acc - -> ExtractAcc acc -> Val env -> Val aenv - -> PreOpenExp acc env aenv t + -> OpenExp env aenv t -> Adoc -prettyTuple ctx prettyAcc extractAcc env aenv exp = case collect exp of +prettyTuple ctx env aenv exp = case collect exp of Just tup -> align $ parensIf (ctxPrecedence ctx > 0) ("T" <> pretty (length tup) <+> sep tup) Nothing -> align $ ppPair exp where - ppPair :: PreOpenExp acc env aenv t' -> Adoc - ppPair (Pair e1 e2) = "(" <> ppPair e1 <> "," <+> prettyPreOpenExp context0 prettyAcc extractAcc env aenv e2 <> ")" - ppPair e = prettyPreOpenExp context0 prettyAcc extractAcc env aenv e + ppPair :: OpenExp env aenv t' -> Adoc + ppPair (Pair e1 e2) = "(" <> ppPair e1 <> "," <+> prettyOpenExp context0 env aenv e2 <> ")" + ppPair e = prettyOpenExp context0 env aenv e - collect :: PreOpenExp acc env aenv t' -> Maybe [Adoc] + collect :: OpenExp env aenv t' -> Maybe [Adoc] collect Nil = Just [] collect (Pair e1 e2) | Just tup <- collect e1 - = Just $ tup ++ [prettyPreOpenExp app prettyAcc extractAcc env aenv e2] + = Just $ tup ++ [prettyOpenExp app env aenv e2] collect _ = Nothing {- diff --git a/src/Data/Array/Accelerate/Smart.hs b/src/Data/Array/Accelerate/Smart.hs index 2bee6dbf8..bd381b206 100644 --- a/src/Data/Array/Accelerate/Smart.hs +++ b/src/Data/Array/Accelerate/Smart.hs @@ -80,8 +80,8 @@ import Data.Array.Accelerate.Array.Sugar (Elt, Arrays, EltRepr, ArrRep import qualified Data.Array.Accelerate.Array.Sugar as Sugar import Data.Array.Accelerate.Array.Representation hiding (DIM1) import Data.Array.Accelerate.AST hiding ( PreOpenAcc(..), OpenAcc(..), Acc - , PreOpenExp(..), OpenExp, PreExp, Exp - , PreBoundary(..), Boundary, HasArraysRepr(..), arrayRepr, expType + , OpenExp(..), Exp + , Boundary(..), HasArraysRepr(..), arrayRepr, expType , showPreAccOp, showPreExpOp ) import GHC.TypeNats diff --git a/src/Data/Array/Accelerate/Trafo.hs b/src/Data/Array/Accelerate/Trafo.hs index 7266a3a1b..9cadbe8c4 100644 --- a/src/Data/Array/Accelerate/Trafo.hs +++ b/src/Data/Array/Accelerate/Trafo.hs @@ -38,8 +38,6 @@ module Data.Array.Accelerate.Trafo ( -- * Fusion DelayedAcc, DelayedOpenAcc(..), DelayedAfun, DelayedOpenAfun, - DelayedExp, DelayedOpenExp, - DelayedFun, DelayedOpenFun, -- * Substitution module Data.Array.Accelerate.Trafo.Substitution, @@ -60,7 +58,7 @@ import Data.Array.Accelerate.Smart import Data.Array.Accelerate.Array.Sugar ( ArrRepr, EltRepr ) import Data.Array.Accelerate.Trafo.Base ( Match(..), matchDelayedOpenAcc, encodeDelayedOpenAcc ) import Data.Array.Accelerate.Trafo.Config -import Data.Array.Accelerate.Trafo.Fusion ( DelayedAcc, DelayedOpenAcc(..), DelayedAfun, DelayedOpenAfun, DelayedExp, DelayedFun, DelayedOpenExp, DelayedOpenFun ) +import Data.Array.Accelerate.Trafo.Fusion ( DelayedAcc, DelayedOpenAcc(..), DelayedAfun, DelayedOpenAfun ) import Data.Array.Accelerate.Trafo.Sharing ( Function, FunctionR, Afunction, AfunctionR, AreprFunctionR, AfunctionRepr(..), afunctionRepr, EltReprFunctionR ) import Data.Array.Accelerate.Trafo.Substitution import qualified Data.Array.Accelerate.AST as AST diff --git a/src/Data/Array/Accelerate/Trafo/Algebra.hs b/src/Data/Array/Accelerate/Trafo/Algebra.hs index bc0edf04e..b24e11b42 100644 --- a/src/Data/Array/Accelerate/Trafo/Algebra.hs +++ b/src/Data/Array/Accelerate/Trafo/Algebra.hs @@ -52,13 +52,13 @@ import qualified Data.Array.Accelerate.Debug.Stats as Stats -- or constant let bindings. Be careful not to follow self-cycles. -- propagate - :: forall acc env aenv exp. Kit acc - => Gamma acc env env aenv - -> PreOpenExp acc env aenv exp + :: forall env aenv exp. + Gamma env env aenv + -> OpenExp env aenv exp -> Maybe exp propagate env = cvtE where - cvtE :: PreOpenExp acc env aenv e -> Maybe e + cvtE :: OpenExp env aenv e -> Maybe e cvtE exp = case exp of Const _ c -> Just c PrimConst c -> Just (evalPrimConst c) @@ -73,11 +73,11 @@ propagate env = cvtE -- Attempt to evaluate primitive function applications -- evalPrimApp - :: forall acc env aenv a r. (Kit acc) - => Gamma acc env env aenv + :: forall env aenv a r. + Gamma env env aenv -> PrimFun (a -> r) - -> PreOpenExp acc env aenv a - -> (Any, PreOpenExp acc env aenv r) + -> OpenExp env aenv a + -> (Any, OpenExp env aenv r) evalPrimApp env f x -- First attempt to move constant values towards the left | Just r <- commutes f x env = evalPrimApp env f r @@ -159,11 +159,11 @@ evalPrimApp env f x -- to the left of the operator. Returning Nothing indicates no change is made. -- commutes - :: forall acc env aenv a r. Kit acc - => PrimFun (a -> r) - -> PreOpenExp acc env aenv a - -> Gamma acc env env aenv - -> Maybe (PreOpenExp acc env aenv a) + :: forall env aenv a r. + PrimFun (a -> r) + -> OpenExp env aenv a + -> Gamma env env aenv + -> Maybe (OpenExp env aenv a) commutes f x env = case f of PrimAdd _ -> swizzle x PrimMul _ -> swizzle x @@ -176,7 +176,7 @@ commutes f x env = case f of PrimMin _ -> swizzle x _ -> Nothing where - swizzle :: PreOpenExp acc env aenv (b,b) -> Maybe (PreOpenExp acc env aenv (b,b)) + swizzle :: OpenExp env aenv (b,b) -> Maybe (OpenExp env aenv (b,b)) swizzle (Pair a b) | Nothing <- propagate env a , Just _ <- propagate env b @@ -213,8 +213,8 @@ commutes f x env = case f of associates :: (Elt a, Elt r) => PrimFun (a -> r) - -> PreOpenExp acc env aenv a - -> Maybe (PreOpenExp acc env aenv r) + -> OpenExp env aenv a + -> Maybe (OpenExp env aenv r) associates fun exp = case fun of PrimAdd _ -> swizzle fun exp [PrimAdd ty, PrimSub ty] PrimSub _ -> swizzle fun exp [PrimAdd ty, PrimSub ty] @@ -226,7 +226,7 @@ associates fun exp = case fun of ty = undefined ops = [ PrimMul ty, PrimFDiv ty, PrimAdd ty, PrimSub ty, PrimBAnd ty, PrimBOr ty, PrimBXor ty ] - swizzle :: (Elt a, Elt r) => PrimFun (a -> r) -> PreOpenExp acc env aenv a -> [PrimFun (a -> r)] -> Maybe (PreOpenExp acc env aenv r) + swizzle :: (Elt a, Elt r) => PrimFun (a -> r) -> OpenExp env aenv a -> [PrimFun (a -> r)] -> Maybe (OpenExp env aenv r) swizzle f x lvl | Just Refl <- matches f ops , Just (a,bc) <- untup2 x @@ -253,7 +253,7 @@ associates fun exp = case fun of -- Helper functions -- ---------------- -type a :-> b = forall acc env aenv. Kit acc => PreOpenExp acc env aenv a -> Gamma acc env env aenv -> Maybe (PreOpenExp acc env aenv b) +type a :-> b = forall env aenv. OpenExp env aenv a -> Gamma env env aenv -> Maybe (OpenExp env aenv b) eval1 :: SingleType b -> (a -> b) -> a :-> b eval1 tp f x env @@ -270,10 +270,10 @@ eval2 tp f (untup2 -> Just (x,y)) env eval2 _ _ _ _ = Nothing -tup2 :: (PreOpenExp acc env aenv a, PreOpenExp acc env aenv b) -> PreOpenExp acc env aenv (a, b) +tup2 :: (OpenExp env aenv a, OpenExp env aenv b) -> OpenExp env aenv (a, b) tup2 (a,b) = Pair a b -untup2 :: PreOpenExp acc env aenv (a, b) -> Maybe (PreOpenExp acc env aenv a, PreOpenExp acc env aenv b) +untup2 :: OpenExp env aenv (a, b) -> Maybe (OpenExp env aenv a, OpenExp env aenv b) untup2 exp | Pair a b <- exp = Just (a, b) | otherwise = Nothing diff --git a/src/Data/Array/Accelerate/Trafo/Base.hs b/src/Data/Array/Accelerate/Trafo/Base.hs index 060d77527..91ffdc181 100644 --- a/src/Data/Array/Accelerate/Trafo/Base.hs +++ b/src/Data/Array/Accelerate/Trafo/Base.hs @@ -37,8 +37,6 @@ module Data.Array.Accelerate.Trafo.Base ( -- Delayed Arrays DelayedAcc, DelayedOpenAcc(..), DelayedAfun, DelayedOpenAfun, - DelayedExp, DelayedOpenExp, - DelayedFun, DelayedOpenFun, matchDelayedOpenAcc, encodeDelayedOpenAcc, @@ -46,7 +44,7 @@ module Data.Array.Accelerate.Trafo.Base ( Gamma(..), incExp, prjExp, pushExp, Extend(..), pushArrayEnv, append, bind, Sink(..), SinkExp(..), sinkA, sink1, - PreOpenExp', bindExps, + OpenExp', bindExps, -- Adding new variables to the environment declareVars, DeclareVars(..), @@ -105,7 +103,7 @@ encodeOpenAcc :: EncodeAcc OpenAcc encodeOpenAcc options (OpenAcc pacc) = encodePreOpenAcc options encodeAcc pacc matchOpenAcc :: MatchAcc OpenAcc -matchOpenAcc (OpenAcc pacc1) (OpenAcc pacc2) = matchPreOpenAcc matchAcc encodeAcc pacc1 pacc2 +matchOpenAcc (OpenAcc pacc1) (OpenAcc pacc2) = matchPreOpenAcc matchAcc pacc1 pacc2 avarIn :: forall acc aenv a. Kit acc => ArrayVar aenv a -> acc aenv a avarIn v@(Var ArrayR{} _) = inject $ Avar v @@ -176,17 +174,17 @@ instance Match ArrayR where instance Match a => Match (TupR a) where match = matchTupR match -instance Kit acc => Match (PreOpenExp acc env aenv) where +instance Match (OpenExp env aenv) where {-# INLINEABLE match #-} - match = matchPreOpenExp matchAcc encodeAcc + match = matchOpenExp -instance Kit acc => Match (PreOpenFun acc env aenv) where +instance Match (OpenFun env aenv) where {-# INLINEABLE match #-} - match = matchPreOpenFun matchAcc encodeAcc + match = matchOpenFun instance Kit acc => Match (PreOpenAcc acc aenv) where {-# INLINEABLE match #-} - match = matchPreOpenAcc matchAcc encodeAcc + match = matchPreOpenAcc matchAcc instance {-# INCOHERENT #-} Kit acc => Match (acc aenv) where {-# INLINEABLE match #-} @@ -203,17 +201,12 @@ instance {-# INCOHERENT #-} Kit acc => Match (acc aenv) where type DelayedAcc = DelayedOpenAcc () type DelayedAfun = PreOpenAfun DelayedOpenAcc () -type DelayedExp = DelayedOpenExp () -type DelayedFun = DelayedOpenFun () - -- data DelayedSeq t where -- DelayedSeq :: Extend DelayedOpenAcc () aenv -- -> DelayedOpenSeq aenv () t -- -> DelayedSeq t type DelayedOpenAfun = PreOpenAfun DelayedOpenAcc -type DelayedOpenExp = PreOpenExp DelayedOpenAcc -type DelayedOpenFun = PreOpenFun DelayedOpenAcc -- type DelayedOpenSeq = PreOpenSeq DelayedOpenAcc data DelayedOpenAcc aenv a where @@ -221,9 +214,9 @@ data DelayedOpenAcc aenv a where Delayed :: { reprD :: ArrayR (Array sh e) - , extentD :: PreExp DelayedOpenAcc aenv sh - , indexD :: PreFun DelayedOpenAcc aenv (sh -> e) - , linearIndexD :: PreFun DelayedOpenAcc aenv (Int -> e) + , extentD :: Exp aenv sh + , indexD :: Fun aenv (sh -> e) + , linearIndexD :: Fun aenv (Int -> e) } -> DelayedOpenAcc aenv (Array sh e) instance HasArraysRepr DelayedOpenAcc where @@ -235,10 +228,10 @@ instance Rebuildable DelayedOpenAcc where {-# INLINEABLE rebuildPartial #-} rebuildPartial v acc = case acc of Manifest pacc -> Manifest <$> rebuildPartial v pacc - Delayed{..} -> Delayed reprD - <$> rebuildPartial v extentD - <*> rebuildPartial v indexD - <*> rebuildPartial v linearIndexD + Delayed{..} -> (\e i l -> Delayed reprD (unOpenAccExp e) (unOpenAccFun i) (unOpenAccFun l)) + <$> rebuildPartial v (OpenAccExp extentD) + <*> rebuildPartial v (OpenAccFun indexD) + <*> rebuildPartial v (OpenAccFun linearIndexD) instance Sink DelayedOpenAcc where weaken k = Stats.substitution "weaken" . rebuildA (rebuildWeakenVar k) @@ -265,11 +258,11 @@ instance NFData (DelayedOpenAcc aenv t) where encodeDelayedOpenAcc :: EncodeAcc DelayedOpenAcc encodeDelayedOpenAcc options acc = let - travE :: DelayedExp aenv sh -> Builder - travE = encodePreOpenExp options encodeDelayedOpenAcc + travE :: Exp aenv sh -> Builder + travE = encodeOpenExp - travF :: DelayedFun aenv f -> Builder - travF = encodePreOpenFun options encodeDelayedOpenAcc + travF :: Fun aenv f -> Builder + travF = encodeOpenFun travA :: PreOpenAcc DelayedOpenAcc aenv a -> Builder travA = encodePreOpenAcc options encodeDelayedOpenAcc @@ -285,12 +278,12 @@ encodeDelayedOpenAcc options acc = {-# INLINEABLE matchDelayedOpenAcc #-} matchDelayedOpenAcc :: MatchAcc DelayedOpenAcc matchDelayedOpenAcc (Manifest pacc1) (Manifest pacc2) - = matchPreOpenAcc matchDelayedOpenAcc encodeDelayedOpenAcc pacc1 pacc2 + = matchPreOpenAcc matchDelayedOpenAcc pacc1 pacc2 matchDelayedOpenAcc (Delayed _ sh1 ix1 lx1) (Delayed _ sh2 ix2 lx2) - | Just Refl <- matchPreOpenExp matchDelayedOpenAcc encodeDelayedOpenAcc sh1 sh2 - , Just Refl <- matchPreOpenFun matchDelayedOpenAcc encodeDelayedOpenAcc ix1 ix2 - , Just Refl <- matchPreOpenFun matchDelayedOpenAcc encodeDelayedOpenAcc lx1 lx2 + | Just Refl <- matchOpenExp sh1 sh2 + , Just Refl <- matchOpenFun ix1 ix2 + , Just Refl <- matchOpenFun lx1 lx2 = Just Refl matchDelayedOpenAcc _ _ @@ -299,9 +292,9 @@ matchDelayedOpenAcc _ _ rnfDelayedOpenAcc :: DelayedOpenAcc aenv t -> () rnfDelayedOpenAcc (Manifest pacc) = rnfPreOpenAcc rnfDelayedOpenAcc pacc rnfDelayedOpenAcc (Delayed repr sh ix lx) = rnfArrayR repr - `seq` rnfPreOpenExp rnfDelayedOpenAcc sh - `seq` rnfPreOpenFun rnfDelayedOpenAcc ix - `seq` rnfPreOpenFun rnfDelayedOpenAcc lx + `seq` rnfOpenExp sh + `seq` rnfOpenFun ix + `seq` rnfOpenFun lx {-- rnfDelayedSeq :: DelayedSeq t -> () @@ -321,18 +314,18 @@ rnfExtend rnfA (PushEnv env a) = rnfExtend rnfA env `seq` rnfA a -- environment variable env' is used to project out the corresponding -- index when looking up in the environment congruent expressions. -- -data Gamma acc env env' aenv where - EmptyExp :: Gamma acc env env' aenv +data Gamma env env' aenv where + EmptyExp :: Gamma env env' aenv - PushExp :: Gamma acc env env' aenv - -> WeakPreOpenExp acc env aenv t - -> Gamma acc env (env', t) aenv + PushExp :: Gamma env env' aenv + -> WeakOpenExp env aenv t + -> Gamma env (env', t) aenv -data WeakPreOpenExp acc env aenv t where +data WeakOpenExp env aenv t where Subst :: env :> env' - -> PreOpenExp acc env aenv t - -> PreOpenExp acc env' aenv t {- LAZY -} - -> WeakPreOpenExp acc env' aenv t + -> OpenExp env aenv t + -> OpenExp env' aenv t {- LAZY -} + -> WeakOpenExp env' aenv t -- XXX: The simplifier calls this function every time it moves under a let -- binding. This means we have a number of calls to 'weakenE' exponential in the @@ -346,28 +339,26 @@ data WeakPreOpenExp acc env aenv t where -- -- incExp - :: Kit acc - => Gamma acc env env' aenv - -> Gamma acc (env,s) env' aenv + :: Gamma env env' aenv + -> Gamma (env,s) env' aenv incExp EmptyExp = EmptyExp incExp (PushExp env w) = incExp env `PushExp` subs w where - subs :: forall acc env aenv s t. Kit acc => WeakPreOpenExp acc env aenv t -> WeakPreOpenExp acc (env,s) aenv t - subs (Subst k (e :: PreOpenExp acc env_ aenv t) _) = Subst (weakenSucc' k) e (weakenE (weakenSucc' k) e) + subs :: forall env aenv s t. WeakOpenExp env aenv t -> WeakOpenExp (env,s) aenv t + subs (Subst k (e :: OpenExp env_ aenv t) _) = Subst (weakenSucc' k) e (weakenE (weakenSucc' k) e) -prjExp :: Idx env' t -> Gamma acc env env' aenv -> PreOpenExp acc env aenv t +prjExp :: Idx env' t -> Gamma env env' aenv -> OpenExp env aenv t prjExp ZeroIdx (PushExp _ (Subst _ _ e)) = e prjExp (SuccIdx ix) (PushExp env _) = prjExp ix env prjExp _ _ = $internalError "prjExp" "inconsistent valuation" -pushExp :: Gamma acc env env' aenv -> PreOpenExp acc env aenv t -> Gamma acc env (env',t) aenv +pushExp :: Gamma env env' aenv -> OpenExp env aenv t -> Gamma env (env',t) aenv pushExp env e = env `PushExp` Subst weakenId e e {-- lookupExp - :: Kit acc - => Gamma acc env env' aenv - -> PreOpenExp acc env aenv t + :: Gamma env env' aenv + -> OpenExp env aenv t -> Maybe (Idx env' t) lookupExp EmptyExp _ = Nothing lookupExp (PushExp env e) x @@ -375,17 +366,16 @@ lookupExp (PushExp env e) x | otherwise = SuccIdx `fmap` lookupExp env x weakenGamma1 - :: Kit acc - => Gamma acc env env' aenv - -> Gamma acc env env' (aenv,t) + :: Gamma env env' aenv + -> Gamma env env' (aenv,t) weakenGamma1 EmptyExp = EmptyExp weakenGamma1 (PushExp env e) = PushExp (weakenGamma1 env) (weaken SuccIdx e) sinkGamma :: Kit acc => Extend acc aenv aenv' - -> Gamma acc env env' aenv - -> Gamma acc env env' aenv' + -> Gamma env env' aenv + -> Gamma env env' aenv' sinkGamma _ EmptyExp = EmptyExp sinkGamma ext (PushExp env e) = PushExp (sinkGamma ext env) (sinkA ext e) --} @@ -440,24 +430,22 @@ sinkWeaken (PushEnv e (LeftHandSidePair l1 l2) _) = sinkWeaken (PushEnv (PushEnv sink1 :: Sink f => Extend s acc env env' -> f (env,t') t -> f (env',t') t sink1 env = weaken $ sink $ sinkWeaken env --- Wrapper around PreOpenExp, with the order of type arguments env and aenv flipped -newtype PreOpenExp' acc aenv env e = PreOpenExp' (PreOpenExp acc env aenv e) +-- Wrapper around OpenExp, with the order of type arguments env and aenv flipped +newtype OpenExp' aenv env e = OpenExp' (OpenExp env aenv e) -bindExps :: Kit acc - => Extend ScalarType (PreOpenExp' acc aenv) env env' - -> PreOpenExp acc env' aenv e - -> PreOpenExp acc env aenv e +bindExps :: Extend ScalarType (OpenExp' aenv) env env' + -> OpenExp env' aenv e + -> OpenExp env aenv e bindExps BaseEnv = id -bindExps (PushEnv g lhs (PreOpenExp' b)) = bindExps g . Let lhs b +bindExps (PushEnv g lhs (OpenExp' b)) = bindExps g . Let lhs b -- Utilities for working with shapes -mkShapeBinary :: (HasArraysRepr acc, RebuildableAcc acc) - => (forall env'. PreOpenExp acc env' aenv Int -> PreOpenExp acc env' aenv Int -> PreOpenExp acc env' aenv Int) +mkShapeBinary :: (forall env'. OpenExp env' aenv Int -> OpenExp env' aenv Int -> OpenExp env' aenv Int) -> ShapeR sh - -> PreOpenExp acc env aenv sh - -> PreOpenExp acc env aenv sh - -> PreOpenExp acc env aenv sh + -> OpenExp env aenv sh + -> OpenExp env aenv sh + -> OpenExp env aenv sh mkShapeBinary _ ShapeRz _ _ = Nil mkShapeBinary f (ShapeRsnoc shr) (Pair as a) (Pair bs b) = mkShapeBinary f shr as bs `Pair` f a b mkShapeBinary f shr (Let lhs bnd a) b = Let lhs bnd $ mkShapeBinary f shr a (weakenE (weakenWithLHS lhs) b) @@ -469,20 +457,18 @@ mkShapeBinary f shr a b -- `b` is not a Pair | DeclareVars lhs k value <- declareVars $ shapeType shr = Let lhs b $ mkShapeBinary f shr (weakenE k a) (evars $ value weakenId) -mkIntersect :: (HasArraysRepr acc, RebuildableAcc acc) - => ShapeR sh - -> PreOpenExp acc env aenv sh - -> PreOpenExp acc env aenv sh - -> PreOpenExp acc env aenv sh +mkIntersect :: ShapeR sh + -> OpenExp env aenv sh + -> OpenExp env aenv sh + -> OpenExp env aenv sh mkIntersect = mkShapeBinary f where f a b = PrimApp (PrimMin singleType) $ Pair a b -mkUnion :: (HasArraysRepr acc, RebuildableAcc acc) - => ShapeR sh - -> PreOpenExp acc env aenv sh - -> PreOpenExp acc env aenv sh - -> PreOpenExp acc env aenv sh +mkUnion :: ShapeR sh + -> OpenExp env aenv sh + -> OpenExp env aenv sh + -> OpenExp env aenv sh mkUnion = mkShapeBinary f where f a b = PrimApp (PrimMax singleType) $ Pair a b diff --git a/src/Data/Array/Accelerate/Trafo/Config.hs b/src/Data/Array/Accelerate/Trafo/Config.hs index a3f984023..488d02001 100644 --- a/src/Data/Array/Accelerate/Trafo/Config.hs +++ b/src/Data/Array/Accelerate/Trafo/Config.hs @@ -17,7 +17,7 @@ module Data.Array.Accelerate.Trafo.Config ( defaultOptions, -- Other options not controlled by the command line flags - float_out_acc, + -- float_out_acc, ) where @@ -46,5 +46,5 @@ defaultOptions = unsafePerformIO $! -- Extra options not covered by command line flags -- -float_out_acc = Flag 31 +-- float_out_acc = Flag 31 diff --git a/src/Data/Array/Accelerate/Trafo/Fusion.hs b/src/Data/Array/Accelerate/Trafo/Fusion.hs index c6f513117..b2f4258a2 100644 --- a/src/Data/Array/Accelerate/Trafo/Fusion.hs +++ b/src/Data/Array/Accelerate/Trafo/Fusion.hs @@ -38,7 +38,6 @@ module Data.Array.Accelerate.Trafo.Fusion ( -- ** Types DelayedAcc, DelayedOpenAcc(..), DelayedAfun, DelayedOpenAfun, - DelayedExp, DelayedFun, DelayedOpenExp, DelayedOpenFun, -- ** Conversion convertAcc, convertAccWith, @@ -133,21 +132,14 @@ delayed config (embedOpenAcc config -> Embed env cc) | BaseEnv <- env = case simplify cc of Done v -> avarsIn v - Yield repr (cvtE -> sh) (cvtF -> f) -> Delayed repr sh f (f `compose` fromIndex (arrayRshape repr) sh) - Step repr (cvtE -> sh) (cvtF -> p) (cvtF -> f) v + Yield repr sh f -> Delayed repr sh f (f `compose` fromIndex (arrayRshape repr) sh) + Step repr sh p f v | Just Refl <- match sh (arrayShape v) , Just Refl <- isIdentity p -> Delayed repr sh (f `compose` indexArray v) (f `compose` linearIndex v) | f' <- f `compose` indexArray v `compose` p -> Delayed repr sh f' (f' `compose` fromIndex (arrayRshape repr) sh) -- | otherwise = manifest config (computeAcc (Embed env cc)) - where - cvtE :: OpenExp env aenv t -> DelayedOpenExp env aenv t - cvtE = convertOpenExp config - - cvtF :: OpenFun env aenv f -> DelayedOpenFun env aenv f - cvtF (Lam lhs f) = Lam lhs (cvtF f) - cvtF (Body b) = Body (cvtE b) -- Convert array programs as manifest terms. @@ -161,9 +153,9 @@ manifest config (OpenAcc pacc) = -- ----------------- Avar ix -> Avar ix Use repr arr -> Use repr arr - Unit tp e -> Unit tp (cvtE e) + Unit tp e -> Unit tp e Alet lhs bnd body -> alet lhs (manifest config bnd) (manifest config body) - Acond p t e -> Acond (cvtE p) (manifest config t) (manifest config e) + Acond p t e -> Acond p (manifest config t) (manifest config e) Awhile p f a -> Awhile (cvtAF p) (cvtAF f) (manifest config a) Apair a1 a2 -> Apair (manifest config a1) (manifest config a2) Anil -> Anil @@ -178,11 +170,11 @@ manifest config (OpenAcc pacc) = -- of a let-binding to be used multiple times. The input array here -- should be a evaluated array term, else something went wrong. -- - Map tp f a -> Map tp (cvtF f) (delayed config a) - Generate repr sh f -> Generate repr (cvtE sh) (cvtF f) - Transform repr sh p f a -> Transform repr (cvtE sh) (cvtF p) (cvtF f) (delayed config a) - Backpermute shr sh p a -> Backpermute shr (cvtE sh) (cvtF p) (delayed config a) - Reshape slr sl a -> Reshape slr (cvtE sl) (manifest config a) + Map tp f a -> Map tp f (delayed config a) + Generate repr sh f -> Generate repr sh f + Transform repr sh p f a -> Transform repr sh p f (delayed config a) + Backpermute shr sh p a -> Backpermute shr sh p (delayed config a) + Reshape slr sl a -> Reshape slr sl (manifest config a) Replicate{} -> fusionError Slice{} -> fusionError @@ -195,20 +187,20 @@ manifest config (OpenAcc pacc) = -- with local bindings, these will have been floated up above the -- consumer already -- - Fold f z a -> Fold (cvtF f) (cvtE z) (delayed config a) - Fold1 f a -> Fold1 (cvtF f) (delayed config a) - FoldSeg i f z a s -> FoldSeg i (cvtF f) (cvtE z) (delayed config a) (delayed config s) - Fold1Seg i f a s -> Fold1Seg i (cvtF f) (delayed config a) (delayed config s) - Scanl f z a -> Scanl (cvtF f) (cvtE z) (delayed config a) - Scanl1 f a -> Scanl1 (cvtF f) (delayed config a) - Scanl' f z a -> Scanl' (cvtF f) (cvtE z) (delayed config a) - Scanr f z a -> Scanr (cvtF f) (cvtE z) (delayed config a) - Scanr1 f a -> Scanr1 (cvtF f) (delayed config a) - Scanr' f z a -> Scanr' (cvtF f) (cvtE z) (delayed config a) - Permute f d p a -> Permute (cvtF f) (manifest config d) (cvtF p) (delayed config a) - Stencil s tp f x a -> Stencil s tp (cvtF f) (cvtB x) (delayed config a) + Fold f z a -> Fold f z (delayed config a) + Fold1 f a -> Fold1 f (delayed config a) + FoldSeg i f z a s -> FoldSeg i f z (delayed config a) (delayed config s) + Fold1Seg i f a s -> Fold1Seg i f (delayed config a) (delayed config s) + Scanl f z a -> Scanl f z (delayed config a) + Scanl1 f a -> Scanl1 f (delayed config a) + Scanl' f z a -> Scanl' f z (delayed config a) + Scanr f z a -> Scanr f z (delayed config a) + Scanr1 f a -> Scanr1 f (delayed config a) + Scanr' f z a -> Scanr' f z (delayed config a) + Permute f d p a -> Permute f (manifest config d) p (delayed config a) + Stencil s tp f x a -> Stencil s tp f x (delayed config a) Stencil2 s1 s2 tp f x a y b - -> Stencil2 s1 s2 tp (cvtF f) (cvtB x) (delayed config a) (cvtB y) (delayed config b) + -> Stencil2 s1 s2 tp f x (delayed config a) y (delayed config b) -- Collect s -> Collect (cvtS s) where @@ -256,58 +248,6 @@ manifest config (OpenAcc pacc) = -- cvtS :: PreOpenSeq OpenAcc aenv senv s -> PreOpenSeq DelayedOpenAcc aenv senv s -- cvtS = convertOpenSeq config - -- Conversions for closed scalar functions and expressions - -- - cvtF :: OpenFun env aenv f -> DelayedOpenFun env aenv f - cvtF (Lam lhs f) = Lam lhs (cvtF f) - cvtF (Body b) = Body (cvtE b) - - cvtE :: OpenExp env aenv t -> DelayedOpenExp env aenv t - cvtE = convertOpenExp config - - cvtB :: Boundary aenv t -> PreBoundary DelayedOpenAcc aenv t - cvtB Clamp = Clamp - cvtB Mirror = Mirror - cvtB Wrap = Wrap - cvtB (Constant v) = Constant v - cvtB (Function f) = Function (cvtF f) - -convertOpenExp :: Config -> OpenExp env aenv t -> DelayedOpenExp env aenv t -convertOpenExp config exp = - case exp of - Let lhs bnd body -> Let lhs (cvtE bnd) (cvtE body) - Evar var -> Evar var - Const tp c -> Const tp c - Undef tp -> Undef tp - Nil -> Nil - Pair e1 e2 -> Pair (cvtE e1) (cvtE e2) - VecPack vec e -> VecPack vec (cvtE e) - VecUnpack vec e -> VecUnpack vec (cvtE e) - IndexSlice x ix sh -> IndexSlice x (cvtE ix) (cvtE sh) - IndexFull x ix sl -> IndexFull x (cvtE ix) (cvtE sl) - ToIndex shr sh ix -> ToIndex shr (cvtE sh) (cvtE ix) - FromIndex shr sh ix -> FromIndex shr (cvtE sh) (cvtE ix) - Cond p t e -> Cond (cvtE p) (cvtE t) (cvtE e) - While p f x -> While (cvtF p) (cvtF f) (cvtE x) - PrimConst c -> PrimConst c - PrimApp f x -> PrimApp f (cvtE x) - Index a sh -> Index (manifest config a) (cvtE sh) - LinearIndex a i -> LinearIndex (manifest config a) (cvtE i) - Shape a -> Shape (manifest config a) - ShapeSize shr sh -> ShapeSize shr (cvtE sh) - Foreign tp ff f e -> Foreign tp ff (cvtF f) (cvtE e) - Coerce t1 t2 e -> Coerce t1 t2 (cvtE e) - where - -- Conversions for closed scalar functions and expressions - -- - cvtF :: OpenFun env aenv f -> DelayedOpenFun env aenv f - cvtF (Lam lhs f) = Lam lhs (cvtF f) - cvtF (Body b) = Body (cvtE b) - - cvtE :: OpenExp env aenv t -> DelayedOpenExp env aenv t - cvtE = convertOpenExp config - - convertOpenAfun :: Config -> OpenAfun aenv f -> DelayedOpenAfun aenv f convertOpenAfun c (Alam lhs f) = Alam lhs (convertOpenAfun c f) convertOpenAfun c (Abody b) = Abody (convertOpenAcc c b) @@ -512,13 +452,13 @@ embedPreAcc config embedAcc elimAcc pacc -- Conversions for closed scalar functions and expressions. This just -- applies scalar simplifications. -- - cvtF :: PreFun acc aenv' t -> PreFun acc aenv' t + cvtF :: Fun aenv' t -> Fun aenv' t cvtF = simplify - cvtE :: PreExp acc aenv' t -> PreExp acc aenv' t + cvtE :: Exp aenv' t -> Exp aenv' t cvtE = simplify - cvtB :: PreBoundary acc aenv' t -> PreBoundary acc aenv' t + cvtB :: Boundary aenv' t -> Boundary aenv' t cvtB Clamp = Clamp cvtB Mirror = Mirror cvtB Wrap = Wrap @@ -543,12 +483,12 @@ embedPreAcc config embedAcc elimAcc pacc -- directly on the delayed representation. See also: [Representing -- delayed arrays] -- - fuse :: (forall aenv'. Extend ArrayR acc aenv aenv' -> Cunctation acc aenv' as -> Cunctation acc aenv' bs) + fuse :: (forall aenv'. Extend ArrayR acc aenv aenv' -> Cunctation aenv' as -> Cunctation aenv' bs) -> acc aenv as -> Embed acc aenv bs fuse op (embedAcc -> Embed env cc) = Embed env (op env cc) - fuse2 :: (forall aenv'. Extend ArrayR acc aenv aenv' -> Cunctation acc aenv' as -> Cunctation acc aenv' bs -> Cunctation acc aenv' cs) + fuse2 :: (forall aenv'. Extend ArrayR acc aenv aenv' -> Cunctation aenv' as -> Cunctation aenv' bs -> Cunctation aenv' cs) -> acc aenv as -> acc aenv bs -> Embed acc aenv cs @@ -709,10 +649,10 @@ embedSeq embedAcc s cvtCT NilAtup = NilAtup cvtCT (SnocAtup t c) = SnocAtup (cvtCT t) (travC c env) - cvtE :: Elt t => PreExp acc aenv' t -> PreExp acc aenv' t + cvtE :: Elt t => Exp aenv' t -> Exp aenv' t cvtE = simplify - cvtF :: PreFun acc aenv' t -> PreFun acc aenv' t + cvtF :: Fun aenv' t -> Fun aenv' t cvtF = simplify cvtA :: Arrays a => acc aenv' a -> acc aenv' a @@ -771,7 +711,7 @@ data ExtendProducer acc aenv senv arrs where -- data Embed acc aenv a where Embed :: Extend ArrayR acc aenv aenv' - -> Cunctation acc aenv' a + -> Cunctation aenv' a -> Embed acc aenv a instance HasArraysRepr acc => HasArraysRepr (Embed acc) where @@ -785,23 +725,23 @@ instance HasArraysRepr acc => HasArraysRepr (Embed acc) where -- element at each index, and fusing successive producers by combining these -- scalar functions. -- -data Cunctation acc aenv a where +data Cunctation aenv a where -- The base case is just a real (manifest) array term. No fusion happens here. -- Note that the array is referenced by an index into the extended -- environment, ensuring that the array is manifest and making the term -- non-recursive in 'acc'. -- - Done :: ArrayVars aenv arrs - -> Cunctation acc aenv arrs + Done :: ArrayVars aenv arrs + -> Cunctation aenv arrs -- We can represent an array by its shape and a function to compute an element -- at each index. -- Yield :: ArrayR (Array sh e) - -> PreExp acc aenv sh - -> PreFun acc aenv (sh -> e) - -> Cunctation acc aenv (Array sh e) + -> Exp aenv sh + -> Fun aenv (sh -> e) + -> Cunctation aenv (Array sh e) -- A more restrictive form than 'Yield' may afford greater opportunities for -- optimisation by a backend. This more structured form applies an index and @@ -810,13 +750,13 @@ data Cunctation acc aenv a where -- it is always possible to embed into a collective operation. -- Step :: ArrayR (Array sh' b) - -> PreExp acc aenv sh' - -> PreFun acc aenv (sh' -> sh) - -> PreFun acc aenv (a -> b) - -> ArrayVar aenv (Array sh a) - -> Cunctation acc aenv (Array sh' b) + -> Exp aenv sh' + -> Fun aenv (sh' -> sh) + -> Fun aenv (a -> b) + -> ArrayVar aenv (Array sh a) + -> Cunctation aenv (Array sh' b) -instance Kit acc => Simplify (Cunctation acc aenv a) where +instance Simplify (Cunctation aenv a) where simplify = \case Done v -> Done v Yield repr (simplify -> sh) (simplify -> f) -> Yield repr sh f @@ -826,7 +766,7 @@ instance Kit acc => Simplify (Cunctation acc aenv a) where , Just Refl <- isIdentity f -> Done $ VarsSingle v | otherwise -> Step repr sh p f v -instance HasArraysRepr (Cunctation acc) where +instance HasArraysRepr Cunctation where arraysRepr (Done v) = varsType v arraysRepr (Yield repr _ _) = TupRsingle repr arraysRepr (Step repr _ _ _ _) = TupRsingle repr @@ -840,14 +780,13 @@ done pacc | DeclareVars lhs _ value <- declareVars (arraysRepr pacc) = Embed (PushEnv BaseEnv lhs $ inject pacc) $ Done $ value weakenId -doneZeroIdx :: ArrayR (Array sh e) -> Cunctation acc (aenv, Array sh e) (Array sh e) +doneZeroIdx :: ArrayR (Array sh e) -> Cunctation (aenv, Array sh e) (Array sh e) doneZeroIdx repr = Done $ VarsSingle $ Var repr ZeroIdx -- Recast a cunctation into a mapping from indices to elements. -- -yield :: Kit acc - => Cunctation acc aenv (Array sh e) - -> Cunctation acc aenv (Array sh e) +yield :: Cunctation aenv (Array sh e) + -> Cunctation aenv (Array sh e) yield cc = case cc of Yield{} -> cc @@ -859,9 +798,8 @@ yield cc = -- Recast a cunctation into transformation step form. Not possible if the source -- was in the Yield formulation. -- -step :: Kit acc - => Cunctation acc aenv (Array sh e) - -> Maybe (Cunctation acc aenv (Array sh e)) +step :: Cunctation aenv (Array sh e) + -> Maybe (Cunctation aenv (Array sh e)) step cc = case cc of Yield{} -> Nothing @@ -872,7 +810,7 @@ step cc = -- Get the shape of a delayed array -- -shape :: Kit acc => Cunctation acc aenv (Array sh e) -> PreExp acc aenv sh +shape :: Cunctation aenv (Array sh e) -> Exp aenv sh shape cc | Just (Step _ sh _ _ _) <- step cc = sh | Yield _ sh _ <- yield cc = sh @@ -881,7 +819,7 @@ shape cc -- Environment manipulation -- ======================== -instance Kit acc => Sink (Cunctation acc) where +instance Sink Cunctation where weaken k = \case Done v -> Done (weaken k v) Step repr sh p f v -> Step repr (weaken k sh) (weaken k p) (weaken k f) (weaken k v) @@ -958,25 +896,28 @@ computeAcc (Embed env@(PushEnv bot lhs top) cc) = -> case ix of ZeroIdx | LeftHandSideSingle ArrayR{} <- lhs - , Just g <- strengthen noTop f -> bindA bot (inject (Map (arrayRtype repr) g top)) - _ -> bindA env (inject (Map (arrayRtype repr) f (avarIn v))) + , Just (OpenAccFun g) <- strengthen noTop (OpenAccFun f) + -> bindA bot (inject (Map (arrayRtype repr) g top)) + _ -> bindA env (inject (Map (arrayRtype repr) f (avarIn v))) | Just Refl <- isIdentity f -> case ix of ZeroIdx | LeftHandSideSingle ArrayR{} <- lhs - , Just q <- strengthen noTop p - , Just sz <- strengthen noTop sh -> bindA bot (inject (Backpermute (arrayRshape repr) sz q top)) - _ -> bindA env (inject (Backpermute (arrayRshape repr) sh p (avarIn v))) + , Just (OpenAccFun q) <- strengthen noTop (OpenAccFun p) + , Just (OpenAccExp sz) <- strengthen noTop (OpenAccExp sh) + -> bindA bot (inject (Backpermute (arrayRshape repr) sz q top)) + _ -> bindA env (inject (Backpermute (arrayRshape repr) sh p (avarIn v))) | otherwise -> case ix of ZeroIdx | LeftHandSideSingle ArrayR{} <- lhs - , Just g <- strengthen noTop f - , Just q <- strengthen noTop p - , Just sz <- strengthen noTop sh -> bindA bot (inject (Transform repr sz q g top)) - _ -> bindA env (inject (Transform repr sh p f (avarIn v))) + , Just (OpenAccFun g) <- strengthen noTop (OpenAccFun f) + , Just (OpenAccFun q) <- strengthen noTop (OpenAccFun p) + , Just (OpenAccExp sz) <- strengthen noTop (OpenAccExp sh) + -> bindA bot (inject (Transform repr sz q g top)) + _ -> bindA env (inject (Transform repr sh p f (avarIn v))) where bindA :: Kit acc @@ -1000,7 +941,7 @@ computeAcc (Embed env@(PushEnv bot lhs top) cc) = -- Convert the internal representation of delayed arrays into a real AST -- node. Use the most specific version of a combinator whenever possible. -- -compute :: Kit acc => Cunctation acc aenv arrs -> PreOpenAcc acc aenv arrs +compute :: Kit acc => Cunctation aenv arrs -> PreOpenAcc acc aenv arrs compute cc = case simplify cc of Done VarsNil -> Anil Done (VarsSingle v@(Var ArrayR{} _)) -> Avar v @@ -1016,8 +957,8 @@ compute cc = case simplify cc of -- Representation of a generator as a delayed array -- generateD :: ArrayR (Array sh e) - -> PreExp acc aenv sh - -> PreFun acc aenv (sh -> e) + -> Exp aenv sh + -> Fun aenv (sh -> e) -> Embed acc aenv (Array sh e) generateD repr sh f = Stats.ruleFired "generateD" @@ -1029,7 +970,7 @@ generateD repr sh f -- mapD :: Kit acc => TupleType b - -> PreFun acc aenv (a -> b) + -> Fun aenv (a -> b) -> Embed acc aenv (Array sh a) -> Embed acc aenv (Array sh b) mapD tp f (unzipD tp f -> Just a) = a @@ -1047,7 +988,7 @@ mapD tp f (Embed env cc) unzipD :: Kit acc => TupleType b - -> PreFun acc aenv (a -> b) + -> Fun aenv (a -> b) -> Embed acc aenv (Array sh a) -> Maybe (Embed acc aenv (Array sh b)) unzipD tp f (Embed env cc@(Done v)) @@ -1063,12 +1004,11 @@ unzipD _ _ _ -- the destination array read there data from in the source array. -- backpermuteD - :: Kit acc - => ShapeR sh' - -> PreExp acc aenv sh' - -> PreFun acc aenv (sh' -> sh) - -> Cunctation acc aenv (Array sh e) - -> Cunctation acc aenv (Array sh' e) + :: ShapeR sh' + -> Exp aenv sh' + -> Fun aenv (sh' -> sh) + -> Cunctation aenv (Array sh e) + -> Cunctation aenv (Array sh' e) backpermuteD shr' sh' p = Stats.ruleFired "backpermuteD" . go where go (step -> Just (Step (ArrayR _ tp) _ q f v)) = Step (ArrayR shr' tp) sh' (q `compose` p) f v @@ -1080,9 +1020,9 @@ backpermuteD shr' sh' p = Stats.ruleFired "backpermuteD" . go transformD :: Kit acc => ArrayR (Array sh' b) - -> PreExp acc aenv sh' - -> PreFun acc aenv (sh' -> sh) - -> PreFun acc aenv (a -> b) + -> Exp aenv sh' + -> Fun aenv (sh' -> sh) + -> Fun aenv (a -> b) -> Embed acc aenv (Array sh a) -> Embed acc aenv (Array sh' b) transformD (ArrayR shr' tp) sh' p f @@ -1090,7 +1030,7 @@ transformD (ArrayR shr' tp) sh' p f . fuse (into2 (backpermuteD shr') sh' p) . mapD tp f where - fuse :: (forall aenv'. Extend ArrayR acc aenv aenv' -> Cunctation acc aenv' as -> Cunctation acc aenv' bs) + fuse :: (forall aenv'. Extend ArrayR acc aenv aenv' -> Cunctation aenv' as -> Cunctation aenv' bs) -> Embed acc aenv as -> Embed acc aenv bs fuse op (Embed env cc) = Embed env (op env cc) @@ -1107,11 +1047,10 @@ transformD (ArrayR shr' tp) sh' p f -- expensive and/or `sh` is large. -- replicateD - :: Kit acc - => SliceIndex slix sl co sh - -> PreExp acc aenv slix - -> Cunctation acc aenv (Array sl e) - -> Cunctation acc aenv (Array sh e) + :: SliceIndex slix sl co sh + -> Exp aenv slix + -> Cunctation aenv (Array sl e) + -> Cunctation aenv (Array sh e) replicateD sliceIndex slix cc = Stats.ruleFired "replicateD" $ backpermuteD (sliceDomainR sliceIndex) (IndexFull sliceIndex slix (shape cc)) (extend sliceIndex slix) cc @@ -1120,11 +1059,10 @@ replicateD sliceIndex slix cc -- Dimensional slice as a backwards permutation -- sliceD - :: Kit acc - => SliceIndex slix sl co sh - -> PreExp acc aenv slix - -> Cunctation acc aenv (Array sh e) - -> Cunctation acc aenv (Array sl e) + :: SliceIndex slix sl co sh + -> Exp aenv slix + -> Cunctation aenv (Array sh e) + -> Cunctation aenv (Array sl e) sliceD sliceIndex slix cc = Stats.ruleFired "sliceD" $ backpermuteD (sliceShapeR sliceIndex) (IndexSlice sliceIndex slix (shape cc)) (restrict sliceIndex slix) cc @@ -1144,7 +1082,7 @@ reshapeD :: Kit acc => ShapeR sl -> Embed acc aenv (Array sh e) - -> PreExp acc aenv sl + -> Exp aenv sl -> Embed acc aenv (Array sl e) reshapeD slr (Embed env cc) (sinkA env -> sl) | Done v <- cc @@ -1162,12 +1100,11 @@ reshapeD slr (Embed env cc) (sinkA env -> sl) -- Combine two arrays element-wise with a binary function to produce a delayed -- array. -- -zipWithD :: Kit acc - => TupleType c - -> PreFun acc aenv (a -> b -> c) - -> Cunctation acc aenv (Array sh a) - -> Cunctation acc aenv (Array sh b) - -> Cunctation acc aenv (Array sh c) +zipWithD :: TupleType c + -> Fun aenv (a -> b -> c) + -> Cunctation aenv (Array sh a) + -> Cunctation aenv (Array sh b) + -> Cunctation aenv (Array sh c) zipWithD tp f cc1 cc0 -- Two stepper functions identically accessing the same array can be kept in -- stepping form. This might yield a simpler final term. @@ -1188,11 +1125,11 @@ zipWithD tp f cc1 cc0 $ Yield (ArrayR shr tp) (mkIntersect shr sh1 sh0) (combine f f1 f0) where - combine :: forall acc aenv a b c e. Kit acc - => PreFun acc aenv (a -> b -> c) - -> PreFun acc aenv (e -> a) - -> PreFun acc aenv (e -> b) - -> PreFun acc aenv (e -> c) + combine :: forall aenv a b c e. + Fun aenv (a -> b -> c) + -> Fun aenv (e -> a) + -> Fun aenv (e -> b) + -> Fun aenv (e -> c) combine c ixa ixb | Lam lhs1 (Body ixa') <- ixa -- else the skolem 'e' will escape , Lam lhs2 (Body ixb') <- ixb @@ -1379,7 +1316,7 @@ aletD' embedAcc elimAcc (LeftHandSideSingle ArrayR{}) (Embed env1 cc1) (Embed en -- eliminate :: forall aenv aenv' sh e brrs. Extend ArrayR acc aenv aenv' - -> Cunctation acc aenv' (Array sh e) + -> Cunctation aenv' (Array sh e) -> acc (aenv', Array sh e) brrs -> Embed acc aenv brrs eliminate env1 cc1 body @@ -1391,7 +1328,7 @@ aletD' embedAcc elimAcc (LeftHandSideSingle ArrayR{}) (Embed env1 cc1) (Embed en bnd :: PreOpenAcc acc aenv' (Array sh e) bnd = compute cc1 - elim :: ArrayR (Array sh e) -> PreExp acc aenv' sh -> PreFun acc aenv' (sh -> e) -> Embed acc aenv brrs + elim :: ArrayR (Array sh e) -> Exp aenv' sh -> Fun aenv' (sh -> e) -> Embed acc aenv brrs elim r sh1 f1 | sh1' <- weaken (weakenSucc' weakenId) sh1 , f1' <- weaken (weakenSucc' weakenId) f1 @@ -1408,9 +1345,9 @@ aletD' embedAcc elimAcc (LeftHandSideSingle ArrayR{}) (Embed env1 cc1) (Embed en -- things, but that is limited in what it looks for. -- replaceE :: forall env aenv sh e t. - PreOpenExp acc env aenv sh -> PreOpenFun acc env aenv (sh -> e) -> ArrayVar aenv (Array sh e) - -> PreOpenExp acc env aenv t - -> PreOpenExp acc env aenv t + OpenExp env aenv sh -> OpenFun env aenv (sh -> e) -> ArrayVar aenv (Array sh e) + -> OpenExp env aenv t + -> OpenExp env aenv t replaceE sh' f' avar@(Var (ArrayR shr _) _) exp = case exp of Let lhs x y -> let k = weakenWithLHS lhs @@ -1433,16 +1370,16 @@ aletD' embedAcc elimAcc (LeftHandSideSingle ArrayR{}) (Embed env1 cc1) (Embed en Coerce t1 t2 e -> Coerce t1 t2 (cvtE e) Shape a - | Just Refl <- match a a' -> Stats.substitution "replaceE/shape" sh' + | Just Refl <- match a avar -> Stats.substitution "replaceE/shape" sh' | otherwise -> exp Index a sh - | Just Refl <- match a a' + | Just Refl <- match a avar , Lam lhs (Body b) <- f' -> Stats.substitution "replaceE/!" . cvtE $ Let lhs sh b | otherwise -> Index a (cvtE sh) LinearIndex a i - | Just Refl <- match a a' + | Just Refl <- match a avar , Lam lhs (Body b) <- f' -> Stats.substitution "replaceE/!!" . cvtE $ Let lhs @@ -1451,16 +1388,13 @@ aletD' embedAcc elimAcc (LeftHandSideSingle ArrayR{}) (Embed env1 cc1) (Embed en | otherwise -> LinearIndex a (cvtE i) where - a' :: acc aenv (Array sh e) - a' = avarIn avar - - cvtE :: PreOpenExp acc env aenv s -> PreOpenExp acc env aenv s + cvtE :: OpenExp env aenv s -> OpenExp env aenv s cvtE = replaceE sh' f' avar replaceF :: forall env aenv sh e t. - PreOpenExp acc env aenv sh -> PreOpenFun acc env aenv (sh -> e) -> ArrayVar aenv (Array sh e) - -> PreOpenFun acc env aenv t - -> PreOpenFun acc env aenv t + OpenExp env aenv sh -> OpenFun env aenv (sh -> e) -> ArrayVar aenv (Array sh e) + -> OpenFun env aenv t + -> OpenFun env aenv t replaceF sh' f' avar fun = case fun of Body e -> Body (replaceE sh' f' avar e) @@ -1468,7 +1402,7 @@ aletD' embedAcc elimAcc (LeftHandSideSingle ArrayR{}) (Embed env1 cc1) (Embed en in Lam lhs (replaceF (weakenE k sh') (weakenE k f') avar f) replaceA :: forall aenv sh e a. - PreExp acc aenv sh -> PreFun acc aenv (sh -> e) -> ArrayVar aenv (Array sh e) + Exp aenv sh -> Fun aenv (sh -> e) -> ArrayVar aenv (Array sh e) -> PreOpenAcc acc aenv a -> PreOpenAcc acc aenv a replaceA sh' f' avar pacc = @@ -1521,13 +1455,13 @@ aletD' embedAcc elimAcc (LeftHandSideSingle ArrayR{}) (Embed env1 cc1) (Embed en cvtA :: acc aenv s -> acc aenv s cvtA = kmap (replaceA sh' f' avar) - cvtE :: PreExp acc aenv s -> PreExp acc aenv s + cvtE :: Exp aenv s -> Exp aenv s cvtE = replaceE sh' f' avar - cvtF :: PreFun acc aenv s -> PreFun acc aenv s + cvtF :: Fun aenv s -> Fun aenv s cvtF = replaceF sh' f' avar - cvtB :: PreBoundary acc aenv s -> PreBoundary acc aenv s + cvtB :: Boundary aenv s -> Boundary aenv s cvtB Clamp = Clamp cvtB Mirror = Mirror cvtB Wrap = Wrap @@ -1538,7 +1472,7 @@ aletD' embedAcc elimAcc (LeftHandSideSingle ArrayR{}) (Embed env1 cc1) (Embed en cvtAF = cvt sh' f' avar where cvt :: forall aenv a. - PreExp acc aenv sh -> PreFun acc aenv (sh -> e) -> ArrayVar aenv (Array sh e) + Exp aenv sh -> Fun aenv (sh -> e) -> ArrayVar aenv (Array sh e) -> PreOpenAfun acc aenv a -> PreOpenAfun acc aenv a cvt sh'' f'' avar' (Abody a) = Abody $ kmap (replaceA sh'' f'' avar') a @@ -1602,7 +1536,7 @@ aletD' _ _ lhs (Embed env1 cc1) (Embed env0 cc0) -- acondD :: Kit acc => EmbedAcc acc - -> PreExp acc aenv Bool + -> Exp aenv Bool -> acc aenv arrs -> acc aenv arrs -> Embed acc aenv arrs @@ -1617,53 +1551,50 @@ acondD embedAcc p t e -- Scalar expressions -- ------------------ -identity :: TupleType a -> PreOpenFun acc env aenv (a -> a) +identity :: TupleType a -> OpenFun env aenv (a -> a) identity tp | DeclareVars lhs _ value <- declareVars tp = Lam lhs $ Body $ evars $ value weakenId -toIndex :: Kit acc => ShapeR sh -> PreOpenExp acc env aenv sh -> PreOpenFun acc env aenv (sh -> Int) +toIndex :: ShapeR sh -> OpenExp env aenv sh -> OpenFun env aenv (sh -> Int) toIndex shr sh | DeclareVars lhs k value <- declareVars $ shapeType shr = Lam lhs $ Body $ ToIndex shr (weakenE k sh) $ evars $ value weakenId -fromIndex :: Kit acc => ShapeR sh -> PreOpenExp acc env aenv sh -> PreOpenFun acc env aenv (Int -> sh) +fromIndex :: ShapeR sh -> OpenExp env aenv sh -> OpenFun env aenv (Int -> sh) fromIndex shr sh = Lam (LeftHandSideSingle scalarTypeInt) $ Body $ FromIndex shr (weakenE (weakenSucc' weakenId) sh) $ Evar $ Var scalarTypeInt ZeroIdx -reindex :: Kit acc - => ShapeR sh' - -> PreOpenExp acc env aenv sh' +reindex :: ShapeR sh' + -> OpenExp env aenv sh' -> ShapeR sh - -> PreOpenExp acc env aenv sh - -> PreOpenFun acc env aenv (sh -> sh') + -> OpenExp env aenv sh + -> OpenFun env aenv (sh -> sh') reindex shr' sh' shr sh | Just Refl <- match sh sh' = identity (shapeType shr') | otherwise = fromIndex shr' sh' `compose` toIndex shr sh -extend :: Kit acc - => SliceIndex slix sl co sh - -> PreExp acc aenv slix - -> PreFun acc aenv (sh -> sl) +extend :: SliceIndex slix sl co sh + -> Exp aenv slix + -> Fun aenv (sh -> sl) extend sliceIndex slix | DeclareVars lhs k value <- declareVars $ shapeType $ sliceDomainR sliceIndex = Lam lhs $ Body $ IndexSlice sliceIndex (weakenE k slix) $ evars $ value weakenId -restrict :: Kit acc - => SliceIndex slix sl co sh - -> PreExp acc aenv slix - -> PreFun acc aenv (sl -> sh) +restrict :: SliceIndex slix sl co sh + -> Exp aenv slix + -> Fun aenv (sl -> sh) restrict sliceIndex slix | DeclareVars lhs k value <- declareVars $ shapeType $ sliceShapeR sliceIndex = Lam lhs $ Body $ IndexFull sliceIndex (weakenE k slix) $ evars $ value weakenId -arrayShape :: Kit acc => ArrayVar aenv (Array sh e) -> PreExp acc aenv sh -arrayShape = simplify . Shape . avarIn +arrayShape :: ArrayVar aenv (Array sh e) -> Exp aenv sh +arrayShape = simplify . Shape -indexArray :: Kit acc => ArrayVar aenv (Array sh e) -> PreFun acc aenv (sh -> e) +indexArray :: ArrayVar aenv (Array sh e) -> Fun aenv (sh -> e) indexArray v@(Var (ArrayR shr _) _) | DeclareVars lhs _ value <- declareVars $ shapeType shr - = Lam lhs $ Body $ Index (avarIn v) $ evars $ value weakenId + = Lam lhs $ Body $ Index v $ evars $ value weakenId -linearIndex :: Kit acc => ArrayVar aenv (Array sh e) -> PreFun acc aenv (Int -> e) -linearIndex v = Lam (LeftHandSideSingle scalarTypeInt) $ Body $ LinearIndex (avarIn v) $ Evar $ Var scalarTypeInt ZeroIdx +linearIndex :: ArrayVar aenv (Array sh e) -> Fun aenv (Int -> e) +linearIndex v = Lam (LeftHandSideSingle scalarTypeInt) $ Body $ LinearIndex v $ Evar $ Var scalarTypeInt ZeroIdx diff --git a/src/Data/Array/Accelerate/Trafo/LetSplit.hs b/src/Data/Array/Accelerate/Trafo/LetSplit.hs index e0adebe7f..6294de7ea 100644 --- a/src/Data/Array/Accelerate/Trafo/LetSplit.hs +++ b/src/Data/Array/Accelerate/Trafo/LetSplit.hs @@ -16,7 +16,6 @@ module Data.Array.Accelerate.Trafo.LetSplit ( ) where import Prelude hiding ( exp ) -import Data.Array.Accelerate.Array.Representation import Data.Array.Accelerate.AST import Data.Array.Accelerate.Trafo.Base @@ -32,31 +31,31 @@ travA (Apair a1 a2) = inject $ Apair (convertAcc a1) (conver travA Anil = inject $ Anil travA (Apply repr f a) = inject $ Apply repr (convertAfun f) (convertAcc a) travA (Aforeign repr asm f a) = inject $ Aforeign repr asm (convertAfun f) (convertAcc a) -travA (Acond e a1 a2) = inject $ Acond (travE e) (convertAcc a1) (convertAcc a2) +travA (Acond e a1 a2) = inject $ Acond e (convertAcc a1) (convertAcc a2) travA (Awhile c f a) = inject $ Awhile (convertAfun c) (convertAfun f) (convertAcc a) travA (Use repr arr) = inject $ Use repr arr -travA (Unit tp e) = inject $ Unit tp (travE e) -travA (Reshape shr e a) = inject $ Reshape shr (travE e) a -travA (Generate repr e f) = inject $ Generate repr (travE e) (travF f) -travA (Transform repr sh f g a) = inject $ Transform repr (travE sh) (travF f) (travF g) (convertAcc a) -travA (Replicate slix sl a) = inject $ Replicate slix (travE sl) (convertAcc a) -travA (Slice slix a sl) = inject $ Slice slix (convertAcc a) (travE sl) -travA (Map tp f a) = inject $ Map tp (travF f) (convertAcc a) -travA (ZipWith tp f a1 a2) = inject $ ZipWith tp (travF f) (convertAcc a1) (convertAcc a2) -travA (Fold f e a) = inject $ Fold (travF f) (travE e) (convertAcc a) -travA (Fold1 f a) = inject $ Fold1 (travF f) (convertAcc a) -travA (FoldSeg i f e a s) = inject $ FoldSeg i (travF f) (travE e) (convertAcc a) (convertAcc s) -travA (Fold1Seg i f a s) = inject $ Fold1Seg i (travF f) (convertAcc a) (convertAcc s) -travA (Scanl f e a) = inject $ Scanl (travF f) (travE e) (convertAcc a) -travA (Scanl' f e a) = inject $ Scanl' (travF f) (travE e) (convertAcc a) -travA (Scanl1 f a) = inject $ Scanl1 (travF f) (convertAcc a) -travA (Scanr f e a) = inject $ Scanr (travF f) (travE e) (convertAcc a) -travA (Scanr' f e a) = inject $ Scanr' (travF f) (travE e) (convertAcc a) -travA (Scanr1 f a) = inject $ Scanr1 (travF f) (convertAcc a) -travA (Permute f a1 g a2) = inject $ Permute (travF f) (convertAcc a1) (travF g) (convertAcc a2) -travA (Backpermute shr sh f a) = inject $ Backpermute shr (travE sh) (travF f) (convertAcc a) -travA (Stencil s tp f b a) = inject $ Stencil s tp (travF f) (travB b) (convertAcc a) -travA (Stencil2 s1 s2 tp f b1 a1 b2 a2) = inject $ Stencil2 s1 s2 tp (travF f) (travB b1) (convertAcc a1) (travB b2) (convertAcc a2) +travA (Unit tp e) = inject $ Unit tp e +travA (Reshape shr e a) = inject $ Reshape shr e a +travA (Generate repr e f) = inject $ Generate repr e f +travA (Transform repr sh f g a) = inject $ Transform repr sh f g (convertAcc a) +travA (Replicate slix sl a) = inject $ Replicate slix sl (convertAcc a) +travA (Slice slix a sl) = inject $ Slice slix (convertAcc a) sl +travA (Map tp f a) = inject $ Map tp f (convertAcc a) +travA (ZipWith tp f a1 a2) = inject $ ZipWith tp f (convertAcc a1) (convertAcc a2) +travA (Fold f e a) = inject $ Fold f e (convertAcc a) +travA (Fold1 f a) = inject $ Fold1 f (convertAcc a) +travA (FoldSeg i f e a s) = inject $ FoldSeg i f e (convertAcc a) (convertAcc s) +travA (Fold1Seg i f a s) = inject $ Fold1Seg i f (convertAcc a) (convertAcc s) +travA (Scanl f e a) = inject $ Scanl f e (convertAcc a) +travA (Scanl' f e a) = inject $ Scanl' f e (convertAcc a) +travA (Scanl1 f a) = inject $ Scanl1 f (convertAcc a) +travA (Scanr f e a) = inject $ Scanr f e (convertAcc a) +travA (Scanr' f e a) = inject $ Scanr' f e (convertAcc a) +travA (Scanr1 f a) = inject $ Scanr1 f (convertAcc a) +travA (Permute f a1 g a2) = inject $ Permute f (convertAcc a1) g (convertAcc a2) +travA (Backpermute shr sh f a) = inject $ Backpermute shr sh f (convertAcc a) +travA (Stencil s tp f b a) = inject $ Stencil s tp f b (convertAcc a) +travA (Stencil2 s1 s2 tp f b1 a1 b2 a2) = inject $ Stencil2 s1 s2 tp f b1 (convertAcc a1) b2 (convertAcc a2) travBinding :: Kit acc => ALeftHandSide bnd aenv aenv' -> acc aenv bnd -> acc aenv' a -> acc aenv a travBinding (LeftHandSideWildcard _) _ a = a @@ -65,18 +64,6 @@ travBinding lhs@(LeftHandSidePair l1 l2) bnd a = case extract bnd of Just (Apair b1 b2) -> travBinding l1 b1 $ travBinding l2 (weaken (weakenWithLHS l1) b2) a _ -> inject $ Alet lhs bnd a --- XXX: We assume that any Acc contained in an expression is Avar. --- We thus do not have to descend into expressions. --- This isn't yet enforced using the types however. -travE :: PreExp acc aenv t -> PreExp acc aenv t -travE = id - -travF :: PreFun acc aenv t -> PreFun acc aenv t -travF = id - -travB :: PreBoundary acc aenv (Array sh e) -> PreBoundary acc aenv (Array sh e) -travB = id - convertAfun :: Kit acc => PreOpenAfun acc aenv f -> PreOpenAfun acc aenv f convertAfun (Alam lhs f) = Alam lhs $ convertAfun f convertAfun (Abody a) = Abody $ convertAcc a diff --git a/src/Data/Array/Accelerate/Trafo/Sharing.hs b/src/Data/Array/Accelerate/Trafo/Sharing.hs index f1d60b410..4b4a68977 100644 --- a/src/Data/Array/Accelerate/Trafo/Sharing.hs +++ b/src/Data/Array/Accelerate/Trafo/Sharing.hs @@ -70,8 +70,8 @@ import Data.Array.Accelerate.Array.Representation hiding ((!!)) import Data.Array.Accelerate.Array.Sugar ( Elt, EltRepr, Arrays, ArrRepr, eltType ) import qualified Data.Array.Accelerate.Array.Sugar as Sugar import Data.Array.Accelerate.AST hiding ( PreOpenAcc(..), OpenAcc(..), Acc - , PreOpenExp(..), OpenExp, PreExp, Exp - , PreBoundary(..), Boundary + , OpenExp(..), Exp + , Boundary(..) , showPreAccOp, showPreExpOp, expType, HasArraysRepr(..), arraysRepr ) import qualified Data.Array.Accelerate.AST as AST import Data.Array.Accelerate.Debug.Trace as Debug @@ -537,7 +537,7 @@ convertSharingBoundary -> [StableSharingAcc] -> ShapeR sh -> PreBoundary ScopedAcc ScopedExp (Array sh e) - -> AST.PreBoundary AST.OpenAcc aenv (Array sh e) + -> AST.Boundary aenv (Array sh e) convertSharingBoundary config alyt aenv shr = cvt where cvt :: PreBoundary ScopedAcc ScopedExp (Array sh e) -> AST.Boundary aenv (Array sh e) @@ -573,7 +573,7 @@ convertSharingBoundary config alyt aenv shr = cvt convertFun :: Function f => f -> AST.Fun () (EltReprFunctionR f) convertFun = convertFunWith - $ defaultOptions { options = options defaultOptions \\ [seq_sharing, acc_sharing, float_out_acc] } + $ defaultOptions { options = options defaultOptions \\ [seq_sharing, acc_sharing] } convertFunWith :: Function f => Config -> f -> AST.Fun () (EltReprFunctionR f) convertFunWith config = convertOpenFun config EmptyLayout @@ -633,7 +633,7 @@ convertSmartFun config tp f convertExp :: Exp e -> AST.Exp () (EltRepr e) convertExp = convertExpWith - $ defaultOptions { options = options defaultOptions \\ [seq_sharing, acc_sharing, float_out_acc] } + $ defaultOptions { options = options defaultOptions \\ [seq_sharing, acc_sharing] } convertExpWith :: Config -> Exp e -> AST.Exp () (EltRepr e) convertExpWith config (Exp e) = convertOpenExp config EmptyLayout e @@ -743,9 +743,9 @@ convertSharingExp config lyt alyt env aenv exp@(ScopedExp lams _) = cvt exp While tp p it i -> AST.While (cvtFun1 tp p) (cvtFun1 tp it) (cvt i) PrimConst c -> AST.PrimConst c PrimApp f e -> cvtPrimFun f (cvt e) - Index _ a e -> AST.Index (cvtA a) (cvt e) - LinearIndex _ a i -> AST.LinearIndex (cvtA a) (cvt i) - Shape _ a -> AST.Shape (cvtA a) + Index _ a e -> AST.Index (cvtAvar a) (cvt e) + LinearIndex _ a i -> AST.LinearIndex (cvtAvar a) (cvt i) + Shape _ a -> AST.Shape (cvtAvar a) ShapeSize shr e -> AST.ShapeSize shr (cvt e) Foreign repr ff f e -> AST.Foreign repr ff (convertSmartFun config (expType e) f) (cvt e) Coerce t1 t2 e -> AST.Coerce t1 t2 (cvt e) @@ -760,6 +760,11 @@ convertSharingExp config lyt alyt env aenv exp@(ScopedExp lams _) = cvt exp cvtA :: ScopedAcc a -> AST.OpenAcc aenv a cvtA = convertSharingAcc config alyt aenv + cvtAvar :: ScopedAcc a -> AST.ArrayVar aenv a + cvtAvar a = case cvtA a of + AST.OpenAcc (AST.Avar var) -> var + _ -> $internalError "convertSharingExp" "Expected array computation in expression to be floated out" + cvtFun1 :: TupleType a -> (SmartExp a -> ScopedExp b) -> AST.OpenFun env aenv (a -> b) cvtFun1 tp f | DeclareVars lhs k value <- declareVars tp @@ -2554,7 +2559,7 @@ determineScopesSharingExp config accOccMap expOccMap = scopesExp travA :: (ScopedAcc a -> PreSmartExp ScopedAcc ScopedExp t) -> UnscopedAcc a -> (ScopedExp t, NodeCounts) - travA c acc = maybeFloatOutAcc c acc' accCount + travA c acc = floatOutAcc c acc' accCount where (acc', accCount) = scopesAcc acc @@ -2562,20 +2567,19 @@ determineScopesSharingExp config accOccMap expOccMap = scopesExp -> UnscopedAcc a -> UnscopedExp b -> (ScopedExp t, NodeCounts) - travAE c acc e = maybeFloatOutAcc (`c` e') acc' (accCountA +++ accCountE) + travAE c acc e = floatOutAcc (`c` e') acc' (accCountA +++ accCountE) where (acc', accCountA) = scopesAcc acc (e' , accCountE) = scopesExp e - maybeFloatOutAcc :: (ScopedAcc a -> PreSmartExp ScopedAcc ScopedExp t) + floatOutAcc :: (ScopedAcc a -> PreSmartExp ScopedAcc ScopedExp t) -> ScopedAcc a -> NodeCounts -> (ScopedExp t, NodeCounts) - maybeFloatOutAcc c acc@(ScopedAcc _ (AvarSharing _ _)) accCount -- nothing to float out + floatOutAcc c acc@(ScopedAcc _ (AvarSharing _ _)) accCount -- nothing to float out = reconstruct (c acc) accCount - maybeFloatOutAcc c acc accCount - | float_out_acc `member` options config = reconstruct (c var) ((stableAcc `insertAccNode` noNodeCounts) +++ accCount) - | otherwise = reconstruct (c acc) accCount + floatOutAcc c acc accCount + = reconstruct (c var) ((stableAcc `insertAccNode` noNodeCounts) +++ accCount) where (var, stableAcc) = abstract acc (\(ScopedAcc _ s) -> s) diff --git a/src/Data/Array/Accelerate/Trafo/Shrink.hs b/src/Data/Array/Accelerate/Trafo/Shrink.hs index c43183b7b..a3d04a426 100644 --- a/src/Data/Array/Accelerate/Trafo/Shrink.hs +++ b/src/Data/Array/Accelerate/Trafo/Shrink.hs @@ -67,10 +67,10 @@ class Shrink f where shrink = snd . shrink' -instance Kit acc => Shrink (PreOpenExp acc env aenv e) where +instance Shrink (OpenExp env aenv e) where shrink' = shrinkExp -instance Kit acc => Shrink (PreOpenFun acc env aenv f) where +instance Shrink (OpenFun env aenv f) where shrink' = shrinkFun data VarsRange env = VarsRange !(Exists (Idx env)) !Int !(Maybe RangeTuple) -- rightmost variable, count, tuple @@ -110,10 +110,10 @@ weakenVarsRange lhs (VarsRange ix n t) = VarsRange (go lhs ix) n t go (LeftHandSideSingle _) (Exists ix') = Exists (SuccIdx ix') go (LeftHandSidePair l1 l2) ix' = go l2 $ go l1 ix' -matchEVarsRange :: VarsRange env -> PreOpenExp acc env aenv t -> Bool +matchEVarsRange :: VarsRange env -> OpenExp env aenv t -> Bool matchEVarsRange (VarsRange (Exists first) _ (Just rt)) expr = isJust $ go (idxToInt first) rt expr where - go :: Int -> RangeTuple -> PreOpenExp acc env aenv t -> Maybe Int + go :: Int -> RangeTuple -> OpenExp env aenv t -> Maybe Int go i RTNil Nil = Just i go i RTSingle (Evar (Var _ ix)) | checkIdx i ix = Just (i + 1) @@ -224,7 +224,7 @@ strengthenShrunkLHS _ _ _ = $inter -- instance of beta-reduction to cases where the bound variable is used zero -- (dead-code elimination) or one (linear inlining) times. -- -shrinkExp :: Kit acc => PreOpenExp acc env aenv t -> (Bool, PreOpenExp acc env aenv t) +shrinkExp :: OpenExp env aenv t -> (Bool, OpenExp env aenv t) shrinkExp = Stats.substitution "shrinkE" . first getAny . shrinkE where -- If the bound variable is used at most this many times, it will be inlined @@ -234,7 +234,7 @@ shrinkExp = Stats.substitution "shrinkE" . first getAny . shrinkE lIMIT :: Int lIMIT = 1 - cheap :: PreOpenExp acc env aenv t -> Bool + cheap :: OpenExp env aenv t -> Bool cheap (Evar _) = True cheap (Pair e1 e2) = cheap e1 && cheap e2 cheap Nil = True @@ -244,7 +244,7 @@ shrinkExp = Stats.substitution "shrinkE" . first getAny . shrinkE cheap (Coerce _ _ e) = cheap e cheap _ = False - shrinkE :: Kit acc => PreOpenExp acc env aenv t -> (Any, PreOpenExp acc env aenv t) + shrinkE :: OpenExp env aenv t -> (Any, OpenExp env aenv t) shrinkE exp = case exp of Let (LeftHandSideSingle _) bnd@Evar{} body -> Stats.inline "Var" . yes $ shrinkE (inline body bnd) Let lhs bnd body @@ -303,7 +303,7 @@ shrinkExp = Stats.substitution "shrinkE" . first getAny . shrinkE Foreign repr ff f e -> Foreign repr ff <$> shrinkF f <*> shrinkE e Coerce t1 t2 e -> Coerce t1 t2 <$> shrinkE e - shrinkF :: Kit acc => PreOpenFun acc env aenv t -> (Any, PreOpenFun acc env aenv t) + shrinkF :: OpenFun env aenv t -> (Any, OpenFun env aenv t) shrinkF = first Any . shrinkFun first :: (a -> a') -> (a,b) -> (a',b) @@ -312,7 +312,7 @@ shrinkExp = Stats.substitution "shrinkE" . first getAny . shrinkE yes :: (Any, x) -> (Any, x) yes (_, x) = (Any True, x) -shrinkFun :: Kit acc => PreOpenFun acc env aenv f -> (Bool, PreOpenFun acc env aenv f) +shrinkFun :: OpenFun env aenv f -> (Bool, OpenFun env aenv f) shrinkFun (Lam lhs f) = case lhsVarsRange lhs of Left Refl -> let b' = case lhs of @@ -418,7 +418,7 @@ shrinkPreAcc shrinkAcc reduceAcc = Stats.substitution "shrinkA" shrinkA shrinkCT (SnocAtup t c) = SnocAtup (shrinkCT t) (shrinkC c) --} - shrinkE :: PreOpenExp acc env aenv' t -> PreOpenExp acc env aenv' t + shrinkE :: OpenExp env aenv' t -> OpenExp env aenv' t shrinkE exp = case exp of Let bnd body -> Let (shrinkE bnd) (shrinkE body) Var idx -> Var idx @@ -448,11 +448,11 @@ shrinkPreAcc shrinkAcc reduceAcc = Stats.substitution "shrinkA" shrinkA Foreign ff f e -> Foreign ff (shrinkF f) (shrinkE e) Coerce e -> Coerce (shrinkE e) - shrinkF :: PreOpenFun acc env aenv' f -> PreOpenFun acc env aenv' f + shrinkF :: OpenFun env aenv' f -> OpenFun env aenv' f shrinkF (Lam f) = Lam (shrinkF f) shrinkF (Body b) = Body (shrinkE b) - shrinkT :: Tuple (PreOpenExp acc env aenv') t -> Tuple (PreOpenExp acc env aenv') t + shrinkT :: Tuple (OpenExp env aenv') t -> Tuple (OpenExp env aenv') t shrinkT NilTup = NilTup shrinkT (SnocTup t e) = shrinkT t `SnocTup` shrinkE e @@ -467,10 +467,10 @@ shrinkPreAcc shrinkAcc reduceAcc = Stats.substitution "shrinkA" shrinkA -- Count the number of occurrences an in-scope scalar expression bound at the -- given variable index recursively in a term. -- -usesOfExp :: forall acc env aenv t. VarsRange env -> PreOpenExp acc env aenv t -> Count +usesOfExp :: forall env aenv t. VarsRange env -> OpenExp env aenv t -> Count usesOfExp range = countE where - countE :: PreOpenExp acc env aenv e -> Count + countE :: OpenExp env aenv e -> Count countE exp | matchEVarsRange range exp = Finite 1 countE exp = case exp of Evar v -> case varInRange range v of @@ -499,7 +499,7 @@ usesOfExp range = countE Foreign _ _ _ e -> countE e Coerce _ _ e -> countE e -usesOfFun :: VarsRange env -> PreOpenFun acc env aenv f -> Count +usesOfFun :: VarsRange env -> OpenFun env aenv f -> Count usesOfFun range (Lam lhs f) = usesOfFun (weakenVarsRange lhs range) f usesOfFun range (Body b) = usesOfExp range b @@ -525,7 +525,7 @@ usesOfPreAcc withShape countAcc idx = count count :: PreOpenAcc acc aenv a -> Int count pacc = case pacc of - Avar (Var _ this) -> countIdx this + Avar var -> countAvar var -- Alet lhs bnd body -> countA bnd + countAcc withShape (weakenWithLHS lhs >:> idx) body Apair a1 a2 -> countA a1 + countA a2 @@ -563,7 +563,7 @@ usesOfPreAcc withShape countAcc idx = count Stencil2 _ _ _ f _ a1 _ a2 -> countF f + countA a1 + countA a2 -- Collect s -> countS s - countE :: PreOpenExp acc env aenv e -> Int + countE :: OpenExp env aenv e -> Int countE exp = case exp of Let _ bnd body -> countE bnd + countE body Evar _ -> 0 @@ -581,11 +581,11 @@ usesOfPreAcc withShape countAcc idx = count While p f x -> countF p + countF f + countE x PrimConst _ -> 0 PrimApp _ x -> countE x - Index a sh -> countA a + countE sh - LinearIndex a i -> countA a + countE i + Index a sh -> countAvar a + countE sh + LinearIndex a i -> countAvar a + countE i ShapeSize _ sh -> countE sh Shape a - | withShape -> countA a + | withShape -> countAvar a | otherwise -> 0 Foreign _ _ _ e -> countE e Coerce _ _ e -> countE e @@ -593,13 +593,16 @@ usesOfPreAcc withShape countAcc idx = count countA :: acc aenv a -> Int countA = countAcc withShape idx + countAvar :: ArrayVar aenv a -> Int + countAvar (Var _ this) = countIdx this + countAF :: PreOpenAfun acc aenv' f -> Idx aenv' s -> Int countAF (Alam lhs f) v = countAF f (weakenWithLHS lhs >:> v) countAF (Abody a) v = countAcc withShape v a - countF :: PreOpenFun acc env aenv f -> Int + countF :: OpenFun env aenv f -> Int countF (Lam _ f) = countF f countF (Body b) = countE b diff --git a/src/Data/Array/Accelerate/Trafo/Simplify.hs b/src/Data/Array/Accelerate/Trafo/Simplify.hs index 50c5747f6..ec6f0f928 100644 --- a/src/Data/Array/Accelerate/Trafo/Simplify.hs +++ b/src/Data/Array/Accelerate/Trafo/Simplify.hs @@ -52,10 +52,10 @@ import qualified Data.Array.Accelerate.Debug.Trace as Debug class Simplify f where simplify :: f -> f -instance Kit acc => Simplify (PreFun acc aenv f) where +instance Simplify (Fun aenv f) where simplify = simplifyFun -instance Kit acc => Simplify (PreExp acc aenv e) where +instance Simplify (Exp aenv e) where simplify = simplifyExp @@ -85,10 +85,10 @@ instance Kit acc => Simplify (PreExp acc aenv e) where -- tricky and target-dependent issue by, for now, simply ignoring it. -- localCSE :: (Kit acc, Elt a) - => Gamma acc env env aenv - -> PreOpenExp acc env aenv a - -> PreOpenExp acc (env,a) aenv b - -> Maybe (PreOpenExp acc env aenv b) + => Gamma acc env env aenv + -> OpenExp env aenv a + -> OpenExp (env,a) aenv b + -> Maybe (OpenExp env aenv b) localCSE env bnd body | Just ix <- lookupExp env bnd = Stats.ruleFired "CSE" . Just $ inline body (Var ix) | otherwise = Nothing @@ -101,9 +101,9 @@ localCSE env bnd body -- > let x = e in .. e .. -- globalCSE :: (Kit acc, Elt t) - => Gamma acc env env aenv - -> PreOpenExp acc env aenv t - -> Maybe (PreOpenExp acc env aenv t) + => Gamma acc env env aenv + -> OpenExp env aenv t + -> Maybe (OpenExp env aenv t) globalCSE env exp | Just ix <- lookupExp env exp = Stats.ruleFired "CSE" . Just $ Var ix | otherwise = Nothing @@ -139,10 +139,10 @@ globalCSE env exp -- recoverLoops :: (Kit acc, Elt b) - => Gamma acc env env aenv - -> PreOpenExp acc env aenv a - -> PreOpenExp acc (env,a) aenv b - -> Maybe (PreOpenExp acc env aenv b) + => Gamma acc env env aenv + -> OpenExp env aenv a + -> OpenExp (env,a) aenv b + -> Maybe (OpenExp env aenv b) recoverLoops _ bnd e3 -- To introduce scaler loops, we look for expressions of the form: -- @@ -177,15 +177,15 @@ recoverLoops _ bnd e3 = Nothing where - plus :: PreOpenExp acc env aenv Int -> PreOpenExp acc env aenv Int -> PreOpenExp acc env aenv Int + plus :: OpenExp env aenv Int -> OpenExp env aenv Int -> OpenExp env aenv Int plus x y = PrimApp (PrimAdd numType) $ Tuple $ NilTup `SnocTup` x `SnocTup` y - constant :: Int -> PreOpenExp acc env aenv Int + constant :: Int -> OpenExp env aenv Int constant i = Const ((),i) matchEnvTop :: (Elt s, Elt t) - => PreOpenExp acc (env,s) aenv f - -> PreOpenExp acc (env,t) aenv g + => OpenExp (env,s) aenv f + -> OpenExp (env,t) aenv g -> Maybe (s :=: t) matchEnvTop _ _ = gcast Refl --} @@ -204,13 +204,13 @@ recoverLoops _ bnd e3 -- If we do not want to do inlining, we should remove the environment here. -- simplifyOpenExp - :: forall acc env aenv e. (Kit acc) - => Gamma acc env env aenv - -> PreOpenExp acc env aenv e - -> (Bool, PreOpenExp acc env aenv e) + :: forall env aenv e. + Gamma env env aenv + -> OpenExp env aenv e + -> (Bool, OpenExp env aenv e) simplifyOpenExp env = first getAny . cvtE where - cvtE :: PreOpenExp acc env aenv t -> (Any, PreOpenExp acc env aenv t) + cvtE :: OpenExp env aenv t -> (Any, OpenExp env aenv t) cvtE exp = case exp of Let lhs bnd body -> (u <> v, exp') where @@ -241,17 +241,17 @@ simplifyOpenExp env = first getAny . cvtE While p f x -> While <$> cvtF env p <*> cvtF env f <*> cvtE x Coerce t1 t2 e -> Coerce t1 t2 <$> cvtE e - cvtE' :: Gamma acc env' env' aenv -> PreOpenExp acc env' aenv e' -> (Any, PreOpenExp acc env' aenv e') + cvtE' :: Gamma env' env' aenv -> OpenExp env' aenv e' -> (Any, OpenExp env' aenv e') cvtE' env' = first Any . simplifyOpenExp env' - cvtF :: Gamma acc env' env' aenv -> PreOpenFun acc env' aenv f -> (Any, PreOpenFun acc env' aenv f) + cvtF :: Gamma env' env' aenv -> OpenFun env' aenv f -> (Any, OpenFun env' aenv f) cvtF env' = first Any . simplifyOpenFun env' - cvtLet :: Gamma acc env' env' aenv + cvtLet :: Gamma env' env' aenv -> ELeftHandSide bnd env' env'' - -> PreOpenExp acc env' aenv bnd - -> (Gamma acc env'' env'' aenv -> (Any, PreOpenExp acc env'' aenv t)) - -> (Any, PreOpenExp acc env' aenv t) + -> OpenExp env' aenv bnd + -> (Gamma env'' env'' aenv -> (Any, OpenExp env'' aenv t)) + -> (Any, OpenExp env' aenv t) cvtLet env' lhs@(LeftHandSideSingle _) bnd body = Let lhs bnd <$> body (incExp $ env' `pushExp` bnd) -- Single variable on the LHS, add binding to the environment cvtLet env' (LeftHandSideWildcard _) _ body = body env' -- Binding not used, remove let binding cvtLet env' (LeftHandSidePair l1 l2) (Pair e1 e2) body -- Split binding to multiple bindings @@ -263,10 +263,10 @@ simplifyOpenExp env = first getAny . cvtE -- Simplify conditional expressions, in particular by eliminating branches -- when the predicate is a known constant. -- - cond :: (Any, PreOpenExp acc env aenv Bool) - -> (Any, PreOpenExp acc env aenv t) - -> (Any, PreOpenExp acc env aenv t) - -> (Any, PreOpenExp acc env aenv t) + cond :: (Any, OpenExp env aenv Bool) + -> (Any, OpenExp env aenv t) + -> (Any, OpenExp env aenv t) + -> (Any, OpenExp env aenv t) cond p@(_,p') t@(_,t') e@(_,e') | Const _ True <- p' = Stats.knownBranch "True" (yes t') | Const _ False <- p' = Stats.knownBranch "False" (yes e') @@ -275,16 +275,13 @@ simplifyOpenExp env = first getAny . cvtE -- Shape manipulations -- - shape :: acc aenv (Array sh t) -> (Any, PreOpenExp acc env aenv sh) - shape a - | ArrayR ShapeRz _ <- arrayRepr a + shape :: ArrayVar aenv (Array sh t) -> (Any, OpenExp env aenv sh) + shape (Var (ArrayR ShapeRz _) _) = Stats.ruleFired "shape/Z" $ yes Nil shape a = pure $ Shape a - shapeSize :: ShapeR sh - -> (Any, PreOpenExp acc env aenv sh) - -> (Any, PreOpenExp acc env aenv Int) + shapeSize :: ShapeR sh -> (Any, OpenExp env aenv sh) -> (Any, OpenExp env aenv Int) shapeSize shr (_, sh) | Just c <- extractConstTuple sh = Stats.ruleFired "shapeSize/const" $ yes (Const scalarTypeInt (product (shapeToList shr c))) @@ -292,17 +289,17 @@ simplifyOpenExp env = first getAny . cvtE = ShapeSize shr <$> sh toIndex :: ShapeR sh - -> (Any, PreOpenExp acc env aenv sh) - -> (Any, PreOpenExp acc env aenv sh) - -> (Any, PreOpenExp acc env aenv Int) + -> (Any, OpenExp env aenv sh) + -> (Any, OpenExp env aenv sh) + -> (Any, OpenExp env aenv Int) toIndex _ (_,sh) (_,FromIndex _ sh' ix) | Just Refl <- match sh sh' = Stats.ruleFired "toIndex/fromIndex" $ yes ix toIndex shr sh ix = ToIndex shr <$> sh <*> ix fromIndex :: ShapeR sh - -> (Any, PreOpenExp acc env aenv sh) - -> (Any, PreOpenExp acc env aenv Int) - -> (Any, PreOpenExp acc env aenv sh) + -> (Any, OpenExp env aenv sh) + -> (Any, OpenExp env aenv Int) + -> (Any, OpenExp env aenv sh) fromIndex _ (_,sh) (_,ToIndex _ sh' ix) | Just Refl <- match sh sh' = Stats.ruleFired "fromIndex/toIndex" $ yes ix fromIndex shr sh ix = FromIndex shr <$> sh <*> ix @@ -313,7 +310,7 @@ simplifyOpenExp env = first getAny . cvtE yes :: x -> (Any, x) yes x = (Any True, x) -extractConstTuple :: PreOpenExp acc env aenv t -> Maybe t +extractConstTuple :: OpenExp env aenv t -> Maybe t extractConstTuple Nil = Just () extractConstTuple (Pair e1 e2) = (,) <$> extractConstTuple e1 <*> extractConstTuple e2 extractConstTuple (Const _ c) = Just c @@ -322,16 +319,15 @@ extractConstTuple _ = Nothing -- Simplification for open functions -- simplifyOpenFun - :: Kit acc - => Gamma acc env env aenv - -> PreOpenFun acc env aenv f - -> (Bool, PreOpenFun acc env aenv f) + :: Gamma env env aenv + -> OpenFun env aenv f + -> (Bool, OpenFun env aenv f) simplifyOpenFun env (Body e) = Body <$> simplifyOpenExp env e simplifyOpenFun env (Lam lhs f) = Lam lhs <$> simplifyOpenFun env' f where env' = lhsExpr lhs env -lhsExpr :: Kit acc => ELeftHandSide t env env' -> Gamma acc env env aenv -> Gamma acc env' env' aenv +lhsExpr :: ELeftHandSide t env env' -> Gamma env env aenv -> Gamma env' env' aenv lhsExpr (LeftHandSideWildcard _) env = env lhsExpr (LeftHandSideSingle tp) env = incExp env `pushExp` Evar (Var tp ZeroIdx) lhsExpr (LeftHandSidePair l1 l2) env = lhsExpr l2 $ lhsExpr l1 env @@ -339,10 +335,10 @@ lhsExpr (LeftHandSidePair l1 l2) env = lhsExpr l2 $ lhsExpr l1 env -- Simplify closed expressions and functions. The process is applied -- repeatedly until no more changes are made. -- -simplifyExp :: Kit acc => PreExp acc aenv t -> PreExp acc aenv t +simplifyExp :: Exp aenv t -> Exp aenv t simplifyExp = iterate summariseOpenExp (simplifyOpenExp EmptyExp) -simplifyFun :: Kit acc => PreFun acc aenv f -> PreFun acc aenv f +simplifyFun :: Fun aenv f -> Fun aenv f simplifyFun = iterate summariseOpenFun (simplifyOpenFun EmptyExp) @@ -442,19 +438,19 @@ ops = lens _ops (\Stats{..} v -> Stats { _ops = v, ..}) {-# INLINE vars #-} {-# INLINE ops #-} -summariseOpenFun :: PreOpenFun acc env aenv f -> Stats +summariseOpenFun :: OpenFun env aenv f -> Stats summariseOpenFun (Body e) = summariseOpenExp e & terms +~ 1 summariseOpenFun (Lam _ f) = summariseOpenFun f & terms +~ 1 & binders +~ 1 -summariseOpenExp :: PreOpenExp acc env aenv t -> Stats +summariseOpenExp :: OpenExp env aenv t -> Stats summariseOpenExp = (terms +~ 1) . goE where zero = Stats 0 0 0 0 0 - travE :: PreOpenExp acc env aenv t -> Stats + travE :: OpenExp env aenv t -> Stats travE = summariseOpenExp - travF :: PreOpenFun acc env aenv t -> Stats + travF :: OpenFun env aenv t -> Stats travF = summariseOpenFun travA :: acc aenv a -> Stats @@ -498,7 +494,7 @@ summariseOpenExp = (terms +~ 1) . goE -- travVectorType (Vector16Type t) = travSingleType t & types +~ 1 -- The scrutinee has already been counted - goE :: PreOpenExp acc env aenv t -> Stats + goE :: OpenExp env aenv t -> Stats goE exp = case exp of Let _ bnd body -> travE bnd +++ travE body & binders +~ 1 diff --git a/src/Data/Array/Accelerate/Trafo/Substitution.hs b/src/Data/Array/Accelerate/Trafo/Substitution.hs index 8b179049d..e4e3c7249 100644 --- a/src/Data/Array/Accelerate/Trafo/Substitution.hs +++ b/src/Data/Array/Accelerate/Trafo/Substitution.hs @@ -37,6 +37,7 @@ module Data.Array.Accelerate.Trafo.Substitution ( -- ** Rebuilding terms RebuildAcc, Rebuildable(..), RebuildableAcc, RebuildableExp(..), rebuildWeakenVar, rebuildLHS, + OpenAccFun(..), OpenAccExp(..), -- ** Checks isIdentity, isIdentityIndexing, extractExpVars, @@ -101,34 +102,32 @@ bindingIsTrivial lhs vars = Just Refl bindingIsTrivial _ _ = Nothing -isIdentity :: PreOpenFun acc env aenv (a -> b) -> Maybe (a :~: b) +isIdentity :: OpenFun env aenv (a -> b) -> Maybe (a :~: b) isIdentity (Lam lhs (Body (extractExpVars -> Just vars))) = bindingIsTrivial lhs vars isIdentity _ = Nothing -- Detects whether the function is of the form \ix -> a ! ix -isIdentityIndexing :: PreOpenFun acc env aenv (a -> b) -> Maybe (acc aenv (Array a b)) +isIdentityIndexing :: OpenFun env aenv (a -> b) -> Maybe (ArrayVar aenv (Array a b)) isIdentityIndexing (Lam lhs (Body body)) - | Index a ix <- body - , Just vars <- extractExpVars ix - , Just Refl <- bindingIsTrivial lhs vars - = Just a + | Index avar ix <- body + , Just vars <- extractExpVars ix + , Just Refl <- bindingIsTrivial lhs vars + = Just avar isIdentityIndexing _ = Nothing -- | Replace the first variable with the given expression. The environment -- shrinks. -- -inline :: RebuildableAcc acc - => PreOpenExp acc (env, s) aenv t - -> PreOpenExp acc env aenv s - -> PreOpenExp acc env aenv t +inline :: OpenExp (env, s) aenv t + -> OpenExp env aenv s + -> OpenExp env aenv t inline f g = Stats.substitution "inline" $ rebuildE (subTop g) f -inlineVars :: forall acc env env' aenv t1 t2. - RebuildableAcc acc - => ELeftHandSide t1 env env' - -> PreOpenExp acc env' aenv t2 - -> PreOpenExp acc env aenv t1 - -> Maybe (PreOpenExp acc env aenv t2) +inlineVars :: forall env env' aenv t1 t2. + ELeftHandSide t1 env env' + -> OpenExp env' aenv t2 + -> OpenExp env aenv t1 + -> Maybe (OpenExp env aenv t2) inlineVars lhsBound expr bound | Just vars <- lhsFullVars lhsBound = substitute (strengthenWithLHS lhsBound) weakenId vars expr where @@ -136,8 +135,8 @@ inlineVars lhsBound expr bound env1 :?> env2 -> env :> env2 -> ExpVars env1 t1 - -> PreOpenExp acc env1 aenv t - -> Maybe (PreOpenExp acc env2 aenv t) + -> OpenExp env1 aenv t + -> Maybe (OpenExp env2 aenv t) substitute _ k2 vars (extractExpVars -> Just vars') | Just Refl <- matchVars vars vars' = Just $ weakenE k2 bound substitute k1 k2 vars e = case e of @@ -167,18 +166,18 @@ inlineVars lhsBound expr bound Coerce t1 t2 e1 -> Coerce t1 t2 <$> travE e1 where - travE :: PreOpenExp acc env1 aenv s -> Maybe (PreOpenExp acc env2 aenv s) + travE :: OpenExp env1 aenv s -> Maybe (OpenExp env2 aenv s) travE = substitute k1 k2 vars - travF :: PreOpenFun acc env1 aenv s -> Maybe (PreOpenFun acc env2 aenv s) + travF :: OpenFun env1 aenv s -> Maybe (OpenFun env2 aenv s) travF = substituteF k1 k2 vars substituteF :: forall env1 env2 t. env1 :?> env2 -> env :> env2 -> ExpVars env1 t1 - -> PreOpenFun acc env1 aenv t - -> Maybe (PreOpenFun acc env2 aenv t) + -> OpenFun env1 aenv t + -> Maybe (OpenFun env2 aenv t) substituteF k1 k2 vars (Body e) = Body <$> substitute k1 k2 vars e substituteF k1 k2 vars (Lam lhs f) | Exists lhs' <- rebuildLHS lhs = Lam lhs' <$> substituteF (strengthenAfter lhs lhs' k1) (weakenWithLHS lhs' .> k2) (weakenWithLHS lhs `weaken` vars) f @@ -189,44 +188,45 @@ inlineVars _ _ _ = Nothing -- | Replace an expression that uses the top environment variable with another. -- The result of the first is let bound into the second. -- -{- substitute' :: RebuildableAcc acc - => PreOpenExp acc (env, b) aenv c - -> PreOpenExp acc (env, a) aenv b - -> PreOpenExp acc (env, a) aenv c +{- substitute' :: OpenExp (env, b) aenv c + -> OpenExp (env, a) aenv b + -> OpenExp (env, a) aenv c substitute' f g | Stats.substitution "substitute" False = undefined | isIdentity f = g -- don't rebind an identity function | isIdentity g = f | otherwise = Let g $ rebuildE split f where - split :: Idx (env,b) c -> PreOpenExp acc ((env,a),b) aenv c + split :: Idx (env,b) c -> OpenExp ((env,a),b) aenv c split ZeroIdx = Var ZeroIdx split (SuccIdx ix) = Var (SuccIdx (SuccIdx ix)) -substitute :: RebuildableAcc acc - => LeftHandSide b env envb - -> PreOpenExp acc envb c +substitute :: LeftHandSide b env envb + -> OpenExp envb c -> LeftHandSide a env enva - -> PreOpenExp acc enva b + -> OpenExp enva b -} -- | Composition of unary functions. -- -compose :: RebuildableAcc acc - => PreOpenFun acc env aenv (b -> c) - -> PreOpenFun acc env aenv (a -> b) - -> PreOpenFun acc env aenv (a -> c) +compose :: OpenFun env aenv (b -> c) + -> OpenFun env aenv (a -> b) + -> OpenFun env aenv (a -> c) compose f@(Lam lhsB (Body c)) g@(Lam lhsA (Body b)) | Stats.substitution "compose" False = undefined | Just Refl <- isIdentity f = g -- don't rebind an identity function | Just Refl <- isIdentity g = f | Exists lhsB' <- rebuildLHS lhsB - = Lam lhsA $ Body $ Let lhsB' b (weakenE (sinkWithLHS lhsB lhsB' $ weakenWithLHS lhsA) c) + = Lam lhsA + $ Body + $ Let lhsB' b + $ weakenE (sinkWithLHS lhsB lhsB' $ weakenWithLHS lhsA) c -- = Stats.substitution "compose" . Lam lhs2 . Body $ substitute' f g -compose _ _ = error "compose: impossible evaluation" +compose _ + _ = error "compose: impossible evaluation" -subTop :: PreOpenExp acc env aenv s -> ExpVar (env, s) t -> PreOpenExp acc env aenv t +subTop :: OpenExp env aenv s -> ExpVar (env, s) t -> OpenExp env aenv t subTop s (Var _ ZeroIdx ) = s subTop _ (Var tp (SuccIdx ix)) = Evar $ Var tp ix @@ -269,13 +269,13 @@ class Rebuildable f where class RebuildableExp f where {-# MINIMAL rebuildPartialE #-} rebuildPartialE :: (Applicative f', SyntacticExp fe) - => (forall e'. ExpVar env e' -> f' (fe (AccClo (f env)) env' aenv e')) - -> f env aenv e + => (forall e'. ExpVar env e' -> f' (fe env' aenv e')) + -> f env aenv e -> f' (f env' aenv e) {-# INLINEABLE rebuildE #-} rebuildE :: SyntacticExp fe - => (forall e'. ExpVar env e' -> fe (AccClo (f env)) env' aenv e') + => (forall e'. ExpVar env e' -> fe env' aenv e') -> f env aenv e -> f env' aenv e rebuildE v = runIdentity . rebuildPartialE (Identity . v) @@ -284,17 +284,25 @@ class RebuildableExp f where -- type RebuildableAcc acc = (Rebuildable acc, AccClo acc ~ acc) +-- Wrappers which add the 'acc' type argument +-- +data OpenAccExp (acc :: Type -> Type -> Type) env aenv a where + OpenAccExp :: { unOpenAccExp :: OpenExp env aenv a } -> OpenAccExp acc env aenv a + +data OpenAccFun (acc :: Type -> Type -> Type) env aenv a where + OpenAccFun :: { unOpenAccFun :: OpenFun env aenv a } -> OpenAccFun acc env aenv a + -- We can use the same plumbing to rebuildPartial all the things we want to rebuild. -- -instance RebuildableAcc acc => Rebuildable (PreOpenExp acc env) where - type AccClo (PreOpenExp acc env) = acc +instance Rebuildable (OpenAccExp acc env) where + type AccClo (OpenAccExp acc env) = acc {-# INLINEABLE rebuildPartial #-} - rebuildPartial x = Stats.substitution "rebuild" $ rebuildPreOpenExp rebuildPartial (pure . IE) x + rebuildPartial v (OpenAccExp e) = OpenAccExp <$> Stats.substitution "rebuild" (rebuildOpenExp (pure . IE) (reindexAvar v) e) -instance RebuildableAcc acc => Rebuildable (PreOpenFun acc env) where - type AccClo (PreOpenFun acc env) = acc +instance Rebuildable (OpenAccFun acc env) where + type AccClo (OpenAccFun acc env) = acc {-# INLINEABLE rebuildPartial #-} - rebuildPartial x = Stats.substitution "rebuild" $ rebuildFun rebuildPartial (pure . IE) x + rebuildPartial v (OpenAccFun f) = OpenAccFun <$> Stats.substitution "rebuild" (rebuildFun (pure . IE) (reindexAvar v) f) instance RebuildableAcc acc => Rebuildable (PreOpenAcc acc) where type AccClo (PreOpenAcc acc) = acc @@ -311,13 +319,13 @@ instance Rebuildable OpenAcc where {-# INLINEABLE rebuildPartial #-} rebuildPartial x = Stats.substitution "rebuild" $ rebuildOpenAcc x -instance RebuildableAcc acc => RebuildableExp (PreOpenExp acc) where +instance RebuildableExp OpenExp where {-# INLINEABLE rebuildPartialE #-} - rebuildPartialE v x = Stats.substitution "rebuild" $ rebuildPreOpenExp rebuildPartial v (pure . IA) x + rebuildPartialE v x = Stats.substitution "rebuild" $ rebuildOpenExp v (ReindexAvar pure) x -instance RebuildableAcc acc => RebuildableExp (PreOpenFun acc) where +instance RebuildableExp OpenFun where {-# INLINEABLE rebuildPartialE #-} - rebuildPartialE v x = Stats.substitution "rebuild" $ rebuildFun rebuildPartial v (pure . IA) x + rebuildPartialE v x = Stats.substitution "rebuild" $ rebuildFun v (ReindexAvar pure) x -- NOTE: [Weakening] -- @@ -363,7 +371,7 @@ instance Sink (Vars s) where rebuildWeakenVar :: env :> env' -> ArrayVar env (Array sh e) -> PreOpenAcc acc env' (Array sh e) rebuildWeakenVar k (Var s idx) = Avar $ Var s $ k >:> idx -rebuildWeakenEvar :: env :> env' -> ExpVar env t -> PreOpenExp acc env' aenv t +rebuildWeakenEvar :: env :> env' -> ExpVar env t -> OpenExp env' aenv t rebuildWeakenEvar k (Var s idx) = Evar $ Var s $ k >:> idx instance RebuildableAcc acc => Sink (PreOpenAcc acc) where @@ -374,15 +382,15 @@ instance RebuildableAcc acc => Sink (PreOpenAfun acc) where {-# INLINEABLE weaken #-} weaken k = Stats.substitution "weaken" . rebuildA (rebuildWeakenVar k) -instance RebuildableAcc acc => Sink (PreOpenExp acc env) where +instance Sink (OpenExp env) where {-# INLINEABLE weaken #-} - weaken k = Stats.substitution "weaken" . rebuildA (rebuildWeakenVar k) + weaken k = Stats.substitution "weaken" . runIdentity . rebuildOpenExp (Identity . Evar) (ReindexAvar (Identity . weaken k)) -instance RebuildableAcc acc => Sink (PreOpenFun acc env) where +instance Sink (OpenFun env) where {-# INLINEABLE weaken #-} - weaken k = Stats.substitution "weaken" . rebuildA (rebuildWeakenVar k) + weaken k = Stats.substitution "weaken" . runIdentity . rebuildFun (Identity . Evar) (ReindexAvar (Identity . weaken k)) -instance RebuildableAcc acc => Sink (PreBoundary acc) where +instance Sink Boundary where {-# INLINEABLE weaken #-} weaken k bndy = case bndy of @@ -412,11 +420,11 @@ class SinkExp f where -- default weakenE :: RebuildableExp f => env :> env' -> f env aenv t -> f env' aenv t -- weakenE v = Stats.substitution "weakenE" . rebuildE (IE . v) -instance RebuildableAcc acc => SinkExp (PreOpenExp acc) where +instance SinkExp OpenExp where {-# INLINEABLE weakenE #-} weakenE v = Stats.substitution "weakenE" . rebuildE (rebuildWeakenEvar v) -instance RebuildableAcc acc => SinkExp (PreOpenFun acc) where +instance SinkExp OpenFun where {-# INLINEABLE weakenE #-} weakenE v = Stats.substitution "weakenE" . rebuildE (rebuildWeakenEvar v) @@ -447,7 +455,7 @@ strengthenE k x = Stats.substitution "strengthenE" $ rebuildPartialE @f @Maybe @ strengthenWithLHS :: LeftHandSide s t env1 env2 -> env2 :?> env1 strengthenWithLHS (LeftHandSideWildcard _) = Just -strengthenWithLHS (LeftHandSideSingle _) = \ix -> case ix of +strengthenWithLHS (LeftHandSideSingle _) = \ix -> case ix of ZeroIdx -> Nothing SuccIdx i -> Just i strengthenWithLHS (LeftHandSidePair l1 l2) = strengthenWithLHS l2 >=> strengthenWithLHS l1 @@ -471,57 +479,51 @@ strengthenAfter _ _ _ = error "Substitution.strengthenAfter: left hand sides do -- SEE: [Weakening] -- class SyntacticExp f where - varIn :: ExpVar env t -> f acc env aenv t - expOut :: f acc env aenv t -> PreOpenExp acc env aenv t - weakenExp :: RebuildAcc acc -> f acc env aenv t -> f acc (env, s) aenv t - -- weakenExpAcc :: RebuildAcc acc -> f acc env aenv t -> f acc env (aenv, s) t + varIn :: ExpVar env t -> f env aenv t + expOut :: f env aenv t -> OpenExp env aenv t + weakenExp :: f env aenv t -> f (env, s) aenv t -newtype IdxE (acc :: Type -> Type -> Type) env aenv t = IE { unIE :: ExpVar env t } +newtype IdxE env aenv t = IE { unIE :: ExpVar env t } instance SyntacticExp IdxE where varIn = IE expOut = Evar . unIE - weakenExp _ (IE (Var tp ix)) = IE $ Var tp $ SuccIdx ix - -- weakenExpAcc _ = IE . unIE + weakenExp (IE (Var tp ix)) = IE $ Var tp $ SuccIdx ix -instance SyntacticExp PreOpenExp where +instance SyntacticExp OpenExp where varIn = Evar expOut = id - weakenExp k = runIdentity . rebuildPreOpenExp k (Identity . weakenExp k . IE) (Identity . IA) - -- weakenExpAcc k = runIdentity . rebuildPreOpenExp k (Identity . IE) (Identity . weakenAcc k . IA) + weakenExp = runIdentity . rebuildOpenExp (Identity . weakenExp . IE) (ReindexAvar Identity) {-# INLINEABLE shiftE #-} shiftE :: (Applicative f, SyntacticExp fe) - => RebuildAcc acc - -> RebuildEvar f fe acc env env' aenv - -> RebuildEvar f fe acc (env, s) (env', s) aenv -shiftE _ _ (Var tp ZeroIdx) = pure $ varIn (Var tp ZeroIdx) -shiftE k v (Var tp (SuccIdx ix)) = weakenExp k <$> v (Var tp ix) + => RebuildEvar f fe env env' aenv + -> RebuildEvar f fe (env, s) (env', s) aenv +shiftE _ (Var tp ZeroIdx) = pure $ varIn (Var tp ZeroIdx) +shiftE v (Var tp (SuccIdx ix)) = weakenExp <$> v (Var tp ix) {-# INLINEABLE shiftE' #-} shiftE' :: (Applicative f, SyntacticExp fa) => ELeftHandSide t env1 env1' -> ELeftHandSide t env2 env2' - -> RebuildAcc acc - -> RebuildEvar f fa acc env1 env2 aenv - -> RebuildEvar f fa acc env1' env2' aenv -shiftE' (LeftHandSideWildcard _) (LeftHandSideWildcard _) _ v = v -shiftE' (LeftHandSideSingle _) (LeftHandSideSingle _) k v = shiftE k v -shiftE' (LeftHandSidePair a1 b1) (LeftHandSidePair a2 b2) k v = shiftE' b1 b2 k $ shiftE' a1 a2 k v -shiftE' _ _ _ _ = error "Substitution: left hand sides do not match" + -> RebuildEvar f fa env1 env2 aenv + -> RebuildEvar f fa env1' env2' aenv +shiftE' (LeftHandSideWildcard _) (LeftHandSideWildcard _) v = v +shiftE' (LeftHandSideSingle _) (LeftHandSideSingle _) v = shiftE v +shiftE' (LeftHandSidePair a1 b1) (LeftHandSidePair a2 b2) v = shiftE' b1 b2 $ shiftE' a1 a2 v +shiftE' _ _ _ = error "Substitution: left hand sides do not match" -{-# INLINEABLE rebuildPreOpenExp #-} -rebuildPreOpenExp - :: (Applicative f, SyntacticExp fe, SyntacticAcc fa) - => RebuildAcc acc - -> RebuildEvar f fe acc env env' aenv' - -> RebuildAvar f fa acc aenv aenv' - -> PreOpenExp acc env aenv t - -> f (PreOpenExp acc env' aenv' t) -rebuildPreOpenExp k v av exp = +{-# INLINEABLE rebuildOpenExp #-} +rebuildOpenExp + :: (Applicative f, SyntacticExp fe) + => RebuildEvar f fe env env' aenv' + -> ReindexAvar f aenv aenv' + -> OpenExp env aenv t + -> f (OpenExp env' aenv' t) +rebuildOpenExp v av@(ReindexAvar reindex) exp = case exp of Const t c -> pure $ Const t c PrimConst c -> pure $ PrimConst c @@ -529,39 +531,38 @@ rebuildPreOpenExp k v av exp = Evar var -> expOut <$> v var Let lhs a b | Exists lhs' <- rebuildLHS lhs - -> Let lhs' <$> rebuildPreOpenExp k v av a <*> rebuildPreOpenExp k (shiftE' lhs lhs' k v) av b - Pair e1 e2 -> Pair <$> rebuildPreOpenExp k v av e1 <*> rebuildPreOpenExp k v av e2 + -> Let lhs' <$> rebuildOpenExp v av a <*> rebuildOpenExp (shiftE' lhs lhs' v) av b + Pair e1 e2 -> Pair <$> rebuildOpenExp v av e1 <*> rebuildOpenExp v av e2 Nil -> pure Nil - VecPack vec e -> VecPack vec <$> rebuildPreOpenExp k v av e - VecUnpack vec e -> VecUnpack vec <$> rebuildPreOpenExp k v av e - IndexSlice x ix sh -> IndexSlice x <$> rebuildPreOpenExp k v av ix <*> rebuildPreOpenExp k v av sh - IndexFull x ix sl -> IndexFull x <$> rebuildPreOpenExp k v av ix <*> rebuildPreOpenExp k v av sl - ToIndex shr sh ix -> ToIndex shr <$> rebuildPreOpenExp k v av sh <*> rebuildPreOpenExp k v av ix - FromIndex shr sh ix -> FromIndex shr <$> rebuildPreOpenExp k v av sh <*> rebuildPreOpenExp k v av ix - Cond p t e -> Cond <$> rebuildPreOpenExp k v av p <*> rebuildPreOpenExp k v av t <*> rebuildPreOpenExp k v av e - While p f x -> While <$> rebuildFun k v av p <*> rebuildFun k v av f <*> rebuildPreOpenExp k v av x - PrimApp f x -> PrimApp f <$> rebuildPreOpenExp k v av x - Index a sh -> Index <$> k av a <*> rebuildPreOpenExp k v av sh - LinearIndex a i -> LinearIndex <$> k av a <*> rebuildPreOpenExp k v av i - Shape a -> Shape <$> k av a - ShapeSize shr sh -> ShapeSize shr <$> rebuildPreOpenExp k v av sh - Foreign tp ff f e -> Foreign tp ff f <$> rebuildPreOpenExp k v av e - Coerce t1 t2 e -> Coerce t1 t2 <$> rebuildPreOpenExp k v av e + VecPack vec e -> VecPack vec <$> rebuildOpenExp v av e + VecUnpack vec e -> VecUnpack vec <$> rebuildOpenExp v av e + IndexSlice x ix sh -> IndexSlice x <$> rebuildOpenExp v av ix <*> rebuildOpenExp v av sh + IndexFull x ix sl -> IndexFull x <$> rebuildOpenExp v av ix <*> rebuildOpenExp v av sl + ToIndex shr sh ix -> ToIndex shr <$> rebuildOpenExp v av sh <*> rebuildOpenExp v av ix + FromIndex shr sh ix -> FromIndex shr <$> rebuildOpenExp v av sh <*> rebuildOpenExp v av ix + Cond p t e -> Cond <$> rebuildOpenExp v av p <*> rebuildOpenExp v av t <*> rebuildOpenExp v av e + While p f x -> While <$> rebuildFun v av p <*> rebuildFun v av f <*> rebuildOpenExp v av x + PrimApp f x -> PrimApp f <$> rebuildOpenExp v av x + Index a sh -> Index <$> reindex a <*> rebuildOpenExp v av sh + LinearIndex a i -> LinearIndex <$> reindex a <*> rebuildOpenExp v av i + Shape a -> Shape <$> reindex a + ShapeSize shr sh -> ShapeSize shr <$> rebuildOpenExp v av sh + Foreign tp ff f e -> Foreign tp ff f <$> rebuildOpenExp v av e + Coerce t1 t2 e -> Coerce t1 t2 <$> rebuildOpenExp v av e {-# INLINEABLE rebuildFun #-} rebuildFun - :: (Applicative f, SyntacticExp fe, SyntacticAcc fa) - => RebuildAcc acc - -> RebuildEvar f fe acc env env' aenv' - -> RebuildAvar f fa acc aenv aenv' - -> PreOpenFun acc env aenv t - -> f (PreOpenFun acc env' aenv' t) -rebuildFun k v av fun = + :: (Applicative f, SyntacticExp fe) + => RebuildEvar f fe env env' aenv' + -> ReindexAvar f aenv aenv' + -> OpenFun env aenv t + -> f (OpenFun env' aenv' t) +rebuildFun v av fun = case fun of - Body e -> Body <$> rebuildPreOpenExp k v av e + Body e -> Body <$> rebuildOpenExp v av e Lam lhs f | Exists lhs' <- rebuildLHS lhs - -> Lam lhs' <$> rebuildFun k (shiftE' lhs lhs' k v) av f + -> Lam lhs' <$> rebuildFun (shiftE' lhs lhs' v) av f -- The array environment -- ----------------- @@ -592,16 +593,34 @@ instance SyntacticAcc PreOpenAcc where type RebuildAvar f (fa :: (Type -> Type -> Type) -> Type -> Type -> Type) acc aenv aenv' = forall sh e. ArrayVar aenv (Array sh e) -> f (fa acc aenv' (Array sh e)) -type RebuildEvar f fe (acc :: Type -> Type -> Type) env env' aenv' = - forall t'. ExpVar env t' -> f (fe acc env' aenv' t') +type RebuildEvar f fe env env' aenv' = + forall t'. ExpVar env t' -> f (fe env' aenv' t') + +newtype ReindexAvar f aenv aenv' = + ReindexAvar (forall sh e. ArrayVar aenv (Array sh e) -> f (ArrayVar aenv' (Array sh e))) + +reindexAvar + :: forall f fa acc aenv aenv'. + (Applicative f, SyntacticAcc fa) + => RebuildAvar f fa acc aenv aenv' + -> ReindexAvar f aenv aenv' +reindexAvar v = ReindexAvar f where + f :: forall sh e. ArrayVar aenv (Array sh e) -> f (ArrayVar aenv' (Array sh e)) + f var = g <$> v var + + g :: fa acc aenv' (Array sh e) -> ArrayVar aenv' (Array sh e) + g fa = case accOut fa of + Avar var' -> var' + _ -> $internalError "reindexAvar" "An Avar which was used in an Exp was mapped to an array term other than Avar. This mapping is invalid as an Exp can only contain array variables." + {-# INLINEABLE shiftA #-} shiftA :: (Applicative f, SyntacticAcc fa) => RebuildAcc acc -> RebuildAvar f fa acc aenv aenv' - -> ArrayVar (aenv, s) (Array sh e) - -> f (fa acc (aenv', s) (Array sh e)) + -> ArrayVar (aenv, s) (Array sh e) + -> f (fa acc (aenv', s) (Array sh e)) shiftA _ _ (Var s ZeroIdx) = pure $ avarIn $ Var s ZeroIdx shiftA k v (Var s (SuccIdx ix)) = weakenAcc k <$> v (Var s ix) @@ -640,33 +659,35 @@ rebuildPreOpenAcc k av acc = Apair as bs -> Apair <$> k av as <*> k av bs Anil -> pure Anil Apply repr f a -> Apply repr <$> rebuildAfun k av f <*> k av a - Acond p t e -> Acond <$> rebuildPreOpenExp k (pure . IE) av p <*> k av t <*> k av e + Acond p t e -> Acond <$> rebuildOpenExp (pure . IE) av' p <*> k av t <*> k av e Awhile p f a -> Awhile <$> rebuildAfun k av p <*> rebuildAfun k av f <*> k av a - Unit tp e -> Unit tp <$> rebuildPreOpenExp k (pure . IE) av e - Reshape shr e a -> Reshape shr <$> rebuildPreOpenExp k (pure . IE) av e <*> k av a - Generate repr e f -> Generate repr <$> rebuildPreOpenExp k (pure . IE) av e <*> rebuildFun k (pure . IE) av f - Transform repr sh ix f a -> Transform repr <$> rebuildPreOpenExp k (pure . IE) av sh <*> rebuildFun k (pure . IE) av ix <*> rebuildFun k (pure . IE) av f <*> k av a - Replicate sl slix a -> Replicate sl <$> rebuildPreOpenExp k (pure . IE) av slix <*> k av a - Slice sl a slix -> Slice sl <$> k av a <*> rebuildPreOpenExp k (pure . IE) av slix - Map tp f a -> Map tp <$> rebuildFun k (pure . IE) av f <*> k av a - ZipWith tp f a1 a2 -> ZipWith tp <$> rebuildFun k (pure . IE) av f <*> k av a1 <*> k av a2 - Fold f z a -> Fold <$> rebuildFun k (pure . IE) av f <*> rebuildPreOpenExp k (pure . IE) av z <*> k av a - Fold1 f a -> Fold1 <$> rebuildFun k (pure . IE) av f <*> k av a - FoldSeg itp f z a s -> FoldSeg itp <$> rebuildFun k (pure . IE) av f <*> rebuildPreOpenExp k (pure . IE) av z <*> k av a <*> k av s - Fold1Seg itp f a s -> Fold1Seg itp <$> rebuildFun k (pure . IE) av f <*> k av a <*> k av s - Scanl f z a -> Scanl <$> rebuildFun k (pure . IE) av f <*> rebuildPreOpenExp k (pure . IE) av z <*> k av a - Scanl' f z a -> Scanl' <$> rebuildFun k (pure . IE) av f <*> rebuildPreOpenExp k (pure . IE) av z <*> k av a - Scanl1 f a -> Scanl1 <$> rebuildFun k (pure . IE) av f <*> k av a - Scanr f z a -> Scanr <$> rebuildFun k (pure . IE) av f <*> rebuildPreOpenExp k (pure . IE) av z <*> k av a - Scanr' f z a -> Scanr' <$> rebuildFun k (pure . IE) av f <*> rebuildPreOpenExp k (pure . IE) av z <*> k av a - Scanr1 f a -> Scanr1 <$> rebuildFun k (pure . IE) av f <*> k av a - Permute f1 a1 f2 a2 -> Permute <$> rebuildFun k (pure . IE) av f1 <*> k av a1 <*> rebuildFun k (pure . IE) av f2 <*> k av a2 - Backpermute shr sh f a -> Backpermute shr <$> rebuildPreOpenExp k (pure . IE) av sh <*> rebuildFun k (pure . IE) av f <*> k av a - Stencil sr tp f b a -> Stencil sr tp <$> rebuildFun k (pure . IE) av f <*> rebuildBoundary k av b <*> k av a + Unit tp e -> Unit tp <$> rebuildOpenExp (pure . IE) av' e + Reshape shr e a -> Reshape shr <$> rebuildOpenExp (pure . IE) av' e <*> k av a + Generate repr e f -> Generate repr <$> rebuildOpenExp (pure . IE) av' e <*> rebuildFun (pure . IE) av' f + Transform repr sh ix f a -> Transform repr <$> rebuildOpenExp (pure . IE) av' sh <*> rebuildFun (pure . IE) av' ix <*> rebuildFun (pure . IE) av' f <*> k av a + Replicate sl slix a -> Replicate sl <$> rebuildOpenExp (pure . IE) av' slix <*> k av a + Slice sl a slix -> Slice sl <$> k av a <*> rebuildOpenExp (pure . IE) av' slix + Map tp f a -> Map tp <$> rebuildFun (pure . IE) av' f <*> k av a + ZipWith tp f a1 a2 -> ZipWith tp <$> rebuildFun (pure . IE) av' f <*> k av a1 <*> k av a2 + Fold f z a -> Fold <$> rebuildFun (pure . IE) av' f <*> rebuildOpenExp (pure . IE) av' z <*> k av a + Fold1 f a -> Fold1 <$> rebuildFun (pure . IE) av' f <*> k av a + FoldSeg itp f z a s -> FoldSeg itp <$> rebuildFun (pure . IE) av' f <*> rebuildOpenExp (pure . IE) av' z <*> k av a <*> k av s + Fold1Seg itp f a s -> Fold1Seg itp <$> rebuildFun (pure . IE) av' f <*> k av a <*> k av s + Scanl f z a -> Scanl <$> rebuildFun (pure . IE) av' f <*> rebuildOpenExp (pure . IE) av' z <*> k av a + Scanl' f z a -> Scanl' <$> rebuildFun (pure . IE) av' f <*> rebuildOpenExp (pure . IE) av' z <*> k av a + Scanl1 f a -> Scanl1 <$> rebuildFun (pure . IE) av' f <*> k av a + Scanr f z a -> Scanr <$> rebuildFun (pure . IE) av' f <*> rebuildOpenExp (pure . IE) av' z <*> k av a + Scanr' f z a -> Scanr' <$> rebuildFun (pure . IE) av' f <*> rebuildOpenExp (pure . IE) av' z <*> k av a + Scanr1 f a -> Scanr1 <$> rebuildFun (pure . IE) av' f <*> k av a + Permute f1 a1 f2 a2 -> Permute <$> rebuildFun (pure . IE) av' f1 <*> k av a1 <*> rebuildFun (pure . IE) av' f2 <*> k av a2 + Backpermute shr sh f a -> Backpermute shr <$> rebuildOpenExp (pure . IE) av' sh <*> rebuildFun (pure . IE) av' f <*> k av a + Stencil sr tp f b a -> Stencil sr tp <$> rebuildFun (pure . IE) av' f <*> rebuildBoundary av' b <*> k av a Stencil2 s1 s2 tp f b1 a1 b2 a2 - -> Stencil2 s1 s2 tp <$> rebuildFun k (pure . IE) av f <*> rebuildBoundary k av b1 <*> k av a1 <*> rebuildBoundary k av b2 <*> k av a2 + -> Stencil2 s1 s2 tp <$> rebuildFun (pure . IE) av' f <*> rebuildBoundary av' b1 <*> k av a1 <*> rebuildBoundary av' b2 <*> k av a2 Aforeign repr ff afun as -> Aforeign repr ff afun <$> k av as -- Collect seq -> Collect <$> rebuildSeq k av seq + where + av' = reindexAvar av {-# INLINEABLE rebuildAfun #-} rebuildAfun @@ -703,18 +724,17 @@ rebuildLHS (LeftHandSidePair as bs) {-# INLINEABLE rebuildBoundary #-} rebuildBoundary - :: (Applicative f, SyntacticAcc fa) - => RebuildAcc acc - -> RebuildAvar f fa acc aenv aenv' - -> PreBoundary acc aenv t - -> f (PreBoundary acc aenv' t) -rebuildBoundary k av bndy = + :: Applicative f + => ReindexAvar f aenv aenv' + -> Boundary aenv t + -> f (Boundary aenv' t) +rebuildBoundary av bndy = case bndy of Clamp -> pure Clamp Mirror -> pure Mirror Wrap -> pure Wrap Constant v -> pure (Constant v) - Function f -> Function <$> rebuildFun k (pure . IE) av f + Function f -> Function <$> rebuildFun (pure . IE) av f {-- {-# INLINEABLE rebuildSeq #-} @@ -743,7 +763,7 @@ rebuildP k v p = MapSeq f x -> MapSeq <$> rebuildAfun k v f <*> pure x ChunkedMapSeq f x -> ChunkedMapSeq <$> rebuildAfun k v f <*> pure x ZipWithSeq f x y -> ZipWithSeq <$> rebuildAfun k v f <*> pure x <*> pure y - ScanSeq f e x -> ScanSeq <$> rebuildFun k (pure . IE) v f <*> rebuildPreOpenExp k (pure . IE) v e <*> pure x + ScanSeq f e x -> ScanSeq <$> rebuildFun (pure . IE) v f <*> rebuildOpenExp (pure . IE) v e <*> pure x {-# INLINEABLE rebuildC #-} rebuildC :: forall acc fa f aenv aenv' senv a. (SyntacticAcc fa, Applicative f) @@ -753,7 +773,7 @@ rebuildC :: forall acc fa f aenv aenv' senv a. (SyntacticAcc fa, Applicative f) -> f (Consumer acc aenv' senv a) rebuildC k v c = case c of - FoldSeq f e x -> FoldSeq <$> rebuildFun k (pure . IE) v f <*> rebuildPreOpenExp k (pure . IE) v e <*> pure x + FoldSeq f e x -> FoldSeq <$> rebuildFun (pure . IE) v f <*> rebuildOpenExp (pure . IE) v e <*> pure x FoldSeqFlatten f acc x -> FoldSeqFlatten <$> rebuildAfun k v f <*> k v acc <*> pure x Stuple t -> Stuple <$> rebuildT t where @@ -762,7 +782,7 @@ rebuildC k v c = rebuildT (SnocAtup t s) = SnocAtup <$> (rebuildT t) <*> (rebuildC k v s) --} -extractExpVars :: PreOpenExp acc env aenv a -> Maybe (ExpVars env a) +extractExpVars :: OpenExp env aenv a -> Maybe (ExpVars env a) extractExpVars Nil = Just VarsNil extractExpVars (Pair e1 e2) = VarsPair <$> extractExpVars e1 <*> extractExpVars e2 extractExpVars (Evar v) = Just $ VarsSingle v