{-# LANGUAGE TypeFamilies #-}
module Futhark.Optimise.DoubleBuffer (doubleBufferGPU, doubleBufferMC) where
import Control.Monad.Reader
import Control.Monad.State
import Data.Bifunctor
import Data.List qualified as L
import Data.Map.Strict qualified as M
import Data.Maybe
import Futhark.Construct
import Futhark.IR.GPUMem as GPU
import Futhark.IR.MCMem as MC
import Futhark.IR.Mem.LMAD qualified as LMAD
import Futhark.Pass
import Futhark.Pass.ExplicitAllocations.GPU ()
import Futhark.Transform.Substitute
import Futhark.Util (mapAccumLM)
-- | The type of a loop transformation.  It receives the result
-- pattern of a loop, its merge parameters paired with their initial
-- values, and the loop body.  It returns the (possibly rewritten)
-- pattern, merge parameters and body, together with statements that
-- the caller places in front of the loop statement.
type OptimiseLoop rep =
  Pat (LetDec rep) ->
  [(FParam rep, SubExp)] ->
  Body rep ->
  DoubleBufferM
    rep
    ( Stms rep,
      Pat (LetDec rep),
      [(FParam rep, SubExp)],
      Body rep
    )
-- | The type of a transformation applied to the representation-specific
-- operation ('Op') of an expression.
type OptimiseOp rep =
  Op rep -> DoubleBufferM rep (Op rep)
-- | The read-only environment of 'DoubleBufferM'.
data Env rep = Env
  { -- | The names currently in scope; returned by 'askScope' and
    -- extended by 'localScope'.
    envScope :: Scope rep,
    -- | The transformation that 'optimiseStm' applies to every loop it
    -- encounters.
    envOptimiseLoop :: OptimiseLoop rep,
    -- | The transformation that 'optimiseStm' applies to every 'Op' it
    -- encounters.
    envOptimiseOp :: OptimiseOp rep
  }
-- | The monad in which the pass runs: a reader over 'Env' layered on
-- a state monad that carries the 'VNameSource' used for generating
-- fresh names.
--
-- NOTE(review): the non-stock classes in the deriving clause rely on
-- @GeneralizedNewtypeDeriving@ being enabled outside this file (only
-- @TypeFamilies@ is enabled in the file header) - confirm in the
-- build configuration.
newtype DoubleBufferM rep a = DoubleBufferM
  { runDoubleBufferM :: ReaderT (Env rep) (State VNameSource) a
  }
  deriving (Functor, Applicative, Monad, MonadReader (Env rep), MonadFreshNames)
-- | The current scope is the 'envScope' field of the environment.
instance (ASTRep rep) => HasScope rep (DoubleBufferM rep) where
  askScope = asks envScope
-- | Running an action in a local scope combines the existing
-- 'envScope' with the given scope (existing scope on the left of
-- '<>') for the duration of that action.
instance (ASTRep rep) => LocalScope rep (DoubleBufferM rep) where
  localScope scope = local $ \env -> env {envScope = envScope env <> scope}
-- | Transform all statements of a body with 'optimiseStms', leaving
-- every other field of the body untouched.
optimiseBody :: (ASTRep rep) => Body rep -> DoubleBufferM rep (Body rep)
optimiseBody body = do
  stms' <- optimiseStms $ stmsToList $ bodyStms body
  pure $ body {bodyStms = stms'}
-- | Transform a sequence of statements from left to right.  Each
-- statement may expand into several statements; the names those bind
-- are brought into scope while the remaining statements are
-- processed.
optimiseStms :: (ASTRep rep) => [Stm rep] -> DoubleBufferM rep (Stms rep)
optimiseStms [] = pure mempty
optimiseStms (e : es) = do
  e_es <- optimiseStm e
  es' <- localScope (castScope $ scopeOf e_es) $ optimiseStms es
  pure $ e_es <> es'
-- | Transform a single statement, possibly producing several.
--
-- For a loop, the body is transformed first (with the loop form and
-- the merge parameters in scope), and then the loop transformation
-- from the environment ('envOptimiseLoop') is applied; any statements
-- it returns are placed before the rebuilt loop statement.
--
-- For any other expression, nested bodies are transformed with
-- 'optimiseBody' and operations with 'envOptimiseOp'.
optimiseStm :: forall rep. (ASTRep rep) => Stm rep -> DoubleBufferM rep (Stms rep)
optimiseStm (Let pat aux (Loop merge form body)) = do
  body' <-
    localScope (scopeOfLoopForm form <> scopeOfFParams (map fst merge)) $
      optimiseBody body
  opt_loop <- asks envOptimiseLoop
  (stms, pat', merge', body'') <- opt_loop pat merge body'
  pure $ stms <> oneStm (Let pat' aux $ Loop merge' form body'')
optimiseStm (Let pat aux e) = do
  onOp <- asks envOptimiseOp
  oneStm . Let pat aux <$> mapExpM (optimise onOp) e
  where
    -- The identity mapper, overridden for bodies and operations only.
    -- The scope argument passed to 'mapOnBody' is ignored.
    optimise onOp =
      (identityMapper @rep)
        { mapOnBody = \_ x ->
            optimiseBody x :: DoubleBufferM rep (Body rep),
          mapOnOp = onOp
        }
-- | The operation transformation for the GPU representation.  The
-- lambdas and kernel body of a 'SegOp' are traversed with
-- 'envOptimiseLoop' set to 'optimiseLoop'; every other operation is
-- returned unchanged.
optimiseGPUOp :: OptimiseOp GPUMem
optimiseGPUOp (Inner (SegOp op)) =
  local inSegOp $ Inner . SegOp <$> mapSegOpM mapper op
  where
    mapper =
      identitySegOpMapper
        { mapOnSegOpLambda = optimiseLambda,
          mapOnSegOpBody = optimiseKernelBody
        }
    -- Within the SegOp, loops are handled by 'optimiseLoop'.
    inSegOp env = env {envOptimiseLoop = optimiseLoop}
optimiseGPUOp op = pure op
-- | The operation transformation for the multicore representation.
-- Both the optional and the mandatory 'SegOp' of a 'ParOp' are
-- traversed with 'envOptimiseLoop' set to 'optimiseLoop'; every other
-- operation is returned unchanged.
optimiseMCOp :: OptimiseOp MCMem
optimiseMCOp (Inner (ParOp par_op op)) =
  local inSegOp $
    Inner
      <$> (ParOp <$> traverse (mapSegOpM mapper) par_op <*> mapSegOpM mapper op)
  where
    mapper =
      identitySegOpMapper
        { mapOnSegOpLambda = optimiseLambda,
          mapOnSegOpBody = optimiseKernelBody
        }
    -- Within the SegOps, loops are handled by 'optimiseLoop'.
    inSegOp env = env {envOptimiseLoop = optimiseLoop}
optimiseMCOp op = pure op
-- | Transform all statements of a kernel body with 'optimiseStms',
-- leaving every other field of the kernel body untouched.
optimiseKernelBody ::
  (ASTRep rep) =>
  KernelBody rep ->
  DoubleBufferM rep (KernelBody rep)
optimiseKernelBody kbody = do
  stms' <- optimiseStms $ stmsToList $ kernelBodyStms kbody
  pure $ kbody {kernelBodyStms = stms'}
-- | Transform the body of a lambda with 'optimiseBody', with the
-- scope of the lambda added to the current scope.
optimiseLambda ::
  (ASTRep rep) =>
  Lambda rep ->
  DoubleBufferM rep (Lambda rep)
optimiseLambda lam = do
  body <- localScope (castScope $ scopeOf lam) $ optimiseBody $ lambdaBody lam
  pure lam {lambdaBody = body}
-- | The constraints shared by the memory-aware helpers in this
-- module: a memory representation with builder support, trivial
-- expression and body decorations, and 'LetDecMem' as the pattern
-- decoration.
type Constraints rep inner =
  ( Mem rep inner,
    BuilderOps rep,
    ExpDec rep ~ (),
    BodyDec rep ~ (),
    LetDec rep ~ LetDecMem
  )
-- | @extractAllocOf bound needle stms@ searches @stms@ from the front
-- for a single-name allocation statement that binds @needle@ and
-- whose size is invariant: either a constant, or a variable that is
-- not in @bound@.  Names bound by the statements that are passed over
-- are added to @bound@ along the way.
--
-- On success, the allocation statement is returned together with the
-- remaining statements (those passed over, followed by those after
-- the allocation, in their original order).  The result is 'Nothing'
-- if 'stmsHead' yields 'Nothing' before such a statement is found.
extractAllocOf :: (Constraints rep inner) => Names -> VName -> Stms rep -> Maybe (Stm rep, Stms rep)
extractAllocOf bound needle stms = do
  (stm, stms') <- stmsHead stms
  case stm of
    Let (Pat [pe]) _ (Op (Alloc size _))
      | patElemName pe == needle,
        invariant size ->
          Just (stm, stms')
    _ ->
      -- Not the allocation we are looking for: record what this
      -- statement binds, keep searching, and put the statement back
      -- in front of whatever remains.
      let bound' = namesFromList (patNames (stmPat stm)) <> bound
       in second (oneStm stm <>) <$> extractAllocOf bound' needle stms'
  where
    invariant Constant {} = True
    invariant (Var v) = v `notNameIn` bound
-- | @isArrayIn x param@ is 'True' exactly when @param@ is an array
-- parameter whose memory binding ('ArrayIn') names @x@.
isArrayIn :: VName -> Param FParamMem -> Bool
isArrayIn x (Param _ _ (MemArray _ _ _ (ArrayIn y _))) = x == y
isArrayIn _ _ = False
-- | Whether memory in the given space is a candidate for the
-- transformation: 'False' for 'ScalarSpace', 'True' for every other
-- space.
doubleBufferSpace :: Space -> Bool
doubleBufferSpace ScalarSpace {} = False
doubleBufferSpace _ = True
optimiseLoop :: (Constraints rep inner) => OptimiseLoop rep
optimiseLoop :: forall rep (inner :: * -> *).
Constraints rep inner =>
OptimiseLoop rep
optimiseLoop (Pat [PatElem (LetDec rep)]
pes) [(FParam rep, SubExp)]
merge body :: Body rep
body@(Body BodyDec rep
_ Stms rep
body_stms Result
body_res) = do
((pat', merge', body'), outer_stms) <- Builder rep (Pat LetDecMem, [(Param FParamMem, SubExp)], Body rep)
-> DoubleBufferM
rep
((Pat LetDecMem, [(Param FParamMem, SubExp)], Body rep), Stms rep)
forall (m :: * -> *) somerep rep a.
(MonadFreshNames m, HasScope somerep m, SameScope somerep rep) =>
Builder rep a -> m (a, Stms rep)
runBuilder (Builder rep (Pat LetDecMem, [(Param FParamMem, SubExp)], Body rep)
-> DoubleBufferM
rep
((Pat LetDecMem, [(Param FParamMem, SubExp)], Body rep), Stms rep))
-> Builder
rep (Pat LetDecMem, [(Param FParamMem, SubExp)], Body rep)
-> DoubleBufferM
rep
((Pat LetDecMem, [(Param FParamMem, SubExp)], Body rep), Stms rep)
forall a b. (a -> b) -> a -> b
$ do
((param_changes, body_stms'), (pes', merge', body_res')) <-
([([PatElem LetDecMem], [(Param FParamMem, SubExp)], Result)]
-> ([[PatElem LetDecMem]], [[(Param FParamMem, SubExp)]],
[Result]))
-> (((Param FParamMem, SubExp) -> (Param FParamMem, SubExp),
Stms rep),
[([PatElem LetDecMem], [(Param FParamMem, SubExp)], Result)])
-> (((Param FParamMem, SubExp) -> (Param FParamMem, SubExp),
Stms rep),
([[PatElem LetDecMem]], [[(Param FParamMem, SubExp)]], [Result]))
forall b c a. (b -> c) -> (a, b) -> (a, c)
forall (p :: * -> * -> *) b c a.
Bifunctor p =>
(b -> c) -> p a b -> p a c
second [([PatElem LetDecMem], [(Param FParamMem, SubExp)], Result)]
-> ([[PatElem LetDecMem]], [[(Param FParamMem, SubExp)]], [Result])
forall a b c. [(a, b, c)] -> ([a], [b], [c])
unzip3 ((((Param FParamMem, SubExp) -> (Param FParamMem, SubExp),
Stms rep),
[([PatElem LetDecMem], [(Param FParamMem, SubExp)], Result)])
-> (((Param FParamMem, SubExp) -> (Param FParamMem, SubExp),
Stms rep),
([[PatElem LetDecMem]], [[(Param FParamMem, SubExp)]], [Result])))
-> BuilderT
rep
(State VNameSource)
(((Param FParamMem, SubExp) -> (Param FParamMem, SubExp),
Stms rep),
[([PatElem LetDecMem], [(Param FParamMem, SubExp)], Result)])
-> BuilderT
rep
(State VNameSource)
(((Param FParamMem, SubExp) -> (Param FParamMem, SubExp),
Stms rep),
([[PatElem LetDecMem]], [[(Param FParamMem, SubExp)]], [Result]))
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> (((Param FParamMem, SubExp) -> (Param FParamMem, SubExp), Stms rep)
-> (PatElem LetDecMem, (Param FParamMem, SubExp), SubExpRes)
-> BuilderT
rep
(State VNameSource)
(((Param FParamMem, SubExp) -> (Param FParamMem, SubExp),
Stms rep),
([PatElem LetDecMem], [(Param FParamMem, SubExp)], Result)))
-> ((Param FParamMem, SubExp) -> (Param FParamMem, SubExp),
Stms rep)
-> [(PatElem LetDecMem, (Param FParamMem, SubExp), SubExpRes)]
-> BuilderT
rep
(State VNameSource)
(((Param FParamMem, SubExp) -> (Param FParamMem, SubExp),
Stms rep),
[([PatElem LetDecMem], [(Param FParamMem, SubExp)], Result)])
forall (m :: * -> *) (t :: * -> *) acc x y.
(Monad m, Traversable t) =>
(acc -> x -> m (acc, y)) -> acc -> t x -> m (acc, t y)
mapAccumLM ((Param FParamMem, SubExp) -> (Param FParamMem, SubExp), Stms rep)
-> (PatElem LetDecMem, (Param FParamMem, SubExp), SubExpRes)
-> BuilderT
rep
(State VNameSource)
(((Param FParamMem, SubExp) -> (Param FParamMem, SubExp),
Stms rep),
([PatElem LetDecMem], [(Param FParamMem, SubExp)], Result))
check ((Param FParamMem, SubExp) -> (Param FParamMem, SubExp)
forall a. a -> a
id, Stms rep
body_stms) ([PatElem LetDecMem]
-> [(Param FParamMem, SubExp)]
-> Result
-> [(PatElem LetDecMem, (Param FParamMem, SubExp), SubExpRes)]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [PatElem (LetDec rep)]
[PatElem LetDecMem]
pes [(FParam rep, SubExp)]
[(Param FParamMem, SubExp)]
merge Result
body_res)
pure
( Pat $ mconcat pes',
map param_changes $ mconcat merge',
Body () body_stms' $ mconcat body_res'
)
pure (outer_stms, pat', merge', body')
where
bound_in_loop :: Names
bound_in_loop =
[VName] -> Names
namesFromList (((Param FParamMem, SubExp) -> VName)
-> [(Param FParamMem, SubExp)] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map (Param FParamMem -> VName
forall dec. Param dec -> VName
paramName (Param FParamMem -> VName)
-> ((Param FParamMem, SubExp) -> Param FParamMem)
-> (Param FParamMem, SubExp)
-> VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Param FParamMem, SubExp) -> Param FParamMem
forall a b. (a, b) -> a
fst) [(FParam rep, SubExp)]
[(Param FParamMem, SubExp)]
merge) Names -> Names -> Names
forall a. Semigroup a => a -> a -> a
<> Body rep -> Names
forall rep. Body rep -> Names
boundInBody Body rep
body
findLmadOfArray :: VName -> Maybe LMAD
findLmadOfArray VName
v = [LMAD] -> Maybe LMAD
forall a. [a] -> Maybe a
listToMaybe ([LMAD] -> Maybe LMAD)
-> ([Stm rep] -> [LMAD]) -> [Stm rep] -> Maybe LMAD
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Stm rep -> Maybe LMAD) -> [Stm rep] -> [LMAD]
forall a b. (a -> Maybe b) -> [a] -> [b]
mapMaybe Stm rep -> Maybe LMAD
onStm ([Stm rep] -> Maybe LMAD) -> [Stm rep] -> Maybe LMAD
forall a b. (a -> b) -> a -> b
$ Stms rep -> [Stm rep]
forall rep. Stms rep -> [Stm rep]
stmsToList Stms rep
body_stms
where
onStm :: Stm rep -> Maybe LMAD
onStm = [LMAD] -> Maybe LMAD
forall a. [a] -> Maybe a
listToMaybe ([LMAD] -> Maybe LMAD)
-> (Stm rep -> [LMAD]) -> Stm rep -> Maybe LMAD
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (PatElem LetDecMem -> Maybe LMAD) -> [PatElem LetDecMem] -> [LMAD]
forall a b. (a -> Maybe b) -> [a] -> [b]
mapMaybe PatElem LetDecMem -> Maybe LMAD
onPatElem ([PatElem LetDecMem] -> [LMAD])
-> (Stm rep -> [PatElem LetDecMem]) -> Stm rep -> [LMAD]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Pat LetDecMem -> [PatElem LetDecMem]
forall dec. Pat dec -> [PatElem dec]
patElems (Pat LetDecMem -> [PatElem LetDecMem])
-> (Stm rep -> Pat LetDecMem) -> Stm rep -> [PatElem LetDecMem]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Stm rep -> Pat (LetDec rep)
Stm rep -> Pat LetDecMem
forall rep. Stm rep -> Pat (LetDec rep)
stmPat
onPatElem :: PatElem LetDecMem -> Maybe LMAD
onPatElem (PatElem VName
pe_v (MemArray PrimType
_ ShapeBase SubExp
_ NoUniqueness
_ (ArrayIn VName
_ LMAD
lmad)))
| VName
v VName -> VName -> Bool
forall a. Eq a => a -> a -> Bool
== VName
pe_v,
Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ Names
bound_in_loop Names -> Names -> Bool
`namesIntersect` LMAD -> Names
forall a. FreeIn a => a -> Names
freeIn LMAD
lmad =
LMAD -> Maybe LMAD
forall a. a -> Maybe a
Just LMAD
lmad
onPatElem PatElem LetDecMem
_ = Maybe LMAD
forall a. Maybe a
Nothing
changeParam :: a -> (a, b) -> (a, b) -> (a, b)
changeParam a
p_needle (a, b)
new (a
p, b
p_initial) =
if a
p a -> a -> Bool
forall a. Eq a => a -> a -> Bool
== a
p_needle then (a, b)
new else (a
p, b
p_initial)
check :: ((Param FParamMem, SubExp) -> (Param FParamMem, SubExp), Stms rep)
-> (PatElem LetDecMem, (Param FParamMem, SubExp), SubExpRes)
-> BuilderT
rep
(State VNameSource)
(((Param FParamMem, SubExp) -> (Param FParamMem, SubExp),
Stms rep),
([PatElem LetDecMem], [(Param FParamMem, SubExp)], Result))
check ((Param FParamMem, SubExp) -> (Param FParamMem, SubExp)
param_changes, Stms rep
body_stms') (PatElem LetDecMem
pe, (Param FParamMem
param, SubExp
arg), SubExpRes
res)
| Mem Space
space <- Param FParamMem -> Type
forall dec. Typed dec => Param dec -> Type
paramType Param FParamMem
param,
Space -> Bool
doubleBufferSpace Space
space,
Var VName
arg_v <- SubExp
arg,
[((Param FParamMem
arr_param, Var VName
arr_param_initial), Var VName
arr_v)] <-
(((Param FParamMem, SubExp), SubExp) -> Bool)
-> [((Param FParamMem, SubExp), SubExp)]
-> [((Param FParamMem, SubExp), SubExp)]
forall a. (a -> Bool) -> [a] -> [a]
filter
(VName -> Param FParamMem -> Bool
isArrayIn (Param FParamMem -> VName
forall dec. Param dec -> VName
paramName Param FParamMem
param) (Param FParamMem -> Bool)
-> (((Param FParamMem, SubExp), SubExp) -> Param FParamMem)
-> ((Param FParamMem, SubExp), SubExp)
-> Bool
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Param FParamMem, SubExp) -> Param FParamMem
forall a b. (a, b) -> a
fst ((Param FParamMem, SubExp) -> Param FParamMem)
-> (((Param FParamMem, SubExp), SubExp)
-> (Param FParamMem, SubExp))
-> ((Param FParamMem, SubExp), SubExp)
-> Param FParamMem
forall b c a. (b -> c) -> (a -> b) -> a -> c
. ((Param FParamMem, SubExp), SubExp) -> (Param FParamMem, SubExp)
forall a b. (a, b) -> a
fst)
([(Param FParamMem, SubExp)]
-> [SubExp] -> [((Param FParamMem, SubExp), SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [(FParam rep, SubExp)]
[(Param FParamMem, SubExp)]
merge ([SubExp] -> [((Param FParamMem, SubExp), SubExp)])
-> [SubExp] -> [((Param FParamMem, SubExp), SubExp)]
forall a b. (a -> b) -> a -> b
$ (SubExpRes -> SubExp) -> Result -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map SubExpRes -> SubExp
resSubExp Result
body_res),
MemArray PrimType
pt ShapeBase SubExp
shape Uniqueness
_ (ArrayIn VName
_ LMAD
param_lmad) <- Param FParamMem -> FParamMem
forall dec. Param dec -> dec
paramDec Param FParamMem
arr_param,
Var VName
arr_mem_out <- SubExpRes -> SubExp
resSubExp SubExpRes
res,
Just LMAD
arr_lmad <- VName -> Maybe LMAD
findLmadOfArray VName
arr_v,
Just (Stm rep
arr_mem_out_alloc, Stms rep
body_stms'') <-
Names -> VName -> Stms rep -> Maybe (Stm rep, Stms rep)
forall rep (inner :: * -> *).
Constraints rep inner =>
Names -> VName -> Stms rep -> Maybe (Stm rep, Stms rep)
extractAllocOf Names
bound_in_loop VName
arr_mem_out Stms rep
body_stms' = do
num_bytes <-
String
-> Exp (Rep (BuilderT rep (State VNameSource)))
-> BuilderT rep (State VNameSource) SubExp
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m SubExp
letSubExp String
"num_bytes" (Exp rep -> BuilderT rep (State VNameSource) SubExp)
-> BuilderT rep (State VNameSource) (Exp rep)
-> BuilderT rep (State VNameSource) SubExp
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< TPrimExp Int64 VName
-> BuilderT
rep
(State VNameSource)
(Exp (Rep (BuilderT rep (State VNameSource))))
forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
forall (m :: * -> *).
MonadBuilder m =>
TPrimExp Int64 VName -> m (Exp (Rep m))
toExp (PrimType -> TPrimExp Int64 VName
forall a. Num a => PrimType -> a
primByteSize PrimType
pt TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* (TPrimExp Int64 VName
1 TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
+ LMAD -> TPrimExp Int64 VName
forall num.
Pretty num =>
LMAD (TPrimExp Int64 num) -> TPrimExp Int64 num
LMAD.range LMAD
arr_lmad))
arr_mem_in <-
letExp (baseString arg_v <> "_in") $ Op $ Alloc num_bytes space
addStm arr_mem_out_alloc
pe_unused <-
PatElem
<$> newVName (baseString (patElemName pe) <> "_unused")
<*> pure (MemMem space)
param_out <-
newParam (baseString (paramName param) <> "_out") (MemMem space)
arr_v_copy <- newVName $ baseString arr_v <> "_db_copy"
let arr_initial_info =
PrimType
-> ShapeBase SubExp -> NoUniqueness -> MemBind -> LetDecMem
forall d u ret.
PrimType -> ShapeBase d -> u -> ret -> MemInfo d u ret
MemArray PrimType
pt ShapeBase SubExp
shape NoUniqueness
NoUniqueness (MemBind -> LetDecMem) -> MemBind -> LetDecMem
forall a b. (a -> b) -> a -> b
$ VName -> LMAD -> MemBind
ArrayIn VName
arr_mem_in LMAD
arr_lmad
arr_initial_pe =
VName -> LetDecMem -> PatElem LetDecMem
forall dec. VName -> dec -> PatElem dec
PatElem VName
arr_v_copy LetDecMem
arr_initial_info
addStm . Let (Pat [arr_initial_pe]) (defAux ()) . BasicOp $
Replicate mempty (Var arr_param_initial)
let arr_param' =
Attrs -> VName -> FParamMem -> Param FParamMem
forall dec. Attrs -> VName -> dec -> Param dec
Param Attrs
forall a. Monoid a => a
mempty (Param FParamMem -> VName
forall dec. Param dec -> VName
paramName Param FParamMem
arr_param) (FParamMem -> Param FParamMem) -> FParamMem -> Param FParamMem
forall a b. (a -> b) -> a -> b
$
PrimType -> ShapeBase SubExp -> Uniqueness -> MemBind -> FParamMem
forall d u ret.
PrimType -> ShapeBase d -> u -> ret -> MemInfo d u ret
MemArray PrimType
pt ShapeBase SubExp
shape Uniqueness
Unique (VName -> LMAD -> MemBind
ArrayIn (Param FParamMem -> VName
forall dec. Param dec -> VName
paramName Param FParamMem
param) LMAD
param_lmad)
let mkUpdate VName
lmad_v =
case (((Param FParamMem, SubExp), SubExpRes) -> Bool)
-> [((Param FParamMem, SubExp), SubExpRes)]
-> Maybe ((Param FParamMem, SubExp), SubExpRes)
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Maybe a
L.find ((VName -> VName -> Bool
forall a. Eq a => a -> a -> Bool
== VName
lmad_v) (VName -> Bool)
-> (((Param FParamMem, SubExp), SubExpRes) -> VName)
-> ((Param FParamMem, SubExp), SubExpRes)
-> Bool
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Param FParamMem -> VName
forall dec. Param dec -> VName
paramName (Param FParamMem -> VName)
-> (((Param FParamMem, SubExp), SubExpRes) -> Param FParamMem)
-> ((Param FParamMem, SubExp), SubExpRes)
-> VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Param FParamMem, SubExp) -> Param FParamMem
forall a b. (a, b) -> a
fst ((Param FParamMem, SubExp) -> Param FParamMem)
-> (((Param FParamMem, SubExp), SubExpRes)
-> (Param FParamMem, SubExp))
-> ((Param FParamMem, SubExp), SubExpRes)
-> Param FParamMem
forall b c a. (b -> c) -> (a -> b) -> a -> c
. ((Param FParamMem, SubExp), SubExpRes) -> (Param FParamMem, SubExp)
forall a b. (a, b) -> a
fst) ([((Param FParamMem, SubExp), SubExpRes)]
-> Maybe ((Param FParamMem, SubExp), SubExpRes))
-> [((Param FParamMem, SubExp), SubExpRes)]
-> Maybe ((Param FParamMem, SubExp), SubExpRes)
forall a b. (a -> b) -> a -> b
$
[(Param FParamMem, SubExp)]
-> Result -> [((Param FParamMem, SubExp), SubExpRes)]
forall a b. [a] -> [b] -> [(a, b)]
zip [(FParam rep, SubExp)]
[(Param FParamMem, SubExp)]
merge Result
body_res of
Maybe ((Param FParamMem, SubExp), SubExpRes)
Nothing -> (Param FParamMem, SubExp) -> (Param FParamMem, SubExp)
forall a. a -> a
id
Just ((Param FParamMem
p, SubExp
_), SubExpRes
p_res) -> Param FParamMem
-> (Param FParamMem, SubExp)
-> (Param FParamMem, SubExp)
-> (Param FParamMem, SubExp)
forall {a} {b}. Eq a => a -> (a, b) -> (a, b) -> (a, b)
changeParam Param FParamMem
p (Param FParamMem
p, SubExpRes -> SubExp
resSubExp SubExpRes
p_res)
updateLmadParam =
(((Param FParamMem, SubExp) -> (Param FParamMem, SubExp))
-> ((Param FParamMem, SubExp) -> (Param FParamMem, SubExp))
-> (Param FParamMem, SubExp)
-> (Param FParamMem, SubExp))
-> ((Param FParamMem, SubExp) -> (Param FParamMem, SubExp))
-> [(Param FParamMem, SubExp) -> (Param FParamMem, SubExp)]
-> (Param FParamMem, SubExp)
-> (Param FParamMem, SubExp)
forall b a. (b -> a -> b) -> b -> [a] -> b
forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl ((Param FParamMem, SubExp) -> (Param FParamMem, SubExp))
-> ((Param FParamMem, SubExp) -> (Param FParamMem, SubExp))
-> (Param FParamMem, SubExp)
-> (Param FParamMem, SubExp)
forall b c a. (b -> c) -> (a -> b) -> a -> c
(.) (Param FParamMem, SubExp) -> (Param FParamMem, SubExp)
forall a. a -> a
id ([(Param FParamMem, SubExp) -> (Param FParamMem, SubExp)]
-> (Param FParamMem, SubExp) -> (Param FParamMem, SubExp))
-> [(Param FParamMem, SubExp) -> (Param FParamMem, SubExp)]
-> (Param FParamMem, SubExp)
-> (Param FParamMem, SubExp)
forall a b. (a -> b) -> a -> b
$ (VName -> (Param FParamMem, SubExp) -> (Param FParamMem, SubExp))
-> [VName]
-> [(Param FParamMem, SubExp) -> (Param FParamMem, SubExp)]
forall a b. (a -> b) -> [a] -> [b]
map VName -> (Param FParamMem, SubExp) -> (Param FParamMem, SubExp)
mkUpdate ([VName]
-> [(Param FParamMem, SubExp) -> (Param FParamMem, SubExp)])
-> [VName]
-> [(Param FParamMem, SubExp) -> (Param FParamMem, SubExp)]
forall a b. (a -> b) -> a -> b
$ Names -> [VName]
namesToList (Names -> [VName]) -> Names -> [VName]
forall a b. (a -> b) -> a -> b
$ LMAD -> Names
forall a. FreeIn a => a -> Names
freeIn LMAD
param_lmad
pure
( ( updateLmadParam
. changeParam arr_param (arr_param', Var arr_v_copy)
. param_changes,
substituteNames (M.singleton arr_mem_out (paramName param_out)) body_stms''
),
( [pe, pe_unused],
[(param, Var arr_mem_in), (param_out, Var arr_mem_out)],
[ res {resSubExp = Var $ paramName param_out},
subExpRes $ Var $ paramName param
]
)
)
| Bool
otherwise =
(((Param FParamMem, SubExp) -> (Param FParamMem, SubExp),
Stms rep),
([PatElem LetDecMem], [(Param FParamMem, SubExp)], Result))
-> BuilderT
rep
(State VNameSource)
(((Param FParamMem, SubExp) -> (Param FParamMem, SubExp),
Stms rep),
([PatElem LetDecMem], [(Param FParamMem, SubExp)], Result))
forall a. a -> BuilderT rep (State VNameSource) a
forall (f :: * -> *) a. Applicative f => a -> f a
pure
( ((Param FParamMem, SubExp) -> (Param FParamMem, SubExp)
param_changes, Stms rep
body_stms'),
([PatElem LetDecMem
pe], [(Param FParamMem
param, SubExp
arg)], [SubExpRes
res])
)
-- | Build a double-buffering 'Pass' for some memory representation.
--
-- The first argument becomes the pass name and the second the pass
-- description.  The third argument is the representation-specific
-- handler stored in the environment as 'envOptimiseOp'.
--
-- The resulting pass applies 'optimiseStms' to each statement sequence
-- handed to it by 'intraproceduralTransformation'.
doubleBuffer :: (Mem rep inner) => String -> String -> OptimiseOp rep -> Pass rep rep
doubleBuffer name desc onOp =
  Pass
    { passName = name,
      passDescription = desc,
      passFunction = intraproceduralTransformation optimise
    }
  where
    -- NOTE: the local bindings deliberately carry no type signatures.
    -- Their types mention the outer @rep@, which is not in scope here
    -- without an explicit @forall@ and scoped type variables.

    -- Run the optimiser over @stms@ with @scope@ as the local scope,
    -- threading the name source through so that fresh names generated
    -- inside the 'DoubleBufferM' computation are recorded.
    optimise scope stms = modifyNameSource $ \src ->
      let m =
            runDoubleBufferM $
              localScope scope $
                optimiseStms $
                  stmsToList stms
       in runState (runReaderT m env) src

    -- Initial environment: empty scope, a loop optimiser that leaves
    -- loops alone, and the caller-supplied operation handler.
    env =
      Env
        { envScope = mempty,
          envOptimiseLoop = doNotTouchLoop,
          envOptimiseOp = onOp
        }

    -- Loop optimiser that changes nothing: no extra statements
    -- ('mempty'), and the pattern, merge parameters and body are
    -- returned as given.
    doNotTouchLoop pat merge body = pure (mempty, pat, merge, body)
-- | The double-buffering pass for the 'GPUMem' representation.
--
-- This is 'doubleBuffer' with 'optimiseGPUOp' as the operation handler.
doubleBufferGPU :: Pass GPUMem GPUMem
doubleBufferGPU =
  doubleBuffer
    "Double buffer GPU"
    "Double buffer memory in sequential loops (GPU rep)."
    optimiseGPUOp
-- | The double-buffering pass for the 'MCMem' representation.
--
-- This is 'doubleBuffer' with 'optimiseMCOp' as the operation handler.
doubleBufferMC :: Pass MCMem MCMem
doubleBufferMC =
  doubleBuffer
    "Double buffer MC"
    "Double buffer memory in sequential loops (MC rep)."
    optimiseMCOp