{-# LANGUAGE TypeFamilies #-}
module Futhark.Pass.ExpandAllocations (expandAllocations) where
import Control.Monad
import Control.Monad.Except
import Control.Monad.Reader
import Control.Monad.State
import Control.Monad.Writer
import Data.Bifunctor
import Data.Either (rights)
import Data.List (find, foldl')
import Data.Map.Strict qualified as M
import Data.Maybe
import Data.Sequence qualified as Seq
import Futhark.Analysis.Alias as Alias
import Futhark.Analysis.SymbolTable qualified as ST
import Futhark.Error
import Futhark.IR
import Futhark.IR.GPU.Simplify qualified as GPU
import Futhark.IR.GPUMem
import Futhark.IR.Mem.LMAD qualified as LMAD
import Futhark.MonadFreshNames
import Futhark.Optimise.Simplify.Rep (addScopeWisdom)
import Futhark.Pass
import Futhark.Pass.ExplicitAllocations.GPU (explicitAllocationsInStms)
import Futhark.Pass.ExtractKernels.BlockedKernel (nonSegRed)
import Futhark.Pass.ExtractKernels.ToGPU (segThread)
import Futhark.Tools
import Futhark.Transform.CopyPropagate (copyPropagateInFun)
import Futhark.Transform.Rename (renameStm)
import Futhark.Transform.Substitute
import Futhark.Util (mapAccumLM)
import Prelude hiding (quot)
-- | The memory expansion pass definition.
--
-- Allocations are first expanded in the program's top-level constants
-- (starting from an empty scope), and then in every function, with
-- the scope of the transformed constants available.  Any expansion
-- failure is turned into a compiler limitation by 'limitationOnLeft'.
expandAllocations :: Pass GPUMem GPUMem
expandAllocations =
  Pass "expand allocations" "Expand allocations" expandProg
  where
    expandProg :: Prog GPUMem -> PassM (Prog GPUMem)
    expandProg prog = do
      let constsAction = transformStms (progConsts prog)
      consts' <-
        modifyNameSource $ \src ->
          limitationOnLeft $ runStateT (runReaderT constsAction mempty) src
      funs' <- mapM (transformFunDef (scopeOf consts')) (progFuns prog)
      pure (prog {progConsts = consts', progFuns = funs'})
-- | The monad in which expansion runs: a read-only 'Scope' of the
-- bindings currently in scope, a threaded 'VNameSource' for fresh
-- names, and failure with a 'String' message.
type ExpandM = ReaderT (Scope GPUMem) (StateT VNameSource (Either String))
-- | Unwrap the outcome of an expansion.  A successful result is
-- returned as-is; an error message is reported as a compiler
-- limitation through 'compilerLimitationS'.
limitationOnLeft :: Either String a -> a
limitationOnLeft (Right x) = x
limitationOnLeft (Left msg) = compilerLimitationS msg
-- | Expand allocations in the body of a single function.
--
-- The body is transformed with the given @scope@ and the function's
-- own bindings in scope.  Failures are reported via
-- 'limitationOnLeft'.  Copy propagation is then run on the function
-- with the new body, using a symbol table built from @scope@.
transformFunDef ::
  Scope GPUMem ->
  FunDef GPUMem ->
  PassM (FunDef GPUMem)
transformFunDef scope fundec = do
  let expandBody =
        localScope scope . inScopeOf fundec . transformBody $ funDefBody fundec
  body' <-
    modifyNameSource $ \src ->
      limitationOnLeft $ runStateT (runReaderT expandBody mempty) src
  copyPropagateInFun
    simpleGPUMem
    (ST.fromScope (addScopeWisdom scope))
    fundec {funDefBody = body'}
-- | Expand allocations in every statement of a body.  The body's
-- result is carried over untouched.
transformBody :: Body GPUMem -> ExpandM (Body GPUMem)
transformBody (Body () stms res) = do
  stms' <- transformStms stms
  pure $ Body () stms' res
-- | Expand allocations in the body of a lambda, with the lambda's
-- parameters in scope.  Parameters and return types are unchanged.
transformLambda :: Lambda GPUMem -> ExpandM (Lambda GPUMem)
transformLambda (Lambda params ret body) = do
  body' <- localScope (scopeOfLParams params) $ transformBody body
  pure $ Lambda params ret body'
-- | Expand allocations in a sequence of statements.
--
-- The bindings of the whole sequence are placed in scope before any
-- statement is processed.  Each statement may expand into several
-- statements, and the results are concatenated in order.
transformStms :: Stms GPUMem -> ExpandM (Stms GPUMem)
transformStms stms = inScopeOf stms $ do
  expanded <- mapM transformStm $ stmsToList stms
  pure $ mconcat expanded
-- | Expand allocations in a single statement.  The result consists of
-- any new statements produced while transforming the expression
-- (see 'transformExp'), followed by the rewritten statement itself.
--
-- For a 'Match' annotated with 'MatchEquiv', expansion is allowed to
-- fail in individual branches.  Failing branches are dropped instead
-- of failing the whole statement.  If the default branch fails, the
-- last surviving case becomes the new default.  The error is
-- propagated only if every branch fails; in that case the default
-- branch's error is reported.
--
-- FIXME: this can remove safety checks if the default branch fails!
transformStm :: Stm GPUMem -> ExpandM (Stms GPUMem)
transformStm (Let pat aux (Match cond cases defbody (MatchDec ts MatchEquiv))) = do
  cases' <- rights <$> mapM onCase cases
  defbody' <- attempt $ transformBody defbody
  case (defbody', splitLast cases') of
    -- The default branch survived: keep every case that also survived.
    (Right defbody'', _) ->
      rebuild cases' defbody''
    -- The default branch failed: promote the last surviving case to
    -- be the new default branch.
    (Left _, Just (cases'', lastcase)) ->
      rebuild cases'' $ caseBody lastcase
    -- Every branch failed.
    (Left e, Nothing) ->
      throwError e
  where
    -- Run an expansion, capturing a failure as 'Left' instead of
    -- propagating it.
    attempt m = (Right <$> m) `catchError` (pure . Left)

    onCase (Case vs body) = fmap (Case vs) <$> attempt (transformBody body)

    rebuild cases'' def =
      pure . oneStm . Let pat aux $
        Match cond cases'' def (MatchDec ts MatchEquiv)

    -- Total replacement for the pair ('init' xs, 'last' xs); 'Nothing'
    -- for the empty list.
    splitLast :: [a] -> Maybe ([a], a)
    splitLast = foldr step Nothing
      where
        step x Nothing = Just ([], x)
        step x (Just (xs, l)) = Just (x : xs, l)
transformStm (Let pat aux e) = do
  (stms, e') <- transformExp =<< mapExpM transform e
  pure $ stms <> oneStm (Let pat aux e')
  where
    -- Recursively expand allocations in every nested body, with the
    -- scope that 'mapExpM' supplies for that body in effect.
    transform =
      identityMapper
        { mapOnBody = \scope -> localScope scope . transformBody
        }
-- | Expand allocations in a single expression, returning the
-- statements that must precede it together with the rewritten
-- expression.
--
-- * 'SegMap', 'SegRed', 'SegScan' and 'SegHist' are handled by
--   'transformScanRed'.  It receives the operator lambdas (none for
--   'SegMap') and returns the new level, the rewritten lambdas and the
--   rewritten kernel body.
--
-- * For 'WithAcc', the lambda is transformed.  For each input that
--   carries an operator, the operator's allocations are extracted and
--   expanded.  An allocation whose size is a variable not bound outside
--   the operator cannot be handled and causes an error.
--
-- * Every other expression is returned unchanged, with no extra
--   statements.
transformExp :: Exp GPUMem -> ExpandM (Stms GPUMem, Exp GPUMem)
transformExp (Op (Inner (SegOp (SegMap lvl space ts kbody)))) = do
  (alloc_stms, (lvl', _, kbody')) <- transformScanRed lvl space [] kbody
  pure (alloc_stms, Op . Inner . SegOp $ SegMap lvl' space ts kbody')
transformExp (Op (Inner (SegOp (SegRed lvl space ts kbody reds)))) = do
  (alloc_stms, (lvl', lams', kbody')) <-
    transformScanRed lvl space (map segBinOpLambda reds) kbody
  let reds' = zipWith setLambda reds lams'
  pure (alloc_stms, Op . Inner . SegOp $ SegRed lvl' space ts kbody' reds')
  where
    setLambda op lam = op {segBinOpLambda = lam}
transformExp (Op (Inner (SegOp (SegScan lvl space ts kbody scans)))) = do
  (alloc_stms, (lvl', lams', kbody')) <-
    transformScanRed lvl space (map segBinOpLambda scans) kbody
  let scans' = zipWith setLambda scans lams'
  pure (alloc_stms, Op . Inner . SegOp $ SegScan lvl' space ts kbody' scans')
  where
    setLambda op lam = op {segBinOpLambda = lam}
transformExp (Op (Inner (SegOp (SegHist lvl space ts kbody ops)))) = do
  (alloc_stms, (lvl', lams', kbody')) <-
    transformScanRed lvl space (map histOp ops) kbody
  let ops' = zipWith (\op lam -> op {histOp = lam}) ops lams'
  pure (alloc_stms, Op . Inner . SegOp $ SegHist lvl' space ts kbody' ops')
transformExp (WithAcc inputs lam) = do
  lam' <- transformLambda lam
  (input_alloc_stms, inputs') <- mapAndUnzipM expandInput inputs
  pure (mconcat input_alloc_stms, WithAcc inputs' lam')
  where
    expandInput (shape, arrs, Nothing) =
      pure (mempty, (shape, arrs, Nothing))
    expandInput (shape, arrs, Just (op_lam, nes)) = do
      bound_outside <- asks (namesFromList . M.keys)
      let -- 'WithAcc' carries no level of its own, so a fixed
          -- non-virtualised thread level is supplied here.
          fake_lvl = SegThread SegNoVirt Nothing
          (op_lam', lam_allocs) =
            extractLambdaAllocations (fake_lvl, [0]) bound_outside mempty op_lam
          -- Variant: the size is a variable not bound outside the operator.
          isVariant (_, Var v, _) = v `notNameIn` bound_outside
          isVariant _ = False
          (variant_allocs, invariant_allocs) = M.partition isVariant lam_allocs

      forM_ (listToMaybe $ M.elems variant_allocs) $ \(_, v, _) ->
        throwError $
          "Cannot handle un-sliceable allocation size: "
            ++ prettyString v
            ++ "\nLikely cause: irregular nested operations inside accumulator update operator."

      -- The leading operator parameters (one per dimension of the
      -- accumulator shape) index the expanded allocations.
      let idx_names = map paramName . take (shapeRank shape) $ lambdaParams op_lam
      (alloc_stms, alloc_offsets) <-
        genericExpandedInvariantAllocations
          (const $ const (shape, map le64 idx_names))
          invariant_allocs

      scope <- askScope
      let scope' = scopeOf op_lam <> scope <> scopeOf alloc_stms
      result <- runOffsetM scope' $ do
        op_lam'' <- offsetMemoryInLambda alloc_offsets op_lam'
        pure (alloc_stms, (shape, arrs, Just (op_lam'', nes)))
      either throwError pure result
transformExp e = pure (mempty, e)
-- | Make sure a thread- or block-level 'SegLevel' carries an explicit
-- 'KernelGrid'.
--
-- If the level already has a grid, that grid is returned together with the
-- unchanged level and no extra statements.  Otherwise two size parameters
-- (@num_tblocks@ and @tblock_size@) are requested via 'GetSize'.  The result
-- is the statements binding them, the level rebuilt to carry the new grid,
-- and the grid itself.
--
-- Calling this on 'SegThreadInBlock' is an internal error.
ensureGridKnown :: SegLevel -> ExpandM (Stms GPUMem, SegLevel, KernelGrid)
ensureGridKnown lvl =
  case lvl of
    SegThread _ (Just grid) -> alreadyKnown grid
    SegBlock _ (Just grid) -> alreadyKnown grid
    SegThread virt Nothing -> withFreshGrid $ SegThread virt
    SegBlock virt Nothing -> withFreshGrid $ SegBlock virt
    SegThreadInBlock {} -> error "ensureGridKnown: SegThreadInBlock"
  where
    alreadyKnown grid = pure (mempty, lvl, grid)

    -- Bind the number of thread blocks first, then the block size, and
    -- rebuild the level with the resulting grid.
    withFreshGrid mkLevel = do
      (grid, stms) <- runBuilder $ do
        num_tblocks <- freshSize "num_tblocks" SizeGrid
        tblock_size <- freshSize "tblock_size" SizeThreadBlock
        pure $ KernelGrid (Count num_tblocks) (Count tblock_size)
      pure (stms, mkLevel $ Just grid, grid)

    -- The size key is built from a fresh name, so each request gets its
    -- own key.
    freshSize desc size_class = do
      size_key <- nameFromString . prettyString <$> newVName desc
      letSubExp desc $ Op $ Inner $ SizeOp $ GetSize size_key size_class
-- | Expand the allocations found in a segmented scan/reduction-like
-- operation: its kernel body and its operator lambdas.
--
-- Allocations are pulled out of the body and the lambdas and then split in
-- two groups:
--
-- * variant: the size is a variable not bound outside the operation;
-- * invariant: everything else, including constant sizes.
--
-- Errors are raised when:
--
-- * a variant size is not bound inside the kernel either, so it cannot be
--   sliced;
-- * a 'SegBlock' level has variant allocations.
--
-- If nothing was extracted, the operation is returned untouched.
-- Otherwise a grid is ensured for the level, memory for the allocations is
-- computed, and both the body and the lambdas are rewritten to use the
-- expanded memory.
transformScanRed ::
  SegLevel ->
  SegSpace ->
  [Lambda GPUMem] ->
  KernelBody GPUMem ->
  ExpandM (Stms GPUMem, (SegLevel, [Lambda GPUMem], KernelBody GPUMem))
transformScanRed lvl space ops kbody = do
  bound_outside <- asks $ namesFromList . M.keys
  let user = (lvl, [le64 $ segFlat space])
      (kbody', kbody_allocs) =
        extractKernelBodyAllocations user bound_outside bound_in_kernel kbody
      (ops', ops_allocs) =
        unzip $ map (extractLambdaAllocations user bound_outside mempty) ops
      variantAlloc (_, Var v, _) = v `notNameIn` bound_outside
      variantAlloc _ = False
      (variant_allocs, invariant_allocs) =
        M.partition variantAlloc $ kbody_allocs <> mconcat ops_allocs
      -- A variant size must at least be bound inside the kernel, or it
      -- cannot be sliced.
      badVariant (_, Var v, _) = v `notNameIn` bound_in_kernel
      badVariant _ = False

  forM_ (find badVariant $ M.elems variant_allocs) $ \v ->
    throwError $
      "Cannot handle un-sliceable allocation size: "
        ++ prettyString v
        ++ "\nLikely cause: irregular nested operations inside parallel constructs."

  -- This guard checks the *variant* allocations; the message used to say
  -- "invariant", which misdescribed the rejected case.
  case lvl of
    SegBlock {}
      | not $ null variant_allocs ->
          throwError "Cannot handle variant allocations in SegBlock."
    _ ->
      pure ()

  if null variant_allocs && null invariant_allocs
    then pure (mempty, (lvl, ops, kbody))
    else do
      (lvl_stms, lvl', grid) <- ensureGridKnown lvl
      allocsForBody variant_allocs invariant_allocs grid space kbody kbody' $
        \offsets alloc_stms kbody'' -> do
          ops'' <- forM ops' $ \op' ->
            localScope (scopeOf op') $ offsetMemoryInLambda offsets op'
          pure (lvl_stms <> alloc_stms, (lvl', ops'', kbody''))
  where
    -- Names bound by the segment space or by the kernel body's statements.
    bound_in_kernel :: Names
    bound_in_kernel =
      namesFromList (M.keys $ scopeOfSegSpace space)
        <> boundInKernelBody kbody
-- | The names bound by the statements of a kernel body (not including the
-- names of its results).
boundInKernelBody :: KernelBody GPUMem -> Names
boundInKernelBody kbody =
  namesFromList $ M.keys $ scopeOf $ kernelBodyStms kbody
-- | Prepend the given statements to the statements of a kernel body.
addStmsToKernelBody :: Stms GPUMem -> KernelBody GPUMem -> KernelBody GPUMem
addStmsToKernelBody new_stms kbody =
  let old_stms = kernelBodyStms kbody
   in kbody {kernelBodyStms = new_stms <> old_stms}
-- | Compute the memory needed for the extracted allocations of a kernel
-- body, then run a continuation in 'OffsetM'.
--
-- The allocation statements from 'memoryRequirements' are split:
--
-- * those allocating in the @"shared"@ space are prepended to the
--   offset-rewritten kernel body (@kbody'@);
-- * all others are passed to the continuation @m@, together with the
--   rebase map and the rewritten body.
--
-- Any error reported by 'runOffsetM' is rethrown in 'ExpandM'.
allocsForBody ::
  Extraction ->
  Extraction ->
  KernelGrid ->
  SegSpace ->
  KernelBody GPUMem ->
  KernelBody GPUMem ->
  (RebaseMap -> Stms GPUMem -> KernelBody GPUMem -> OffsetM b) ->
  ExpandM b
allocsForBody variant_allocs invariant_allocs grid space kbody kbody' m = do
  (alloc_offsets, alloc_stms) <-
    memoryRequirements
      grid
      space
      (kernelBodyStms kbody)
      variant_allocs
      invariant_allocs

  let (shared_allocs, device_allocs) = Seq.partition isSharedAlloc alloc_stms

  outer_scope <- askScope
  let offset_scope =
        scopeOfSegSpace space <> outer_scope <> scopeOf alloc_stms

  result <- runOffsetM offset_scope $ do
    rebased_body <- offsetMemoryInKernelBody alloc_offsets kbody'
    m alloc_offsets device_allocs $
      addStmsToKernelBody shared_allocs rebased_body
  either throwError pure result
  where
    -- Does the statement allocate memory in the "shared" space?
    isSharedAlloc (Let _ _ (Op (Alloc _ (Space "shared")))) = True
    isSharedAlloc _ = False
-- | Produce the statements and rebase map for the invariant and variant
-- allocations of a kernel.
--
-- The total thread count is bound first, as the product of the grid's
-- number of blocks and block size.  Both expansions are then done in a
-- scope where that count is visible.  The returned statements are, in
-- order: the thread-count binding, the invariant allocations, and the
-- variant allocations.
memoryRequirements ::
  KernelGrid ->
  SegSpace ->
  Stms GPUMem ->
  Extraction ->
  Extraction ->
  ExpandM (RebaseMap, Stms GPUMem)
memoryRequirements grid space kstms variant_allocs invariant_allocs = do
  let num_tblocks = gridNumBlocks grid
      tblock_size = gridBlockSize grid

  (num_threads, num_threads_stms) <-
    runBuilder . letSubExp "num_threads" . BasicOp $
      BinOp (Mul Int64 OverflowUndef) (unCount num_tblocks) (unCount tblock_size)

  inScopeOf num_threads_stms $ do
    (invariant_stms, invariant_offsets) <-
      expandedInvariantAllocations
        num_threads
        num_tblocks
        tblock_size
        invariant_allocs
    (variant_stms, variant_offsets) <-
      expandedVariantAllocations num_threads space kstms variant_allocs
    pure
      ( invariant_offsets <> variant_offsets,
        num_threads_stms <> invariant_stms <> variant_stms
      )
-- | A 64-bit integer index expression over variable names.
type Exp64 = TPrimExp Int64 VName

-- | The level at which an allocation is used, paired with the index
-- expressions that identify the user (e.g. the enclosing flat thread IDs).
type User = (SegLevel, [Exp64])

-- | Allocations extracted from a body: for each memory block, its user,
-- the allocation size, and the memory space.
--
-- The alias name was missing in the original declaration (@type = ...@),
-- which left 'Extraction' undefined even though the whole pass refers to
-- it.
type Extraction = M.Map VName (User, SubExp, Space)
-- | Remove the extractable allocations from a kernel body.  Returns the
-- remaining body and a map describing what was removed.
extractKernelBodyAllocations ::
  User ->
  Names ->
  Names ->
  KernelBody GPUMem ->
  ( KernelBody GPUMem,
    Extraction
  )
extractKernelBodyAllocations user bound_outside bound_kernel =
  extractGenericBodyAllocations
    user
    bound_outside
    bound_kernel
    kernelBodyStms
    replaceStms
  where
    replaceStms stms kbody = kbody {kernelBodyStms = stms}
-- | Remove the extractable allocations from an ordinary body.  Returns the
-- remaining body and a map describing what was removed.
extractBodyAllocations ::
  User ->
  Names ->
  Names ->
  Body GPUMem ->
  (Body GPUMem, Extraction)
extractBodyAllocations user bound_outside bound_kernel =
  extractGenericBodyAllocations
    user
    bound_outside
    bound_kernel
    bodyStms
    replaceStms
  where
    replaceStms stms body = body {bodyStms = stms}
-- | Remove the extractable allocations from the body of a lambda.  Returns
-- the lambda with its rewritten body and a map describing what was removed.
--
-- The original defining equation lacked the function name, so this
-- declaration had a signature but no binding.
extractLambdaAllocations ::
  User ->
  Names ->
  Names ->
  Lambda GPUMem ->
  (Lambda GPUMem, Extraction)
extractLambdaAllocations user bound_outside bound_kernel lam =
  (lam {lambdaBody = body'}, allocs)
  where
    (body', allocs) =
      extractBodyAllocations user bound_outside bound_kernel $
        lambdaBody lam
-- | Shared driver for pulling allocations out of a body-like value.
--
-- @get_stms@ and @set_stms@ read and replace the statements of the body.
-- Every statement goes through 'extractStmAllocations'.  Extracted
-- allocation statements are dropped; all other statements (possibly
-- rewritten) are kept in their original order.  Names bound by the body's
-- own statements are treated as kernel-bound while processing those
-- statements.
extractGenericBodyAllocations ::
  User ->
  Names ->
  Names ->
  (body -> Stms GPUMem) ->
  (Stms GPUMem -> body -> body) ->
  body ->
  ( body,
    Extraction
  )
extractGenericBodyAllocations user bound_outside bound_kernel get_stms set_stms body =
  (set_stms (stmsFromList kept_stms) body, allocs)
  where
    old_stms = get_stms body
    bound_kernel' = bound_kernel <> boundByStms old_stms
    (kept_stms, allocs) =
      runWriter $
        catMaybes
          <$> mapM
            (extractStmAllocations user bound_outside bound_kernel')
            (stmsToList old_stms)
-- | Can an allocation in the given space be expanded for this user?
-- Scalar-space allocations never can, and neither can @"shared"@
-- allocations used at block level.
expandable :: User -> Space -> Bool
expandable (lvl, _) space =
  case space of
    ScalarSpace {} -> False
    Space "shared" | SegBlock {} <- lvl -> False
    _ -> True
-- | True for every memory space except 'ScalarSpace'.
notScalar :: Space -> Bool
notScalar space =
  case space of
    ScalarSpace {} -> False
    _ -> True
-- | Process one statement during allocation extraction.
--
-- A single-element 'Alloc' statement is extracted when:
--
-- * its space is 'expandable' for the user and its size is a constant or a
--   name bound outside or inside the kernel; or
-- * its size is a name bound inside the kernel and the space is not scalar.
--
-- An extracted statement is recorded in the 'Extraction' under its memory
-- name, and 'Nothing' is returned so that the caller drops it.
--
-- Every other statement is kept.  The pass recurses into its nested bodies
-- and into nested 'SegOp's.  For a nested 'SegOp', the user is replaced by
-- that op's level, and the op's flat thread index is appended to the user's
-- index list.
--
-- NOTE(review): through the second condition, a non-scalar allocation whose
-- size is bound in the kernel is extracted even when 'expandable' is False
-- (e.g. @"shared"@ at block level) — confirm this is intended.
--
-- The original first defining equation lacked the function name; it is
-- restored here.
extractStmAllocations ::
  User ->
  Names ->
  Names ->
  Stm GPUMem ->
  Writer Extraction (Maybe (Stm GPUMem))
extractStmAllocations user bound_outside bound_kernel (Let (Pat [patElem]) _ (Op (Alloc size space)))
  | expandable user space && expandableSize size
      || (boundInKernel size && notScalar space) = do
      tell $ M.singleton (patElemName patElem) (user, size, space)
      pure Nothing
  where
    expandableSize (Var v) = v `nameIn` bound_outside || v `nameIn` bound_kernel
    expandableSize Constant {} = True
    boundInKernel (Var v) = v `nameIn` bound_kernel
    boundInKernel Constant {} = False
extractStmAllocations user bound_outside bound_kernel stm = do
  e <- mapExpM (expMapper user) $ stmExp stm
  pure $ Just $ stm {stmExp = e}
  where
    expMapper user' =
      (identityMapper @GPUMem)
        { mapOnBody = const $ onBody user',
          mapOnOp = onOp user'
        }

    onBody user' body = do
      let (body', allocs) =
            extractBodyAllocations user' bound_outside bound_kernel body
      tell allocs
      pure body'

    -- A nested SegOp becomes the new user, identified by the enclosing
    -- user's indices plus the nested op's flat thread index.
    onOp (_, user_ids) (Inner (SegOp op)) =
      Inner . SegOp <$> mapSegOpM (opMapper user'') op
      where
        user'' =
          (segLevel op, user_ids ++ [le64 (segFlat (segSpace op))])
    onOp _ op = pure op

    opMapper user' =
      identitySegOpMapper
        { mapOnSegOpLambda = onLambda user',
          mapOnSegOpBody = onKernelBody user'
        }

    onKernelBody user' body = do
      let (body', allocs) =
            extractKernelBodyAllocations user' bound_outside bound_kernel body
      tell allocs
      pure body'

    onLambda user' lam = do
      body <- onBody user' $ lambdaBody lam
      pure lam {lambdaBody = body}
genericExpandedInvariantAllocations ::
(User -> Space -> (Shape, [Exp64])) -> Extraction -> ExpandM (Stms GPUMem, RebaseMap)
-- | Expand every allocation in the given 'Extraction' so that a single
-- memory block holds one copy per user.
--
-- The first argument maps a user (a segment level together with its
-- index expressions) and a memory space to the shape of the user grid
-- and the user's own indices within that grid.  Each allocation is
-- re-emitted with size @per_thread_size@ times the product of that
-- grid shape.  The returned 'RebaseMap' maps each memory block to a
-- function that ignores the old shape and yields the user's flattened
-- index in the grid together with the total number of users.
genericExpandedInvariantAllocations ::
  ((SegLevel, [TPrimExp Int64 VName]) -> Space -> (Shape, [TPrimExp Int64 VName])) ->
  Extraction ->
  ExpandM (Stms GPUMem, RebaseMap)
genericExpandedInvariantAllocations getNumUsers invariant_allocs = do
  (rebases, alloc_stms) <-
    runBuilder $ forM (M.toList invariant_allocs) expandOne
  pure (alloc_stms, mconcat rebases)
  where
    -- Bind the enlarged allocation for one memory block and return the
    -- singleton rebase map for it.
    expandOne (mem, (user, per_thread_size, space)) = do
      let (users_shape, _) = getNumUsers user space
          size_factors = pe64 per_thread_size : map pe64 (shapeDims users_shape)
      total_size <- letExp "total_size" =<< toExp (product size_factors)
      letBind (Pat [PatElem mem (MemMem space)]) $
        Op $
          Alloc (Var total_size) space
      pure $ M.singleton mem (rebaseFor user space)

    -- Every segment level ('SegThread', 'SegThreadInBlock', 'SegBlock')
    -- is rebased the same way: flatten the user's indices within the
    -- user grid, and use the grid size as the stride.
    rebaseFor user space _old_shape =
      let (users_shape, user_ids) = getNumUsers user space
          dims = map pe64 (shapeDims users_shape)
       in (flattenIndex dims user_ids, product dims)
-- | Expand invariant allocations for a kernel with @num_threads@
-- threads in total, arranged as @num_tblocks@ blocks of @tblock_size@
-- threads each.
--
-- How many copies an allocation needs depends on the user's segment
-- level, on how many index expressions identify it, and (for
-- in-block threads with two indices) on the memory space.  A
-- combination that is not listed is a compiler error.
expandedInvariantAllocations ::
  SubExp ->
  Count NumBlocks SubExp ->
  Count BlockSize SubExp ->
  Extraction ->
  ExpandM (Stms GPUMem, RebaseMap)
expandedInvariantAllocations num_threads (Count num_tblocks) (Count tblock_size) =
  genericExpandedInvariantAllocations getNumUsers
  where
    -- Every thread of the kernel, indexed by (block id, local thread id).
    all_threads_2d = Shape [num_tblocks, tblock_size]

    -- Alternatives are tried top to bottom; an alternative whose pattern
    -- guard fails falls through to the next one.
    getNumUsers user space =
      case user of
        (SegThread {}, [gtid]) ->
          (Shape [num_threads], [gtid])
        (SegThread {}, [gid, ltid]) ->
          (all_threads_2d, [gid, ltid])
        (SegThreadInBlock {}, [gtid]) ->
          (Shape [num_threads], [gtid])
        (SegThreadInBlock {}, [_gid, ltid])
          | Space "shared" <- space ->
              (Shape [tblock_size], [ltid])
        (SegThreadInBlock {}, [gid, ltid])
          | Space "device" <- space ->
              (all_threads_2d, [gid, ltid])
        (SegBlock {}, [gid]) ->
          (Shape [num_tblocks], [gid])
        _ ->
          error $ "getNumUsers: unhandled " ++ show (user, space)
-- | Expand allocations whose size varies between the threads of a
-- kernel with @num_threads@ threads and segment space @kspace@.
--
-- Returns no statements and an empty map when there is nothing to
-- expand.  Otherwise it returns:
--
-- * statements (produced by 'sliceKernelSizes') that compute the
--   needed sizes, and
-- * one allocation per memory block, whose size is the corresponding
--   computed sum.
--
-- Every such block is rebased to the thread's flat index with stride
-- @num_threads@.
expandedVariantAllocations ::
  SubExp ->
  SegSpace ->
  Stms GPUMem ->
  Extraction ->
  ExpandM (Stms GPUMem, RebaseMap)
expandedVariantAllocations num_threads kspace kstms variant_allocs
  | null variant_allocs = pure (mempty, mempty)
  | otherwise = do
      let sizes_to_blocks = removeCommonSizes variant_allocs
      (slice_stms, offsets, size_sums) <-
        sliceKernelSizes num_threads (map fst sizes_to_blocks) kspace kstms
      -- The size-slicing statements come back in the GPU representation.
      -- Give them explicit allocations, simplify them, and pass them
      -- through 'transformStms' (presumably so their own allocations are
      -- expanded as well -- TODO confirm).
      slice_stms' <-
        transformStms =<< simplifyStms =<< explicitAllocationsInStms slice_stms
      let variant_allocs' :: [(VName, (SubExp, SubExp, Space))]
          variant_allocs' =
            [ (mem, (Var offset, Var total_size, space))
              | ((_, blocks), offset, total_size) <- zip3 sizes_to_blocks offsets size_sums,
                (mem, space) <- blocks
            ]
          (alloc_stms, rebases) = unzip $ map allocFor variant_allocs'
      pure (slice_stms' <> stmsFromList alloc_stms, mconcat rebases)
  where
    -- The allocation statement for one memory block, plus its rebase
    -- entry.  The per-size offset is not used here.
    allocFor :: (VName, (SubExp, SubExp, Space)) -> (Stm GPUMem, RebaseMap)
    allocFor (mem, (_offset, total_size, space)) =
      ( Let (Pat [PatElem mem (MemMem space)]) (defAux ()) (Op (Alloc total_size space)),
        M.singleton mem perThreadBase
      )

    -- Index by the flat thread id, with stride @num_threads@; the old
    -- shape is not consulted.
    perThreadBase :: [Exp64] -> Expansion
    perThreadBase _old_shape = (le64 (segFlat kspace), pe64 num_threads)
-- | The result of rebasing an access into an expanded memory block: a
-- pair of 'Exp64' values.  The rebase functions in this file produce
-- (the user's flat index, the number of users / threads).
type Expansion = (Exp64, Exp64)
-- | For each expanded memory block, a function from the old shape
-- (which the functions built in this file ignore) to its 'Expansion'.
type RebaseMap = M.Map VName ([Exp64] -> Expansion)
-- | A 'GPUMem' builder monad that can also fail with a 'String' error and
-- generate fresh names.  Every instance is inherited from the wrapped
-- 'BuilderT' stack via newtype deriving.
newtype OffsetM a
  = OffsetM (BuilderT GPUMem (StateT VNameSource (Either String)) a)
  deriving
    ( Functor,
      Applicative,
      Monad,
      MonadFreshNames,
      MonadError String,
      HasScope GPUMem,
      LocalScope GPUMem
    )
-- | Every builder operation is delegated to the wrapped 'BuilderT' and
-- re-wrapped in the 'OffsetM' constructor.
instance MonadBuilder OffsetM where
  type Rep OffsetM = GPUMem

  mkExpDecM pat = OffsetM . mkExpDecM pat
  mkBodyM stms = OffsetM . mkBodyM stms
  mkLetNamesM names = OffsetM . mkLetNamesM names

  addStms stms = OffsetM (addStms stms)
  collectStms (OffsetM m) = OffsetM (collectStms m)
-- | Run an 'OffsetM' computation in @scope@, using and updating the
-- name source of the surrounding monad.
--
-- Statements the computation adds at its top level are discarded.  On
-- failure the error is returned, and the name source is left exactly as
-- it was before the call.
runOffsetM ::
  (MonadFreshNames m) =>
  Scope GPUMem ->
  OffsetM a ->
  m (Either String a)
runOffsetM scope (OffsetM m) =
  modifyNameSource $ \src ->
    either
      (\err -> (Left err, src))
      (\(res_and_stms, src') -> (Right (fst res_and_stms), src'))
      (runStateT (runBuilderT m scope) src)
-- | Look up the rebase function registered for memory block @name@ and,
-- if there is one, apply it to the shape @x@.  Returns 'Nothing' when
-- @name@ was not expanded.
lookupNewBase :: VName -> [Exp64] -> RebaseMap -> Maybe Expansion
lookupNewBase name x offsets =
  case M.lookup name offsets of
    Nothing -> Nothing
    Just rebase -> Just (rebase x)
-- | Pass every statement of a kernel body through 'offsetMemoryInStm'.
-- All statements emitted along the way become the new body statements;
-- the rest of the kernel body is unchanged.
offsetMemoryInKernelBody :: RebaseMap -> KernelBody GPUMem -> OffsetM (KernelBody GPUMem)
offsetMemoryInKernelBody offsets kbody = do
  new_stms <- collectStms_ $
    forM_ (kernelBodyStms kbody) $ \stm ->
      offsetMemoryInStm offsets stm >>= addStm
  pure kbody {kernelBodyStms = new_stms}
-- | Rebuild a body with every statement passed through
-- 'offsetMemoryInStm'.  The body's result list is kept as is.
offsetMemoryInBody :: RebaseMap -> Body GPUMem -> OffsetM (Body GPUMem)
offsetMemoryInBody offsets (Body _ stms res) =
  buildBody_ $
    res <$ forM_ stms (\stm -> offsetMemoryInStm offsets stm >>= addStm)
-- | Compute the extra context values for a list of sub-expressions,
-- concatenated in order.
--
-- For a sub-expression that is an array in memory, the context is the
-- memory block followed by one new @ctx@ binding per component of
-- 'LMAD.existentialized' of its LMAD.  Any other sub-expression
-- contributes nothing.
argsContext :: [SubExp] -> OffsetM [SubExp]
argsContext ses = concat <$> mapM contextOf ses
  where
    contextOf :: SubExp -> OffsetM [SubExp]
    contextOf se = do
      info <- subExpMemInfo se
      case info of
        MemArray _ _ _ (ArrayIn mem lmad) ->
          (Var mem :) <$> mapM (letSubExp "ctx" <=< toExp) (LMAD.existentialized lmad)
        _ -> pure []
-- | Like 'offsetMemoryInBody', but also appends to the body's result the
-- context values (see 'argsContext') of the original results.
offsetMemoryInBodyReturnCtx :: RebaseMap -> Body GPUMem -> OffsetM (Body GPUMem)
offsetMemoryInBodyReturnCtx offsets (Body _ stms res) =
  buildBody_ $ do
    forM_ stms $ \stm -> offsetMemoryInStm offsets stm >>= addStm
    ctx <- argsContext (map resSubExp res)
    pure (res ++ subExpsRes ctx)
-- | Build an LMAD with the given @shape@ from a flat list of components.
--
-- The first element of @xs@ becomes the first field of 'LMAD.LMAD'
-- (presumably the offset).  Each remaining element is paired, in order,
-- with the corresponding shape dimension via 'LMAD.LMADDim' (presumably
-- the stride -- TODO confirm against "Futhark.IR.Mem.LMAD").  If the
-- lists have different lengths, the surplus elements are ignored, as
-- 'zipWith' does.
--
-- An empty @xs@ is a compiler bug and is reported with a descriptive
-- error instead of the bare failure of a partial 'head'.
lmadFrom :: LMAD.Shape num -> [num] -> LMAD.LMAD num
lmadFrom shape (offset : strides) =
  LMAD.LMAD offset $ zipWith LMAD.LMADDim strides shape
lmadFrom _ [] =
  error "lmadFrom: empty component list; expected an offset followed by one value per dimension."
-- | Make the memory of every array element of a pattern existential.
--
-- An array element stored in memory block @mem@ is rebound to:
--
-- * a new memory block @mem_ext@ in the same space (found with
--   'lookupMemSpace'), and
-- * an LMAD with the original shape whose components are new @ext@
--   elements of type @i64@, one per component of
--   'LMAD.existentialized'.
--
-- The new memory and @ext@ elements are appended after all original
-- elements, in the order they were created.  Non-array elements are
-- left untouched.
addPatternContext :: Pat LetDecMem -> OffsetM (Pat LetDecMem)
addPatternContext pat@(Pat pes) = localScope (scopeOfPat pat) $ do
  (ctx_pes, new_pes) <- mapAccumLM existentialize [] pes
  pure $ Pat (new_pes ++ ctx_pes)
  where
    existentialize acc (PatElem v (MemArray pt arr_shape u (ArrayIn mem lmad))) = do
      space <- lookupMemSpace mem
      mem' <- newVName $ baseString mem <> "_ext"
      ext_pes <-
        replicateM (length (LMAD.existentialized lmad)) $
          PatElem <$> newVName "ext" <*> pure (MemPrim int64)
      let lmad' = lmadFrom (LMAD.shape lmad) [le64 (patElemName pe) | pe <- ext_pes]
          mem_pe = PatElem mem' (MemMem space)
      pure
        ( acc ++ (mem_pe : ext_pes),
          PatElem v (MemArray pt arr_shape u (ArrayIn mem' lmad'))
        )
    existentialize acc pe = pure (acc, pe)
-- | Give every array parameter that lives in memory its own fresh
-- memory parameter plus fresh @i64@ parameters for its LMAD.
--
-- For each array parameter:
--
-- * a new memory parameter is created, named after the original memory
--   block with an @_ext@ suffix, in the same memory space (found with
--   'lookupMemSpace');
-- * one new @i64@ parameter (named @ext@) is created per element of
--   @LMAD.existentialized@ of the array's LMAD;
-- * the array parameter is rebound to the new memory block, with an
--   LMAD rebuilt by 'lmadFrom' from the original LMAD shape and the new
--   @i64@ parameters.
--
-- All generated context parameters are appended, in order, after the
-- original parameters.  Non-array parameters are returned unchanged.
addParamsContext :: [Param FParamMem] -> OffsetM [Param FParamMem]
addParamsContext ps = localScope (scopeOfFParams ps) $ do
  (ps_ctx, ps') <- mapAccumLM onType [] ps
  pure $ ps' <> ps_ctx
  where
    -- The accumulator holds the context parameters generated so far.
    onType ::
      [Param FParamMem] ->
      Param FParamMem ->
      OffsetM ([Param FParamMem], Param FParamMem)
    onType acc (Param attr v (MemArray pt shape u (ArrayIn mem lmad))) = do
      space <- lookupMemSpace mem
      mem' <- newVName $ baseString mem <> "_ext"
      let num_exts = length (LMAD.existentialized lmad)
      lmad_exts <-
        replicateM num_exts $
          Param mempty <$> newVName "ext" <*> pure (MemPrim int64)
      let lmad' = lmadFrom (LMAD.shape lmad) $ map (le64 . paramName) lmad_exts
      pure
        ( acc ++ Param mempty mem' (MemMem space) : lmad_exts,
          Param attr v $ MemArray pt shape u $ ArrayIn mem' lmad'
        )
    onType acc t = pure (acc, t)
-- | Add existential memory context to the result of a @Match@.
--
-- For each result where the pattern element is an array in memory:
--
-- * the pattern element is rebound to a fresh memory block, named after
--   the original with an @_ext@ suffix, and to an LMAD built by
--   'lmadFrom' from the original LMAD shape and fresh @i64@ pattern
--   elements;
-- * one such @i64@ element is created per element of
--   @LMAD.existentialized@ of the LMAD in the branch type;
-- * the branch type becomes 'ReturnsNewBlock' in the memory space of
--   the original return: looked up with 'lookupMemSpace' for
--   'ReturnsInBlock', or taken directly from 'ReturnsNewBlock'.  Its
--   LMAD is fully existential (@LMAD.mkExistential@).
--
-- The new pattern elements and branch types are appended after the
-- original ones, so the memory context of a result starts at index
-- @length ts + (number of context entries generated before it)@.
-- Results that are not arrays in memory are left unchanged.
offsetBranch ::
  Pat LetDecMem ->
  [BranchTypeMem] ->
  OffsetM (Pat LetDecMem, [BranchTypeMem])
offsetBranch (Pat pes) ts = do
  ((pes_ctx, ts_ctx), (pes', ts')) <-
    bimap unzip unzip <$> mapAccumLM onType [] (zip pes ts)
  pure (Pat $ pes' <> pes_ctx, ts' <> ts_ctx)
  where
    -- The accumulator holds the (pattern element, branch type) context
    -- pairs generated so far.
    onType ::
      [(PatElem LetDecMem, BranchTypeMem)] ->
      (PatElem LetDecMem, BranchTypeMem) ->
      OffsetM
        ( [(PatElem LetDecMem, BranchTypeMem)],
          (PatElem LetDecMem, BranchTypeMem)
        )
    onType
      acc
      ( PatElem pe_v (MemArray _ pe_shape pe_u (ArrayIn pe_mem pe_lmad)),
        MemArray pt shape u meminfo
        ) = do
        (space, lmad) <- case meminfo of
          ReturnsInBlock mem mem_lmad -> do
            mem_space <- lookupMemSpace mem
            pure (mem_space, mem_lmad)
          ReturnsNewBlock new_space _ new_lmad ->
            pure (new_space, new_lmad)
        pe_mem' <- newVName $ baseString pe_mem <> "_ext"
        -- Index of the new memory block in the final branch type list
        -- (original types first, then all generated context).
        let start = length ts + length acc
            num_exts = length (LMAD.existentialized lmad)
            ext (Free se) = Free <$> pe64 se
            ext (Ext i) = le64 (Ext i)
        lmad_exts <-
          replicateM num_exts $
            PatElem <$> newVName "ext" <*> pure (MemPrim int64)
        let pe_lmad' =
              lmadFrom (LMAD.shape pe_lmad) $ map (le64 . patElemName) lmad_exts
        pure
          ( acc
              ++ (PatElem pe_mem' $ MemMem space, MemMem space)
              : map (,MemPrim int64) lmad_exts,
            ( PatElem pe_v $ MemArray pt pe_shape pe_u $ ArrayIn pe_mem' pe_lmad',
              -- The LMAD existentials follow directly after the memory block.
              MemArray pt shape u . ReturnsNewBlock space start . fmap ext $
                LMAD.mkExistential (shapeDims shape) (1 + start)
            )
          )
    onType acc t = pure (acc, t)
-- | Update the memory annotations of a pattern using the 'ExpReturns'
-- of the expression that binds it.
--
-- If a pattern element is an array in memory and its return type
-- carries an LMAD (in either 'ReturnsNewBlock' or 'ReturnsInBlock'),
-- the element keeps its memory block but takes that LMAD.  Existential
-- references @Ext i@ in the LMAD are replaced by the name of the @i@th
-- pattern element.
--
-- Every other element is passed through 'offsetMemoryInMemBound'.
-- Pattern elements and return types are paired with 'zipWith', so any
-- surplus in the longer list is dropped.
offsetMemoryInPat :: RebaseMap -> Pat LetDecMem -> [ExpReturns] -> Pat LetDecMem
offsetMemoryInPat offsets (Pat pes) rets =
  Pat $ zipWith onPE pes rets
  where
    onPE :: PatElem LetDecMem -> ExpReturns -> PatElem LetDecMem
    onPE
      (PatElem name (MemArray pt shape u (ArrayIn mem _)))
      (MemArray _ _ _ info)
        | Just lmad <- getLMAD info =
            PatElem name . MemArray pt shape u . ArrayIn mem $
              fmap (fmap unExt) lmad
    onPE pe _ =
      offsetMemoryInMemBound offsets <$> pe

    -- NOTE(review): partial (uses '!!'); assumes every @Ext i@ produced
    -- by 'expReturns' indexes into this pattern.
    unExt :: Ext VName -> VName
    unExt (Ext i) = patElemName (pes !! i)
    unExt (Free v) = v

    getLMAD :: Maybe MemReturn -> Maybe ExtLMAD
    getLMAD (Just (ReturnsNewBlock _ _ lmad)) = Just lmad
    getLMAD (Just (ReturnsInBlock _ lmad)) = Just lmad
    getLMAD _ = Nothing
-- | Apply 'offsetMemoryInMemBound' to the annotation of a parameter,
-- leaving its name and attributes unchanged.
offsetMemoryInParam :: RebaseMap -> Param (MemBound u) -> Param (MemBound u)
offsetMemoryInParam offsets = fmap $ offsetMemoryInMemBound offsets
-- | Rebase the LMAD of an array annotation if its memory block has an
-- entry in the 'RebaseMap'.
--
-- 'lookupNewBase' is queried with the memory block and the array's LMAD
-- shape.  On a hit it returns a pair @(o, p)@, and the LMAD is replaced
-- by @LMAD.expand o p lmad@.  The memory block is never changed.  All
-- other annotations (and arrays without a hit) are returned as-is.
offsetMemoryInMemBound :: RebaseMap -> MemBound u -> MemBound u
offsetMemoryInMemBound offsets (MemArray pt shape u (ArrayIn mem lmad))
  | Just (o, p) <- lookupNewBase mem (LMAD.shape lmad) offsets =
      MemArray pt shape u $ ArrayIn mem $ LMAD.expand o p lmad
offsetMemoryInMemBound _ info = info
-- | Body-return counterpart of 'offsetMemoryInMemBound'.
--
-- This applies only to an array returned in an existing block
-- ('ReturnsInBlock') whose LMAD passes 'isStaticLMAD'.  If
-- 'lookupNewBase' finds a pair @(o, p)@ for that block and shape, the
-- original (existential-typed) LMAD is expanded with @o@ and @p@ lifted
-- into 'Free'.  Every other return is left unchanged.
offsetMemoryInBodyReturns :: RebaseMap -> BodyReturns -> BodyReturns
offsetMemoryInBodyReturns offsets (MemArray pt shape u (ReturnsInBlock mem lmad))
  | Just lmad' <- isStaticLMAD lmad,
    Just (o, p) <- lookupNewBase mem (LMAD.shape lmad') offsets =
      MemArray pt shape u . ReturnsInBlock mem $
        LMAD.expand (Free <$> o) (Free <$> p) lmad
offsetMemoryInBodyReturns _ br = br
-- | Rewrite a lambda for the given 'RebaseMap'.
--
-- The body is rewritten with 'offsetMemoryInBody' inside the lambda's
-- scope, and each parameter with 'offsetMemoryInParam'.  The return
-- type is left untouched.
offsetMemoryInLambda :: RebaseMap -> Lambda GPUMem -> OffsetM (Lambda GPUMem)
offsetMemoryInLambda offsets lam = do
  body <- inScopeOf lam $ offsetMemoryInBody offsets $ lambdaBody lam
  let params = map (offsetMemoryInParam offsets) $ lambdaParams lam
  pure $ lam {lambdaBody = body, lambdaParams = params}
-- | Prepare loop parameters for rebasing, then run a continuation.
--
-- The loop parameters are extended with memory/LMAD context by
-- 'addParamsContext'.  The initial arguments are extended by appending
-- the result of 'argsContext', and the two lists are zipped back into a
-- merge list.
--
-- The 'RebaseMap' is extended so that a loop parameter whose initial
-- argument is a variable with an entry in the map gets that same entry,
-- keyed by the parameter's name.  The continuation receives the
-- extended map and the extended merge list.
offsetMemoryInLoopParams ::
  RebaseMap ->
  [(FParam GPUMem, SubExp)] ->
  (RebaseMap -> [(FParam GPUMem, SubExp)] -> OffsetM a) ->
  OffsetM a
offsetMemoryInLoopParams offsets merge f = do
  let (params, args) = unzip merge
  params' <- addParamsContext params
  args' <- (args <>) <$> argsContext args
  f offsets' $ zip params' args'
  where
    offsets' :: RebaseMap
    offsets' = foldl' onParamArg offsets merge

    -- Lookups use the map as extended by earlier merge entries.
    onParamArg :: RebaseMap -> (FParam GPUMem, SubExp) -> RebaseMap
    onParamArg rm (param, Var arg)
      | Just x <- M.lookup arg rm =
          M.insert (paramName param) x rm
    onParamArg rm _ = rm
-- | Rewrite an expression for the given 'RebaseMap', via 'mapExpM'.
--
-- * Nested bodies are rewritten with 'offsetMemoryInBody', in the scope
--   supplied by the mapper.
-- * Branch types are rewritten with 'offsetMemoryInBodyReturns'.
-- * 'SegOp's are rewritten inside the scope of their segment space.
--   Their kernel bodies use 'offsetMemoryInKernelBody' and their
--   lambdas use 'offsetMemoryInLambda'.
-- * All other operations are returned unchanged.
offsetMemoryInExp :: RebaseMap -> Exp GPUMem -> OffsetM (Exp GPUMem)
offsetMemoryInExp offsets = mapExpM recurse
  where
    recurse =
      (identityMapper @GPUMem)
        { mapOnBody = \bscope -> localScope bscope . offsetMemoryInBody offsets,
          mapOnBranchType = pure . offsetMemoryInBodyReturns offsets,
          mapOnOp = onOp
        }
    onOp (Inner (SegOp op)) =
      Inner . SegOp
        <$> localScope (scopeOfSegSpace (segSpace op)) (mapSegOpM segOpMapper op)
      where
        segOpMapper =
          identitySegOpMapper
            { mapOnSegOpBody = offsetMemoryInKernelBody offsets,
              mapOnSegOpLambda = offsetMemoryInLambda offsets
            }
    onOp op = pure op
offsetMemoryInStm :: RebaseMap -> Stm GPUMem -> OffsetM (Stm GPUMem)
offsetMemoryInStm :: RebaseMap -> Stm GPUMem -> OffsetM (Stm GPUMem)
offsetMemoryInStm RebaseMap
offsets (Let Pat (LetDec GPUMem)
pat StmAux (ExpDec GPUMem)
dec (Match [SubExp]
cond [Case (Body GPUMem)]
cases Body GPUMem
defbody (MatchDec [BranchType GPUMem]
ts MatchSort
kind))) = do
cases' <- [Case (Body GPUMem)]
-> (Case (Body GPUMem) -> OffsetM (Case (Body GPUMem)))
-> OffsetM [Case (Body GPUMem)]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM [Case (Body GPUMem)]
cases ((Case (Body GPUMem) -> OffsetM (Case (Body GPUMem)))
-> OffsetM [Case (Body GPUMem)])
-> (Case (Body GPUMem) -> OffsetM (Case (Body GPUMem)))
-> OffsetM [Case (Body GPUMem)]
forall a b. (a -> b) -> a -> b
$ \(Case [Maybe PrimValue]
vs Body GPUMem
body) ->
[Maybe PrimValue] -> Body GPUMem -> Case (Body GPUMem)
forall body. [Maybe PrimValue] -> body -> Case body
Case [Maybe PrimValue]
vs (Body GPUMem -> Case (Body GPUMem))
-> OffsetM (Body GPUMem) -> OffsetM (Case (Body GPUMem))
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> RebaseMap -> Body GPUMem -> OffsetM (Body GPUMem)
offsetMemoryInBodyReturnCtx RebaseMap
offsets Body GPUMem
body
defbody' <- offsetMemoryInBodyReturnCtx offsets defbody
(pat', ts') <- offsetBranch pat ts
pure $ Let pat' dec $ Match cond cases' defbody' $ MatchDec ts' kind
offsetMemoryInStm RebaseMap
offsets (Let Pat (LetDec GPUMem)
pat StmAux (ExpDec GPUMem)
dec (Loop [(FParam GPUMem, SubExp)]
merge LoopForm
form Body GPUMem
body)) = do
loop' <-
RebaseMap
-> [(FParam GPUMem, SubExp)]
-> (RebaseMap -> [(FParam GPUMem, SubExp)] -> OffsetM (Exp GPUMem))
-> OffsetM (Exp GPUMem)
forall a.
RebaseMap
-> [(FParam GPUMem, SubExp)]
-> (RebaseMap -> [(FParam GPUMem, SubExp)] -> OffsetM a)
-> OffsetM a
offsetMemoryInLoopParams RebaseMap
offsets [(FParam GPUMem, SubExp)]
merge ((RebaseMap -> [(FParam GPUMem, SubExp)] -> OffsetM (Exp GPUMem))
-> OffsetM (Exp GPUMem))
-> (RebaseMap -> [(FParam GPUMem, SubExp)] -> OffsetM (Exp GPUMem))
-> OffsetM (Exp GPUMem)
forall a b. (a -> b) -> a -> b
$ \RebaseMap
offsets' [(FParam GPUMem, SubExp)]
merge' -> do
body' <-
Scope GPUMem -> OffsetM (Body GPUMem) -> OffsetM (Body GPUMem)
forall a. Scope GPUMem -> OffsetM a -> OffsetM a
forall rep (m :: * -> *) a.
LocalScope rep m =>
Scope rep -> m a -> m a
localScope
([FParam GPUMem] -> Scope GPUMem
forall rep dec. (FParamInfo rep ~ dec) => [Param dec] -> Scope rep
scopeOfFParams (((FParam GPUMem, SubExp) -> FParam GPUMem)
-> [(FParam GPUMem, SubExp)] -> [FParam GPUMem]
forall a b. (a -> b) -> [a] -> [b]
map (FParam GPUMem, SubExp) -> FParam GPUMem
forall a b. (a, b) -> a
fst [(FParam GPUMem, SubExp)]
merge') Scope GPUMem -> Scope GPUMem -> Scope GPUMem
forall a. Semigroup a => a -> a -> a
<> LoopForm -> Scope GPUMem
forall rep. LoopForm -> Scope rep
scopeOfLoopForm LoopForm
form)
(RebaseMap -> Body GPUMem -> OffsetM (Body GPUMem)
offsetMemoryInBodyReturnCtx RebaseMap
offsets' Body GPUMem
body)
pure $ Loop merge' form body'
pat' <- addPatternContext pat
pure $ Let pat' dec loop'
offsetMemoryInStm RebaseMap
offsets (Let Pat (LetDec GPUMem)
pat StmAux (ExpDec GPUMem)
dec Exp GPUMem
e) = do
e' <- RebaseMap -> Exp GPUMem -> OffsetM (Exp GPUMem)
offsetMemoryInExp RebaseMap
offsets Exp GPUMem
e
pat' <-
offsetMemoryInPat offsets pat
<$> ( maybe (throwError "offsetMemoryInStm: ill-typed") pure
=<< expReturns e'
)
scope <- askScope
rts <-
maybe (throwError "offsetMemoryInStm: ill-typed") pure $
runReader (expReturns e') scope
let pat'' = [PatElem (MemInfo SubExp NoUniqueness MemBind)]
-> Pat (MemInfo SubExp NoUniqueness MemBind)
forall dec. [PatElem dec] -> Pat dec
Pat ([PatElem (MemInfo SubExp NoUniqueness MemBind)]
-> Pat (MemInfo SubExp NoUniqueness MemBind))
-> [PatElem (MemInfo SubExp NoUniqueness MemBind)]
-> Pat (MemInfo SubExp NoUniqueness MemBind)
forall a b. (a -> b) -> a -> b
$ (PatElem (MemInfo SubExp NoUniqueness MemBind)
-> ExpReturns -> PatElem (MemInfo SubExp NoUniqueness MemBind))
-> [PatElem (MemInfo SubExp NoUniqueness MemBind)]
-> [ExpReturns]
-> [PatElem (MemInfo SubExp NoUniqueness MemBind)]
forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith PatElem (MemInfo SubExp NoUniqueness MemBind)
-> ExpReturns -> PatElem (MemInfo SubExp NoUniqueness MemBind)
forall {d} {u} {d} {u}.
PatElem (MemInfo d u MemBind)
-> MemInfo d u (Maybe MemReturn) -> PatElem (MemInfo d u MemBind)
pick (Pat (MemInfo SubExp NoUniqueness MemBind)
-> [PatElem (MemInfo SubExp NoUniqueness MemBind)]
forall dec. Pat dec -> [PatElem dec]
patElems Pat (MemInfo SubExp NoUniqueness MemBind)
pat') [ExpReturns]
rts
pure $ Let pat'' dec e'
where
pick :: PatElem (MemInfo d u MemBind)
-> MemInfo d u (Maybe MemReturn) -> PatElem (MemInfo d u MemBind)
pick
(PatElem VName
name (MemArray PrimType
pt ShapeBase d
s u
u MemBind
_ret))
(MemArray PrimType
_ ShapeBase d
_ u
_ (Just (ReturnsInBlock VName
m ExtLMAD
extlmad)))
| Just LMAD
lmad <- ExtLMAD -> Maybe LMAD
instantiateLMAD ExtLMAD
extlmad =
VName -> MemInfo d u MemBind -> PatElem (MemInfo d u MemBind)
forall dec. VName -> dec -> PatElem dec
PatElem VName
name (PrimType -> ShapeBase d -> u -> MemBind -> MemInfo d u MemBind
forall d u ret.
PrimType -> ShapeBase d -> u -> ret -> MemInfo d u ret
MemArray PrimType
pt ShapeBase d
s u
u (VName -> LMAD -> MemBind
ArrayIn VName
m LMAD
lmad))
pick PatElem (MemInfo d u MemBind)
p MemInfo d u (Maybe MemReturn)
_ = PatElem (MemInfo d u MemBind)
p
-- | Turn an existentially-quantified LMAD into a concrete one.  This
-- succeeds only when every variable mentioned in the LMAD is 'Free';
-- a single 'Ext' anywhere makes the whole result 'Nothing'.
instantiateLMAD :: ExtLMAD -> Maybe LMAD
instantiateLMAD extlmad = traverse (traverse fromFree) extlmad
  where
    fromFree v = case v of
      Free x -> Just x
      Ext {} -> Nothing
-- | Convert memory-annotated GPU statements back into plain 'GPU.GPU'
-- statements by discarding all memory information.
--
-- An allocation at the outermost level is kept as a binding, but its
-- value becomes the unit constant.  An allocation inside any nested
-- body (of an expression, lambda or kernel body) is reported as a
-- 'Left'.  Of the inner host operations only 'SizeOp' and 'SegOp' can
-- be converted; the other operations also produce a 'Left'.
unAllocGPUStms :: Stms GPUMem -> Either String (Stms GPU.GPU)
unAllocGPUStms = unAllocStms False
  where
    -- The flag records whether these statements live inside a nested body.
    unAllocStms nested = traverse (unAllocStm nested)

    unAllocStm nested stm@(Let pat dec (Op Alloc {}))
      | nested =
          Left $ "Cannot handle nested allocation: " <> prettyString stm
      | otherwise = do
          -- An allocation has no GPU counterpart: keep the binding but
          -- give it a trivial unit value.
          pat' <- unAllocPat pat
          pure $ Let pat' dec $ BasicOp $ SubExp $ Constant UnitValue
    unAllocStm _ (Let pat dec e) = do
      pat' <- unAllocPat pat
      e' <- mapExpM unAlloc' e
      pure $ Let pat' dec e'

    unAllocBody (Body dec stms res) = do
      stms' <- unAllocStms True stms
      pure $ Body dec stms' res

    unAllocKernelBody (KernelBody dec stms res) = do
      stms' <- unAllocStms True stms
      pure $ KernelBody dec stms' res

    unAllocLambda (Lambda params ret body) =
      Lambda (map unParam params) ret <$> unAllocBody body

    unAllocPat (Pat pes) = Pat <$> traverse (rephrasePatElem unT) pes

    unAllocOp op = case op of
      Alloc {} -> Left "unAllocOp: unhandled Alloc"
      Inner OtherOp {} -> Left "unAllocOp: unhandled OtherOp"
      Inner GPUBody {} -> Left "unAllocOp: unhandled GPUBody"
      Inner (SizeOp size_op) -> pure $ SizeOp size_op
      Inner (SegOp seg_op) -> SegOp <$> mapSegOpM segOpMapper seg_op

    -- Only lambdas and kernel bodies carry memory information that must
    -- be stripped; everything else inside a SegOp is kept as-is.
    segOpMapper =
      identitySegOpMapper
        { mapOnSegOpLambda = unAllocLambda,
          mapOnSegOpBody = unAllocKernelBody
        }

    unParam = fmap unMem

    unT = Right . unMem

    unAlloc' =
      Mapper
        { mapOnBody = const unAllocBody,
          mapOnRetType = unT,
          mapOnBranchType = unT,
          mapOnFParam = pure . unParam,
          mapOnLParam = pure . unParam,
          mapOnOp = unAllocOp,
          mapOnSubExp = pure,
          mapOnVName = pure
        }
-- | Forget the memory annotation of a 'MemInfo' and keep only its type.
-- A memory block has no corresponding value type and becomes 'Unit'.
unMem :: MemInfo d u ret -> TypeBase (ShapeBase d) u
unMem info = case info of
  MemPrim pt -> Prim pt
  MemArray pt shape u _ -> Array pt shape u
  MemAcc acc ispace ts u -> Acc acc ispace ts u
  MemMem {} -> Prim Unit
-- | Strip the memory annotation from every binding in a 'GPUMem' scope,
-- producing the corresponding plain 'GPU.GPU' scope.  Index bindings
-- carry no memory information and are passed through unchanged.
unAllocScope :: Scope GPUMem -> Scope GPU.GPU
unAllocScope = M.map $ \info -> case info of
  LetName dec -> LetName $ unMem dec
  FParamName dec -> FParamName $ unMem dec
  LParamName dec -> LParamName $ unMem dec
  IndexName it -> IndexName it
-- | Group extracted allocations by their size.  Each distinct size is
-- paired with every memory block (and its space) that was allocated with
-- that size.  The result is ordered by size.  Within a group, blocks that
-- come later in the extraction's key order are listed first.
removeCommonSizes :: Extraction -> [(SubExp, [(VName, Space)])]
removeCommonSizes extraction =
  M.toList $
    M.fromListWith
      (++)
      [(size, [(mem, space)]) | (mem, (_, size, space)) <- M.toList extraction]
-- | Make a fresh copy of every variable that alias analysis reports as
-- consumed by the given statements.  The statements are rewritten to
-- refer to those copies.  The returned statements (copies first, then
-- the rewritten statements) are collected rather than added to the
-- current builder.
copyConsumed :: (MonadBuilder m, AliasableRep (Rep m)) => Stms (Rep m) -> m (Stms (Rep m))
copyConsumed stms =
  collectStms_ $ do
    copies <- mapM copyOf consumed
    addStms $ substituteNames (M.fromList (zip consumed copies)) stms
  where
    -- The second component of the alias analysis result is the consumed set.
    consumed = namesToList . snd . snd $ Alias.analyseStms mempty stms
    -- A 'Replicate' with an empty shape is a plain copy of the variable.
    copyOf v =
      letExp (baseString v <> "_copy") . BasicOp . Replicate mempty $ Var v
-- | Compute, for each of the given allocation sizes, how much memory is
-- needed when each of @num_threads@ threads gets its own slice.
--
-- The kernel statements @kstms@, with memory information stripped, are
-- replayed for every point of @space@.  This happens inside a
-- one-dimensional reduction over the flattened space that takes the
-- signed maximum of each size.  Each maximum is then multiplied by
-- @num_threads@.
--
-- Returns the statements that perform this computation, the names bound
-- to the per-thread maxima, and the names bound to the total sizes, both
-- in the same order as @sizes@.
sliceKernelSizes ::
  SubExp ->
  [SubExp] ->
  SegSpace ->
  Stms GPUMem ->
  ExpandM (Stms GPU.GPU, [VName], [VName])
sliceKernelSizes num_threads sizes space kstms = do
  kernel_stms <- liftEither $ unAllocGPUStms kstms
  kernels_scope <- asks unAllocScope
  let num_sizes = length sizes
      i64s = replicate num_sizes $ Prim int64

  -- Reduction operator: pointwise signed maximum of two size tuples.
  max_lam <- fmap fst . flip runBuilderT kernels_scope $ do
    xs <- replicateM num_sizes $ newParam "x" (Prim int64)
    ys <- replicateM num_sizes $ newParam "y" (Prim int64)
    (zs, stms) <-
      localScope (scopeOfLParams $ xs ++ ys) . collectStms $
        forM (zip xs ys) $ \(x, y) ->
          subExpRes
            <$> letSubExp
              "z"
              (BasicOp $ BinOp (SMax Int64) (Var $ paramName x) (Var $ paramName y))
    pure $ Lambda (xs ++ ys) i64s (mkBody stms zs)

  flat_gtid_lparam <- newParam "flat_gtid" (Prim (IntType Int64))

  -- Map function of the reduction: recover the (possibly
  -- multi-dimensional) indices of 'space' from the flat index, replay the
  -- kernel statements, and return the sizes.
  size_lam' <-
    localScope (scopeOfSegSpace space) . fmap fst . flip runBuilderT kernels_scope $
      GPU.simplifyLambda <=< mkLambda [flat_gtid_lparam] $ do
        let (kspace_gtids, kspace_dims) = unzip $ unSegSpace space
            flat_gtid = pe64 $ Var $ paramName flat_gtid_lparam
            new_inds = unflattenIndex (map pe64 kspace_dims) flat_gtid
        zipWithM_ letBindNames (map pure kspace_gtids) =<< mapM toExp new_inds
        -- Replay the statements on copies of whatever they consume.
        -- NOTE(review): presumably so the lambda never consumes its free
        -- variables; confirm.
        addStms =<< copyConsumed kernel_stms
        pure $ subExpsRes sizes

  ((maxes_per_thread, size_sums), slice_stms) <- flip runBuilderT kernels_scope $ do
    pat <- basicPat <$> replicateM num_sizes (newIdent "max_per_thread" $ Prim int64)
    w <-
      letSubExp "size_slice_w"
        =<< foldBinOp (Mul Int64 OverflowUndef) (intConst Int64 1) (segSpaceDims space)
    thread_space_iota <-
      letExp "thread_space_iota" . BasicOp $
        Iota w (intConst Int64 0) (intConst Int64 1) Int64
    -- Zero is a valid neutral element for the maximum because sizes are
    -- never negative.
    let neutral = replicate num_sizes $ intConst Int64 0
        red_op = SegBinOp Commutative max_lam neutral mempty
    lvl <- segThread "segred"
    addStms =<< mapM renameStm =<< nonSegRed lvl pat w [red_op] size_lam' [thread_space_iota]
    let maxes = patNames pat
    totals <- forM maxes $ \per_thread_max ->
      letExp "size_sum" . BasicOp $
        BinOp (Mul Int64 OverflowUndef) (Var per_thread_max) num_threads
    pure (maxes, totals)

  pure (slice_stms, maxes_per_thread, size_sums)