{-# LANGUAGE TypeFamilies #-}

-- | Extract limited nested parallelism for execution inside
-- individual kernel threadblocks.
module Futhark.Pass.ExtractKernels.Intrablock (intrablockParallelise) where

import Control.Monad
import Control.Monad.RWS
import Control.Monad.Trans.Maybe
import Data.Map.Strict qualified as M
import Data.Set qualified as S
import Futhark.Analysis.PrimExp.Convert
import Futhark.IR.GPU hiding (HistOp)
import Futhark.IR.GPU.Op qualified as GPU
import Futhark.IR.SOACS
import Futhark.MonadFreshNames
import Futhark.Pass.ExtractKernels.BlockedKernel
import Futhark.Pass.ExtractKernels.DistributeNests
import Futhark.Pass.ExtractKernels.Distribution
import Futhark.Pass.ExtractKernels.ToGPU
import Futhark.Tools
import Futhark.Transform.FirstOrderTransform qualified as FOT
import Futhark.Util.Log
import Prelude hiding (log)

-- | Convert the statements inside a map nest to kernel statements,
-- attempting to parallelise any remaining (top-level) parallel
-- statements.  Anything that is not a map, scan or reduction will
-- simply be sequentialised.  This includes sequential loops that
-- contain maps, scans or reduction.  In the future, we could probably
-- do something more clever.  Make sure that the amount of parallelism
-- to be exploited does not exceed the group size.  Further, as a hack
-- we also consider the size of all intermediate arrays as
-- "parallelism to be exploited" to avoid exploding shared memory.
--
-- We distinguish between "minimum group size" and "maximum
-- exploitable parallelism".
--
-- The result is 'Nothing' when the transformation is abandoned; in
-- particular, when any of the collected parallelism widths refers to
-- a name that is not available outside the kernel ("Irregular
-- parallelism").  Otherwise the result is a tuple of:
--
--   * the pair @(intra_min_par, intra_avail_par)@ (currently both
--     components are the same value);
--
--   * the computed thread block size;
--
--   * the log accumulated while transforming the body;
--
--   * the prelude statements that compute the sizes above; and
--
--   * the single statement binding the block-level 'SegMap'.
intrablockParallelise ::
  (MonadFreshNames m, LocalScope GPU m) =>
  KernelNest ->
  Lambda SOACS ->
  m
    ( Maybe
        ( (SubExp, SubExp),
          SubExp,
          Log,
          Stms GPU,
          Stms GPU
        )
    )
intrablockParallelise knest lam = runMaybeT $ do
  (ispace, inps) <- lift $ flatKernel knest

  -- The number of thread blocks is the product of all dimensions of
  -- the flattened kernel space.
  (num_tblocks, w_stms) <-
    lift $
      runBuilder $
        letSubExp "intra_num_tblocks"
          =<< foldBinOp (Mul Int64 OverflowUndef) (intConst Int64 1) (map snd ispace)

  let body = lambdaBody lam

  tblock_size <- newVName "computed_tblock_size"
  (wss_min, wss_avail, log, kbody) <-
    lift . localScope (scopeOfLParams $ lambdaParams lam) $
      intrablockParalleliseBody body

  outside_scope <- lift askScope
  -- outside_scope may also contain the inputs, even though those are
  -- not actually available outside the kernel.
  let available v =
        v `M.member` outside_scope
          && v `notElem` map kernelInputName inps
  unless (all available $ namesToList $ freeIn (wss_min ++ wss_avail)) $
    fail "Irregular parallelism"

  ((intra_avail_par, kspace, read_input_stms), prelude_stms) <- lift $
    runBuilder $ do
      -- Like 'foldBinOp', but takes the first element as the initial
      -- value and yields the constant 1 for an empty list.
      let foldBinOp' _ [] = eSubExp $ intConst Int64 1
          foldBinOp' bop (x : xs) = foldBinOp bop x xs
      ws_min <-
        mapM (letSubExp "one_intra_par_min" <=< foldBinOp' (Mul Int64 OverflowUndef)) $
          filter (not . null) wss_min
      ws_avail <-
        mapM (letSubExp "one_intra_par_avail" <=< foldBinOp' (Mul Int64 OverflowUndef)) $
          filter (not . null) wss_avail

      -- The amount of parallelism available *in the worst case* is
      -- equal to the smallest parallel loop, or *at least* 1.
      intra_avail_par <-
        letSubExp "intra_avail_par" =<< foldBinOp' (SMin Int64) ws_avail

      -- The group size is either the maximum of the minimum parallelism
      -- exploited, or the desired parallelism (bounded by the max group
      -- size) in case there is no minimum.
      letBindNames [tblock_size]
        =<< if null ws_min
          then
            eBinOp
              (SMin Int64)
              (eSubExp =<< letSubExp "max_tblock_size" (Op $ SizeOp $ GetSizeMax SizeThreadBlock))
              (eSubExp intra_avail_par)
          else foldBinOp' (SMax Int64) ws_min

      -- Only read those kernel inputs that the body actually mentions.
      let inputIsUsed input = kernelInputName input `nameIn` freeIn body
          used_inps = filter inputIsUsed inps

      addStms w_stms
      read_input_stms <- runBuilder_ $ mapM_ readGroupKernelInput used_inps
      space <- SegSpace <$> newVName "phys_tblock_id" <*> pure ispace
      pure (intra_avail_par, space, read_input_stms)

  -- The input reads must come before the transformed body statements.
  let kbody' = kbody {kernelBodyStms = read_input_stms <> kernelBodyStms kbody}

  let nested_pat = loopNestingPat first_nest
      -- The per-block result types are the pattern types with one
      -- outer dimension stripped per dimension of the kernel space.
      rts = map (length ispace `stripArray`) $ patTypes nested_pat
      grid = KernelGrid (Count num_tblocks) (Count $ Var tblock_size)
      lvl = SegBlock SegNoVirt (Just grid)
      kstm =
        Let nested_pat aux $ Op $ SegOp $ SegMap lvl kspace rts kbody'

  let intra_min_par = intra_avail_par
  pure
    ( (intra_min_par, intra_avail_par),
      Var tblock_size,
      log,
      prelude_stms,
      oneStm kstm
    )
  where
    first_nest = fst knest
    aux = loopNestingAux first_nest

-- | Emit the statements that read a single kernel input.
--
-- For an array-typed input, the input is first read under a fresh
-- name (derived from the base string of the original name), and the
-- original name is then bound to a 'Replicate' with an empty shape
-- of that fresh variable.  NOTE(review): the empty-shape 'Replicate'
-- presumably acts as a copy of the slice - confirm against the IR
-- semantics.
--
-- Any other input is read directly with 'readKernelInput'.
readGroupKernelInput ::
  (DistRep (Rep m), MonadBuilder m) =>
  KernelInput ->
  m ()
readGroupKernelInput inp
  | Array {} <- kernelInputType inp = do
      v <- newVName $ baseString $ kernelInputName inp
      readKernelInput inp {kernelInputName = v}
      letBindNames [kernelInputName inp] $ BasicOp $ Replicate mempty $ Var v
  | otherwise =
      readKernelInput inp

-- | The information accumulated (via the writer component of
-- 'IntrablockM') while transforming statements.
data IntraAcc = IntraAcc
  { -- | Collected widths contributing to the minimum parallelism;
    -- each element is a list of sizes.
    accMinPar :: S.Set [SubExp],
    -- | Collected widths contributing to the available parallelism;
    -- each element is a list of sizes.
    accAvailPar :: S.Set [SubExp],
    -- | Log messages emitted during the transformation.
    accLog :: Log
  }

-- | Accumulators are combined field by field, using the 'Semigroup'
-- instance of each field.
instance Semigroup IntraAcc where
  IntraAcc min_x avail_x log_x <> IntraAcc min_y avail_y log_y =
    IntraAcc (min_x <> min_y) (avail_x <> avail_y) (log_x <> log_y)

-- | The empty accumulator has every field set to its own 'mempty'.
instance Monoid IntraAcc where
  mempty = IntraAcc mempty mempty mempty

-- | The monad in which intrablock statements are constructed: a GPU
-- statement builder layered over an 'RWS' monad with a unit reader
-- environment, an 'IntraAcc' writer, and a 'VNameSource' as state.
type IntrablockM =
  BuilderT GPU (RWS () IntraAcc VNameSource)

-- | Log messages are recorded in the 'accLog' field of the
-- 'IntraAcc' writer output.
instance MonadLogger IntrablockM where
  addLog log = tell mempty {accLog = log}

-- | Run an 'IntrablockM' action in the scope of the enclosing monad,
-- threading the enclosing monad's name source through the
-- computation.  Returns the accumulated 'IntraAcc' together with the
-- statements produced by the builder.
runIntrablockM ::
  (MonadFreshNames m, HasScope GPU m) =>
  IntrablockM () ->
  m (IntraAcc, Stms GPU)
runIntrablockM m = do
  scope <- castScope <$> askScope
  modifyNameSource $ \src ->
    -- The reader environment is unit; the name source is the RWS state.
    let (((), kstms), src', acc) = runRWS (runBuilderT m scope) () src
     in ((acc, kstms), src')

-- | Record a list of widths as contributing to both the minimum and
-- the available parallelism.
parallelMin :: [SubExp] -> IntrablockM ()
parallelMin ws =
  tell
    mempty
      { accMinPar = S.singleton ws,
        accAvailPar = S.singleton ws
      }

-- | Transform a SOACS body into a GPU body: the statements are
-- processed with 'intrablockStms' and collected, while the body
-- result is carried over unchanged.
intrablockBody :: Body SOACS -> IntrablockM (Body GPU)
intrablockBody body = do
  stms <- collectStms_ $ intrablockStms $ bodyStms body
  pure $ mkBody stms $ bodyResult body

-- | Transform a SOACS lambda into a GPU lambda with the same
-- parameters, whose body is the original body transformed by
-- 'intrablockBody'.
intrablockLambda :: Lambda SOACS -> IntrablockM (Lambda GPU)
intrablockLambda lam =
  mkLambda (lambdaParams lam) $
    bodyBind =<< intrablockBody (lambdaBody lam)

-- | Transform the input of a 'WithAcc'.  The shape and arrays are
-- left untouched; if a combining lambda is present, it is
-- transformed with 'intrablockLambda' and its neutral elements are
-- kept as they are.
intrablockWithAccInput :: WithAccInput SOACS -> IntrablockM (WithAccInput GPU)
intrablockWithAccInput (shape, arrs, Nothing) =
  pure (shape, arrs, Nothing)
intrablockWithAccInput (shape, arrs, Just (lam, nes)) = do
  lam' <- intrablockLambda lam
  pure (shape, arrs, Just (lam', nes))

intrablockStm :: Stm SOACS -> IntrablockM ()
intrablockStm :: Stm SOACS -> IntrablockM ()
intrablockStm stm :: Stm SOACS
stm@(Let Pat (LetDec SOACS)
pat StmAux (ExpDec SOACS)
aux Exp SOACS
e) = do
  scope <- BuilderT GPU (RWS () IntraAcc VNameSource) (Scope GPU)
forall rep (m :: * -> *). HasScope rep m => m (Scope rep)
askScope
  let lvl = SegVirt -> SegLevel
SegThreadInBlock SegVirt
SegNoVirt

  case e of
    Loop [(FParam SOACS, SubExp)]
merge LoopForm
form Body SOACS
loopbody ->
      Scope GPU -> IntrablockM () -> IntrablockM ()
forall a.
Scope GPU
-> BuilderT GPU (RWS () IntraAcc VNameSource) a
-> BuilderT GPU (RWS () IntraAcc VNameSource) a
forall rep (m :: * -> *) a.
LocalScope rep m =>
Scope rep -> m a -> m a
localScope (LoopForm -> Scope GPU
forall rep. LoopForm -> Scope rep
scopeOfLoopForm LoopForm
form Scope GPU -> Scope GPU -> Scope GPU
forall a. Semigroup a => a -> a -> a
<> [Param (FParamInfo GPU)] -> Scope GPU
forall rep dec. (FParamInfo rep ~ dec) => [Param dec] -> Scope rep
scopeOfFParams (((Param (FParamInfo GPU), SubExp) -> Param (FParamInfo GPU))
-> [(Param (FParamInfo GPU), SubExp)] -> [Param (FParamInfo GPU)]
forall a b. (a -> b) -> [a] -> [b]
map (Param (FParamInfo GPU), SubExp) -> Param (FParamInfo GPU)
forall a b. (a, b) -> a
fst [(FParam SOACS, SubExp)]
[(Param (FParamInfo GPU), SubExp)]
merge)) (IntrablockM () -> IntrablockM ())
-> IntrablockM () -> IntrablockM ()
forall a b. (a -> b) -> a -> b
$ do
        loopbody' <- Body SOACS -> IntrablockM (Body GPU)
intrablockBody Body SOACS
loopbody
        certifying (stmAuxCerts aux) . letBind pat $
          Loop merge form loopbody'
    Match [SubExp]
cond [Case (Body SOACS)]
cases Body SOACS
defbody MatchDec (BranchType SOACS)
ifdec -> do
      cases' <- (Case (Body SOACS)
 -> BuilderT GPU (RWS () IntraAcc VNameSource) (Case (Body GPU)))
-> [Case (Body SOACS)]
-> BuilderT GPU (RWS () IntraAcc VNameSource) [Case (Body GPU)]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM ((Body SOACS -> IntrablockM (Body GPU))
-> Case (Body SOACS)
-> BuilderT GPU (RWS () IntraAcc VNameSource) (Case (Body GPU))
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
forall (f :: * -> *) a b.
Applicative f =>
(a -> f b) -> Case a -> f (Case b)
traverse Body SOACS -> IntrablockM (Body GPU)
intrablockBody) [Case (Body SOACS)]
cases
      defbody' <- intrablockBody defbody
      certifying (stmAuxCerts aux) . letBind pat $
        Match cond cases' defbody' ifdec
    WithAcc [WithAccInput SOACS]
inputs Lambda SOACS
lam -> do
      inputs' <- (WithAccInput SOACS -> IntrablockM (WithAccInput GPU))
-> [WithAccInput SOACS]
-> BuilderT GPU (RWS () IntraAcc VNameSource) [WithAccInput GPU]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM WithAccInput SOACS -> IntrablockM (WithAccInput GPU)
intrablockWithAccInput [WithAccInput SOACS]
inputs
      lam' <- intrablockLambda lam
      certifying (stmAuxCerts aux) . letBind pat $ WithAcc inputs' lam'
    Op Op SOACS
soac
      | Attr
"sequential_outer" Attr -> Attrs -> Bool
`inAttrs` StmAux () -> Attrs
forall dec. StmAux dec -> Attrs
stmAuxAttrs StmAux ()
StmAux (ExpDec SOACS)
aux ->
          Stms SOACS -> IntrablockM ()
intrablockStms (Stms SOACS -> IntrablockM ())
-> (Stms SOACS -> Stms SOACS) -> Stms SOACS -> IntrablockM ()
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Stm SOACS -> Stm SOACS) -> Stms SOACS -> Stms SOACS
forall a b. (a -> b) -> Seq a -> Seq b
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (Certs -> Stm SOACS -> Stm SOACS
forall rep. Certs -> Stm rep -> Stm rep
certify (StmAux () -> Certs
forall dec. StmAux dec -> Certs
stmAuxCerts StmAux ()
StmAux (ExpDec SOACS)
aux))
            (Stms SOACS -> IntrablockM ())
-> BuilderT GPU (RWS () IntraAcc VNameSource) (Stms SOACS)
-> IntrablockM ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< Builder SOACS ()
-> BuilderT GPU (RWS () IntraAcc VNameSource) (Stms SOACS)
forall (m :: * -> *) somerep rep a.
(MonadFreshNames m, HasScope somerep m, SameScope somerep rep) =>
Builder rep a -> m (Stms rep)
runBuilder_ (Pat (LetDec (Rep (BuilderT SOACS (State VNameSource))))
-> SOAC (Rep (BuilderT SOACS (State VNameSource)))
-> Builder SOACS ()
forall (m :: * -> *).
Transformer m =>
Pat (LetDec (Rep m)) -> SOAC (Rep m) -> m ()
FOT.transformSOAC Pat (LetDec (Rep (BuilderT SOACS (State VNameSource))))
Pat (LetDec SOACS)
pat Op SOACS
SOAC (Rep (BuilderT SOACS (State VNameSource)))
soac)
    Op (Screma SubExp
w [VName]
arrs ScremaForm SOACS
form)
      | Just Lambda SOACS
lam <- ScremaForm SOACS -> Maybe (Lambda SOACS)
forall rep. ScremaForm rep -> Maybe (Lambda rep)
isMapSOAC ScremaForm SOACS
form -> do
          let loopnest :: LoopNesting
loopnest = Pat Type
-> StmAux () -> SubExp -> [(Param Type, VName)] -> LoopNesting
MapNesting Pat Type
Pat (LetDec SOACS)
pat StmAux ()
StmAux (ExpDec SOACS)
aux SubExp
w ([(Param Type, VName)] -> LoopNesting)
-> [(Param Type, VName)] -> LoopNesting
forall a b. (a -> b) -> a -> b
$ [Param Type] -> [VName] -> [(Param Type, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip (Lambda SOACS -> [LParam SOACS]
forall rep. Lambda rep -> [LParam rep]
lambdaParams Lambda SOACS
lam) [VName]
arrs
              env :: DistEnv GPU (BuilderT GPU (RWS () IntraAcc VNameSource))
env =
                DistEnv
                  { distNest :: Nestings
distNest =
                      Nesting -> Nestings
singleNesting (Nesting -> Nestings) -> Nesting -> Nestings
forall a b. (a -> b) -> a -> b
$ Names -> LoopNesting -> Nesting
Nesting Names
forall a. Monoid a => a
mempty LoopNesting
loopnest,
                    distScope :: Scope GPU
distScope =
                      Pat Type -> Scope GPU
forall rep dec. (LetDec rep ~ dec) => Pat dec -> Scope rep
scopeOfPat Pat Type
Pat (LetDec SOACS)
pat
                        Scope GPU -> Scope GPU -> Scope GPU
forall a. Semigroup a => a -> a -> a
<> Scope SOACS -> Scope GPU
scopeForGPU (Lambda SOACS -> Scope SOACS
forall rep a. Scoped rep a => a -> Scope rep
scopeOf Lambda SOACS
lam)
                        Scope GPU -> Scope GPU -> Scope GPU
forall a. Semigroup a => a -> a -> a
<> Scope GPU
scope,
                    distOnInnerMap :: MapLoop
-> DistAcc GPU
-> DistNestT
     GPU (BuilderT GPU (RWS () IntraAcc VNameSource)) (DistAcc GPU)
distOnInnerMap =
                      MapLoop
-> DistAcc GPU
-> DistNestT
     GPU (BuilderT GPU (RWS () IntraAcc VNameSource)) (DistAcc GPU)
forall (m :: * -> *) rep.
(MonadFreshNames m, LocalScope rep m, DistRep rep) =>
MapLoop -> DistAcc rep -> DistNestT rep m (DistAcc rep)
distributeMap,
                    distOnTopLevelStms :: Stms SOACS
-> DistNestT
     GPU (BuilderT GPU (RWS () IntraAcc VNameSource)) (Stms GPU)
distOnTopLevelStms =
                      BuilderT GPU (RWS () IntraAcc VNameSource) (Stms GPU)
-> DistNestT
     GPU (BuilderT GPU (RWS () IntraAcc VNameSource)) (Stms GPU)
forall rep (m :: * -> *) a.
(LocalScope rep m, DistRep rep) =>
m a -> DistNestT rep m a
liftInner (BuilderT GPU (RWS () IntraAcc VNameSource) (Stms GPU)
 -> DistNestT
      GPU (BuilderT GPU (RWS () IntraAcc VNameSource)) (Stms GPU))
-> (Stms SOACS
    -> BuilderT GPU (RWS () IntraAcc VNameSource) (Stms GPU))
-> Stms SOACS
-> DistNestT
     GPU (BuilderT GPU (RWS () IntraAcc VNameSource)) (Stms GPU)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. IntrablockM ()
-> BuilderT
     GPU
     (RWS () IntraAcc VNameSource)
     (Stms (Rep (BuilderT GPU (RWS () IntraAcc VNameSource))))
IntrablockM ()
-> BuilderT GPU (RWS () IntraAcc VNameSource) (Stms GPU)
forall (m :: * -> *) a. MonadBuilder m => m a -> m (Stms (Rep m))
collectStms_ (IntrablockM ()
 -> BuilderT GPU (RWS () IntraAcc VNameSource) (Stms GPU))
-> (Stms SOACS -> IntrablockM ())
-> Stms SOACS
-> BuilderT GPU (RWS () IntraAcc VNameSource) (Stms GPU)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Stms SOACS -> IntrablockM ()
intrablockStms,
                    distSegLevel :: MkSegLevel GPU (BuilderT GPU (RWS () IntraAcc VNameSource))
distSegLevel = \[SubExp]
minw String
_ ThreadRecommendation
_ -> do
                      IntrablockM ()
-> BuilderT GPU (BuilderT GPU (RWS () IntraAcc VNameSource)) ()
forall (m :: * -> *) a. Monad m => m a -> BuilderT GPU m a
forall (t :: (* -> *) -> * -> *) (m :: * -> *) a.
(MonadTrans t, Monad m) =>
m a -> t m a
lift (IntrablockM ()
 -> BuilderT GPU (BuilderT GPU (RWS () IntraAcc VNameSource)) ())
-> IntrablockM ()
-> BuilderT GPU (BuilderT GPU (RWS () IntraAcc VNameSource)) ()
forall a b. (a -> b) -> a -> b
$ [SubExp] -> IntrablockM ()
parallelMin [SubExp]
minw
                      SegLevel
-> BuilderT
     GPU (BuilderT GPU (RWS () IntraAcc VNameSource)) SegLevel
forall a.
a -> BuilderT GPU (BuilderT GPU (RWS () IntraAcc VNameSource)) a
forall (f :: * -> *) a. Applicative f => a -> f a
pure SegLevel
lvl,
                    distOnSOACSStms :: Stm SOACS -> BuilderT GPU (State VNameSource) (Stms GPU)
distOnSOACSStms =
                      Stms GPU -> BuilderT GPU (State VNameSource) (Stms GPU)
forall a. a -> BuilderT GPU (State VNameSource) a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Stms GPU -> BuilderT GPU (State VNameSource) (Stms GPU))
-> (Stm SOACS -> Stms GPU)
-> Stm SOACS
-> BuilderT GPU (State VNameSource) (Stms GPU)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Stm GPU -> Stms GPU
forall rep. Stm rep -> Stms rep
oneStm (Stm GPU -> Stms GPU)
-> (Stm SOACS -> Stm GPU) -> Stm SOACS -> Stms GPU
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Stm SOACS -> Stm GPU
soacsStmToGPU,
                    distOnSOACSLambda :: Lambda SOACS -> Builder GPU (Lambda GPU)
distOnSOACSLambda =
                      Lambda GPU -> Builder GPU (Lambda GPU)
forall a. a -> BuilderT GPU (State VNameSource) a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Lambda GPU -> Builder GPU (Lambda GPU))
-> (Lambda SOACS -> Lambda GPU)
-> Lambda SOACS
-> Builder GPU (Lambda GPU)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Lambda SOACS -> Lambda GPU
soacsLambdaToGPU
                  }
              acc :: DistAcc GPU
acc =
                DistAcc
                  { distTargets :: Targets
distTargets = Target -> Targets
singleTarget (Pat Type
Pat (LetDec SOACS)
pat, Body SOACS -> Result
forall rep. Body rep -> Result
bodyResult (Body SOACS -> Result) -> Body SOACS -> Result
forall a b. (a -> b) -> a -> b
$ Lambda SOACS -> Body SOACS
forall rep. Lambda rep -> Body rep
lambdaBody Lambda SOACS
lam),
                    distStms :: Stms GPU
distStms = Stms GPU
forall a. Monoid a => a
mempty
                  }

          Stms (Rep (BuilderT GPU (RWS () IntraAcc VNameSource)))
-> IntrablockM ()
Stms GPU -> IntrablockM ()
forall (m :: * -> *). MonadBuilder m => Stms (Rep m) -> m ()
addStms
            (Stms GPU -> IntrablockM ())
-> BuilderT GPU (RWS () IntraAcc VNameSource) (Stms GPU)
-> IntrablockM ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< DistEnv GPU (BuilderT GPU (RWS () IntraAcc VNameSource))
-> DistNestT
     GPU (BuilderT GPU (RWS () IntraAcc VNameSource)) (DistAcc GPU)
-> BuilderT GPU (RWS () IntraAcc VNameSource) (Stms GPU)
forall (m :: * -> *) rep.
(MonadLogger m, DistRep rep) =>
DistEnv rep m -> DistNestT rep m (DistAcc rep) -> m (Stms rep)
runDistNestT DistEnv GPU (BuilderT GPU (RWS () IntraAcc VNameSource))
env (DistAcc GPU
-> Stms SOACS
-> DistNestT
     GPU (BuilderT GPU (RWS () IntraAcc VNameSource)) (DistAcc GPU)
forall (m :: * -> *) rep.
(MonadFreshNames m, LocalScope rep m, DistRep rep) =>
DistAcc rep -> Stms SOACS -> DistNestT rep m (DistAcc rep)
distributeMapBodyStms DistAcc GPU
acc (Body SOACS -> Stms SOACS
forall rep. Body rep -> Stms rep
bodyStms (Body SOACS -> Stms SOACS) -> Body SOACS -> Stms SOACS
forall a b. (a -> b) -> a -> b
$ Lambda SOACS -> Body SOACS
forall rep. Lambda rep -> Body rep
lambdaBody Lambda SOACS
lam))
    Op (Screma SubExp
w [VName]
arrs ScremaForm SOACS
form)
      | Just ([Scan SOACS]
scans, Lambda SOACS
mapfun) <- ScremaForm SOACS -> Maybe ([Scan SOACS], Lambda SOACS)
forall rep. ScremaForm rep -> Maybe ([Scan rep], Lambda rep)
isScanomapSOAC ScremaForm SOACS
form,
        -- FIXME: Futhark.CodeGen.ImpGen.GPU.Block.compileGroupOp
        -- cannot handle multiple scan operators yet.
        Scan Lambda SOACS
scanfun [SubExp]
nes <- [Scan SOACS] -> Scan SOACS
forall rep. Buildable rep => [Scan rep] -> Scan rep
singleScan [Scan SOACS]
scans -> do
          let scanfun' :: Lambda GPU
scanfun' = Lambda SOACS -> Lambda GPU
soacsLambdaToGPU Lambda SOACS
scanfun
              mapfun' :: Lambda GPU
mapfun' = Lambda SOACS -> Lambda GPU
soacsLambdaToGPU Lambda SOACS
mapfun
          Certs -> IntrablockM () -> IntrablockM ()
forall a.
Certs
-> BuilderT GPU (RWS () IntraAcc VNameSource) a
-> BuilderT GPU (RWS () IntraAcc VNameSource) a
forall (m :: * -> *) a. MonadBuilder m => Certs -> m a -> m a
certifying (StmAux () -> Certs
forall dec. StmAux dec -> Certs
stmAuxCerts StmAux ()
StmAux (ExpDec SOACS)
aux) (IntrablockM () -> IntrablockM ())
-> IntrablockM () -> IntrablockM ()
forall a b. (a -> b) -> a -> b
$
            Stms (Rep (BuilderT GPU (RWS () IntraAcc VNameSource)))
-> IntrablockM ()
Stms GPU -> IntrablockM ()
forall (m :: * -> *). MonadBuilder m => Stms (Rep m) -> m ()
addStms (Stms GPU -> IntrablockM ())
-> BuilderT GPU (RWS () IntraAcc VNameSource) (Stms GPU)
-> IntrablockM ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< SegOpLevel GPU
-> Pat (LetDec GPU)
-> Certs
-> SubExp
-> [SegBinOp GPU]
-> Lambda GPU
-> [VName]
-> [(VName, SubExp)]
-> [KernelInput]
-> BuilderT GPU (RWS () IntraAcc VNameSource) (Stms GPU)
forall (m :: * -> *) rep.
(MonadFreshNames m, DistRep rep, HasScope rep m) =>
SegOpLevel rep
-> Pat (LetDec rep)
-> Certs
-> SubExp
-> [SegBinOp rep]
-> Lambda rep
-> [VName]
-> [(VName, SubExp)]
-> [KernelInput]
-> m (Stms rep)
segScan SegOpLevel GPU
SegLevel
lvl Pat (LetDec SOACS)
Pat (LetDec GPU)
pat Certs
forall a. Monoid a => a
mempty SubExp
w [Commutativity -> Lambda GPU -> [SubExp] -> Shape -> SegBinOp GPU
forall rep.
Commutativity -> Lambda rep -> [SubExp] -> Shape -> SegBinOp rep
SegBinOp Commutativity
Noncommutative Lambda GPU
scanfun' [SubExp]
nes Shape
forall a. Monoid a => a
mempty] Lambda GPU
mapfun' [VName]
arrs [] []
          [SubExp] -> IntrablockM ()
parallelMin [SubExp
w]
    Op (Screma SubExp
w [VName]
arrs ScremaForm SOACS
form)
      | Just ([Reduce SOACS]
reds, Lambda SOACS
map_lam) <- ScremaForm SOACS -> Maybe ([Reduce SOACS], Lambda SOACS)
forall rep. ScremaForm rep -> Maybe ([Reduce rep], Lambda rep)
isRedomapSOAC ScremaForm SOACS
form -> do
          let onReduce :: Reduce SOACS -> SegBinOp GPU
onReduce (Reduce Commutativity
comm Lambda SOACS
red_lam [SubExp]
nes) =
                Commutativity -> Lambda GPU -> [SubExp] -> Shape -> SegBinOp GPU
forall rep.
Commutativity -> Lambda rep -> [SubExp] -> Shape -> SegBinOp rep
SegBinOp Commutativity
comm (Lambda SOACS -> Lambda GPU
soacsLambdaToGPU Lambda SOACS
red_lam) [SubExp]
nes Shape
forall a. Monoid a => a
mempty
              reds' :: [SegBinOp GPU]
reds' = (Reduce SOACS -> SegBinOp GPU) -> [Reduce SOACS] -> [SegBinOp GPU]
forall a b. (a -> b) -> [a] -> [b]
map Reduce SOACS -> SegBinOp GPU
onReduce [Reduce SOACS]
reds
              map_lam' :: Lambda GPU
map_lam' = Lambda SOACS -> Lambda GPU
soacsLambdaToGPU Lambda SOACS
map_lam
          Certs -> IntrablockM () -> IntrablockM ()
forall a.
Certs
-> BuilderT GPU (RWS () IntraAcc VNameSource) a
-> BuilderT GPU (RWS () IntraAcc VNameSource) a
forall (m :: * -> *) a. MonadBuilder m => Certs -> m a -> m a
certifying (StmAux () -> Certs
forall dec. StmAux dec -> Certs
stmAuxCerts StmAux ()
StmAux (ExpDec SOACS)
aux) (IntrablockM () -> IntrablockM ())
-> IntrablockM () -> IntrablockM ()
forall a b. (a -> b) -> a -> b
$
            Stms (Rep (BuilderT GPU (RWS () IntraAcc VNameSource)))
-> IntrablockM ()
Stms GPU -> IntrablockM ()
forall (m :: * -> *). MonadBuilder m => Stms (Rep m) -> m ()
addStms (Stms GPU -> IntrablockM ())
-> BuilderT GPU (RWS () IntraAcc VNameSource) (Stms GPU)
-> IntrablockM ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< SegOpLevel GPU
-> Pat (LetDec GPU)
-> Certs
-> SubExp
-> [SegBinOp GPU]
-> Lambda GPU
-> [VName]
-> [(VName, SubExp)]
-> [KernelInput]
-> BuilderT GPU (RWS () IntraAcc VNameSource) (Stms GPU)
forall (m :: * -> *) rep.
(MonadFreshNames m, DistRep rep, HasScope rep m) =>
SegOpLevel rep
-> Pat (LetDec rep)
-> Certs
-> SubExp
-> [SegBinOp rep]
-> Lambda rep
-> [VName]
-> [(VName, SubExp)]
-> [KernelInput]
-> m (Stms rep)
segRed SegOpLevel GPU
SegLevel
lvl Pat (LetDec SOACS)
Pat (LetDec GPU)
pat Certs
forall a. Monoid a => a
mempty SubExp
w [SegBinOp GPU]
reds' Lambda GPU
map_lam' [VName]
arrs [] []
          [SubExp] -> IntrablockM ()
parallelMin [SubExp
w]
    Op (Screma SubExp
w [VName]
arrs ScremaForm SOACS
form) ->
      -- This screma is too complicated for us to immediately do
      -- anything, so split it up and try again.
      (Stm SOACS -> IntrablockM ()) -> Stms SOACS -> IntrablockM ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
(a -> m b) -> t a -> m ()
mapM_ Stm SOACS -> IntrablockM ()
intrablockStm (Stms SOACS -> IntrablockM ())
-> (((), Stms SOACS) -> Stms SOACS)
-> ((), Stms SOACS)
-> IntrablockM ()
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Stm SOACS -> Stm SOACS) -> Stms SOACS -> Stms SOACS
forall a b. (a -> b) -> Seq a -> Seq b
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (Certs -> Stm SOACS -> Stm SOACS
forall rep. Certs -> Stm rep -> Stm rep
certify (StmAux () -> Certs
forall dec. StmAux dec -> Certs
stmAuxCerts StmAux ()
StmAux (ExpDec SOACS)
aux)) (Stms SOACS -> Stms SOACS)
-> (((), Stms SOACS) -> Stms SOACS)
-> ((), Stms SOACS)
-> Stms SOACS
forall b c a. (b -> c) -> (a -> b) -> a -> c
. ((), Stms SOACS) -> Stms SOACS
forall a b. (a, b) -> b
snd
        (((), Stms SOACS) -> IntrablockM ())
-> BuilderT GPU (RWS () IntraAcc VNameSource) ((), Stms SOACS)
-> IntrablockM ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< BuilderT SOACS (BuilderT GPU (RWS () IntraAcc VNameSource)) ()
-> Scope SOACS
-> BuilderT GPU (RWS () IntraAcc VNameSource) ((), Stms SOACS)
forall (m :: * -> *) rep a.
MonadFreshNames m =>
BuilderT rep m a -> Scope rep -> m (a, Stms rep)
runBuilderT (Pat
  (LetDec
     (Rep
        (BuilderT SOACS (BuilderT GPU (RWS () IntraAcc VNameSource)))))
-> SubExp
-> ScremaForm
     (Rep (BuilderT SOACS (BuilderT GPU (RWS () IntraAcc VNameSource))))
-> [VName]
-> BuilderT SOACS (BuilderT GPU (RWS () IntraAcc VNameSource)) ()
forall (m :: * -> *).
(MonadBuilder m, Op (Rep m) ~ SOAC (Rep m), Buildable (Rep m)) =>
Pat (LetDec (Rep m))
-> SubExp -> ScremaForm (Rep m) -> [VName] -> m ()
dissectScrema Pat
  (LetDec
     (Rep
        (BuilderT SOACS (BuilderT GPU (RWS () IntraAcc VNameSource)))))
Pat (LetDec SOACS)
pat SubExp
w ScremaForm
  (Rep (BuilderT SOACS (BuilderT GPU (RWS () IntraAcc VNameSource))))
ScremaForm SOACS
form [VName]
arrs) (Scope GPU -> Scope SOACS
scopeForSOACs Scope GPU
scope)
    Op (Hist SubExp
w [VName]
arrs [HistOp SOACS]
ops Lambda SOACS
bucket_fun) -> do
      ops' <- [HistOp SOACS]
-> (HistOp SOACS
    -> BuilderT GPU (RWS () IntraAcc VNameSource) (HistOp GPU))
-> BuilderT GPU (RWS () IntraAcc VNameSource) [HistOp GPU]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM [HistOp SOACS]
ops ((HistOp SOACS
  -> BuilderT GPU (RWS () IntraAcc VNameSource) (HistOp GPU))
 -> BuilderT GPU (RWS () IntraAcc VNameSource) [HistOp GPU])
-> (HistOp SOACS
    -> BuilderT GPU (RWS () IntraAcc VNameSource) (HistOp GPU))
-> BuilderT GPU (RWS () IntraAcc VNameSource) [HistOp GPU]
forall a b. (a -> b) -> a -> b
$ \(HistOp Shape
num_bins SubExp
rf [VName]
dests [SubExp]
nes Lambda SOACS
op) -> do
        (op', nes', shape) <- Lambda SOACS
-> [SubExp]
-> BuilderT
     GPU (RWS () IntraAcc VNameSource) (Lambda SOACS, [SubExp], Shape)
forall (m :: * -> *).
MonadBuilder m =>
Lambda SOACS -> [SubExp] -> m (Lambda SOACS, [SubExp], Shape)
determineReduceOp Lambda SOACS
op [SubExp]
nes
        let op'' = Lambda SOACS -> Lambda GPU
soacsLambdaToGPU Lambda SOACS
op'
        pure $ GPU.HistOp num_bins rf dests nes' shape op''

      let bucket_fun' = Lambda SOACS -> Lambda GPU
soacsLambdaToGPU Lambda SOACS
bucket_fun
      certifying (stmAuxCerts aux) $
        addStms =<< segHist lvl pat w [] [] ops' bucket_fun' arrs
      parallelMin [w]
    Op (Stream SubExp
w [VName]
arrs [SubExp]
accs Lambda SOACS
lam)
      | LParam SOACS
chunk_size_param : [LParam SOACS]
_ <- Lambda SOACS -> [LParam SOACS]
forall rep. Lambda rep -> [LParam rep]
lambdaParams Lambda SOACS
lam -> do
          types <- (Scope GPU -> Scope SOACS)
-> BuilderT GPU (RWS () IntraAcc VNameSource) (Scope SOACS)
forall a.
(Scope GPU -> a) -> BuilderT GPU (RWS () IntraAcc VNameSource) a
forall rep (m :: * -> *) a.
HasScope rep m =>
(Scope rep -> a) -> m a
asksScope Scope GPU -> Scope SOACS
forall fromrep torep.
SameScope fromrep torep =>
Scope fromrep -> Scope torep
castScope
          ((), stream_stms) <-
            runBuilderT (sequentialStreamWholeArray pat w accs lam arrs) types
          let replace (Var VName
v) | VName
v VName -> VName -> Bool
forall a. Eq a => a -> a -> Bool
== Param Type -> VName
forall dec. Param dec -> VName
paramName Param Type
LParam SOACS
chunk_size_param = SubExp
w
              replace SubExp
se = SubExp
se
              replaceSets (IntraAcc Set [SubExp]
x Set [SubExp]
y Log
log) =
                Set [SubExp] -> Set [SubExp] -> Log -> IntraAcc
IntraAcc (([SubExp] -> [SubExp]) -> Set [SubExp] -> Set [SubExp]
forall b a. Ord b => (a -> b) -> Set a -> Set b
S.map ((SubExp -> SubExp) -> [SubExp] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> SubExp
replace) Set [SubExp]
x) (([SubExp] -> [SubExp]) -> Set [SubExp] -> Set [SubExp]
forall b a. Ord b => (a -> b) -> Set a -> Set b
S.map ((SubExp -> SubExp) -> [SubExp] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> SubExp
replace) Set [SubExp]
y) Log
log
          censor replaceSets $ intrablockStms stream_stms
    Op (Scatter SubExp
w [VName]
ivs ScatterSpec VName
dests Lambda SOACS
lam) -> do
      write_i <- String -> BuilderT GPU (RWS () IntraAcc VNameSource) VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName String
"write_i"
      space <- mkSegSpace [(write_i, w)]

      let lam' = Lambda SOACS -> Lambda GPU
soacsLambdaToGPU Lambda SOACS
lam
          grouped = ScatterSpec VName
-> Result -> [(Shape, VName, [(Result, SubExpRes)])]
forall array a.
ScatterSpec array -> [a] -> [(Shape, array, [([a], a)])]
groupScatterResults ScatterSpec VName
dests (Result -> [(Shape, VName, [(Result, SubExpRes)])])
-> Result -> [(Shape, VName, [(Result, SubExpRes)])]
forall a b. (a -> b) -> a -> b
$ Body GPU -> Result
forall rep. Body rep -> Result
bodyResult (Body GPU -> Result) -> Body GPU -> Result
forall a b. (a -> b) -> a -> b
$ Lambda GPU -> Body GPU
forall rep. Lambda rep -> Body rep
lambdaBody Lambda GPU
lam'
          (_, dest_arrs, _) = unzip3 grouped

      dest_ts <- mapM lookupType dest_arrs

      let krets = do
            (a_t, (_a_w, a, is_vs)) <- [Type]
-> [(Shape, VName, [(Result, SubExpRes)])]
-> [(Type, (Shape, VName, [(Result, SubExpRes)]))]
forall a b. [a] -> [b] -> [(a, b)]
zip [Type]
dest_ts [(Shape, VName, [(Result, SubExpRes)])]
grouped
            let cs =
                  ((Result, SubExpRes) -> Certs) -> [(Result, SubExpRes)] -> Certs
forall m a. Monoid m => (a -> m) -> [a] -> m
forall (t :: * -> *) m a.
(Foldable t, Monoid m) =>
(a -> m) -> t a -> m
foldMap ((SubExpRes -> Certs) -> Result -> Certs
forall m a. Monoid m => (a -> m) -> [a] -> m
forall (t :: * -> *) m a.
(Foldable t, Monoid m) =>
(a -> m) -> t a -> m
foldMap SubExpRes -> Certs
resCerts (Result -> Certs)
-> ((Result, SubExpRes) -> Result) -> (Result, SubExpRes) -> Certs
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Result, SubExpRes) -> Result
forall a b. (a, b) -> a
fst) [(Result, SubExpRes)]
is_vs
                    Certs -> Certs -> Certs
forall a. Semigroup a => a -> a -> a
<> ((Result, SubExpRes) -> Certs) -> [(Result, SubExpRes)] -> Certs
forall m a. Monoid m => (a -> m) -> [a] -> m
forall (t :: * -> *) m a.
(Foldable t, Monoid m) =>
(a -> m) -> t a -> m
foldMap (SubExpRes -> Certs
resCerts (SubExpRes -> Certs)
-> ((Result, SubExpRes) -> SubExpRes)
-> (Result, SubExpRes)
-> Certs
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Result, SubExpRes) -> SubExpRes
forall a b. (a, b) -> b
snd) [(Result, SubExpRes)]
is_vs
                is_vs' = do
                  (is, v) <- [(Result, SubExpRes)]
is_vs
                  pure
                    ( fullSlice a_t $ map (DimFix . resSubExp) is,
                      resSubExp v
                    )
            pure $ WriteReturns cs a is_vs'
          inputs = do
            (p, p_a) <- [Param Type] -> [VName] -> [(Param Type, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip (Lambda GPU -> [Param (LParamInfo GPU)]
forall rep. Lambda rep -> [LParam rep]
lambdaParams Lambda GPU
lam') [VName]
ivs
            pure $ KernelInput (paramName p) (paramType p) p_a [Var write_i]

      kstms <- runBuilder_ $
        localScope (scopeOfSegSpace space) $ do
          mapM_ readKernelInput inputs
          addStms $ bodyStms $ lambdaBody lam'

      certifying (stmAuxCerts aux) $ do
        let body = BodyDec GPU -> Stms GPU -> [KernelResult] -> KernelBody GPU
forall rep.
BodyDec rep -> Stms rep -> [KernelResult] -> KernelBody rep
KernelBody () Stms GPU
kstms [KernelResult]
krets
        letBind pat $ Op $ SegOp $ SegMap lvl space (patTypes pat) body

      parallelMin [w]
    Exp SOACS
_ ->
      Stm (Rep (BuilderT GPU (RWS () IntraAcc VNameSource)))
-> IntrablockM ()
forall (m :: * -> *). MonadBuilder m => Stm (Rep m) -> m ()
addStm (Stm (Rep (BuilderT GPU (RWS () IntraAcc VNameSource)))
 -> IntrablockM ())
-> Stm (Rep (BuilderT GPU (RWS () IntraAcc VNameSource)))
-> IntrablockM ()
forall a b. (a -> b) -> a -> b
$ Stm SOACS -> Stm GPU
soacsStmToGPU Stm SOACS
stm

-- | Process a sequence of SOACS statements by running 'intrablockStm'
-- on each of them in order.  All results are communicated through the
-- effects of 'IntrablockM'; nothing is returned.
intrablockStms :: Stms SOACS -> IntrablockM ()
intrablockStms = mapM_ intrablockStm

-- | Run the intrablock transformation over all statements of a SOACS
-- body and package the outcome as a GPU kernel body.
--
-- The result tuple contains, in order:
--
-- * the first 'S.Set' field of the resulting 'IntraAcc', as a list
--   (presumably the "minimum group size" widths mentioned in the
--   module documentation -- confirm against 'IntraAcc');
--
-- * the second 'S.Set' field of the 'IntraAcc', as a list (presumably
--   the "maximum exploitable parallelism" widths -- confirm against
--   'IntraAcc');
--
-- * the 'Log' accumulated while transforming the statements;
--
-- * a 'KernelBody' holding the generated GPU statements, whose
--   results are the results of the original body.
intrablockParalleliseBody ::
  (MonadFreshNames m, HasScope GPU m) =>
  Body SOACS ->
  m ([[SubExp]], [[SubExp]], Log, KernelBody GPU)
intrablockParalleliseBody body = do
  (IntraAcc min_ws avail_ws log, kstms) <-
    runIntrablockM $ intrablockStms $ bodyStms body
  pure
    ( S.toList min_ws,
      S.toList avail_ws,
      log,
      KernelBody () kstms $ map ret $ bodyResult body
    )
  where
    -- Turn each body result into a kernel result, keeping its
    -- certificates and marking it with 'ResultMaySimplify'.
    ret :: SubExpRes -> KernelResult
    ret (SubExpRes cs se) = Returns ResultMaySimplify cs se