{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE TypeFamilies #-}

-- | Kernel extraction.
--
-- In the following, I will use the term "width" to denote the amount
-- of immediate parallelism in a map - that is, the outer size of the
-- array(s) being used as input.
--
-- = Basic Idea
--
-- If we have:
--
-- @
--   map
--     map(f)
--     stms_a...
--     map(g)
-- @
--
-- Then we want to distribute to:
--
-- @
--   map
--     map(f)
--   map
--     stms_a
--   map
--     map(g)
-- @
--
-- But for now only if
--
--  (0) it can be done without creating irregular arrays.
--      Specifically, the size of the arrays created by @map(f)@, by
--      @map(g)@ and whatever is created by @stms_a@ that is also used
--      in @map(g)@, must be invariant to the outermost loop.
--
--  (1) the maps are /balanced/.  That is, the functions @f@ and @g@
--      must do the same amount of work for every iteration.
--
-- The advantage is that the map-nests containing @map(f)@ and
-- @map(g)@ can now be trivially flattened at no cost, thus exposing
-- more parallelism.  Note that the @stms_a@ map constitutes array
-- expansion, which requires additional storage.
--
-- = Distributing Sequential Loops
--
-- As a starting point, sequential loops are treated like scalar
-- expressions.  That is, not distributed.  However, sometimes it can
-- be worthwhile to distribute if they contain a map:
--
-- @
--   map
--     loop
--       map
--     map
-- @
--
-- If we distribute the loop and interchange the outer map into the
-- loop, we get this:
--
-- @
--   loop
--     map
--       map
--   map
--     map
-- @
--
-- Now more parallelism may be available.
--
-- = Unbalanced Maps
--
-- Unbalanced maps will as a rule be sequentialised, but sometimes,
-- there is another way.  Assume we find this:
--
-- @
--   map
--     map(f)
--       map(g)
--     map
-- @
--
-- Presume that @map(f)@ is unbalanced.  By the simple rule above, we
-- would then fully sequentialise it, resulting in this:
--
-- @
--   map
--     loop
--   map
--     map
-- @
--
-- == Balancing by Loop Interchange
--
-- The above is not ideal, as we cannot flatten the @map-loop@ nest,
-- and we are thus limited in the amount of parallelism available.
--
-- But assume now that the width of @map(g)@ is invariant to the outer
-- loop.  Then if possible, we can interchange @map(f)@ and @map(g)@,
-- sequentialise @map(f)@ and distribute, interchanging the outer
-- parallel loop into the sequential loop:
--
-- @
--   loop(f)
--     map
--       map(g)
--   map
--     map
-- @
--
-- After flattening the two nests we can obtain more parallelism.
--
-- When distributing a map, we also need to distribute everything that
-- the map depends on - possibly as its own map.  When distributing a
-- set of scalar bindings, we will need to know which of the binding
-- results are used afterwards.  Hence, we will need to compute usage
-- information.
--
-- = Redomap
--
-- Redomap can be handled much like map.  Distributed loops are
-- distributed as maps, with the parameters corresponding to the
-- neutral elements added to their bodies.  The remaining loop will
-- remain a redomap.  Example:
--
-- @
-- redomap(op,
--         fn (v) =>
--           map(f)
--           map(g),
--         e,a)
-- @
--
-- distributes to
--
-- @
-- let b = map(fn v =>
--               let acc = e
--               map(f),
--               a)
-- redomap(op,
--         fn (v,dist) =>
--           map(g),
--         e,a,b)
-- @
--
-- Note that there may be further kernel extraction opportunities
-- inside the @map(f)@.  The downside of this approach is that the
-- intermediate array (@b@ above) must be written to main memory.  An
-- often better approach is to just turn the entire @redomap@ into a
-- single kernel.
module Futhark.Pass.ExtractKernels (extractKernels) where

import Control.Monad
import Control.Monad.RWS.Strict
import Control.Monad.Reader
import Data.Bifunctor (first)
import Data.Maybe
import Futhark.IR.GPU
import Futhark.IR.SOACS
import Futhark.IR.SOACS.Simplify (simplifyStms)
import Futhark.MonadFreshNames
import Futhark.Pass
import Futhark.Pass.ExtractKernels.BlockedKernel
import Futhark.Pass.ExtractKernels.DistributeNests
import Futhark.Pass.ExtractKernels.Distribution
import Futhark.Pass.ExtractKernels.ISRWIM
import Futhark.Pass.ExtractKernels.Intrablock
import Futhark.Pass.ExtractKernels.StreamKernel
import Futhark.Pass.ExtractKernels.ToGPU
import Futhark.Tools
import Futhark.Transform.FirstOrderTransform qualified as FOT
import Futhark.Transform.Rename
import Futhark.Util.Log
import Prelude hiding (log)

-- | The kernel extraction pass.  It turns a program written with
-- 'SOACS' into a 'GPU' program in which parallelism is expressed as
-- explicit kernels, following the strategy described in the module
-- header.
extractKernels :: Pass SOACS GPU
extractKernels =
  Pass
    { passFunction = transformProg,
      passName = "extract kernels",
      passDescription = "Perform kernel extraction"
    }

-- | Run kernel extraction on a whole program.  The top-level constants
-- are transformed first, with an empty kernel path.  The scope of the
-- transformed constants is then made available while each function is
-- transformed.
transformProg :: Prog SOACS -> PassM (Prog GPU)
transformProg prog = do
  let const_stms = stmsToList (progConsts prog)
  consts' <- runDistribM (transformStms mempty const_stms)
  let const_scope = scopeOf consts'
  funs' <- traverse (transformFunDef const_scope) (progFuns prog)
  pure prog {progConsts = consts', progFuns = funs'}

-- | The state threaded through 'DistribM'.
--
-- To generate more stable threshold names, the numbers used for
-- thresholds are tracked separately from the ordinary name source.
--
-- Both fields are strict.  The state is updated with 'modify' (for
-- example @counter + 1@ in the threshold counter), and strict fields
-- ensure those updates are evaluated eagerly instead of piling up as
-- thunks inside the state.
data State = State
  { stateNameSource :: !VNameSource,
    stateThresholdCounter :: !Int
  }

-- | The monad in which kernel extraction runs.  It reads the current
-- 'GPU' scope, writes a 'Log' of messages, and threads a 'State'
-- holding the name source and the threshold counter.
newtype DistribM a = DistribM (RWS (Scope GPU) Log State a)
  deriving
    ( Functor,
      Applicative,
      Monad,
      MonadState State,
      MonadLogger,
      HasScope GPU,
      LocalScope GPU
    )

-- | Fresh names are drawn from, and written back to, the
-- 'stateNameSource' field of the 'State'.
instance MonadFreshNames DistribM where
  getNameSource = stateNameSource <$> get
  putNameSource new_src = modify (\st -> st {stateNameSource = new_src})

-- | Execute a 'DistribM' computation inside any monad that can supply
-- fresh names and accept log messages.
--
-- The computation starts with an empty scope and a threshold counter
-- of zero.  It borrows the enclosing monad's name source and hands the
-- final name source back.  Any messages it logged are forwarded with
-- 'addLog'.
runDistribM ::
  (MonadLogger m, MonadFreshNames m) =>
  DistribM a ->
  m a
runDistribM (DistribM action) = do
  (result, logged) <- modifyNameSource step
  addLog logged
  pure result
  where
    step src =
      let (result, final_state, logged) = runRWS action mempty (State src 0)
       in ((result, logged), stateNameSource final_state)

-- | Run kernel extraction on the body of one function definition.
--
-- While the body is transformed, the scope contains the given outer
-- scope (for example, the program's transformed constants) and the
-- function's parameters.  The entry point, attributes, name, return
-- type and parameters are carried over unchanged.
transformFunDef ::
  (MonadFreshNames m, MonadLogger m) =>
  Scope GPU ->
  FunDef SOACS ->
  m (FunDef GPU)
transformFunDef scope (FunDef entry attrs fname rettype params body) =
  runDistribM $
    FunDef entry attrs fname rettype params
      <$> localScope (scope <> scopeOfFParams params) (transformBody mempty body)

-- | Shorthand for a sequence of 'GPU' statements.
type GPUStms = Stms GPU

-- | Transform the statements of a body along the given kernel path.
-- The body's result is kept as it is.
transformBody :: KernelPath -> Body SOACS -> DistribM (Body GPU)
transformBody path body = do
  let res = bodyResult body
  stms' <- transformStms path (stmsToList (bodyStms body))
  pure (mkBody stms' res)

-- | Transform a list of statements from left to right.
--
-- Each statement is first passed to 'sequentialisedUnbalancedStm'.
-- If that yields replacement statements, they are spliced in front of
-- the remaining ones and processed in turn.  Otherwise the statement
-- is transformed with 'transformStm', and the names it binds are put
-- in scope while the rest of the list is transformed.
transformStms :: KernelPath -> [Stm SOACS] -> DistribM GPUStms
transformStms path = go
  where
    go :: [Stm SOACS] -> DistribM GPUStms
    go [] = pure mempty
    go (stm : rest) = do
      sequentialised <- sequentialisedUnbalancedStm stm
      case sequentialised of
        Just replacement ->
          go (stmsToList replacement ++ rest)
        Nothing -> do
          stm' <- transformStm path stm
          rest' <- inScopeOf stm' (go rest)
          pure (stm' <> rest')

-- | Check whether the body of a lambda contains a nested parallel SOAC
-- ('Screma' or 'Stream') whose width may differ between iterations.
--
-- A width counts as varying when it is a variable bound inside the
-- lambda: one of its parameters, or a name bound by a statement in its
-- body or in a nested body.  A 'Match' counts only when its condition
-- varies and one of its branches is itself unbalanced.  'Loop' and all
-- other expressions are considered balanced.
unbalancedLambda :: Lambda SOACS -> Bool
unbalancedLambda orig_lam =
  unbalancedIn param_names (lambdaBody orig_lam)
  where
    param_names = namesFromList (map paramName (lambdaParams orig_lam))

    -- A constant never varies; a variable varies if it is bound within
    -- the lambda.
    varies :: Names -> SubExp -> Bool
    varies bound (Var v) = v `nameIn` bound
    varies _ (Constant _) = False

    unbalancedIn :: Names -> Body SOACS -> Bool
    unbalancedIn bound body =
      let bound' = bound <> boundInBody body
       in any (unbalancedExp bound' . stmExp) (bodyStms body)

    -- XXX - our notion of balancing is probably still too naive.
    unbalancedExp :: Names -> Exp SOACS -> Bool
    unbalancedExp bound e =
      case e of
        Op (Stream w _ _ _) -> varies bound w
        Op (Screma w _ _) -> varies bound w
        Op {} -> False
        Loop {} -> False
        WithAcc _ acc_lam -> unbalancedIn bound (lambdaBody acc_lam)
        Match ses cases defbody _ ->
          any (varies bound) ses
            && ( any (unbalancedIn bound . caseBody) cases
                   || unbalancedIn bound defbody
               )
        BasicOp _ -> False
        Apply {} -> False

-- | Handle a redomap whose map function is unbalanced (see
-- 'unbalancedLambda') and still contains inner parallelism.  Such a
-- statement is rewritten with the first-order transform
-- ('FOT.transformSOAC') in the current scope, and the resulting SOACS
-- statements are returned.  Every other statement yields 'Nothing'.
sequentialisedUnbalancedStm :: Stm SOACS -> DistribM (Maybe (Stms SOACS))
sequentialisedUnbalancedStm (Let pat _ (Op soac@(Screma _ _ form)))
  | Just (_, map_lam) <- isRedomapSOAC form,
    unbalancedLambda map_lam && lambdaContainsParallelism map_lam = do
      soacs_scope <- asksScope scopeForSOACs
      fmap (Just . snd) $ runBuilderT (FOT.transformSOAC pat soac) soacs_scope
sequentialisedUnbalancedStm _ = pure Nothing

-- | Build a 'CmpSizeLe' comparison between the product of the given
-- sizes and a size parameter of the given class.
--
-- The parameter is named @desc@, an underscore, and the current
-- threshold counter, which is then incremented.  Using this separate
-- counter instead of the name source keeps threshold names stable.
-- The result contains the comparison result, the parameter name, and
-- the statements that compute them.
cmpSizeLe ::
  String ->
  SizeClass ->
  [SubExp] ->
  DistribM ((SubExp, Name), Stms GPU)
cmpSizeLe desc size_class to_what = do
  counter <- gets stateThresholdCounter
  modify $ \st -> st {stateThresholdCounter = counter + 1}
  let size_key = nameFromString (concat [desc, "_", show counter])
  runBuilder $ do
    -- The product of all sizes (1 for an empty list) is what gets compared.
    product_se <-
      letSubExp "comparatee"
        =<< foldBinOp (Mul Int64 OverflowUndef) (intConst Int64 1) to_what
    cmp_res <- letSubExp desc . Op . SizeOp $ CmpSizeLe size_key size_class product_se
    pure (cmp_res, size_key)

-- | Bind @pat@ to the first alternative whose condition is true,
-- falling back to @default_body@ when none is.
--
-- The alternatives become a chain of nested 'Match' expressions.  Each
-- inner level binds a copy of the pattern with fresh names, and the
-- enclosing level's false branch returns those names.  With no
-- alternatives, the default body is inlined directly and its results
-- are bound to the pattern, keeping their certificates.
kernelAlternatives ::
  (MonadFreshNames m, HasScope GPU m) =>
  Pat Type ->
  Body GPU ->
  [(SubExp, Body GPU)] ->
  m (Stms GPU)
kernelAlternatives pat default_body alternatives =
  runBuilder_ $ case alternatives of
    [] -> do
      res <- bodyBind default_body
      forM_ (zip (patNames pat) res) $ \(v, SubExpRes cs se) ->
        certifying cs . letBindNames [v] . BasicOp $ SubExp se
    (cond, alt) : rest -> do
      rest_pat <- Pat <$> mapM renamePatElem (patElems pat)
      rest_stms <- kernelAlternatives rest_pat default_body rest
      let rest_body = mkBody rest_stms $ varsRes $ patNames rest_pat
          match_dec = MatchDec (staticShapes (patTypes pat)) MatchEquiv
      letBind pat $
        Match [cond] [Case [Just $ BoolValue True] alt] rest_body match_dec
  where
    -- Copy a pattern element, giving it a fresh name with the same base.
    renamePatElem pe = do
      v <- newVName $ baseString $ patElemName pe
      pure pe {patElemName = v}

-- | Transform the body of a lambda with its parameters in scope.  The
-- parameters and return types are carried over unchanged.
transformLambda :: KernelPath -> Lambda SOACS -> DistribM (Lambda GPU)
transformLambda path (Lambda params ret body) = do
  body' <- localScope (scopeOfLParams params) $ transformBody path body
  pure $ Lambda params ret body'

-- | Version a screma-style statement.  One version parallelises only
-- the outer level; the other also exploits the parallelism inside the
-- map function.
--
-- Only @paralleliseOuter@ is used when the map function contains no
-- parallelism or when the statement carries the @sequential_inner@
-- attribute.  Otherwise 'sufficientParallelism' produces a threshold
-- test on @w@, and both bodies are generated.  The outer body is
-- selected when the test holds; the inner body is the default.  The
-- inner body is built with that threshold recorded as 'False' in the
-- kernel path.
versionScanRed ::
  KernelPath ->
  Pat Type ->
  StmAux () ->
  SubExp ->
  Lambda SOACS ->
  DistribM (Stms GPU) ->
  DistribM (Body GPU) ->
  (KernelPath -> DistribM (Body GPU)) ->
  DistribM (Stms GPU)
versionScanRed path pat aux w map_lam paralleliseOuter outerParallelBody innerParallelBody
  | only_outer = paralleliseOuter
  | otherwise = do
      ((outer_suff, outer_suff_key), suff_stms) <-
        sufficientParallelism "suff_outer_screma" [w] path Nothing
      outer_body <- outerParallelBody
      inner_body <- innerParallelBody ((outer_suff_key, False) : path)
      alt_stms <- kernelAlternatives pat inner_body [(outer_suff, outer_body)]
      pure (suff_stms <> alt_stms)
  where
    only_outer =
      not (lambdaContainsParallelism map_lam)
        || "sequential_inner" `inAttrs` stmAuxAttrs aux

transformStm :: KernelPath -> Stm SOACS -> DistribM GPUStms
transformStm :: KernelPath -> Stm SOACS -> DistribM (Stms GPU)
transformStm KernelPath
_ Stm SOACS
stm
  | Attr
"sequential" Attr -> Attrs -> Bool
`inAttrs` StmAux () -> Attrs
forall dec. StmAux dec -> Attrs
stmAuxAttrs (Stm SOACS -> StmAux (ExpDec SOACS)
forall rep. Stm rep -> StmAux (ExpDec rep)
stmAux Stm SOACS
stm) =
      Builder GPU () -> DistribM (Stms GPU)
forall (m :: * -> *) somerep rep a.
(MonadFreshNames m, HasScope somerep m, SameScope somerep rep) =>
Builder rep a -> m (Stms rep)
runBuilder_ (Builder GPU () -> DistribM (Stms GPU))
-> Builder GPU () -> DistribM (Stms GPU)
forall a b. (a -> b) -> a -> b
$ Stm SOACS -> Builder GPU ()
forall (m :: * -> *).
(Transformer m, LetDec (Rep m) ~ LetDec SOACS) =>
Stm SOACS -> m ()
FOT.transformStmRecursively Stm SOACS
stm
transformStm KernelPath
path (Let Pat (LetDec SOACS)
pat StmAux (ExpDec SOACS)
aux (Op Op SOACS
soac))
  | Attr
"sequential_outer" Attr -> Attrs -> Bool
`inAttrs` StmAux () -> Attrs
forall dec. StmAux dec -> Attrs
stmAuxAttrs StmAux ()
StmAux (ExpDec SOACS)
aux =
      KernelPath -> [Stm SOACS] -> DistribM (Stms GPU)
transformStms KernelPath
path ([Stm SOACS] -> DistribM (Stms GPU))
-> (Stms SOACS -> [Stm SOACS]) -> Stms SOACS -> DistribM (Stms GPU)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Stms SOACS -> [Stm SOACS]
forall rep. Stms rep -> [Stm rep]
stmsToList (Stms SOACS -> [Stm SOACS])
-> (Stms SOACS -> Stms SOACS) -> Stms SOACS -> [Stm SOACS]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Stm SOACS -> Stm SOACS) -> Stms SOACS -> Stms SOACS
forall a b. (a -> b) -> Seq a -> Seq b
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (Certs -> Stm SOACS -> Stm SOACS
forall rep. Certs -> Stm rep -> Stm rep
certify (StmAux () -> Certs
forall dec. StmAux dec -> Certs
stmAuxCerts StmAux ()
StmAux (ExpDec SOACS)
aux))
        (Stms SOACS -> DistribM (Stms GPU))
-> DistribM (Stms SOACS) -> DistribM (Stms GPU)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< Builder SOACS () -> DistribM (Stms SOACS)
forall (m :: * -> *) somerep rep a.
(MonadFreshNames m, HasScope somerep m, SameScope somerep rep) =>
Builder rep a -> m (Stms rep)
runBuilder_ (Pat (LetDec (Rep (BuilderT SOACS (State VNameSource))))
-> SOAC (Rep (BuilderT SOACS (State VNameSource)))
-> Builder SOACS ()
forall (m :: * -> *).
Transformer m =>
Pat (LetDec (Rep m)) -> SOAC (Rep m) -> m ()
FOT.transformSOAC Pat (LetDec (Rep (BuilderT SOACS (State VNameSource))))
Pat (LetDec SOACS)
pat Op SOACS
SOAC (Rep (BuilderT SOACS (State VNameSource)))
soac)
transformStm KernelPath
path (Let Pat (LetDec SOACS)
pat StmAux (ExpDec SOACS)
aux (Match [SubExp]
c [Case (Body SOACS)]
cases Body SOACS
defbody MatchDec (BranchType SOACS)
rt)) = do
  cases' <- (Case (Body SOACS) -> DistribM (Case (Body GPU)))
-> [Case (Body SOACS)] -> DistribM [Case (Body GPU)]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM ((Body SOACS -> DistribM (Body GPU))
-> Case (Body SOACS) -> DistribM (Case (Body GPU))
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
forall (f :: * -> *) a b.
Applicative f =>
(a -> f b) -> Case a -> f (Case b)
traverse ((Body SOACS -> DistribM (Body GPU))
 -> Case (Body SOACS) -> DistribM (Case (Body GPU)))
-> (Body SOACS -> DistribM (Body GPU))
-> Case (Body SOACS)
-> DistribM (Case (Body GPU))
forall a b. (a -> b) -> a -> b
$ KernelPath -> Body SOACS -> DistribM (Body GPU)
transformBody KernelPath
path) [Case (Body SOACS)]
cases
  defbody' <- transformBody path defbody
  pure $ oneStm $ Let pat aux $ Match c cases' defbody' rt
transformStm KernelPath
path (Let Pat (LetDec SOACS)
pat StmAux (ExpDec SOACS)
aux (WithAcc [WithAccInput SOACS]
inputs Lambda SOACS
lam)) =
  Stm GPU -> Stms GPU
forall rep. Stm rep -> Stms rep
oneStm (Stm GPU -> Stms GPU)
-> (Exp GPU -> Stm GPU) -> Exp GPU -> Stms GPU
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Pat (LetDec GPU) -> StmAux (ExpDec GPU) -> Exp GPU -> Stm GPU
forall rep.
Pat (LetDec rep) -> StmAux (ExpDec rep) -> Exp rep -> Stm rep
Let Pat (LetDec SOACS)
Pat (LetDec GPU)
pat StmAux (ExpDec SOACS)
StmAux (ExpDec GPU)
aux
    (Exp GPU -> Stms GPU) -> DistribM (Exp GPU) -> DistribM (Stms GPU)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> ([WithAccInput GPU] -> Lambda GPU -> Exp GPU
forall rep. [WithAccInput rep] -> Lambda rep -> Exp rep
WithAcc ((WithAccInput SOACS -> WithAccInput GPU)
-> [WithAccInput SOACS] -> [WithAccInput GPU]
forall a b. (a -> b) -> [a] -> [b]
map WithAccInput SOACS -> WithAccInput GPU
forall {f :: * -> *} {p :: * -> * -> *} {a} {b} {c}.
(Functor f, Bifunctor p) =>
(a, b, f (p (Lambda SOACS) c)) -> (a, b, f (p (Lambda GPU) c))
transformInput [WithAccInput SOACS]
inputs) (Lambda GPU -> Exp GPU)
-> DistribM (Lambda GPU) -> DistribM (Exp GPU)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> KernelPath -> Lambda SOACS -> DistribM (Lambda GPU)
transformLambda KernelPath
path Lambda SOACS
lam)
  where
    transformInput :: (a, b, f (p (Lambda SOACS) c)) -> (a, b, f (p (Lambda GPU) c))
transformInput (a
shape, b
arrs, f (p (Lambda SOACS) c)
op) =
      (a
shape, b
arrs, (p (Lambda SOACS) c -> p (Lambda GPU) c)
-> f (p (Lambda SOACS) c) -> f (p (Lambda GPU) c)
forall a b. (a -> b) -> f a -> f b
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap ((Lambda SOACS -> Lambda GPU)
-> p (Lambda SOACS) c -> p (Lambda GPU) c
forall a b c. (a -> b) -> p a c -> p b c
forall (p :: * -> * -> *) a b c.
Bifunctor p =>
(a -> b) -> p a c -> p b c
first Lambda SOACS -> Lambda GPU
soacsLambdaToGPU) f (p (Lambda SOACS) c)
op)
transformStm KernelPath
path (Let Pat (LetDec SOACS)
pat StmAux (ExpDec SOACS)
aux (Loop [(FParam SOACS, SubExp)]
merge LoopForm
form Body SOACS
body)) =
  Scope GPU -> DistribM (Stms GPU) -> DistribM (Stms GPU)
forall a. Scope GPU -> DistribM a -> DistribM a
forall rep (m :: * -> *) a.
LocalScope rep m =>
Scope rep -> m a -> m a
localScope (LoopForm -> Scope GPU
forall rep. LoopForm -> Scope rep
scopeOfLoopForm LoopForm
form Scope GPU -> Scope GPU -> Scope GPU
forall a. Semigroup a => a -> a -> a
<> [Param DeclType] -> Scope GPU
forall rep dec. (FParamInfo rep ~ dec) => [Param dec] -> Scope rep
scopeOfFParams [Param DeclType]
params) (DistribM (Stms GPU) -> DistribM (Stms GPU))
-> DistribM (Stms GPU) -> DistribM (Stms GPU)
forall a b. (a -> b) -> a -> b
$
    Stm GPU -> Stms GPU
forall rep. Stm rep -> Stms rep
oneStm (Stm GPU -> Stms GPU)
-> (Body GPU -> Stm GPU) -> Body GPU -> Stms GPU
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Pat (LetDec GPU) -> StmAux (ExpDec GPU) -> Exp GPU -> Stm GPU
forall rep.
Pat (LetDec rep) -> StmAux (ExpDec rep) -> Exp rep -> Stm rep
Let Pat (LetDec SOACS)
Pat (LetDec GPU)
pat StmAux (ExpDec SOACS)
StmAux (ExpDec GPU)
aux (Exp GPU -> Stm GPU)
-> (Body GPU -> Exp GPU) -> Body GPU -> Stm GPU
forall b c a. (b -> c) -> (a -> b) -> a -> c
. [(FParam GPU, SubExp)] -> LoopForm -> Body GPU -> Exp GPU
forall rep.
[(FParam rep, SubExp)] -> LoopForm -> Body rep -> Exp rep
Loop [(FParam SOACS, SubExp)]
[(FParam GPU, SubExp)]
merge LoopForm
form (Body GPU -> Stms GPU)
-> DistribM (Body GPU) -> DistribM (Stms GPU)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> KernelPath -> Body SOACS -> DistribM (Body GPU)
transformBody KernelPath
path Body SOACS
body
  where
    params :: [Param DeclType]
params = ((Param DeclType, SubExp) -> Param DeclType)
-> [(Param DeclType, SubExp)] -> [Param DeclType]
forall a b. (a -> b) -> [a] -> [b]
map (Param DeclType, SubExp) -> Param DeclType
forall a b. (a, b) -> a
fst [(Param DeclType, SubExp)]
[(FParam SOACS, SubExp)]
merge
transformStm KernelPath
path (Let Pat (LetDec SOACS)
pat StmAux (ExpDec SOACS)
aux (Op (Screma SubExp
w [VName]
arrs ScremaForm SOACS
form)))
  | Just Lambda SOACS
lam <- ScremaForm SOACS -> Maybe (Lambda SOACS)
forall rep. ScremaForm rep -> Maybe (Lambda rep)
isMapSOAC ScremaForm SOACS
form =
      KernelPath -> MapLoop -> DistribM (Stms GPU)
onMap KernelPath
path (MapLoop -> DistribM (Stms GPU)) -> MapLoop -> DistribM (Stms GPU)
forall a b. (a -> b) -> a -> b
$ Pat Type
-> StmAux () -> SubExp -> Lambda SOACS -> [VName] -> MapLoop
MapLoop Pat Type
Pat (LetDec SOACS)
pat StmAux ()
StmAux (ExpDec SOACS)
aux SubExp
w Lambda SOACS
lam [VName]
arrs
transformStm KernelPath
path (Let Pat (LetDec SOACS)
pat StmAux (ExpDec SOACS)
aux (Op (Screma SubExp
w [VName]
arrs ScremaForm SOACS
form)))
  | Just [Scan SOACS]
scans <- ScremaForm SOACS -> Maybe [Scan SOACS]
forall rep. ScremaForm rep -> Maybe [Scan rep]
isScanSOAC ScremaForm SOACS
form,
    Scan Lambda SOACS
scan_lam [SubExp]
nes <- [Scan SOACS] -> Scan SOACS
forall rep. Buildable rep => [Scan rep] -> Scan rep
singleScan [Scan SOACS]
scans,
    Just BuilderT SOACS DistribM ()
do_iswim <- Pat Type
-> SubExp
-> Lambda SOACS
-> [(SubExp, VName)]
-> Maybe (BuilderT SOACS DistribM ())
forall (m :: * -> *).
(MonadBuilder m, Rep m ~ SOACS) =>
Pat Type
-> SubExp -> Lambda SOACS -> [(SubExp, VName)] -> Maybe (m ())
iswim Pat Type
Pat (LetDec SOACS)
pat SubExp
w Lambda SOACS
scan_lam ([(SubExp, VName)] -> Maybe (BuilderT SOACS DistribM ()))
-> [(SubExp, VName)] -> Maybe (BuilderT SOACS DistribM ())
forall a b. (a -> b) -> a -> b
$ [SubExp] -> [VName] -> [(SubExp, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [SubExp]
nes [VName]
arrs = do
      types <- (Scope GPU -> Scope SOACS) -> DistribM (Scope SOACS)
forall a. (Scope GPU -> a) -> DistribM a
forall rep (m :: * -> *) a.
HasScope rep m =>
(Scope rep -> a) -> m a
asksScope Scope GPU -> Scope SOACS
scopeForSOACs
      transformStms path . stmsToList . snd =<< runBuilderT (certifying cs do_iswim) types
  | Just ([Scan SOACS]
scans, Lambda SOACS
map_lam) <- ScremaForm SOACS -> Maybe ([Scan SOACS], Lambda SOACS)
forall rep. ScremaForm rep -> Maybe ([Scan rep], Lambda rep)
isScanomapSOAC ScremaForm SOACS
form = do
      let paralleliseOuter :: DistribM (Stms GPU)
paralleliseOuter = Builder GPU () -> DistribM (Stms GPU)
forall (m :: * -> *) somerep rep a.
(MonadFreshNames m, HasScope somerep m, SameScope somerep rep) =>
Builder rep a -> m (Stms rep)
runBuilder_ (Builder GPU () -> DistribM (Stms GPU))
-> Builder GPU () -> DistribM (Stms GPU)
forall a b. (a -> b) -> a -> b
$ do
            scan_ops <- [Scan SOACS]
-> (Scan SOACS -> BuilderT GPU (State VNameSource) (SegBinOp GPU))
-> BuilderT GPU (State VNameSource) [SegBinOp GPU]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM [Scan SOACS]
scans ((Scan SOACS -> BuilderT GPU (State VNameSource) (SegBinOp GPU))
 -> BuilderT GPU (State VNameSource) [SegBinOp GPU])
-> (Scan SOACS -> BuilderT GPU (State VNameSource) (SegBinOp GPU))
-> BuilderT GPU (State VNameSource) [SegBinOp GPU]
forall a b. (a -> b) -> a -> b
$ \(Scan Lambda SOACS
scan_lam [SubExp]
nes) -> do
              (scan_lam', nes', shape) <- Lambda SOACS
-> [SubExp]
-> BuilderT GPU (State VNameSource) (Lambda SOACS, [SubExp], Shape)
forall (m :: * -> *).
MonadBuilder m =>
Lambda SOACS -> [SubExp] -> m (Lambda SOACS, [SubExp], Shape)
determineReduceOp Lambda SOACS
scan_lam [SubExp]
nes
              let scan_lam'' = Lambda SOACS -> Lambda GPU
soacsLambdaToGPU Lambda SOACS
scan_lam'
              pure $ SegBinOp Noncommutative scan_lam'' nes' shape
            let map_lam_sequential = Lambda SOACS -> Lambda GPU
soacsLambdaToGPU Lambda SOACS
map_lam
            lvl <- segThreadCapped [w] "segscan" $ NoRecommendation SegNoVirt
            addStms . fmap (certify cs)
              =<< segScan lvl pat mempty w scan_ops map_lam_sequential arrs [] []

          outerParallelBody :: DistribM (Body GPU)
outerParallelBody =
            Body GPU -> DistribM (Body GPU)
forall rep (m :: * -> *).
(Renameable rep, MonadFreshNames m) =>
Body rep -> m (Body rep)
renameBody
              (Body GPU -> DistribM (Body GPU))
-> DistribM (Body GPU) -> DistribM (Body GPU)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< (Stms GPU -> Result -> Body GPU
forall rep. Buildable rep => Stms rep -> Result -> Body rep
mkBody (Stms GPU -> Result -> Body GPU)
-> DistribM (Stms GPU) -> DistribM (Result -> Body GPU)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> DistribM (Stms GPU)
paralleliseOuter DistribM (Result -> Body GPU)
-> DistribM Result -> DistribM (Body GPU)
forall a b. DistribM (a -> b) -> DistribM a -> DistribM b
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> Result -> DistribM Result
forall a. a -> DistribM a
forall (f :: * -> *) a. Applicative f => a -> f a
pure ([VName] -> Result
varsRes (Pat Type -> [VName]
forall dec. Pat dec -> [VName]
patNames Pat Type
Pat (LetDec SOACS)
pat)))

          paralleliseInner :: KernelPath -> DistribM (Stms GPU)
paralleliseInner KernelPath
path' = do
            (mapstm, scanstm) <-
              Pat (LetDec SOACS)
-> (SubExp, [Scan SOACS], Lambda SOACS, [VName])
-> DistribM (Stm SOACS, Stm SOACS)
forall (m :: * -> *) rep.
(MonadFreshNames m, Buildable rep, ExpDec rep ~ (),
 Op rep ~ SOAC rep) =>
Pat (LetDec rep)
-> (SubExp, [Scan rep], Lambda rep, [VName])
-> m (Stm rep, Stm rep)
scanomapToMapAndScan Pat (LetDec SOACS)
pat (SubExp
w, [Scan SOACS]
scans, Lambda SOACS
map_lam, [VName]
arrs)
            types <- asksScope scopeForSOACs
            transformStms path' . stmsToList <=< (`runBuilderT_` types) $
              addStms =<< simplifyStms (stmsFromList [certify cs mapstm, certify cs scanstm])

          innerParallelBody :: KernelPath -> DistribM (Body GPU)
innerParallelBody KernelPath
path' =
            Body GPU -> DistribM (Body GPU)
forall rep (m :: * -> *).
(Renameable rep, MonadFreshNames m) =>
Body rep -> m (Body rep)
renameBody
              (Body GPU -> DistribM (Body GPU))
-> DistribM (Body GPU) -> DistribM (Body GPU)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< (Stms GPU -> Result -> Body GPU
forall rep. Buildable rep => Stms rep -> Result -> Body rep
mkBody (Stms GPU -> Result -> Body GPU)
-> DistribM (Stms GPU) -> DistribM (Result -> Body GPU)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> KernelPath -> DistribM (Stms GPU)
paralleliseInner KernelPath
path' DistribM (Result -> Body GPU)
-> DistribM Result -> DistribM (Body GPU)
forall a b. DistribM (a -> b) -> DistribM a -> DistribM b
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> Result -> DistribM Result
forall a. a -> DistribM a
forall (f :: * -> *) a. Applicative f => a -> f a
pure ([VName] -> Result
varsRes (Pat Type -> [VName]
forall dec. Pat dec -> [VName]
patNames Pat Type
Pat (LetDec SOACS)
pat)))

      KernelPath
-> Pat Type
-> StmAux ()
-> SubExp
-> Lambda SOACS
-> DistribM (Stms GPU)
-> DistribM (Body GPU)
-> (KernelPath -> DistribM (Body GPU))
-> DistribM (Stms GPU)
versionScanRed KernelPath
path Pat Type
Pat (LetDec SOACS)
pat StmAux ()
StmAux (ExpDec SOACS)
aux SubExp
w Lambda SOACS
map_lam DistribM (Stms GPU)
paralleliseOuter DistribM (Body GPU)
outerParallelBody KernelPath -> DistribM (Body GPU)
innerParallelBody
  where
    cs :: Certs
cs = StmAux () -> Certs
forall dec. StmAux dec -> Certs
stmAuxCerts StmAux ()
StmAux (ExpDec SOACS)
aux
transformStm KernelPath
path (Let Pat (LetDec SOACS)
res_pat StmAux (ExpDec SOACS)
aux (Op (Screma SubExp
w [VName]
arrs ScremaForm SOACS
form)))
  | Just [Reduce Commutativity
comm Lambda SOACS
red_fun [SubExp]
nes] <- ScremaForm SOACS -> Maybe [Reduce SOACS]
forall rep. ScremaForm rep -> Maybe [Reduce rep]
isReduceSOAC ScremaForm SOACS
form,
    let comm' :: Commutativity
comm'
          | Lambda SOACS -> Bool
forall rep. Lambda rep -> Bool
commutativeLambda Lambda SOACS
red_fun = Commutativity
Commutative
          | Bool
otherwise = Commutativity
comm,
    Just BuilderT SOACS DistribM ()
do_irwim <- Pat Type
-> SubExp
-> Commutativity
-> Lambda SOACS
-> [(SubExp, VName)]
-> Maybe (BuilderT SOACS DistribM ())
forall (m :: * -> *).
(MonadBuilder m, Rep m ~ SOACS) =>
Pat Type
-> SubExp
-> Commutativity
-> Lambda SOACS
-> [(SubExp, VName)]
-> Maybe (m ())
irwim Pat Type
Pat (LetDec SOACS)
res_pat SubExp
w Commutativity
comm' Lambda SOACS
red_fun ([(SubExp, VName)] -> Maybe (BuilderT SOACS DistribM ()))
-> [(SubExp, VName)] -> Maybe (BuilderT SOACS DistribM ())
forall a b. (a -> b) -> a -> b
$ [SubExp] -> [VName] -> [(SubExp, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [SubExp]
nes [VName]
arrs = do
      types <- (Scope GPU -> Scope SOACS) -> DistribM (Scope SOACS)
forall a. (Scope GPU -> a) -> DistribM a
forall rep (m :: * -> *) a.
HasScope rep m =>
(Scope rep -> a) -> m a
asksScope Scope GPU -> Scope SOACS
scopeForSOACs
      stms <- fst <$> runBuilderT (simplifyStms =<< collectStms_ (auxing aux do_irwim)) types
      transformStms path $ stmsToList stms
transformStm KernelPath
path (Let Pat (LetDec SOACS)
pat StmAux (ExpDec SOACS)
aux (Op (Screma SubExp
w [VName]
arrs ScremaForm SOACS
form)))
  | Just ([Reduce SOACS]
reds, Lambda SOACS
map_lam) <- ScremaForm SOACS -> Maybe ([Reduce SOACS], Lambda SOACS)
forall rep. ScremaForm rep -> Maybe ([Reduce rep], Lambda rep)
isRedomapSOAC ScremaForm SOACS
form = do
      let paralleliseOuter :: DistribM (Stms GPU)
paralleliseOuter = Builder GPU () -> DistribM (Stms GPU)
forall (m :: * -> *) somerep rep a.
(MonadFreshNames m, HasScope somerep m, SameScope somerep rep) =>
Builder rep a -> m (Stms rep)
runBuilder_ (Builder GPU () -> DistribM (Stms GPU))
-> Builder GPU () -> DistribM (Stms GPU)
forall a b. (a -> b) -> a -> b
$ do
            red_ops <- [Reduce SOACS]
-> (Reduce SOACS
    -> BuilderT GPU (State VNameSource) (SegBinOp GPU))
-> BuilderT GPU (State VNameSource) [SegBinOp GPU]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM [Reduce SOACS]
reds ((Reduce SOACS -> BuilderT GPU (State VNameSource) (SegBinOp GPU))
 -> BuilderT GPU (State VNameSource) [SegBinOp GPU])
-> (Reduce SOACS
    -> BuilderT GPU (State VNameSource) (SegBinOp GPU))
-> BuilderT GPU (State VNameSource) [SegBinOp GPU]
forall a b. (a -> b) -> a -> b
$ \(Reduce Commutativity
comm Lambda SOACS
red_lam [SubExp]
nes) -> do
              (red_lam', nes', shape) <- Lambda SOACS
-> [SubExp]
-> BuilderT GPU (State VNameSource) (Lambda SOACS, [SubExp], Shape)
forall (m :: * -> *).
MonadBuilder m =>
Lambda SOACS -> [SubExp] -> m (Lambda SOACS, [SubExp], Shape)
determineReduceOp Lambda SOACS
red_lam [SubExp]
nes
              let comm'
                    | Lambda SOACS -> Bool
forall rep. Lambda rep -> Bool
commutativeLambda Lambda SOACS
red_lam' = Commutativity
Commutative
                    | Bool
otherwise = Commutativity
comm
                  red_lam'' = Lambda SOACS -> Lambda GPU
soacsLambdaToGPU Lambda SOACS
red_lam'
              pure $ SegBinOp comm' red_lam'' nes' shape
            let map_lam_sequential = Lambda SOACS -> Lambda GPU
soacsLambdaToGPU Lambda SOACS
map_lam
            lvl <- segThreadCapped [w] "segred" $ NoRecommendation SegNoVirt
            addStms . fmap (certify cs)
              =<< nonSegRed lvl pat w red_ops map_lam_sequential arrs

          outerParallelBody :: DistribM (Body GPU)
outerParallelBody =
            Body GPU -> DistribM (Body GPU)
forall rep (m :: * -> *).
(Renameable rep, MonadFreshNames m) =>
Body rep -> m (Body rep)
renameBody
              (Body GPU -> DistribM (Body GPU))
-> DistribM (Body GPU) -> DistribM (Body GPU)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< (Stms GPU -> Result -> Body GPU
forall rep. Buildable rep => Stms rep -> Result -> Body rep
mkBody (Stms GPU -> Result -> Body GPU)
-> DistribM (Stms GPU) -> DistribM (Result -> Body GPU)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> DistribM (Stms GPU)
paralleliseOuter DistribM (Result -> Body GPU)
-> DistribM Result -> DistribM (Body GPU)
forall a b. DistribM (a -> b) -> DistribM a -> DistribM b
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> Result -> DistribM Result
forall a. a -> DistribM a
forall (f :: * -> *) a. Applicative f => a -> f a
pure ([VName] -> Result
varsRes (Pat Type -> [VName]
forall dec. Pat dec -> [VName]
patNames Pat Type
Pat (LetDec SOACS)
pat)))

          paralleliseInner :: KernelPath -> DistribM (Stms GPU)
paralleliseInner KernelPath
path' = do
            (mapstm, redstm) <-
              Pat (LetDec SOACS)
-> (SubExp, [Reduce SOACS], Lambda SOACS, [VName])
-> DistribM (Stm SOACS, Stm SOACS)
forall (m :: * -> *) rep.
(MonadFreshNames m, Buildable rep, ExpDec rep ~ (),
 Op rep ~ SOAC rep) =>
Pat (LetDec rep)
-> (SubExp, [Reduce rep], Lambda rep, [VName])
-> m (Stm rep, Stm rep)
redomapToMapAndReduce Pat (LetDec SOACS)
pat (SubExp
w, [Reduce SOACS]
reds, Lambda SOACS
map_lam, [VName]
arrs)
            types <- asksScope scopeForSOACs
            transformStms path' . stmsToList <=< (`runBuilderT_` types) $
              addStms =<< simplifyStms (stmsFromList [certify cs mapstm, certify cs redstm])

          innerParallelBody :: KernelPath -> DistribM (Body GPU)
innerParallelBody KernelPath
path' =
            Body GPU -> DistribM (Body GPU)
forall rep (m :: * -> *).
(Renameable rep, MonadFreshNames m) =>
Body rep -> m (Body rep)
renameBody
              (Body GPU -> DistribM (Body GPU))
-> DistribM (Body GPU) -> DistribM (Body GPU)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< (Stms GPU -> Result -> Body GPU
forall rep. Buildable rep => Stms rep -> Result -> Body rep
mkBody (Stms GPU -> Result -> Body GPU)
-> DistribM (Stms GPU) -> DistribM (Result -> Body GPU)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> KernelPath -> DistribM (Stms GPU)
paralleliseInner KernelPath
path' DistribM (Result -> Body GPU)
-> DistribM Result -> DistribM (Body GPU)
forall a b. DistribM (a -> b) -> DistribM a -> DistribM b
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> Result -> DistribM Result
forall a. a -> DistribM a
forall (f :: * -> *) a. Applicative f => a -> f a
pure ([VName] -> Result
varsRes (Pat Type -> [VName]
forall dec. Pat dec -> [VName]
patNames Pat Type
Pat (LetDec SOACS)
pat)))

      KernelPath
-> Pat Type
-> StmAux ()
-> SubExp
-> Lambda SOACS
-> DistribM (Stms GPU)
-> DistribM (Body GPU)
-> (KernelPath -> DistribM (Body GPU))
-> DistribM (Stms GPU)
versionScanRed KernelPath
path Pat Type
Pat (LetDec SOACS)
pat StmAux ()
StmAux (ExpDec SOACS)
aux SubExp
w Lambda SOACS
map_lam DistribM (Stms GPU)
paralleliseOuter DistribM (Body GPU)
outerParallelBody KernelPath -> DistribM (Body GPU)
innerParallelBody
  where
    cs :: Certs
cs = StmAux () -> Certs
forall dec. StmAux dec -> Certs
stmAuxCerts StmAux ()
StmAux (ExpDec SOACS)
aux
transformStm KernelPath
path (Let Pat (LetDec SOACS)
pat StmAux (ExpDec SOACS)
aux (Op (Screma SubExp
w [VName]
arrs ScremaForm SOACS
form))) = do
  -- This screma is too complicated for us to immediately do
  -- anything, so split it up and try again.
  scope <- (Scope GPU -> Scope SOACS) -> DistribM (Scope SOACS)
forall a. (Scope GPU -> a) -> DistribM a
forall rep (m :: * -> *) a.
HasScope rep m =>
(Scope rep -> a) -> m a
asksScope Scope GPU -> Scope SOACS
scopeForSOACs
  transformStms path . map (certify (stmAuxCerts aux)) . stmsToList . snd
    =<< runBuilderT (dissectScrema pat w form arrs) scope
transformStm KernelPath
path (Let Pat (LetDec SOACS)
pat StmAux (ExpDec SOACS)
_ (Op (Stream SubExp
w [VName]
arrs [SubExp]
nes Lambda SOACS
fold_fun))) = do
  -- Remove the stream and leave the body parallel.  It will be
  -- distributed.
  types <- (Scope GPU -> Scope SOACS) -> DistribM (Scope SOACS)
forall a. (Scope GPU -> a) -> DistribM a
forall rep (m :: * -> *) a.
HasScope rep m =>
(Scope rep -> a) -> m a
asksScope Scope GPU -> Scope SOACS
scopeForSOACs
  transformStms path . stmsToList . snd
    =<< runBuilderT (sequentialStreamWholeArray pat w nes fold_fun arrs) types
--
-- When we are scattering into a multidimensional array, we want to
-- fully parallelise, such that we do not have threads writing
-- potentially large rows. We do this by fissioning the scatter into a
-- map part and a scatter part, where the former is flattened as
-- usual, and the latter has a thread per primitive element to be
-- written.
--
-- TODO: this could be slightly smarter. If we are dealing with a
-- horizontally fused Scatter that targets both single- and
-- multi-dimensional arrays, we could handle the former in the map
-- stage. This would save us from having to store all the intermediate
-- results to memory. Troels suspects such cases are very rare, but
-- they may appear some day.
transformStm KernelPath
path (Let Pat (LetDec SOACS)
pat StmAux (ExpDec SOACS)
aux (Op (Scatter SubExp
w [VName]
arrs ScatterSpec VName
as Lambda SOACS
lam)))
  | Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ (Type -> Bool) -> [Type] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all Type -> Bool
forall shape u. TypeBase shape u -> Bool
primType ([Type] -> Bool) -> [Type] -> Bool
forall a b. (a -> b) -> a -> b
$ Lambda SOACS -> [Type]
forall rep. Lambda rep -> [Type]
lambdaReturnType Lambda SOACS
lam = do
      -- Produce map stage.
      map_pat <- ([PatElem Type] -> Pat Type)
-> DistribM [PatElem Type] -> DistribM (Pat Type)
forall a b. (a -> b) -> DistribM a -> DistribM b
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap [PatElem Type] -> Pat Type
forall dec. [PatElem dec] -> Pat dec
Pat (DistribM [PatElem Type] -> DistribM (Pat Type))
-> DistribM [PatElem Type] -> DistribM (Pat Type)
forall a b. (a -> b) -> a -> b
$ [Type]
-> (Type -> DistribM (PatElem Type)) -> DistribM [PatElem Type]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM (Lambda SOACS -> [Type]
forall rep. Lambda rep -> [Type]
lambdaReturnType Lambda SOACS
lam) ((Type -> DistribM (PatElem Type)) -> DistribM [PatElem Type])
-> (Type -> DistribM (PatElem Type)) -> DistribM [PatElem Type]
forall a b. (a -> b) -> a -> b
$ \Type
t ->
        VName -> Type -> PatElem Type
forall dec. VName -> dec -> PatElem dec
PatElem (VName -> Type -> PatElem Type)
-> DistribM VName -> DistribM (Type -> PatElem Type)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> [Char] -> DistribM VName
forall (m :: * -> *). MonadFreshNames m => [Char] -> m VName
newVName [Char]
"scatter_tmp" DistribM (Type -> PatElem Type)
-> DistribM Type -> DistribM (PatElem Type)
forall a b. DistribM (a -> b) -> DistribM a -> DistribM b
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> Type -> DistribM Type
forall a. a -> DistribM a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Type
t Type -> SubExp -> Type
forall d.
ArrayShape (ShapeBase d) =>
TypeBase (ShapeBase d) NoUniqueness
-> d -> TypeBase (ShapeBase d) NoUniqueness
`arrayOfRow` SubExp
w)
      map_stms <- onMap path $ MapLoop map_pat aux w lam arrs

      -- Now do the scatters.
      runBuilder_ $ do
        addStms map_stms
        zipWithM_ doScatter (patElems pat) $ groupScatterResults as $ patNames map_pat
  where
    -- Generate code for a scatter where each thread writes only a scalar.
    doScatter :: PatElem Type
-> (Shape, VName, [([VName], VName)]) -> Builder GPU ()
doScatter PatElem Type
res_pe (Shape
scatter_space, VName
arr, [([VName], VName)]
is_vs) = do
      kernel_i <- [Char] -> BuilderT GPU (State VNameSource) VName
forall (m :: * -> *). MonadFreshNames m => [Char] -> m VName
newVName [Char]
"write_i"
      arr_t <- lookupType arr
      val_t <- stripArray (shapeRank scatter_space) <$> lookupType arr
      val_is <- replicateM (arrayRank val_t) (newVName "val_i")
      (kret, kstms) <- collectStms $ do
        is_vs' <- forM is_vs $ \([VName]
is, VName
v) -> do
          v' <- [Char]
-> Exp (Rep (BuilderT GPU (State VNameSource)))
-> BuilderT GPU (State VNameSource) SubExp
forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m SubExp
letSubExp (VName -> [Char]
baseString VName
v [Char] -> [Char] -> [Char]
forall a. Semigroup a => a -> a -> a
<> [Char]
"_elem") (Exp (Rep (BuilderT GPU (State VNameSource)))
 -> BuilderT GPU (State VNameSource) SubExp)
-> Exp (Rep (BuilderT GPU (State VNameSource)))
-> BuilderT GPU (State VNameSource) SubExp
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp (Rep (BuilderT GPU (State VNameSource)))
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp (Rep (BuilderT GPU (State VNameSource))))
-> BasicOp -> Exp (Rep (BuilderT GPU (State VNameSource)))
forall a b. (a -> b) -> a -> b
$ VName -> Slice SubExp -> BasicOp
Index VName
v (Slice SubExp -> BasicOp) -> Slice SubExp -> BasicOp
forall a b. (a -> b) -> a -> b
$ [DimIndex SubExp] -> Slice SubExp
forall d. [DimIndex d] -> Slice d
Slice ([DimIndex SubExp] -> Slice SubExp)
-> [DimIndex SubExp] -> Slice SubExp
forall a b. (a -> b) -> a -> b
$ (VName -> DimIndex SubExp) -> [VName] -> [DimIndex SubExp]
forall a b. (a -> b) -> [a] -> [b]
map (SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix (SubExp -> DimIndex SubExp)
-> (VName -> SubExp) -> VName -> DimIndex SubExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> SubExp
Var) ([VName] -> [DimIndex SubExp]) -> [VName] -> [DimIndex SubExp]
forall a b. (a -> b) -> a -> b
$ VName
kernel_i VName -> [VName] -> [VName]
forall a. a -> [a] -> [a]
: [VName]
val_is
          is' <- forM is $ \VName
i' ->
            [Char]
-> Exp (Rep (BuilderT GPU (State VNameSource)))
-> BuilderT GPU (State VNameSource) SubExp
forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m SubExp
letSubExp (VName -> [Char]
baseString VName
i' [Char] -> [Char] -> [Char]
forall a. Semigroup a => a -> a -> a
<> [Char]
"_i") (Exp (Rep (BuilderT GPU (State VNameSource)))
 -> BuilderT GPU (State VNameSource) SubExp)
-> Exp (Rep (BuilderT GPU (State VNameSource)))
-> BuilderT GPU (State VNameSource) SubExp
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp (Rep (BuilderT GPU (State VNameSource)))
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp (Rep (BuilderT GPU (State VNameSource))))
-> BasicOp -> Exp (Rep (BuilderT GPU (State VNameSource)))
forall a b. (a -> b) -> a -> b
$ VName -> Slice SubExp -> BasicOp
Index VName
i' (Slice SubExp -> BasicOp) -> Slice SubExp -> BasicOp
forall a b. (a -> b) -> a -> b
$ [DimIndex SubExp] -> Slice SubExp
forall d. [DimIndex d] -> Slice d
Slice [SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix (SubExp -> DimIndex SubExp) -> SubExp -> DimIndex SubExp
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
kernel_i]
          pure (Slice $ map DimFix $ is' <> map Var val_is, v')
        pure $ WriteReturns mempty arr is_vs'
      (kernel, stms) <-
        mapKernel
          segThreadCapped
          ((kernel_i, w) : zip val_is (arrayDims val_t))
          mempty
          [arr_t]
          (KernelBody () kstms [kret])
      addStms stms
      letBind (Pat [res_pe]) $ Op $ SegOp kernel
--
transformStm KernelPath
_ (Let Pat (LetDec SOACS)
pat StmAux (ExpDec SOACS)
aux (Op (Scatter SubExp
w [VName]
ivs ScatterSpec VName
as Lambda SOACS
lam))) = Builder GPU () -> DistribM (Stms GPU)
forall (m :: * -> *) somerep rep a.
(MonadFreshNames m, HasScope somerep m, SameScope somerep rep) =>
Builder rep a -> m (Stms rep)
runBuilder_ (Builder GPU () -> DistribM (Stms GPU))
-> Builder GPU () -> DistribM (Stms GPU)
forall a b. (a -> b) -> a -> b
$ do
  let lam' :: Lambda GPU
lam' = Lambda SOACS -> Lambda GPU
soacsLambdaToGPU Lambda SOACS
lam
  write_i <- [Char] -> BuilderT GPU (State VNameSource) VName
forall (m :: * -> *). MonadFreshNames m => [Char] -> m VName
newVName [Char]
"write_i"
  let krets = do
        (_a_w, a, is_vs) <- ScatterSpec VName
-> Result -> [(Shape, VName, [(Result, SubExpRes)])]
forall array a.
ScatterSpec array -> [a] -> [(Shape, array, [([a], a)])]
groupScatterResults ScatterSpec VName
as (Result -> [(Shape, VName, [(Result, SubExpRes)])])
-> Result -> [(Shape, VName, [(Result, SubExpRes)])]
forall a b. (a -> b) -> a -> b
$ Body GPU -> Result
forall rep. Body rep -> Result
bodyResult (Body GPU -> Result) -> Body GPU -> Result
forall a b. (a -> b) -> a -> b
$ Lambda GPU -> Body GPU
forall rep. Lambda rep -> Body rep
lambdaBody Lambda GPU
lam'
        let res_cs =
              ((Result, SubExpRes) -> Certs) -> [(Result, SubExpRes)] -> Certs
forall m a. Monoid m => (a -> m) -> [a] -> m
forall (t :: * -> *) m a.
(Foldable t, Monoid m) =>
(a -> m) -> t a -> m
foldMap ((SubExpRes -> Certs) -> Result -> Certs
forall m a. Monoid m => (a -> m) -> [a] -> m
forall (t :: * -> *) m a.
(Foldable t, Monoid m) =>
(a -> m) -> t a -> m
foldMap SubExpRes -> Certs
resCerts (Result -> Certs)
-> ((Result, SubExpRes) -> Result) -> (Result, SubExpRes) -> Certs
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Result, SubExpRes) -> Result
forall a b. (a, b) -> a
fst) [(Result, SubExpRes)]
is_vs
                Certs -> Certs -> Certs
forall a. Semigroup a => a -> a -> a
<> ((Result, SubExpRes) -> Certs) -> [(Result, SubExpRes)] -> Certs
forall m a. Monoid m => (a -> m) -> [a] -> m
forall (t :: * -> *) m a.
(Foldable t, Monoid m) =>
(a -> m) -> t a -> m
foldMap (SubExpRes -> Certs
resCerts (SubExpRes -> Certs)
-> ((Result, SubExpRes) -> SubExpRes)
-> (Result, SubExpRes)
-> Certs
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Result, SubExpRes) -> SubExpRes
forall a b. (a, b) -> b
snd) [(Result, SubExpRes)]
is_vs
            is_vs' = [([DimIndex SubExp] -> Slice SubExp
forall d. [DimIndex d] -> Slice d
Slice ([DimIndex SubExp] -> Slice SubExp)
-> [DimIndex SubExp] -> Slice SubExp
forall a b. (a -> b) -> a -> b
$ (SubExpRes -> DimIndex SubExp) -> Result -> [DimIndex SubExp]
forall a b. (a -> b) -> [a] -> [b]
map (SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix (SubExp -> DimIndex SubExp)
-> (SubExpRes -> SubExp) -> SubExpRes -> DimIndex SubExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SubExpRes -> SubExp
resSubExp) Result
is, SubExpRes -> SubExp
resSubExp SubExpRes
v) | (Result
is, SubExpRes
v) <- [(Result, SubExpRes)]
is_vs]
        pure $ WriteReturns res_cs a is_vs'
      body = BodyDec GPU -> Stms GPU -> [KernelResult] -> KernelBody GPU
forall rep.
BodyDec rep -> Stms rep -> [KernelResult] -> KernelBody rep
KernelBody () (Body GPU -> Stms GPU
forall rep. Body rep -> Stms rep
bodyStms (Body GPU -> Stms GPU) -> Body GPU -> Stms GPU
forall a b. (a -> b) -> a -> b
$ Lambda GPU -> Body GPU
forall rep. Lambda rep -> Body rep
lambdaBody Lambda GPU
lam') [KernelResult]
krets
      inputs = do
        (p, p_a) <- [Param Type] -> [VName] -> [(Param Type, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip (Lambda GPU -> [LParam GPU]
forall rep. Lambda rep -> [LParam rep]
lambdaParams Lambda GPU
lam') [VName]
ivs
        pure $ KernelInput (paramName p) (paramType p) p_a [Var write_i]
  (kernel, stms) <-
    mapKernel
      segThreadCapped
      [(write_i, w)]
      inputs
      (patTypes pat)
      body
  certifying (stmAuxCerts aux) $ do
    addStms stms
    letBind pat $ Op $ SegOp kernel
transformStm KernelPath
_ (Let Pat (LetDec SOACS)
orig_pat StmAux (ExpDec SOACS)
aux (Op (Hist SubExp
w [VName]
imgs [HistOp SOACS]
ops Lambda SOACS
bucket_fun))) = do
  let bfun' :: Lambda GPU
bfun' = Lambda SOACS -> Lambda GPU
soacsLambdaToGPU Lambda SOACS
bucket_fun

  -- It is important not to launch unnecessarily many threads for
  -- histograms, because it may mean we unnecessarily need to reduce
  -- subhistograms as well.
  Builder GPU () -> DistribM (Stms GPU)
forall (m :: * -> *) somerep rep a.
(MonadFreshNames m, HasScope somerep m, SameScope somerep rep) =>
Builder rep a -> m (Stms rep)
runBuilder_ (Builder GPU () -> DistribM (Stms GPU))
-> Builder GPU () -> DistribM (Stms GPU)
forall a b. (a -> b) -> a -> b
$ do
    lvl <- MkSegLevel GPU (State VNameSource)
forall (m :: * -> *). MonadFreshNames m => MkSegLevel GPU m
segThreadCapped [SubExp
w] [Char]
"seghist" (ThreadRecommendation
 -> BuilderT GPU (State VNameSource) (SegOpLevel GPU))
-> ThreadRecommendation
-> BuilderT GPU (State VNameSource) (SegOpLevel GPU)
forall a b. (a -> b) -> a -> b
$ SegVirt -> ThreadRecommendation
NoRecommendation SegVirt
SegNoVirt
    addStms =<< histKernel onLambda lvl orig_pat [] [] (stmAuxCerts aux) w ops bfun' imgs
  where
    onLambda :: Lambda SOACS -> BuilderT GPU (State VNameSource) (Lambda GPU)
onLambda = Lambda GPU -> BuilderT GPU (State VNameSource) (Lambda GPU)
forall a. a -> BuilderT GPU (State VNameSource) a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Lambda GPU -> BuilderT GPU (State VNameSource) (Lambda GPU))
-> (Lambda SOACS -> Lambda GPU)
-> Lambda SOACS
-> BuilderT GPU (State VNameSource) (Lambda GPU)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Lambda SOACS -> Lambda GPU
soacsLambdaToGPU
transformStm KernelPath
_ Stm SOACS
stm =
  Builder GPU () -> DistribM (Stms GPU)
forall (m :: * -> *) somerep rep a.
(MonadFreshNames m, HasScope somerep m, SameScope somerep rep) =>
Builder rep a -> m (Stms rep)
runBuilder_ (Builder GPU () -> DistribM (Stms GPU))
-> Builder GPU () -> DistribM (Stms GPU)
forall a b. (a -> b) -> a -> b
$ Stm SOACS -> Builder GPU ()
forall (m :: * -> *).
(Transformer m, LetDec (Rep m) ~ LetDec SOACS) =>
Stm SOACS -> m ()
FOT.transformStmRecursively Stm SOACS
stm

-- | Build a runtime check of whether the given widths meet a tunable
-- size threshold.
--
-- This is a thin wrapper around 'cmpSizeLe'. It supplies a
-- 'SizeThreshold' size class built from the current 'KernelPath' and
-- the optional default threshold value.
--
-- * @desc@ is the descriptive name passed through to 'cmpSizeLe'.
-- * @ws@ are the widths being compared.
-- * @path@ is the kernel path identifying this threshold.
-- * @def@ is the optional default threshold value.
--
-- The result is whatever 'cmpSizeLe' produces. Presumably this is the
-- comparison result, the threshold's name, and the statements that
-- compute it (NOTE(review): confirm against 'cmpSizeLe').
sufficientParallelism ::
  String ->
  [SubExp] ->
  KernelPath ->
  Maybe Int64 ->
  DistribM ((SubExp, Name), Stms GPU)
sufficientParallelism desc ws path def =
  cmpSizeLe desc (SizeThreshold path def) ws

-- | Intra-group parallelism is worthwhile if the lambda contains more
-- than one instance of non-map nested parallelism, or any nested
-- parallelism inside a loop.
--
-- Concretely, every statement of the lambda body gets an interest
-- score, and the lambda is considered worthwhile if the total exceeds
-- one:
--
-- * statements tagged @sequential@ score zero;
-- * map-like SOACs (@map@ and @scatter@) score zero if tagged
--   @sequential_inner@, and otherwise the larger of their width score
--   and the score of their body;
-- * loops score ten times their body, so any nested parallelism inside
--   a loop is enough;
-- * @match@ scores the maximum over its branches;
-- * other @screma@s score their width score plus their body;
-- * streams score their body;
-- * everything else scores zero.
--
-- A width scores zero if it is a constant smaller than 32, and one
-- otherwise.
worthIntrablock :: Lambda SOACS -> Bool
worthIntrablock lam = bodyInterest (lambdaBody lam) > 1
  where
    bodyInterest body =
      sum $ interest <$> bodyStms body

    interest stm
      | "sequential" `inAttrs` attrs =
          0 :: Int
      | Op (Screma w _ form) <- stmExp stm,
        Just lam' <- isMapSOAC form =
          mapLike w lam'
      | Op (Scatter w _ _ lam') <- stmExp stm =
          mapLike w lam'
      | Loop _ _ body <- stmExp stm =
          bodyInterest body * 10
      | Match _ cases defbody _ <- stmExp stm =
          -- Only one branch runs, so take the most interesting one.  The
          -- list is never empty because it contains the default branch.
          maximum $ bodyInterest defbody : map (bodyInterest . caseBody) cases
      | Op (Screma w _ (ScremaForm lam' _ _)) <- stmExp stm =
          zeroIfTooSmall w + bodyInterest (lambdaBody lam')
      | Op (Stream _ _ _ lam') <- stmExp stm =
          bodyInterest $ lambdaBody lam'
      | otherwise =
          0
      where
        attrs = stmAuxAttrs $ stmAux stm
        sequential_inner = "sequential_inner" `inAttrs` attrs

        -- Constant widths below 32 contribute nothing on their own.
        zeroIfTooSmall (Constant (IntValue x))
          | intToInt64 x < 32 = 0
        zeroIfTooSmall _ = 1

        mapLike w lam' =
          if sequential_inner
            then 0
            else max (zeroIfTooSmall w) (bodyInterest (lambdaBody lam'))

-- | A lambda is worth sequentialising if it contains enough nested
-- parallelism of an interesting kind.
--
-- Every statement is scored relative to its nesting @depth@, which is
-- zero for statements directly in the lambda body.  The lambda is worth
-- sequentialising if the total score of its body exceeds one:
--
-- * statements tagged @sequential@ score zero;
-- * @map@s score zero if tagged @sequential_inner@, and otherwise the
--   score of their body;
-- * @scatter@s score zero;
-- * @for@-loops score ten times their body;
-- * @WithAcc@ scores its body;
-- * other @screma@s score one plus their body, plus one more if they are
--   a redomap directly in the outer lambda;
-- * everything else (including while-loops and @match@) scores zero.
worthSequentialising :: Lambda SOACS -> Bool
worthSequentialising lam = bodyInterest (0 :: Int) (lambdaBody lam) > 1
  where
    bodyInterest depth body =
      sum $ interest depth <$> bodyStms body

    interest depth stm
      | "sequential" `inAttrs` attrs =
          0 :: Int
      | Op (Screma _ _ form@(ScremaForm lam' _ _)) <- stmExp stm,
        isJust $ isMapSOAC form =
          if sequential_inner
            then 0
            else bodyInterest (depth + 1) (lambdaBody lam')
      | Op Scatter {} <- stmExp stm =
          0 -- Basically a map.
      | Loop _ ForLoop {} body <- stmExp stm =
          bodyInterest (depth + 1) body * 10
      | WithAcc _ withacc_lam <- stmExp stm =
          bodyInterest (depth + 1) (lambdaBody withacc_lam)
      | Op (Screma _ _ form@(ScremaForm lam' _ _)) <- stmExp stm =
          1
            + bodyInterest (depth + 1) (lambdaBody lam')
            +
            -- Give this a bigger score if it's a redomap just inside
            -- the outer lambda, as these are often tileable and thus
            -- benefit more from sequentialisation.
            case (isRedomapSOAC form, depth) of
              (Just _, 0) -> 1
              _ -> 0
      | otherwise =
          0
      where
        attrs = stmAuxAttrs $ stmAux stm
        sequential_inner = "sequential_inner" `inAttrs` attrs

-- | Transform top-level statements encountered during distribution.
--
-- This is the 'distOnTopLevelStms' hook. It runs 'transformStms' with
-- the given 'KernelPath' in the underlying 'DistribM' monad and lifts
-- the result into 'DistNestT'.
onTopLevelStms ::
  KernelPath ->
  Stms SOACS ->
  DistNestT GPU DistribM GPUStms
onTopLevelStms path =
  liftInner . transformStms path . stmsToList

-- | Transform a top-level @map@ into GPU code.
--
-- Two ways of handling the map body are constructed and handed to
-- 'onMap'', which decides how to multi-version them:
--
-- * @exploitOuterParallelism@ adds all statements of the lambda body to
--   a single distribution accumulator and distributes them together;
-- * @exploitInnerParallelism@ distributes the body statements with
--   'distributeMapBodyStms'.
onMap :: KernelPath -> MapLoop -> DistribM GPUStms
onMap path (MapLoop pat aux w lam arrs) = do
  types <- askScope
  let loopnest = MapNesting pat aux w $ zip (lambdaParams lam) arrs

      -- The distribution environment.  It is parameterised on the
      -- 'KernelPath' that is passed on to 'onInnerMap' and
      -- 'onTopLevelStms'.
      env path' =
        DistEnv
          { distNest = singleNesting (Nesting mempty loopnest),
            distScope =
              scopeOfPat pat
                <> scopeForGPU (scopeOf lam)
                <> types,
            distOnInnerMap = onInnerMap path',
            distOnTopLevelStms = onTopLevelStms path',
            distSegLevel = segThreadCapped,
            distOnSOACSStms = pure . oneStm . soacsStmToGPU,
            distOnSOACSLambda = pure . soacsLambdaToGPU
          }

      exploitInnerParallelism path' =
        runDistNestT (env path') $
          distributeMapBodyStms acc (bodyStms $ lambdaBody lam)

      exploitOuterParallelism path' = do
        let lam' = soacsLambdaToGPU lam
        runDistNestT (env path') $
          distribute $
            addStmsToAcc (bodyStms $ lambdaBody lam') acc

  onMap' (newKernel loopnest) path exploitOuterParallelism exploitInnerParallelism pat lam
  where
    -- The initial accumulator: the map's pattern (with the lambda's body
    -- result) is the sole target, and nothing is distributed yet.
    acc =
      DistAcc
        { distTargets = singleTarget (pat, bodyResult $ lambdaBody lam),
          distStms = mempty
        }

-- | Do the attributes contain @incremental_flattening(only_intra)@,
-- i.e. should only the intra-block version be generated?
onlyExploitIntra :: Attrs -> Bool
onlyExploitIntra = inAttrs (AttrComp "incremental_flattening" ["only_intra"])

-- | May outer parallelism be exploited?
--
-- Not if the attributes contain @incremental_flattening(no_outer)@ or
-- @incremental_flattening(only_inner)@.
mayExploitOuter :: Attrs -> Bool
mayExploitOuter attrs = not $ any hasFlag ["no_outer", "only_inner"]
  where
    hasFlag flag = AttrComp "incremental_flattening" [flag] `inAttrs` attrs

-- | May intra-block parallelism be exploited?
--
-- Not if the attributes contain @incremental_flattening(no_intra)@ or
-- @incremental_flattening(only_inner)@.
mayExploitIntra :: Attrs -> Bool
mayExploitIntra attrs = not $ any hasFlag ["no_intra", "only_inner"]
  where
    hasFlag flag = AttrComp "incremental_flattening" [flag] `inAttrs` attrs

-- | The minimum amount of inner parallelism we require (by default) in
-- intra-group versions.  Less than this is usually pointless on a GPU,
-- but it is only the default of a tunable threshold.
intraMinInnerPar :: Int64
intraMinInnerPar = 32 -- One NVIDIA warp

-- | Decide how to multi-version a map nest and generate the code.
--
-- @mk_seq_stms@ and @mk_par_stms@ build code for a given 'KernelPath'.
-- In 'onMap' they exploit the outer and the inner parallelism of the
-- nest, respectively.
--
-- Depending on the attributes of the innermost nesting, and on
-- 'worthIntrablock' and 'worthSequentialising', up to three versions
-- are combined with 'kernelAlternatives':
--
-- * the @mk_seq_stms@ version, guarded by a @suff_outer_par@ threshold;
-- * an intra-block version from 'intrablockParallelise', guarded by a
--   @suff_intra_par@ threshold and by the thread block size not
--   exceeding the maximum thread block size;
-- * the @mk_par_stms@ version, passed as the default body.
--
-- The key of each threshold is pushed onto the 'KernelPath' used for
-- the versions built after that threshold.
onMap' ::
  KernelNest ->
  KernelPath ->
  (KernelPath -> DistribM (Stms GPU)) ->
  (KernelPath -> DistribM (Stms GPU)) ->
  Pat Type ->
  Lambda SOACS ->
  DistribM (Stms GPU)
onMap' loopnest path mk_seq_stms mk_par_stms pat lam = do
  -- Some of the control flow here looks a bit convoluted because we
  -- are trying to avoid generating unneeded threshold parameters,
  -- which means we need to do all the pruning checks up front.

  types <- askScope

  let only_intra = onlyExploitIntra attrs
      may_intra = worthIntrablock lam && mayExploitIntra attrs

  -- NOTE(review): this is also computed when the nest is tagged
  -- @sequential_inner@, in which case the result is discarded below.
  intra <-
    if only_intra || may_intra
      then flip runReaderT types $ intrablockParallelise loopnest lam
      else pure Nothing

  case intra of
    -- Tagged sequential_inner: only the mk_seq_stms version, with no
    -- threshold.
    _ | "sequential_inner" `inAttrs` attrs -> do
      seq_body <- renameBody =<< mkBody <$> mk_seq_stms path <*> pure res
      kernelAlternatives pat seq_body []
    -- No intra-block version.
    Nothing
      | not only_intra,
        Just m <- mkSeqAlts -> do
          (outer_suff, outer_suff_key, outer_suff_stms, seq_body) <- m
          par_body <-
            renameBody
              =<< mkBody
                <$> mk_par_stms ((outer_suff_key, False) : path)
                <*> pure res
          (outer_suff_stms <>) <$> kernelAlternatives pat par_body [(outer_suff, seq_body)]
      --
      | otherwise -> do
          par_body <- renameBody =<< mkBody <$> mk_par_stms path <*> pure res
          kernelAlternatives pat par_body []
    -- An intra-block version is available.
    Just intra'@(_, _, intra_log, intra_prelude, intra_stms)
      | only_intra -> do
          addLog intra_log
          group_par_body <- renameBody $ mkBody intra_stms res
          (intra_prelude <>) <$> kernelAlternatives pat group_par_body []
      --
      | otherwise -> do
          addLog intra_log

          case mkSeqAlts of
            Nothing -> do
              (group_par_body, intra_ok, intra_suff_key, intra_suff_stms) <-
                checkSuffIntraPar path intra'

              par_body <-
                renameBody
                  =<< mkBody
                    <$> mk_par_stms ((intra_suff_key, False) : path)
                    <*> pure res

              (intra_suff_stms <>)
                <$> kernelAlternatives pat par_body [(intra_ok, group_par_body)]
            Just m -> do
              (outer_suff, outer_suff_key, outer_suff_stms, seq_body) <- m

              (group_par_body, intra_ok, intra_suff_key, intra_suff_stms) <-
                checkSuffIntraPar ((outer_suff_key, False) : path) intra'

              par_body <-
                renameBody
                  =<< mkBody
                    <$> mk_par_stms
                      ( [ (outer_suff_key, False),
                          (intra_suff_key, False)
                        ]
                          ++ path
                      )
                    <*> pure res

              ((outer_suff_stms <> intra_suff_stms) <>)
                <$> kernelAlternatives
                  pat
                  par_body
                  [(outer_suff, seq_body), (intra_ok, group_par_body)]
  where
    nest_ws = kernelNestWidths loopnest
    res = varsRes $ patNames pat
    aux = loopNestingAux $ innermostKernelNesting loopnest
    attrs = stmAuxAttrs aux

    -- The mk_seq_stms version together with its outer-parallelism
    -- threshold.  'Nothing' if the lambda is not worth sequentialising,
    -- or if the attributes forbid exploiting outer parallelism.
    mkSeqAlts
      | worthSequentialising lam,
        mayExploitOuter attrs = Just $ do
          ((outer_suff, outer_suff_key), outer_suff_stms) <- checkSuffOuterPar
          seq_body <-
            renameBody
              =<< mkBody
                <$> mk_seq_stms ((outer_suff_key, True) : path)
                <*> pure res
          pure (outer_suff, outer_suff_key, outer_suff_stms, seq_body)
      | otherwise =
          Nothing

    checkSuffOuterPar =
      sufficientParallelism "suff_outer_par" nest_ws path Nothing

    -- Build the intra-block body and the condition under which it may
    -- be used.  The condition requires:
    --
    --  * enough available intra-block parallelism, with
    --    'intraMinInnerPar' as the default threshold;
    --  * the thread block size fitting within the maximum.
    --
    -- The statements computing the condition include the intra-block
    -- prelude.
    checkSuffIntraPar
      path'
      ((_intra_min_par, intra_avail_par), tblock_size, _, intra_prelude, intra_stms) = do
        -- We must check that all intra-group parallelism fits in a group.
        ((intra_ok, intra_suff_key), intra_suff_stms) <- do
          ((intra_suff, suff_key), check_suff_stms) <-
            sufficientParallelism
              "suff_intra_par"
              [intra_avail_par]
              path'
              (Just intraMinInnerPar)

          runBuilder $ do
            addStms intra_prelude

            max_tblock_size <-
              letSubExp "max_tblock_size" $ Op $ SizeOp $ GetSizeMax SizeThreadBlock
            fits <-
              letSubExp "fits" $
                BasicOp $
                  CmpOp (CmpSle Int64) tblock_size max_tblock_size

            addStms check_suff_stms

            intra_ok <- letSubExp "intra_suff_and_fits" $ BasicOp $ BinOp LogAnd fits intra_suff
            pure (intra_ok, suff_key)

        group_par_body <- renameBody $ mkBody intra_stms res
        pure (group_par_body, intra_ok, intra_suff_key, intra_suff_stms)

-- | Prune the results of a @map@ whose pattern elements are not
-- mentioned in @res@. In 'onInnerMap', @res@ is the result returned by
-- 'distributeSingleStm'.
--
-- The pattern elements, the lambda body results and the lambda return
-- types are filtered in lockstep.
--
-- Returns 'Nothing' unless the surviving pattern elements are a
-- permutation of @res@, as determined by 'isPermutationOf'. Otherwise
-- returns that permutation, the pruned pattern and the pruned lambda.
removeUnusedMapResults ::
  Pat Type ->
  [SubExpRes] ->
  Lambda rep ->
  Maybe ([Int], Pat Type, Lambda rep)
removeUnusedMapResults (Pat pes) res lam = do
  let (pes', ret', body_res) =
        unzip3 . filter (\(pe, _, _) -> used pe) $
          zip3 pes (lambdaReturnType lam) (bodyResult (lambdaBody lam))
  perm <- map (Var . patElemName) pes' `isPermutationOf` map resSubExp res
  pure
    ( perm,
      Pat pes',
      lam
        { -- Keep the declared return types in sync with the pruned body
          -- result.  Otherwise they would have a different length, and
          -- would be misaligned with the body result, whenever a result
          -- is removed.
          lambdaReturnType = ret',
          lambdaBody = (lambdaBody lam) {bodyResult = body_res}
        }
    )
  where
    used pe = patElemName pe `nameIn` freeIn res

onInnerMap ::
  KernelPath ->
  MapLoop ->
  DistAcc GPU ->
  DistNestT GPU DistribM (DistAcc GPU)
onInnerMap :: KernelPath
-> MapLoop -> DistAcc GPU -> DistNestT GPU DistribM (DistAcc GPU)
onInnerMap KernelPath
path maploop :: MapLoop
maploop@(MapLoop Pat Type
pat StmAux ()
aux SubExp
w Lambda SOACS
lam [VName]
arrs) DistAcc GPU
acc
  | Lambda SOACS -> Bool
unbalancedLambda Lambda SOACS
lam,
    Lambda SOACS -> Bool
lambdaContainsParallelism Lambda SOACS
lam =
      Stm SOACS -> DistAcc GPU -> DistNestT GPU DistribM (DistAcc GPU)
forall (m :: * -> *) rep.
(MonadFreshNames m, DistRep rep) =>
Stm SOACS -> DistAcc rep -> DistNestT rep m (DistAcc rep)
addStmToAcc (MapLoop -> Stm SOACS
mapLoopStm MapLoop
maploop) DistAcc GPU
acc
  | Bool
otherwise =
      DistAcc GPU
-> Stm SOACS
-> DistNestT
     GPU
     DistribM
     (Maybe (PostStms GPU, Result, KernelNest, DistAcc GPU))
forall (m :: * -> *) rep.
(MonadFreshNames m, LocalScope rep m, DistRep rep) =>
DistAcc rep
-> Stm SOACS
-> DistNestT
     rep m (Maybe (PostStms rep, Result, KernelNest, DistAcc rep))
distributeSingleStm DistAcc GPU
acc (MapLoop -> Stm SOACS
mapLoopStm MapLoop
maploop) DistNestT
  GPU
  DistribM
  (Maybe (PostStms GPU, Result, KernelNest, DistAcc GPU))
-> (Maybe (PostStms GPU, Result, KernelNest, DistAcc GPU)
    -> DistNestT GPU DistribM (DistAcc GPU))
-> DistNestT GPU DistribM (DistAcc GPU)
forall a b.
DistNestT GPU DistribM a
-> (a -> DistNestT GPU DistribM b) -> DistNestT GPU DistribM b
forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>= \case
        Just (PostStms GPU
post_kernels, Result
res, KernelNest
nest, DistAcc GPU
acc')
          | Just ([Int]
perm, Pat Type
pat', Lambda SOACS
lam') <- Pat Type
-> Result -> Lambda SOACS -> Maybe ([Int], Pat Type, Lambda SOACS)
forall rep.
Pat Type
-> Result -> Lambda rep -> Maybe ([Int], Pat Type, Lambda rep)
removeUnusedMapResults Pat Type
pat Result
res Lambda SOACS
lam -> do
              PostStms GPU -> DistNestT GPU DistribM ()
forall (m :: * -> *) rep.
Monad m =>
PostStms rep -> DistNestT rep m ()
addPostStms PostStms GPU
post_kernels
              [Int]
-> KernelNest
-> DistAcc GPU
-> Pat Type
-> Lambda SOACS
-> DistNestT GPU DistribM (DistAcc GPU)
forall {rep}.
[Int]
-> KernelNest
-> DistAcc rep
-> Pat Type
-> Lambda SOACS
-> DistNestT GPU DistribM (DistAcc rep)
multiVersion [Int]
perm KernelNest
nest DistAcc GPU
acc' Pat Type
pat' Lambda SOACS
lam'
        Maybe (PostStms GPU, Result, KernelNest, DistAcc GPU)
_ -> MapLoop -> DistAcc GPU -> DistNestT GPU DistribM (DistAcc GPU)
forall (m :: * -> *) rep.
(MonadFreshNames m, LocalScope rep m, DistRep rep) =>
MapLoop -> DistAcc rep -> DistNestT rep m (DistAcc rep)
distributeMap MapLoop
maploop DistAcc GPU
acc
  where
    discardTargets :: DistAcc rep -> DistAcc rep
discardTargets DistAcc rep
acc' =
      -- FIXME: work around bogus targets.
      DistAcc rep
acc' {distTargets = singleTarget (mempty, mempty)}

    -- GHC 9.2 loops without the type annotation.
    generate ::
      [Int] ->
      KernelNest ->
      Pat Type ->
      Lambda SOACS ->
      DistEnv GPU DistribM ->
      Scope GPU ->
      DistribM (Stms GPU)
    generate :: [Int]
-> KernelNest
-> Pat Type
-> Lambda SOACS
-> DistEnv GPU DistribM
-> Scope GPU
-> DistribM (Stms GPU)
generate [Int]
perm KernelNest
nest Pat Type
pat' Lambda SOACS
lam' DistEnv GPU DistribM
dist_env Scope GPU
extra_scope = Scope GPU -> DistribM (Stms GPU) -> DistribM (Stms GPU)
forall a. Scope GPU -> DistribM a -> DistribM a
forall rep (m :: * -> *) a.
LocalScope rep m =>
Scope rep -> m a -> m a
localScope Scope GPU
extra_scope (DistribM (Stms GPU) -> DistribM (Stms GPU))
-> DistribM (Stms GPU) -> DistribM (Stms GPU)
forall a b. (a -> b) -> a -> b
$ do
      let maploop' :: MapLoop
maploop' = Pat Type
-> StmAux () -> SubExp -> Lambda SOACS -> [VName] -> MapLoop
MapLoop Pat Type
pat' StmAux ()
aux SubExp
w Lambda SOACS
lam' [VName]
arrs

          exploitInnerParallelism :: KernelPath -> DistribM (Stms GPU)
exploitInnerParallelism KernelPath
path' = do
            let dist_env' :: DistEnv GPU DistribM
dist_env' =
                  DistEnv GPU DistribM
dist_env
                    { distOnTopLevelStms = onTopLevelStms path',
                      distOnInnerMap = onInnerMap path'
                    }
            DistEnv GPU DistribM
-> DistNestT GPU DistribM (DistAcc GPU) -> DistribM (Stms GPU)
forall (m :: * -> *) rep.
(MonadLogger m, DistRep rep) =>
DistEnv rep m -> DistNestT rep m (DistAcc rep) -> m (Stms rep)
runDistNestT DistEnv GPU DistribM
dist_env' (DistNestT GPU DistribM (DistAcc GPU) -> DistribM (Stms GPU))
-> (DistNestT GPU DistribM (DistAcc GPU)
    -> DistNestT GPU DistribM (DistAcc GPU))
-> DistNestT GPU DistribM (DistAcc GPU)
-> DistribM (Stms GPU)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. KernelNest
-> DistNestT GPU DistribM (DistAcc GPU)
-> DistNestT GPU DistribM (DistAcc GPU)
forall (m :: * -> *) rep a.
(Monad m, DistRep rep) =>
KernelNest -> DistNestT rep m a -> DistNestT rep m a
inNesting KernelNest
nest (DistNestT GPU DistribM (DistAcc GPU)
 -> DistNestT GPU DistribM (DistAcc GPU))
-> (DistNestT GPU DistribM (DistAcc GPU)
    -> DistNestT GPU DistribM (DistAcc GPU))
-> DistNestT GPU DistribM (DistAcc GPU)
-> DistNestT GPU DistribM (DistAcc GPU)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Scope GPU
-> DistNestT GPU DistribM (DistAcc GPU)
-> DistNestT GPU DistribM (DistAcc GPU)
forall a.
Scope GPU -> DistNestT GPU DistribM a -> DistNestT GPU DistribM a
forall rep (m :: * -> *) a.
LocalScope rep m =>
Scope rep -> m a -> m a
localScope Scope GPU
extra_scope (DistNestT GPU DistribM (DistAcc GPU) -> DistribM (Stms GPU))
-> DistNestT GPU DistribM (DistAcc GPU) -> DistribM (Stms GPU)
forall a b. (a -> b) -> a -> b
$
              DistAcc GPU -> DistAcc GPU
forall {rep}. DistAcc rep -> DistAcc rep
discardTargets
                (DistAcc GPU -> DistAcc GPU)
-> DistNestT GPU DistribM (DistAcc GPU)
-> DistNestT GPU DistribM (DistAcc GPU)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> MapLoop -> DistAcc GPU -> DistNestT GPU DistribM (DistAcc GPU)
forall (m :: * -> *) rep.
(MonadFreshNames m, LocalScope rep m, DistRep rep) =>
MapLoop -> DistAcc rep -> DistNestT rep m (DistAcc rep)
distributeMap MapLoop
maploop' DistAcc GPU
acc {distStms = mempty}
      -- Normally the permutation is for the output pattern, but
      -- we can't really change that, so we change the result
      -- order instead.
      let lam_res' :: Result
lam_res' =
            [Int] -> Result -> Result
forall a. [Int] -> [a] -> [a]
rearrangeShape ([Int] -> [Int]
rearrangeInverse [Int]
perm) (Result -> Result) -> Result -> Result
forall a b. (a -> b) -> a -> b
$
              Body SOACS -> Result
forall rep. Body rep -> Result
bodyResult (Body SOACS -> Result) -> Body SOACS -> Result
forall a b. (a -> b) -> a -> b
$
                Lambda SOACS -> Body SOACS
forall rep. Lambda rep -> Body rep
lambdaBody Lambda SOACS
lam'
          lam'' :: Lambda SOACS
lam'' = Lambda SOACS
lam' {lambdaBody = (lambdaBody lam') {bodyResult = lam_res'}}
          map_nesting :: LoopNesting
map_nesting = Pat Type
-> StmAux () -> SubExp -> [(Param Type, VName)] -> LoopNesting
MapNesting Pat Type
pat' StmAux ()
aux SubExp
w ([(Param Type, VName)] -> LoopNesting)
-> [(Param Type, VName)] -> LoopNesting
forall a b. (a -> b) -> a -> b
$ [Param Type] -> [VName] -> [(Param Type, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip (Lambda SOACS -> [LParam SOACS]
forall rep. Lambda rep -> [LParam rep]
lambdaParams Lambda SOACS
lam') [VName]
arrs
          nest' :: KernelNest
nest' = Target -> LoopNesting -> KernelNest -> KernelNest
pushInnerKernelNesting (Pat Type
pat', Result
lam_res') LoopNesting
map_nesting KernelNest
nest

      -- XXX: we do not construct a new KernelPath when
      -- sequentialising.  This is only OK as long as further
      -- versioning does not take place down that branch (it currently
      -- does not).
      (sequentialised_kernel, nestw_stms) <- Scope GPU
-> DistribM (Stm GPU, Stms GPU) -> DistribM (Stm GPU, Stms GPU)
forall a. Scope GPU -> DistribM a -> DistribM a
forall rep (m :: * -> *) a.
LocalScope rep m =>
Scope rep -> m a -> m a
localScope Scope GPU
extra_scope (DistribM (Stm GPU, Stms GPU) -> DistribM (Stm GPU, Stms GPU))
-> DistribM (Stm GPU, Stms GPU) -> DistribM (Stm GPU, Stms GPU)
forall a b. (a -> b) -> a -> b
$ do
        let sequentialised_lam :: Lambda GPU
sequentialised_lam = Lambda SOACS -> Lambda GPU
soacsLambdaToGPU Lambda SOACS
lam''
        MkSegLevel GPU DistribM
-> KernelNest -> Body GPU -> DistribM (Stm GPU, Stms GPU)
forall rep (m :: * -> *).
(DistRep rep, MonadFreshNames m, LocalScope rep m) =>
MkSegLevel rep m -> KernelNest -> Body rep -> m (Stm rep, Stms rep)
constructKernel MkSegLevel GPU DistribM
forall (m :: * -> *). MonadFreshNames m => MkSegLevel GPU m
segThreadCapped KernelNest
nest' (Body GPU -> DistribM (Stm GPU, Stms GPU))
-> Body GPU -> DistribM (Stm GPU, Stms GPU)
forall a b. (a -> b) -> a -> b
$ Lambda GPU -> Body GPU
forall rep. Lambda rep -> Body rep
lambdaBody Lambda GPU
sequentialised_lam

      let outer_pat = LoopNesting -> Pat Type
loopNestingPat (LoopNesting -> Pat Type) -> LoopNesting -> Pat Type
forall a b. (a -> b) -> a -> b
$ KernelNest -> LoopNesting
forall a b. (a, b) -> a
fst KernelNest
nest
      (nestw_stms <>)
        <$> onMap'
          nest'
          path
          (const $ pure $ oneStm sequentialised_kernel)
          exploitInnerParallelism
          outer_pat
          lam''

    multiVersion :: [Int]
-> KernelNest
-> DistAcc rep
-> Pat Type
-> Lambda SOACS
-> DistNestT GPU DistribM (DistAcc rep)
multiVersion [Int]
perm KernelNest
nest DistAcc rep
acc' Pat Type
pat' Lambda SOACS
lam' = do
      -- The kernel can be distributed by itself, so now we can
      -- decide whether to just sequentialise, or exploit inner
      -- parallelism.
      dist_env <- DistNestT GPU DistribM (DistEnv GPU DistribM)
forall r (m :: * -> *). MonadReader r m => m r
ask
      let extra_scope = Targets -> Scope GPU
forall rep. DistRep rep => Targets -> Scope rep
targetsScope (Targets -> Scope GPU) -> Targets -> Scope GPU
forall a b. (a -> b) -> a -> b
$ DistAcc rep -> Targets
forall rep. DistAcc rep -> Targets
distTargets DistAcc rep
acc'

      stms <- liftInner $ generate perm nest pat' lam' dist_env extra_scope
      postStm stms
      pure acc'