{-# LANGUAGE TypeFamilies #-}
module Futhark.Pass.ExtractKernels.Intrablock (intrablockParallelise) where
import Control.Monad
import Control.Monad.RWS
import Control.Monad.Trans.Maybe
import Data.Map.Strict qualified as M
import Data.Set qualified as S
import Futhark.Analysis.PrimExp.Convert
import Futhark.IR.GPU hiding (HistOp)
import Futhark.IR.GPU.Op qualified as GPU
import Futhark.IR.SOACS
import Futhark.MonadFreshNames
import Futhark.Pass.ExtractKernels.BlockedKernel
import Futhark.Pass.ExtractKernels.DistributeNests
import Futhark.Pass.ExtractKernels.Distribution
import Futhark.Pass.ExtractKernels.ToGPU
import Futhark.Tools
import Futhark.Transform.FirstOrderTransform qualified as FOT
import Futhark.Util.Log
import Prelude hiding (log)
-- | Attempt to execute the innermost lambda of a kernel nest as one
-- thread block per point of the (flattened) nest.  The parallel
-- constructs inside the lambda body are mapped onto threads within the
-- block.
--
-- The result is 'Nothing' when a size that the intra-block parallelism
-- depends on is not available outside the kernel.  This includes the
-- case where the size refers to a kernel input.  It is reported as
-- irregular parallelism.
--
-- On success the tuple contains:
--
-- * the (minimum, available) intra-block parallelism; at present both
--   components are the same 'SubExp';
-- * the computed thread block size;
-- * the log produced while transforming the lambda body;
-- * statements that must be executed before the kernel (number of
--   blocks, parallelism and block size computations, input reads);
-- * the kernel itself, as a single block-level 'SegMap' statement.
intrablockParallelise ::
  (MonadFreshNames m, LocalScope GPU m) =>
  KernelNest ->
  Lambda SOACS ->
  m
    ( Maybe
        ( (SubExp, SubExp),
          SubExp,
          Log,
          Stms GPU,
          Stms GPU
        )
    )
intrablockParallelise knest lam = runMaybeT $ do
  (ispace, inps) <- lift $ flatKernel knest

  -- One thread block per point of the flattened outer iteration space.
  (num_tblocks, w_stms) <- lift . runBuilder $ do
    total <- foldBinOp (Mul Int64 OverflowUndef) (intConst Int64 1) (map snd ispace)
    letSubExp "intra_num_tblocks" total

  let body = lambdaBody lam
      outer_nest = fst knest

  tblock_size <- newVName "computed_tblock_size"
  (wss_min, wss_avail, body_log, kbody) <-
    lift . localScope (scopeOfLParams $ lambdaParams lam) $
      intrablockParalleliseBody body

  -- The parallelism sizes are computed before the kernel, so every name
  -- they mention must be in the outer scope.  Kernel inputs are
  -- explicitly treated as unavailable, even if they occur in that scope.
  outside_scope <- lift askScope
  let input_names = map kernelInputName inps
      isAvailable v = M.member v outside_scope && notElem v input_names
      size_names = namesToList $ freeIn (wss_min ++ wss_avail)
  unless (all isAvailable size_names) $
    -- 'fail' in 'MaybeT' makes the whole result 'Nothing'.
    fail "Irregular parallelism"

  ((intra_avail_par, kspace, read_input_stms), prelude_stms) <-
    lift . runBuilder $ do
      let -- 'foldBinOp' extended to the empty list, which yields 1.
          foldOrOne _ [] = eSubExp $ intConst Int64 1
          foldOrOne op (x : xs) = foldBinOp op x xs
          -- Bind the product of every non-empty size list.
          bindProducts desc =
            mapM (letSubExp desc <=< foldOrOne (Mul Int64 OverflowUndef))
              . filter (not . null)

      ws_min <- bindProducts "one_intra_par_min" wss_min
      ws_avail <- bindProducts "one_intra_par_avail" wss_avail

      -- Available parallelism: the smallest of the available products
      -- (1 if there are none).
      intra_avail_par <-
        letSubExp "intra_avail_par" =<< foldOrOne (SMin Int64) ws_avail

      -- Block size: the largest minimum-parallelism product if there is
      -- one.  Otherwise it is the available parallelism, capped by the
      -- maximum thread block size.
      tblock_size_e <-
        if null ws_min
          then
            eBinOp
              (SMin Int64)
              (eSubExp =<< letSubExp "max_tblock_size" (Op $ SizeOp $ GetSizeMax SizeThreadBlock))
              (eSubExp intra_avail_par)
          else foldOrOne (SMax Int64) ws_min
      letBindNames [tblock_size] tblock_size_e

      addStms w_stms
      -- Only the inputs that the lambda body actually refers to are read.
      let used_inps = filter ((`nameIn` freeIn body) . kernelInputName) inps
      read_input_stms <- runBuilder_ $ mapM_ readGroupKernelInput used_inps
      space <- SegSpace <$> newVName "phys_tblock_id" <*> pure ispace
      pure (intra_avail_par, space, read_input_stms)

  let kbody' = kbody {kernelBodyStms = read_input_stms <> kernelBodyStms kbody}
      nested_pat = loopNestingPat outer_nest
      -- Per-block result types: the pattern types with the outer
      -- (flattened) dimensions stripped.
      rts = map (stripArray (length ispace)) (patTypes nested_pat)
      grid = KernelGrid (Count num_tblocks) (Count (Var tblock_size))
      lvl = SegBlock SegNoVirt (Just grid)
      kstm =
        Let nested_pat (loopNestingAux outer_nest) . Op . SegOp $
          SegMap lvl kspace rts kbody'
      -- The minimum parallelism is currently reported as the available one.
      intra_min_par = intra_avail_par

  pure
    ( (intra_min_par, intra_avail_par),
      Var tblock_size,
      body_log,
      prelude_stms,
      oneStm kstm
    )
-- | Read one kernel input inside a thread block.
--
-- A non-array input is read directly with 'readKernelInput'.  An array
-- input is first read into a fresh temporary that has the same base
-- name.  The input's own name is then bound to a 'Replicate' of that
-- temporary with an empty shape (presumably acting as a copy; confirm).
readGroupKernelInput ::
  (DistRep (Rep m), MonadBuilder m) =>
  KernelInput ->
  m ()
readGroupKernelInput inp =
  case kernelInputType inp of
    Array {} -> do
      let name = kernelInputName inp
      tmp <- newVName (baseString name)
      readKernelInput inp {kernelInputName = tmp}
      letBindNames [name] . BasicOp . Replicate mempty $ Var tmp
    _ -> readKernelInput inp
-- | Writer output accumulated while turning a body into intra-block
-- code.
data IntraAcc = IntraAcc
  { -- | Sizes recorded as minimum parallelism.  Each element is one list
    -- of sizes; presumably their product is one amount of parallelism
    -- (confirm against 'intrablockParalleliseBody').
    accMinPar :: S.Set [SubExp],
    -- | Sizes recorded as available parallelism, in the same form as
    -- 'accMinPar'.
    accAvailPar :: S.Set [SubExp],
    -- | Log messages emitted during the transformation.
    accLog :: Log
  }
-- | Combine two accumulators field-wise: take the union of the
-- parallelism sets and concatenate the logs.
instance Semigroup IntraAcc where
  IntraAcc mins_a avails_a log_a <> IntraAcc mins_b avails_b log_b =
    IntraAcc
      { accMinPar = mins_a <> mins_b,
        accAvailPar = avails_a <> avails_b,
        accLog = log_a <> log_b
      }
-- | The empty accumulator: no recorded parallelism and an empty log.
instance Monoid IntraAcc where
  mempty =
    IntraAcc
      { accMinPar = mempty,
        accAvailPar = mempty,
        accLog = mempty
      }
-- | Builder of GPU statements over an 'RWS' monad.  The RWS has no
-- reader environment, uses 'IntraAcc' as its writer output and
-- 'VNameSource' as its state.
type IntrablockM = BuilderT GPU (RWS () IntraAcc VNameSource)
-- | Log messages are written into the 'accLog' field of the writer
-- output; the parallelism sets are left empty.
instance MonadLogger IntrablockM where
  addLog entry = tell $ IntraAcc mempty mempty entry
-- | Run an 'IntrablockM' action in the current scope, threading the
-- ambient name source through it.  Returns the accumulated writer
-- output together with the statements the action produced.
runIntrablockM ::
  (MonadFreshNames m, HasScope GPU m) =>
  IntrablockM () ->
  m (IntraAcc, Stms GPU)
runIntrablockM action = do
  scope <- castScope <$> askScope
  modifyNameSource (step scope)
  where
    -- Run the builder purely and hand back the updated name source.
    step scope src =
      let (((), stms), src', acc) = runRWS (runBuilderT action scope) () src
       in ((acc, stms), src')
-- | Record the sizes @ws@ both as minimum and as available parallelism.
parallelMin :: [SubExp] -> IntrablockM ()
parallelMin ws = tell $ IntraAcc sizes sizes mempty
  where
    sizes = S.singleton ws
-- | Transform the statements of a body with 'intrablockStms'.  The
-- collected statements are paired with the body's original result.
intrablockBody :: Body SOACS -> IntrablockM (Body GPU)
intrablockBody body =
  (`mkBody` bodyResult body) <$> collectStms_ (intrablockStms (bodyStms body))
-- | Transform a lambda's body with 'intrablockBody', keeping its
-- parameters unchanged.
intrablockLambda :: Lambda SOACS -> IntrablockM (Lambda GPU)
intrablockLambda lam = mkLambda params $ do
  body' <- intrablockBody (lambdaBody lam)
  bodyBind body'
  where
    params = lambdaParams lam
-- | Transform the optional combining operator of a 'WithAcc' input with
-- 'intrablockLambda'.  The shape, arrays and neutral elements are kept
-- as they are.
intrablockWithAccInput :: WithAccInput SOACS -> IntrablockM (WithAccInput GPU)
intrablockWithAccInput (shape, arrs, op) = do
  op' <- traverse onOp op
  pure (shape, arrs, op')
  where
    onOp (lam, nes) = do
      lam' <- intrablockLambda lam
      pure (lam', nes)
intrablockStm :: Stm SOACS -> IntrablockM ()
intrablockStm :: Stm SOACS -> IntrablockM ()
intrablockStm stm :: Stm SOACS
stm@(Let Pat (LetDec SOACS)
pat StmAux (ExpDec SOACS)
aux Exp SOACS
e) = do
scope <- BuilderT GPU (RWS () IntraAcc VNameSource) (Scope GPU)
forall rep (m :: * -> *). HasScope rep m => m (Scope rep)
askScope
let lvl = SegVirt -> SegLevel
SegThreadInBlock SegVirt
SegNoVirt
case e of
Loop [(FParam SOACS, SubExp)]
merge LoopForm
form Body SOACS
loopbody ->
Scope GPU -> IntrablockM () -> IntrablockM ()
forall a.
Scope GPU
-> BuilderT GPU (RWS () IntraAcc VNameSource) a
-> BuilderT GPU (RWS () IntraAcc VNameSource) a
forall rep (m :: * -> *) a.
LocalScope rep m =>
Scope rep -> m a -> m a
localScope (LoopForm -> Scope GPU
forall rep. LoopForm -> Scope rep
scopeOfLoopForm LoopForm
form Scope GPU -> Scope GPU -> Scope GPU
forall a. Semigroup a => a -> a -> a
<> [Param (FParamInfo GPU)] -> Scope GPU
forall rep dec. (FParamInfo rep ~ dec) => [Param dec] -> Scope rep
scopeOfFParams (((Param (FParamInfo GPU), SubExp) -> Param (FParamInfo GPU))
-> [(Param (FParamInfo GPU), SubExp)] -> [Param (FParamInfo GPU)]
forall a b. (a -> b) -> [a] -> [b]
map (Param (FParamInfo GPU), SubExp) -> Param (FParamInfo GPU)
forall a b. (a, b) -> a
fst [(FParam SOACS, SubExp)]
[(Param (FParamInfo GPU), SubExp)]
merge)) (IntrablockM () -> IntrablockM ())
-> IntrablockM () -> IntrablockM ()
forall a b. (a -> b) -> a -> b
$ do
loopbody' <- Body SOACS -> IntrablockM (Body GPU)
intrablockBody Body SOACS
loopbody
certifying (stmAuxCerts aux) . letBind pat $
Loop merge form loopbody'
Match [SubExp]
cond [Case (Body SOACS)]
cases Body SOACS
defbody MatchDec (BranchType SOACS)
ifdec -> do
cases' <- (Case (Body SOACS)
-> BuilderT GPU (RWS () IntraAcc VNameSource) (Case (Body GPU)))
-> [Case (Body SOACS)]
-> BuilderT GPU (RWS () IntraAcc VNameSource) [Case (Body GPU)]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM ((Body SOACS -> IntrablockM (Body GPU))
-> Case (Body SOACS)
-> BuilderT GPU (RWS () IntraAcc VNameSource) (Case (Body GPU))
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
forall (f :: * -> *) a b.
Applicative f =>
(a -> f b) -> Case a -> f (Case b)
traverse Body SOACS -> IntrablockM (Body GPU)
intrablockBody) [Case (Body SOACS)]
cases
defbody' <- intrablockBody defbody
certifying (stmAuxCerts aux) . letBind pat $
Match cond cases' defbody' ifdec
WithAcc [WithAccInput SOACS]
inputs Lambda SOACS
lam -> do
inputs' <- (WithAccInput SOACS -> IntrablockM (WithAccInput GPU))
-> [WithAccInput SOACS]
-> BuilderT GPU (RWS () IntraAcc VNameSource) [WithAccInput GPU]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM WithAccInput SOACS -> IntrablockM (WithAccInput GPU)
intrablockWithAccInput [WithAccInput SOACS]
inputs
lam' <- intrablockLambda lam
certifying (stmAuxCerts aux) . letBind pat $ WithAcc inputs' lam'
Op Op SOACS
soac
| Attr
"sequential_outer" Attr -> Attrs -> Bool
`inAttrs` StmAux () -> Attrs
forall dec. StmAux dec -> Attrs
stmAuxAttrs StmAux ()
StmAux (ExpDec SOACS)
aux ->
Stms SOACS -> IntrablockM ()
intrablockStms (Stms SOACS -> IntrablockM ())
-> (Stms SOACS -> Stms SOACS) -> Stms SOACS -> IntrablockM ()
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Stm SOACS -> Stm SOACS) -> Stms SOACS -> Stms SOACS
forall a b. (a -> b) -> Seq a -> Seq b
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (Certs -> Stm SOACS -> Stm SOACS
forall rep. Certs -> Stm rep -> Stm rep
certify (StmAux () -> Certs
forall dec. StmAux dec -> Certs
stmAuxCerts StmAux ()
StmAux (ExpDec SOACS)
aux))
(Stms SOACS -> IntrablockM ())
-> BuilderT GPU (RWS () IntraAcc VNameSource) (Stms SOACS)
-> IntrablockM ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< Builder SOACS ()
-> BuilderT GPU (RWS () IntraAcc VNameSource) (Stms SOACS)
forall (m :: * -> *) somerep rep a.
(MonadFreshNames m, HasScope somerep m, SameScope somerep rep) =>
Builder rep a -> m (Stms rep)
runBuilder_ (Pat (LetDec (Rep (BuilderT SOACS (State VNameSource))))
-> SOAC (Rep (BuilderT SOACS (State VNameSource)))
-> Builder SOACS ()
forall (m :: * -> *).
Transformer m =>
Pat (LetDec (Rep m)) -> SOAC (Rep m) -> m ()
FOT.transformSOAC Pat (LetDec (Rep (BuilderT SOACS (State VNameSource))))
Pat (LetDec SOACS)
pat Op SOACS
SOAC (Rep (BuilderT SOACS (State VNameSource)))
soac)
Op (Screma SubExp
w [VName]
arrs ScremaForm SOACS
form)
| Just Lambda SOACS
lam <- ScremaForm SOACS -> Maybe (Lambda SOACS)
forall rep. ScremaForm rep -> Maybe (Lambda rep)
isMapSOAC ScremaForm SOACS
form -> do
let loopnest :: LoopNesting
loopnest = Pat Type
-> StmAux () -> SubExp -> [(Param Type, VName)] -> LoopNesting
MapNesting Pat Type
Pat (LetDec SOACS)
pat StmAux ()
StmAux (ExpDec SOACS)
aux SubExp
w ([(Param Type, VName)] -> LoopNesting)
-> [(Param Type, VName)] -> LoopNesting
forall a b. (a -> b) -> a -> b
$ [Param Type] -> [VName] -> [(Param Type, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip (Lambda SOACS -> [LParam SOACS]
forall rep. Lambda rep -> [LParam rep]
lambdaParams Lambda SOACS
lam) [VName]
arrs
env :: DistEnv GPU (BuilderT GPU (RWS () IntraAcc VNameSource))
env =
DistEnv
{ distNest :: Nestings
distNest =
Nesting -> Nestings
singleNesting (Nesting -> Nestings) -> Nesting -> Nestings
forall a b. (a -> b) -> a -> b
$ Names -> LoopNesting -> Nesting
Nesting Names
forall a. Monoid a => a
mempty LoopNesting
loopnest,
distScope :: Scope GPU
distScope =
Pat Type -> Scope GPU
forall rep dec. (LetDec rep ~ dec) => Pat dec -> Scope rep
scopeOfPat Pat Type
Pat (LetDec SOACS)
pat
Scope GPU -> Scope GPU -> Scope GPU
forall a. Semigroup a => a -> a -> a
<> Scope SOACS -> Scope GPU
scopeForGPU (Lambda SOACS -> Scope SOACS
forall rep a. Scoped rep a => a -> Scope rep
scopeOf Lambda SOACS
lam)
Scope GPU -> Scope GPU -> Scope GPU
forall a. Semigroup a => a -> a -> a
<> Scope GPU
scope,
distOnInnerMap :: MapLoop
-> DistAcc GPU
-> DistNestT
GPU (BuilderT GPU (RWS () IntraAcc VNameSource)) (DistAcc GPU)
distOnInnerMap =
MapLoop
-> DistAcc GPU
-> DistNestT
GPU (BuilderT GPU (RWS () IntraAcc VNameSource)) (DistAcc GPU)
forall (m :: * -> *) rep.
(MonadFreshNames m, LocalScope rep m, DistRep rep) =>
MapLoop -> DistAcc rep -> DistNestT rep m (DistAcc rep)
distributeMap,
distOnTopLevelStms :: Stms SOACS
-> DistNestT
GPU (BuilderT GPU (RWS () IntraAcc VNameSource)) (Stms GPU)
distOnTopLevelStms =
BuilderT GPU (RWS () IntraAcc VNameSource) (Stms GPU)
-> DistNestT
GPU (BuilderT GPU (RWS () IntraAcc VNameSource)) (Stms GPU)
forall rep (m :: * -> *) a.
(LocalScope rep m, DistRep rep) =>
m a -> DistNestT rep m a
liftInner (BuilderT GPU (RWS () IntraAcc VNameSource) (Stms GPU)
-> DistNestT
GPU (BuilderT GPU (RWS () IntraAcc VNameSource)) (Stms GPU))
-> (Stms SOACS
-> BuilderT GPU (RWS () IntraAcc VNameSource) (Stms GPU))
-> Stms SOACS
-> DistNestT
GPU (BuilderT GPU (RWS () IntraAcc VNameSource)) (Stms GPU)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. IntrablockM ()
-> BuilderT
GPU
(RWS () IntraAcc VNameSource)
(Stms (Rep (BuilderT GPU (RWS () IntraAcc VNameSource))))
IntrablockM ()
-> BuilderT GPU (RWS () IntraAcc VNameSource) (Stms GPU)
forall (m :: * -> *) a. MonadBuilder m => m a -> m (Stms (Rep m))
collectStms_ (IntrablockM ()
-> BuilderT GPU (RWS () IntraAcc VNameSource) (Stms GPU))
-> (Stms SOACS -> IntrablockM ())
-> Stms SOACS
-> BuilderT GPU (RWS () IntraAcc VNameSource) (Stms GPU)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Stms SOACS -> IntrablockM ()
intrablockStms,
distSegLevel :: MkSegLevel GPU (BuilderT GPU (RWS () IntraAcc VNameSource))
distSegLevel = \[SubExp]
minw String
_ ThreadRecommendation
_ -> do
IntrablockM ()
-> BuilderT GPU (BuilderT GPU (RWS () IntraAcc VNameSource)) ()
forall (m :: * -> *) a. Monad m => m a -> BuilderT GPU m a
forall (t :: (* -> *) -> * -> *) (m :: * -> *) a.
(MonadTrans t, Monad m) =>
m a -> t m a
lift (IntrablockM ()
-> BuilderT GPU (BuilderT GPU (RWS () IntraAcc VNameSource)) ())
-> IntrablockM ()
-> BuilderT GPU (BuilderT GPU (RWS () IntraAcc VNameSource)) ()
forall a b. (a -> b) -> a -> b
$ [SubExp] -> IntrablockM ()
parallelMin [SubExp]
minw
SegLevel
-> BuilderT
GPU (BuilderT GPU (RWS () IntraAcc VNameSource)) SegLevel
forall a.
a -> BuilderT GPU (BuilderT GPU (RWS () IntraAcc VNameSource)) a
forall (f :: * -> *) a. Applicative f => a -> f a
pure SegLevel
lvl,
distOnSOACSStms :: Stm SOACS -> BuilderT GPU (State VNameSource) (Stms GPU)
distOnSOACSStms =
Stms GPU -> BuilderT GPU (State VNameSource) (Stms GPU)
forall a. a -> BuilderT GPU (State VNameSource) a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Stms GPU -> BuilderT GPU (State VNameSource) (Stms GPU))
-> (Stm SOACS -> Stms GPU)
-> Stm SOACS
-> BuilderT GPU (State VNameSource) (Stms GPU)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Stm GPU -> Stms GPU
forall rep. Stm rep -> Stms rep
oneStm (Stm GPU -> Stms GPU)
-> (Stm SOACS -> Stm GPU) -> Stm SOACS -> Stms GPU
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Stm SOACS -> Stm GPU
soacsStmToGPU,
distOnSOACSLambda :: Lambda SOACS -> Builder GPU (Lambda GPU)
distOnSOACSLambda =
Lambda GPU -> Builder GPU (Lambda GPU)
forall a. a -> BuilderT GPU (State VNameSource) a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Lambda GPU -> Builder GPU (Lambda GPU))
-> (Lambda SOACS -> Lambda GPU)
-> Lambda SOACS
-> Builder GPU (Lambda GPU)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Lambda SOACS -> Lambda GPU
soacsLambdaToGPU
}
acc :: DistAcc GPU
acc =
DistAcc
{ distTargets :: Targets
distTargets = Target -> Targets
singleTarget (Pat Type
Pat (LetDec SOACS)
pat, Body SOACS -> Result
forall rep. Body rep -> Result
bodyResult (Body SOACS -> Result) -> Body SOACS -> Result
forall a b. (a -> b) -> a -> b
$ Lambda SOACS -> Body SOACS
forall rep. Lambda rep -> Body rep
lambdaBody Lambda SOACS
lam),
distStms :: Stms GPU
distStms = Stms GPU
forall a. Monoid a => a
mempty
}
Stms (Rep (BuilderT GPU (RWS () IntraAcc VNameSource)))
-> IntrablockM ()
Stms GPU -> IntrablockM ()
forall (m :: * -> *). MonadBuilder m => Stms (Rep m) -> m ()
addStms
(Stms GPU -> IntrablockM ())
-> BuilderT GPU (RWS () IntraAcc VNameSource) (Stms GPU)
-> IntrablockM ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< DistEnv GPU (BuilderT GPU (RWS () IntraAcc VNameSource))
-> DistNestT
GPU (BuilderT GPU (RWS () IntraAcc VNameSource)) (DistAcc GPU)
-> BuilderT GPU (RWS () IntraAcc VNameSource) (Stms GPU)
forall (m :: * -> *) rep.
(MonadLogger m, DistRep rep) =>
DistEnv rep m -> DistNestT rep m (DistAcc rep) -> m (Stms rep)
runDistNestT DistEnv GPU (BuilderT GPU (RWS () IntraAcc VNameSource))
env (DistAcc GPU
-> Stms SOACS
-> DistNestT
GPU (BuilderT GPU (RWS () IntraAcc VNameSource)) (DistAcc GPU)
forall (m :: * -> *) rep.
(MonadFreshNames m, LocalScope rep m, DistRep rep) =>
DistAcc rep -> Stms SOACS -> DistNestT rep m (DistAcc rep)
distributeMapBodyStms DistAcc GPU
acc (Body SOACS -> Stms SOACS
forall rep. Body rep -> Stms rep
bodyStms (Body SOACS -> Stms SOACS) -> Body SOACS -> Stms SOACS
forall a b. (a -> b) -> a -> b
$ Lambda SOACS -> Body SOACS
forall rep. Lambda rep -> Body rep
lambdaBody Lambda SOACS
lam))
Op (Screma SubExp
w [VName]
arrs ScremaForm SOACS
form)
| Just ([Scan SOACS]
scans, Lambda SOACS
mapfun) <- ScremaForm SOACS -> Maybe ([Scan SOACS], Lambda SOACS)
forall rep. ScremaForm rep -> Maybe ([Scan rep], Lambda rep)
isScanomapSOAC ScremaForm SOACS
form,
Scan Lambda SOACS
scanfun [SubExp]
nes <- [Scan SOACS] -> Scan SOACS
forall rep. Buildable rep => [Scan rep] -> Scan rep
singleScan [Scan SOACS]
scans -> do
let scanfun' :: Lambda GPU
scanfun' = Lambda SOACS -> Lambda GPU
soacsLambdaToGPU Lambda SOACS
scanfun
mapfun' :: Lambda GPU
mapfun' = Lambda SOACS -> Lambda GPU
soacsLambdaToGPU Lambda SOACS
mapfun
Certs -> IntrablockM () -> IntrablockM ()
forall a.
Certs
-> BuilderT GPU (RWS () IntraAcc VNameSource) a
-> BuilderT GPU (RWS () IntraAcc VNameSource) a
forall (m :: * -> *) a. MonadBuilder m => Certs -> m a -> m a
certifying (StmAux () -> Certs
forall dec. StmAux dec -> Certs
stmAuxCerts StmAux ()
StmAux (ExpDec SOACS)
aux) (IntrablockM () -> IntrablockM ())
-> IntrablockM () -> IntrablockM ()
forall a b. (a -> b) -> a -> b
$
Stms (Rep (BuilderT GPU (RWS () IntraAcc VNameSource)))
-> IntrablockM ()
Stms GPU -> IntrablockM ()
forall (m :: * -> *). MonadBuilder m => Stms (Rep m) -> m ()
addStms (Stms GPU -> IntrablockM ())
-> BuilderT GPU (RWS () IntraAcc VNameSource) (Stms GPU)
-> IntrablockM ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< SegOpLevel GPU
-> Pat (LetDec GPU)
-> Certs
-> SubExp
-> [SegBinOp GPU]
-> Lambda GPU
-> [VName]
-> [(VName, SubExp)]
-> [KernelInput]
-> BuilderT GPU (RWS () IntraAcc VNameSource) (Stms GPU)
forall (m :: * -> *) rep.
(MonadFreshNames m, DistRep rep, HasScope rep m) =>
SegOpLevel rep
-> Pat (LetDec rep)
-> Certs
-> SubExp
-> [SegBinOp rep]
-> Lambda rep
-> [VName]
-> [(VName, SubExp)]
-> [KernelInput]
-> m (Stms rep)
segScan SegOpLevel GPU
SegLevel
lvl Pat (LetDec SOACS)
Pat (LetDec GPU)
pat Certs
forall a. Monoid a => a
mempty SubExp
w [Commutativity -> Lambda GPU -> [SubExp] -> Shape -> SegBinOp GPU
forall rep.
Commutativity -> Lambda rep -> [SubExp] -> Shape -> SegBinOp rep
SegBinOp Commutativity
Noncommutative Lambda GPU
scanfun' [SubExp]
nes Shape
forall a. Monoid a => a
mempty] Lambda GPU
mapfun' [VName]
arrs [] []
[SubExp] -> IntrablockM ()
parallelMin [SubExp
w]
Op (Screma SubExp
w [VName]
arrs ScremaForm SOACS
form)
| Just ([Reduce SOACS]
reds, Lambda SOACS
map_lam) <- ScremaForm SOACS -> Maybe ([Reduce SOACS], Lambda SOACS)
forall rep. ScremaForm rep -> Maybe ([Reduce rep], Lambda rep)
isRedomapSOAC ScremaForm SOACS
form -> do
let onReduce :: Reduce SOACS -> SegBinOp GPU
onReduce (Reduce Commutativity
comm Lambda SOACS
red_lam [SubExp]
nes) =
Commutativity -> Lambda GPU -> [SubExp] -> Shape -> SegBinOp GPU
forall rep.
Commutativity -> Lambda rep -> [SubExp] -> Shape -> SegBinOp rep
SegBinOp Commutativity
comm (Lambda SOACS -> Lambda GPU
soacsLambdaToGPU Lambda SOACS
red_lam) [SubExp]
nes Shape
forall a. Monoid a => a
mempty
reds' :: [SegBinOp GPU]
reds' = (Reduce SOACS -> SegBinOp GPU) -> [Reduce SOACS] -> [SegBinOp GPU]
forall a b. (a -> b) -> [a] -> [b]
map Reduce SOACS -> SegBinOp GPU
onReduce [Reduce SOACS]
reds
map_lam' :: Lambda GPU
map_lam' = Lambda SOACS -> Lambda GPU
soacsLambdaToGPU Lambda SOACS
map_lam
Certs -> IntrablockM () -> IntrablockM ()
forall a.
Certs
-> BuilderT GPU (RWS () IntraAcc VNameSource) a
-> BuilderT GPU (RWS () IntraAcc VNameSource) a
forall (m :: * -> *) a. MonadBuilder m => Certs -> m a -> m a
certifying (StmAux () -> Certs
forall dec. StmAux dec -> Certs
stmAuxCerts StmAux ()
StmAux (ExpDec SOACS)
aux) (IntrablockM () -> IntrablockM ())
-> IntrablockM () -> IntrablockM ()
forall a b. (a -> b) -> a -> b
$
Stms (Rep (BuilderT GPU (RWS () IntraAcc VNameSource)))
-> IntrablockM ()
Stms GPU -> IntrablockM ()
forall (m :: * -> *). MonadBuilder m => Stms (Rep m) -> m ()
addStms (Stms GPU -> IntrablockM ())
-> BuilderT GPU (RWS () IntraAcc VNameSource) (Stms GPU)
-> IntrablockM ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< SegOpLevel GPU
-> Pat (LetDec GPU)
-> Certs
-> SubExp
-> [SegBinOp GPU]
-> Lambda GPU
-> [VName]
-> [(VName, SubExp)]
-> [KernelInput]
-> BuilderT GPU (RWS () IntraAcc VNameSource) (Stms GPU)
forall (m :: * -> *) rep.
(MonadFreshNames m, DistRep rep, HasScope rep m) =>
SegOpLevel rep
-> Pat (LetDec rep)
-> Certs
-> SubExp
-> [SegBinOp rep]
-> Lambda rep
-> [VName]
-> [(VName, SubExp)]
-> [KernelInput]
-> m (Stms rep)
segRed SegOpLevel GPU
SegLevel
lvl Pat (LetDec SOACS)
Pat (LetDec GPU)
pat Certs
forall a. Monoid a => a
mempty SubExp
w [SegBinOp GPU]
reds' Lambda GPU
map_lam' [VName]
arrs [] []
[SubExp] -> IntrablockM ()
parallelMin [SubExp
w]
Op (Screma SubExp
w [VName]
arrs ScremaForm SOACS
form) ->
(Stm SOACS -> IntrablockM ()) -> Stms SOACS -> IntrablockM ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
(a -> m b) -> t a -> m ()
mapM_ Stm SOACS -> IntrablockM ()
intrablockStm (Stms SOACS -> IntrablockM ())
-> (((), Stms SOACS) -> Stms SOACS)
-> ((), Stms SOACS)
-> IntrablockM ()
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Stm SOACS -> Stm SOACS) -> Stms SOACS -> Stms SOACS
forall a b. (a -> b) -> Seq a -> Seq b
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (Certs -> Stm SOACS -> Stm SOACS
forall rep. Certs -> Stm rep -> Stm rep
certify (StmAux () -> Certs
forall dec. StmAux dec -> Certs
stmAuxCerts StmAux ()
StmAux (ExpDec SOACS)
aux)) (Stms SOACS -> Stms SOACS)
-> (((), Stms SOACS) -> Stms SOACS)
-> ((), Stms SOACS)
-> Stms SOACS
forall b c a. (b -> c) -> (a -> b) -> a -> c
. ((), Stms SOACS) -> Stms SOACS
forall a b. (a, b) -> b
snd
(((), Stms SOACS) -> IntrablockM ())
-> BuilderT GPU (RWS () IntraAcc VNameSource) ((), Stms SOACS)
-> IntrablockM ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< BuilderT SOACS (BuilderT GPU (RWS () IntraAcc VNameSource)) ()
-> Scope SOACS
-> BuilderT GPU (RWS () IntraAcc VNameSource) ((), Stms SOACS)
forall (m :: * -> *) rep a.
MonadFreshNames m =>
BuilderT rep m a -> Scope rep -> m (a, Stms rep)
runBuilderT (Pat
(LetDec
(Rep
(BuilderT SOACS (BuilderT GPU (RWS () IntraAcc VNameSource)))))
-> SubExp
-> ScremaForm
(Rep (BuilderT SOACS (BuilderT GPU (RWS () IntraAcc VNameSource))))
-> [VName]
-> BuilderT SOACS (BuilderT GPU (RWS () IntraAcc VNameSource)) ()
forall (m :: * -> *).
(MonadBuilder m, Op (Rep m) ~ SOAC (Rep m), Buildable (Rep m)) =>
Pat (LetDec (Rep m))
-> SubExp -> ScremaForm (Rep m) -> [VName] -> m ()
dissectScrema Pat
(LetDec
(Rep
(BuilderT SOACS (BuilderT GPU (RWS () IntraAcc VNameSource)))))
Pat (LetDec SOACS)
pat SubExp
w ScremaForm
(Rep (BuilderT SOACS (BuilderT GPU (RWS () IntraAcc VNameSource))))
ScremaForm SOACS
form [VName]
arrs) (Scope GPU -> Scope SOACS
scopeForSOACs Scope GPU
scope)
Op (Hist SubExp
w [VName]
arrs [HistOp SOACS]
ops Lambda SOACS
bucket_fun) -> do
ops' <- [HistOp SOACS]
-> (HistOp SOACS
-> BuilderT GPU (RWS () IntraAcc VNameSource) (HistOp GPU))
-> BuilderT GPU (RWS () IntraAcc VNameSource) [HistOp GPU]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM [HistOp SOACS]
ops ((HistOp SOACS
-> BuilderT GPU (RWS () IntraAcc VNameSource) (HistOp GPU))
-> BuilderT GPU (RWS () IntraAcc VNameSource) [HistOp GPU])
-> (HistOp SOACS
-> BuilderT GPU (RWS () IntraAcc VNameSource) (HistOp GPU))
-> BuilderT GPU (RWS () IntraAcc VNameSource) [HistOp GPU]
forall a b. (a -> b) -> a -> b
$ \(HistOp Shape
num_bins SubExp
rf [VName]
dests [SubExp]
nes Lambda SOACS
op) -> do
(op', nes', shape) <- Lambda SOACS
-> [SubExp]
-> BuilderT
GPU (RWS () IntraAcc VNameSource) (Lambda SOACS, [SubExp], Shape)
forall (m :: * -> *).
MonadBuilder m =>
Lambda SOACS -> [SubExp] -> m (Lambda SOACS, [SubExp], Shape)
determineReduceOp Lambda SOACS
op [SubExp]
nes
let op'' = Lambda SOACS -> Lambda GPU
soacsLambdaToGPU Lambda SOACS
op'
pure $ GPU.HistOp num_bins rf dests nes' shape op''
let bucket_fun' = Lambda SOACS -> Lambda GPU
soacsLambdaToGPU Lambda SOACS
bucket_fun
certifying (stmAuxCerts aux) $
addStms =<< segHist lvl pat w [] [] ops' bucket_fun' arrs
parallelMin [w]
Op (Stream SubExp
w [VName]
arrs [SubExp]
accs Lambda SOACS
lam)
| LParam SOACS
chunk_size_param : [LParam SOACS]
_ <- Lambda SOACS -> [LParam SOACS]
forall rep. Lambda rep -> [LParam rep]
lambdaParams Lambda SOACS
lam -> do
types <- (Scope GPU -> Scope SOACS)
-> BuilderT GPU (RWS () IntraAcc VNameSource) (Scope SOACS)
forall a.
(Scope GPU -> a) -> BuilderT GPU (RWS () IntraAcc VNameSource) a
forall rep (m :: * -> *) a.
HasScope rep m =>
(Scope rep -> a) -> m a
asksScope Scope GPU -> Scope SOACS
forall fromrep torep.
SameScope fromrep torep =>
Scope fromrep -> Scope torep
castScope
((), stream_stms) <-
runBuilderT (sequentialStreamWholeArray pat w accs lam arrs) types
let replace (Var VName
v) | VName
v VName -> VName -> Bool
forall a. Eq a => a -> a -> Bool
== Param Type -> VName
forall dec. Param dec -> VName
paramName Param Type
LParam SOACS
chunk_size_param = SubExp
w
replace SubExp
se = SubExp
se
replaceSets (IntraAcc Set [SubExp]
x Set [SubExp]
y Log
log) =
Set [SubExp] -> Set [SubExp] -> Log -> IntraAcc
IntraAcc (([SubExp] -> [SubExp]) -> Set [SubExp] -> Set [SubExp]
forall b a. Ord b => (a -> b) -> Set a -> Set b
S.map ((SubExp -> SubExp) -> [SubExp] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> SubExp
replace) Set [SubExp]
x) (([SubExp] -> [SubExp]) -> Set [SubExp] -> Set [SubExp]
forall b a. Ord b => (a -> b) -> Set a -> Set b
S.map ((SubExp -> SubExp) -> [SubExp] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> SubExp
replace) Set [SubExp]
y) Log
log
censor replaceSets $ intrablockStms stream_stms
Op (Scatter SubExp
w [VName]
ivs ScatterSpec VName
dests Lambda SOACS
lam) -> do
write_i <- String -> BuilderT GPU (RWS () IntraAcc VNameSource) VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName String
"write_i"
space <- mkSegSpace [(write_i, w)]
let lam' = Lambda SOACS -> Lambda GPU
soacsLambdaToGPU Lambda SOACS
lam
grouped = ScatterSpec VName
-> Result -> [(Shape, VName, [(Result, SubExpRes)])]
forall array a.
ScatterSpec array -> [a] -> [(Shape, array, [([a], a)])]
groupScatterResults ScatterSpec VName
dests (Result -> [(Shape, VName, [(Result, SubExpRes)])])
-> Result -> [(Shape, VName, [(Result, SubExpRes)])]
forall a b. (a -> b) -> a -> b
$ Body GPU -> Result
forall rep. Body rep -> Result
bodyResult (Body GPU -> Result) -> Body GPU -> Result
forall a b. (a -> b) -> a -> b
$ Lambda GPU -> Body GPU
forall rep. Lambda rep -> Body rep
lambdaBody Lambda GPU
lam'
(_, dest_arrs, _) = unzip3 grouped
dest_ts <- mapM lookupType dest_arrs
let krets = do
(a_t, (_a_w, a, is_vs)) <- [Type]
-> [(Shape, VName, [(Result, SubExpRes)])]
-> [(Type, (Shape, VName, [(Result, SubExpRes)]))]
forall a b. [a] -> [b] -> [(a, b)]
zip [Type]
dest_ts [(Shape, VName, [(Result, SubExpRes)])]
grouped
let cs =
((Result, SubExpRes) -> Certs) -> [(Result, SubExpRes)] -> Certs
forall m a. Monoid m => (a -> m) -> [a] -> m
forall (t :: * -> *) m a.
(Foldable t, Monoid m) =>
(a -> m) -> t a -> m
foldMap ((SubExpRes -> Certs) -> Result -> Certs
forall m a. Monoid m => (a -> m) -> [a] -> m
forall (t :: * -> *) m a.
(Foldable t, Monoid m) =>
(a -> m) -> t a -> m
foldMap SubExpRes -> Certs
resCerts (Result -> Certs)
-> ((Result, SubExpRes) -> Result) -> (Result, SubExpRes) -> Certs
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Result, SubExpRes) -> Result
forall a b. (a, b) -> a
fst) [(Result, SubExpRes)]
is_vs
Certs -> Certs -> Certs
forall a. Semigroup a => a -> a -> a
<> ((Result, SubExpRes) -> Certs) -> [(Result, SubExpRes)] -> Certs
forall m a. Monoid m => (a -> m) -> [a] -> m
forall (t :: * -> *) m a.
(Foldable t, Monoid m) =>
(a -> m) -> t a -> m
foldMap (SubExpRes -> Certs
resCerts (SubExpRes -> Certs)
-> ((Result, SubExpRes) -> SubExpRes)
-> (Result, SubExpRes)
-> Certs
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Result, SubExpRes) -> SubExpRes
forall a b. (a, b) -> b
snd) [(Result, SubExpRes)]
is_vs
is_vs' = do
(is, v) <- [(Result, SubExpRes)]
is_vs
pure
( fullSlice a_t $ map (DimFix . resSubExp) is,
resSubExp v
)
pure $ WriteReturns cs a is_vs'
inputs = do
(p, p_a) <- [Param Type] -> [VName] -> [(Param Type, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip (Lambda GPU -> [Param (LParamInfo GPU)]
forall rep. Lambda rep -> [LParam rep]
lambdaParams Lambda GPU
lam') [VName]
ivs
pure $ KernelInput (paramName p) (paramType p) p_a [Var write_i]
kstms <- runBuilder_ $
localScope (scopeOfSegSpace space) $ do
mapM_ readKernelInput inputs
addStms $ bodyStms $ lambdaBody lam'
certifying (stmAuxCerts aux) $ do
let body = BodyDec GPU -> Stms GPU -> [KernelResult] -> KernelBody GPU
forall rep.
BodyDec rep -> Stms rep -> [KernelResult] -> KernelBody rep
KernelBody () Stms GPU
kstms [KernelResult]
krets
letBind pat $ Op $ SegOp $ SegMap lvl space (patTypes pat) body
parallelMin [w]
Exp SOACS
_ ->
Stm (Rep (BuilderT GPU (RWS () IntraAcc VNameSource)))
-> IntrablockM ()
forall (m :: * -> *). MonadBuilder m => Stm (Rep m) -> m ()
addStm (Stm (Rep (BuilderT GPU (RWS () IntraAcc VNameSource)))
-> IntrablockM ())
-> Stm (Rep (BuilderT GPU (RWS () IntraAcc VNameSource)))
-> IntrablockM ()
forall a b. (a -> b) -> a -> b
$ Stm SOACS -> Stm GPU
soacsStmToGPU Stm SOACS
stm
-- | Feed every statement of the given sequence, in order, through
-- 'intrablockStm' inside the 'IntrablockM' monad.
intrablockStms :: Stms SOACS -> IntrablockM ()
intrablockStms stms = forM_ stms intrablockStm
-- | Run 'intrablockStms' over the statements of a SOACS body via
-- 'runIntrablockM' and package the outcome.
--
-- The result holds:
--
-- * the minimum-width set from the accumulated 'IntraAcc', as a list;
--
-- * the available-width set from the accumulated 'IntraAcc', as a list;
--
-- * the accumulated 'Log';
--
-- * a 'KernelBody' built from the produced GPU statements.
--
-- Each result of the original body becomes a 'Returns' with
-- 'ResultMaySimplify'.
intrablockParalleliseBody ::
  (MonadFreshNames m, HasScope GPU m) =>
  Body SOACS ->
  m ([[SubExp]], [[SubExp]], Log, KernelBody GPU)
intrablockParalleliseBody body =
  runIntrablockM (intrablockStms (bodyStms body))
    >>= \(IntraAcc minWidths availWidths acclog, gpuStms) ->
      pure
        ( S.toList minWidths,
          S.toList availWidths,
          acclog,
          KernelBody () gpuStms (map toKernelResult (bodyResult body))
        )
  where
    -- Keep the certificates and the value of each body result. Mark it
    -- as simplifiable.
    toKernelResult :: SubExpRes -> KernelResult
    toKernelResult (SubExpRes cs se) = Returns ResultMaySimplify cs se