{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE TypeFamilies #-}
{-# OPTIONS_GHC -Wno-overlapping-patterns -Wno-incomplete-patterns -Wno-incomplete-uni-patterns -Wno-incomplete-record-updates #-}

module Futhark.Pass.ExtractKernels.DistributeNests
  ( MapLoop (..),
    mapLoopStm,
    bodyContainsParallelism,
    lambdaContainsParallelism,
    determineReduceOp,
    histKernel,
    DistEnv (..),
    DistAcc (..),
    runDistNestT,
    DistNestT,
    liftInner,
    distributeMap,
    distribute,
    distributeSingleStm,
    distributeMapBodyStms,
    addStmsToAcc,
    addStmToAcc,
    permutationAndMissing,
    addPostStms,
    postStm,
    inNesting,
  )
where

import Control.Arrow (first)
import Control.Monad
import Control.Monad.RWS.Strict
import Control.Monad.Reader
import Control.Monad.Trans.Maybe
import Control.Monad.Writer.Strict
import Data.List (find, partition, tails)
import Data.List.NonEmpty (NonEmpty (..))
import Data.Map qualified as M
import Data.Maybe
import Futhark.IR
import Futhark.IR.GPU.Op (SegVirt (..))
import Futhark.IR.SOACS (SOACS)
import Futhark.IR.SOACS qualified as SOACS
import Futhark.IR.SOACS.SOAC hiding (HistOp, histDest)
import Futhark.IR.SOACS.Simplify (simpleSOACS, simplifyStms)
import Futhark.IR.SegOp
import Futhark.MonadFreshNames
import Futhark.Pass.ExtractKernels.BlockedKernel
import Futhark.Pass.ExtractKernels.Distribution
import Futhark.Pass.ExtractKernels.ISRWIM
import Futhark.Pass.ExtractKernels.Interchange
import Futhark.Tools
import Futhark.Transform.CopyPropagate
import Futhark.Transform.FirstOrderTransform qualified as FOT
import Futhark.Transform.Rename
import Futhark.Util.Log

-- | View a scope of representation @rep@ as a 'SOACS' scope.  This is
-- 'castScope' with the target representation fixed to 'SOACS'; the
-- 'SameScope' constraint is the one 'castScope' requires.
scopeForSOACs :: (SameScope rep SOACS) => Scope rep -> Scope SOACS
scopeForSOACs = castScope

-- | The pieces of a @map@ SOAC statement, kept separate so they can be
-- inspected individually; 'mapLoopStm' reassembles them into a statement.
data MapLoop
  = MapLoop
      (Pat Type)
      -- ^ Pattern the results of the map are bound to.
      (StmAux ())
      -- ^ Auxiliary data of the statement.
      SubExp
      -- ^ Width of the map (the 'Screma' width).
      (Lambda SOACS)
      -- ^ Function applied to each element.
      [VName]
      -- ^ Input arrays.

-- | Rebuild the statement described by a 'MapLoop'.  The pattern and
-- auxiliary data are bound to a 'Screma' of the given width over the
-- input arrays, whose form is a plain map ('mapSOAC') of the lambda.
mapLoopStm :: MapLoop -> Stm SOACS
mapLoopStm (MapLoop pat aux w lam arrs) =
  Let pat aux . Op . Screma w arrs $ mapSOAC lam

-- | Read-only environment of a 'DistNestT' computation.  It holds the
-- current nesting and scope.  It also holds callbacks that let this
-- generic machinery delegate representation-specific work to whoever
-- built the environment.
data DistEnv rep m = DistEnv
  { -- | The nestings we are currently inside.  'mapNesting' pushes onto
    -- it and 'withStm' records let-bound names in it.
    distNest :: Nestings,
    -- | Variables in scope.  Extended by 'localScope', 'withStm' and
    -- 'mapNesting'.
    distScope :: Scope rep,
    -- | Callback on SOACS statements (presumably top-level ones, judging
    -- by the name; not used in this part of the module).
    distOnTopLevelStms :: Stms SOACS -> DistNestT rep m (Stms rep),
    -- | Callback given an inner map and the current accumulator (not used
    -- in this part of the module).
    distOnInnerMap ::
      MapLoop ->
      DistAcc rep ->
      DistNestT rep m (DistAcc rep),
    -- | Translates one SOACS statement; used by 'addStmToAcc'.
    distOnSOACSStms :: Stm SOACS -> Builder rep (Stms rep),
    -- | Translates a SOACS lambda; used by 'soacsLambda'.
    distOnSOACSLambda :: Lambda SOACS -> Builder rep (Lambda rep),
    -- | Segment-level chooser (not used in this part of the module).
    distSegLevel :: MkSegLevel rep m
  }

-- | Accumulator threaded through distribution.  It holds the remaining
-- targets and the statements gathered so far.
data DistAcc rep = DistAcc
  { -- | Targets; 'leavingNesting' pops the innermost one.
    distTargets :: Targets,
    -- | Statements gathered so far; 'addStmsToAcc' and 'addStmToAcc'
    -- prepend to it.
    distStms :: Stms rep
  }

-- | Writer output of 'DistNestT'.
data DistRes rep = DistRes
  { -- | Statements added through 'addPostStms' / 'postStm'.
    -- 'runDistNestT' puts them at the front of its result.
    accPostStms :: PostStms rep,
    -- | Log messages.  'runDistNestT' forwards them to the inner
    -- monad's logger.
    accLog :: Log
  }

-- | Field-wise combination of the post statements and the logs.
instance Semigroup (DistRes rep) where
  DistRes ks1 log1 <> DistRes ks2 log2 =
    DistRes (ks1 <> ks2) (log1 <> log2)

-- | No post statements and an empty log.
instance Monoid (DistRes rep) where
  mempty = DistRes mempty mempty

-- | Statements collected through the writer for 'runDistNestT' to emit.
-- The wrapper exists to give them the reversed 'Semigroup' below.
newtype PostStms rep = PostStms {unPostStms :: Stms rep}

-- | Deliberately reversed: in @a <> b@, the statements of @b@ come
-- before those of @a@.  A writer's 'tell' appends on the right, so
-- statements posted later end up earlier in the combined sequence.
instance Semigroup (PostStms rep) where
  PostStms xs <> PostStms ys = PostStms $ ys <> xs

-- | No statements.
instance Monoid (PostStms rep) where
  mempty = PostStms mempty

-- | The scope bound by the pattern of the accumulator's outermost target.
typeEnvFromDistAcc :: (DistRep rep) => DistAcc rep -> Scope rep
typeEnvFromDistAcc = scopeOfPat . fst . outerTarget . distTargets

-- | Prepend already-translated statements to the accumulator.
addStmsToAcc :: Stms rep -> DistAcc rep -> DistAcc rep
addStmsToAcc stms acc =
  acc {distStms = stms <> distStms acc}

-- | Translate a SOACS statement with the environment's
-- 'distOnSOACSStms' callback and prepend the result to the accumulator.
addStmToAcc ::
  (MonadFreshNames m, DistRep rep) =>
  Stm SOACS ->
  DistAcc rep ->
  DistNestT rep m (DistAcc rep)
addStmToAcc stm acc = do
  onSoacs <- asks distOnSOACSStms
  -- Only the statements the callback returns are kept; anything it
  -- emitted into the builder itself (second component) is dropped.
  -- NOTE(review): presumably callbacks return everything they produce.
  (stm', _) <- runBuilder $ onSoacs stm
  pure acc {distStms = stm' <> distStms acc}

-- | Translate a SOACS lambda with the environment's 'distOnSOACSLambda'
-- callback.  As in 'addStmToAcc', statements emitted into the builder
-- are discarded.
soacsLambda ::
  (MonadFreshNames m, DistRep rep) =>
  Lambda SOACS ->
  DistNestT rep m (Lambda rep)
soacsLambda lam = do
  onLambda <- asks distOnSOACSLambda
  fst <$> runBuilder (onLambda lam)

-- | The distribution monad: a reader of 'DistEnv' layered over a writer
-- of 'DistRes', on top of an inner monad @m@.  All class instances listed
-- here are derived through the newtype.
newtype DistNestT rep m a
  = DistNestT (ReaderT (DistEnv rep m) (WriterT (DistRes rep) m) a)
  deriving
    ( Functor,
      Applicative,
      Monad,
      MonadReader (DistEnv rep m),
      MonadWriter (DistRes rep)
    )

-- | Run an action of the inner monad inside 'DistNestT'.
--
-- Scope entries known to the distribution environment but missing (by
-- name) from the inner monad's scope are added to its local scope first,
-- so the action can see them.  The lifted action contributes nothing to
-- the 'DistRes' output.
liftInner :: (LocalScope rep m, DistRep rep) => m a -> DistNestT rep m a
liftInner m = do
  outer_scope <- askScope
  DistNestT $
    lift $
      lift $ do
        inner_scope <- askScope
        localScope (outer_scope `M.difference` inner_scope) m

-- | Name sources are read and written through the writer layer's
-- 'MonadFreshNames' instance, lifted past the reader.
instance (MonadFreshNames m) => MonadFreshNames (DistNestT rep m) where
  getNameSource = DistNestT $ lift getNameSource
  putNameSource = DistNestT . lift . putNameSource

-- | The scope is the environment's 'distScope'.
instance (Monad m, ASTRep rep) => HasScope rep (DistNestT rep m) where
  askScope = asks distScope

-- | Extend 'distScope' for the duration of a computation.  The new
-- bindings are the left operand of '<>'.
instance (Monad m, ASTRep rep) => LocalScope rep (DistNestT rep m) where
  localScope types = local $ \env ->
    env {distScope = types <> distScope env}

-- | Log messages are collected in the writer's 'accLog' field.
-- 'runDistNestT' later forwards them to the inner monad.
instance (Monad m) => MonadLogger (DistNestT rep m) where
  addLog msgs = tell $ DistRes mempty msgs

-- | Run a distribution computation in the inner monad.
--
-- The accumulated log is forwarded to @m@'s logger.  The result is the
-- posted statements followed by statements that bind the targets still
-- outstanding in the final accumulator's outermost target.
runDistNestT ::
  (MonadLogger m, DistRep rep) =>
  DistEnv rep m ->
  DistNestT rep m (DistAcc rep) ->
  m (Stms rep)
runDistNestT env (DistNestT m) = do
  (acc, res) <- runWriterT $ runReaderT m env
  addLog $ accLog res
  -- There may be a few final targets remaining - these correspond to
  -- arrays that are identity mapped, and must have statements
  -- inserted here.
  pure $
    unPostStms (accPostStms res) <> identityStms (outerTarget $ distTargets acc)
  where
    -- The first of the enclosing nestings if there are any, otherwise
    -- the only (current) one.
    outermost = nestingLoop $
      case distNest env of
        (nest, []) -> nest
        (_, nest : _) -> nest

    -- Parameters of the outermost nesting, paired with their input arrays.
    params_to_arrs =
      map (first paramName) $
        loopNestingParamsAndArrs outermost

    identityStms (rem_pat, res) =
      stmsFromList $ zipWith identityStm (patElems rem_pat) res

    -- A result that is a nesting parameter is bound to its input array
    -- ('Replicate' with an empty shape).  Any other result is replicated
    -- across the outermost nesting's width.
    identityStm pe (SubExpRes cs (Var v))
      | Just arr <- lookup v params_to_arrs =
          certify cs . Let (Pat [pe]) (defAux ()) . BasicOp $
            Replicate mempty (Var arr)
    identityStm pe (SubExpRes cs se) =
      certify cs . Let (Pat [pe]) (defAux ()) . BasicOp $
        Replicate (Shape [loopNestingWidth outermost]) se

-- | Record statements in the writer output.  'runDistNestT' puts them at
-- the front of its result.
--
-- The 'DistRes' is built directly rather than by record-updating
-- 'mempty'.  A type-changing update of a polymorphic 'mempty' leaves the
-- intermediate value's type ambiguous.
addPostStms :: (Monad m) => PostStms rep -> DistNestT rep m ()
addPostStms ks = tell $ DistRes ks mempty

-- | 'addPostStms' for a plain statement sequence.
postStm :: (Monad m) => Stms rep -> DistNestT rep m ()
postStm = addPostStms . PostStms

-- | Run a computation in an environment that knows about a statement.
-- The statement's bindings are added to the scope (cast from 'SOACS').
-- The names bound by its pattern are passed to 'letBindInInnerNesting'
-- to update the current nesting.
withStm ::
  (Monad m, DistRep rep) =>
  Stm SOACS ->
  DistNestT rep m a ->
  DistNestT rep m a
withStm stm = local $ \env ->
  env
    { distScope = castScope (scopeOf stm) <> distScope env,
      distNest = letBindInInnerNesting provided $ distNest env
    }
  where
    provided = namesFromList $ patNames $ stmPat stm

-- | Close the innermost nesting and turn what remains in the accumulator
-- into statements for the enclosing level.
--
-- The innermost target is popped off 'distTargets'.  It is an error if
-- there is none.  Then:
--
-- * If statements are left in the accumulator, they could not be
--   distributed.  They are wrapped back into a @map@ over the innermost
--   nesting, keeping only the parameters they use, and sequentialised
--   with 'FOT.transformSOAC'.
--
-- * Otherwise each leftover result is bound directly.  A result that is
--   a parameter of the innermost nesting becomes a copy of its input
--   array ('Replicate' with an empty shape).  Any other result is
--   replicated across the nesting's width.
--
-- Either way, the new statements replace 'distStms' and the popped target
-- is dropped.
leavingNesting ::
  (MonadFreshNames m, DistRep rep) =>
  DistAcc rep ->
  DistNestT rep m (DistAcc rep)
leavingNesting acc =
  case popInnerTarget $ distTargets acc of
    Nothing ->
      error "The kernel targets list is unexpectedly small"
    Just ((pat, res), newtargets)
      | not $ null $ distStms acc -> do
          -- Any statements left over correspond to something that
          -- could not be distributed because it would cause irregular
          -- arrays.  These must be reconstructed into a Map SOAC
          -- that will be sequentialised. XXX: life would be better if
          -- we were able to distribute irregular parallelism.
          (Nesting _ inner, _) <- asks distNest
          let MapNesting _ aux w params_and_arrs = inner
              body = Body () (distStms acc) res
              used_in_body = freeIn body
              (used_params, used_arrs) =
                unzip $
                  filter ((`nameIn` used_in_body) . paramName . fst) params_and_arrs
              lam' =
                Lambda
                  { lambdaParams = used_params,
                    lambdaBody = body,
                    lambdaReturnType = map rowType $ patTypes pat
                  }
          stms <-
            runBuilder_ . auxing aux . FOT.transformSOAC pat $
              Screma w used_arrs $
                mapSOAC lam'

          pure $ acc {distTargets = newtargets, distStms = stms}
      | otherwise -> do
          -- Any results left over correspond to a Replicate or a Copy in
          -- the parent nesting, depending on whether the argument is a
          -- parameter of the innermost nesting.
          (Nesting _ inner_nesting, _) <- asks distNest
          let w = loopNestingWidth inner_nesting
              aux = loopNestingAux inner_nesting
              inps = loopNestingParamsAndArrs inner_nesting

              -- The copy case is expressed as a Replicate with an empty shape.
              remnantStm pe (SubExpRes cs (Var v))
                | Just (_, arr) <- find ((== v) . paramName . fst) inps =
                    certify cs . Let (Pat [pe]) aux . BasicOp $
                      Replicate mempty (Var arr)
              remnantStm pe (SubExpRes cs se) =
                certify cs . Let (Pat [pe]) aux . BasicOp $
                  Replicate (Shape [w]) se

              stms =
                stmsFromList $ zipWith remnantStm (patElems pat) res

          pure $ acc {distTargets = newtargets, distStms = stms}

-- | Run a computation inside a new map nesting.  The nesting is built
-- from a map's pattern, auxiliary data, width, and lambda parameters
-- paired with its input arrays, and is pushed with 'pushInnerNesting'.
-- The lambda's parameters are brought into scope.  The computation's
-- accumulator is then passed to 'leavingNesting', which also runs inside
-- the extended environment, so it sees the new nesting.
mapNesting ::
  (MonadFreshNames m, DistRep rep) =>
  Pat Type ->
  StmAux () ->
  SubExp ->
  Lambda SOACS ->
  [VName] ->
  DistNestT rep m (DistAcc rep) ->
  DistNestT rep m (DistAcc rep)
mapNesting pat aux w lam arrs m =
  local extend $ leavingNesting =<< m
  where
    nest =
      Nesting mempty $
        MapNesting pat aux w $
          zip (lambdaParams lam) arrs
    extend env =
      env
        { distNest = pushInnerNesting nest $ distNest env,
          distScope = castScope (scopeOf lam) <> distScope env
        }

-- | Run an action with the current nesting replaced by the given
-- kernel nest.  The last loop nesting of the kernel nest (or @outer@,
-- if there are no others) becomes the current innermost nesting.  The
-- remaining loop nestings, starting with @outer@ and in their original
-- order, become the enclosing nestings.  The scope of every loop
-- nesting is added to the distribution scope.
inNesting ::
  (Monad m, DistRep rep) =>
  KernelNest ->
  DistNestT rep m a ->
  DistNestT rep m a
inNesting (outer, nests) = local enter
  where
    enter env =
      env
        { distNest = (asNesting innermost, map asNesting enclosing),
          distScope = foldMap scopeOfLoopNesting (outer : nests) <> distScope env
        }

    (innermost, enclosing) = splitLast outer nests

    -- Split a non-empty sequence (given as head and tail) into its
    -- final element and the elements preceding it, in order.
    splitLast x [] = (x, [])
    splitLast x (y : ys) =
      let (deepest, prefix) = splitLast y ys
       in (deepest, x : prefix)

    asNesting = Nesting mempty

-- | Does the body contain a statement that can be exploited for
-- parallelism?  A statement counts only if it is not tagged with the
-- @sequential@ attribute, and it is either an 'Op', a 'ForLoop' whose
-- body contains parallelism, or a 'WithAcc' whose lambda body contains
-- parallelism.  'BasicOp', 'Apply', 'Match' and 'WhileLoop'
-- statements never count.
bodyContainsParallelism :: Body SOACS -> Bool
bodyContainsParallelism body = any parallelStm (bodyStms body)
  where
    parallelStm (Let _ aux e) =
      expIsParallel e && not ("sequential" `inAttrs` stmAuxAttrs aux)

    expIsParallel = \case
      Op {} -> True
      Loop _ ForLoop {} loop_body -> bodyContainsParallelism loop_body
      Loop _ WhileLoop {} _ -> False
      WithAcc _ acc_lam -> bodyContainsParallelism (lambdaBody acc_lam)
      BasicOp {} -> False
      Apply {} -> False
      Match {} -> False

-- | Does the body of the lambda contain parallelism, in the sense of
-- 'bodyContainsParallelism'?
lambdaContainsParallelism :: Lambda SOACS -> Bool
lambdaContainsParallelism lam = bodyContainsParallelism (lambdaBody lam)

-- | Distribute the statements of a map body into the accumulator, then
-- apply 'distribute' to the result.
--
-- A 'Stream' statement is replaced by statements that sequentialise
-- it ('sequentialStreamWholeArray').  Those statements are
-- copy-propagated, given the certificates of the original statement,
-- and processed in its place.  Any other statement is handled after
-- the statements that follow it: the tail is processed first, then
-- 'maybeDistributeStm' is applied to the statement, with the
-- statement itself in scope.
distributeMapBodyStms ::
  (MonadFreshNames m, LocalScope rep m, DistRep rep) =>
  DistAcc rep ->
  Stms SOACS ->
  DistNestT rep m (DistAcc rep)
distributeMapBodyStms acc0 stms0 = distribute =<< go acc0 (stmsToList stms0)
  where
    go acc = \case
      [] -> pure acc
      Let pat aux (Op (Stream w arrs nes lam)) : rest -> do
        soacs_scope <- asksScope scopeForSOACs
        (_, seq_stms) <-
          runBuilderT (sequentialStreamWholeArray pat w nes lam arrs) soacs_scope
        seq_stms' <-
          runReaderT (copyPropagateInStms simpleSOACS soacs_scope seq_stms) soacs_scope
        let certified = certify (stmAuxCerts aux) <$> seq_stms'
        go acc $ stmsToList certified ++ rest
      stm : rest ->
        -- It is important that stm is in scope if 'maybeDistributeStm'
        -- wants to distribute, even if this causes the slightly silly
        -- situation that stm is in scope of itself.
        withStm stm $ maybeDistributeStm stm =<< go acc rest

-- | Hand an inner map loop and the current accumulator to the
-- 'distOnInnerMap' handler stored in the distribution environment.
onInnerMap :: (Monad m) => MapLoop -> DistAcc rep -> DistNestT rep m (DistAcc rep)
onInnerMap loop acc =
  asks distOnInnerMap >>= \handler -> handler loop acc

-- | Convert top-level SOACS statements with the 'distOnTopLevelStms'
-- handler stored in the distribution environment, and record the
-- converted statements with 'postStm'.
onTopLevelStms :: (Monad m) => Stms SOACS -> DistNestT rep m ()
onTopLevelStms stms = do
  handler <- asks distOnTopLevelStms
  converted <- handler stms
  postStm converted

maybeDistributeStm ::
  (MonadFreshNames m, LocalScope rep m, DistRep rep) =>
  Stm SOACS ->
  DistAcc rep ->
  DistNestT rep m (DistAcc rep)
maybeDistributeStm :: forall (m :: * -> *) rep.
(MonadFreshNames m, LocalScope rep m, DistRep rep) =>
Stm SOACS -> DistAcc rep -> DistNestT rep m (DistAcc rep)
maybeDistributeStm Stm SOACS
stm DistAcc rep
acc
  | Attr
"sequential" Attr -> Attrs -> Bool
`inAttrs` StmAux () -> Attrs
forall dec. StmAux dec -> Attrs
stmAuxAttrs (Stm SOACS -> StmAux (ExpDec SOACS)
forall rep. Stm rep -> StmAux (ExpDec rep)
stmAux Stm SOACS
stm) =
      Stm SOACS -> DistAcc rep -> DistNestT rep m (DistAcc rep)
forall (m :: * -> *) rep.
(MonadFreshNames m, DistRep rep) =>
Stm SOACS -> DistAcc rep -> DistNestT rep m (DistAcc rep)
addStmToAcc Stm SOACS
stm DistAcc rep
acc
maybeDistributeStm (Let Pat (LetDec SOACS)
pat StmAux (ExpDec SOACS)
aux (Op Op SOACS
soac)) DistAcc rep
acc
  | Attr
"sequential_outer" Attr -> Attrs -> Bool
`inAttrs` StmAux () -> Attrs
forall dec. StmAux dec -> Attrs
stmAuxAttrs StmAux ()
StmAux (ExpDec SOACS)
aux =
      DistAcc rep -> Stms SOACS -> DistNestT rep m (DistAcc rep)
forall (m :: * -> *) rep.
(MonadFreshNames m, LocalScope rep m, DistRep rep) =>
DistAcc rep -> Stms SOACS -> DistNestT rep m (DistAcc rep)
distributeMapBodyStms DistAcc rep
acc (Stms SOACS -> DistNestT rep m (DistAcc rep))
-> (Stms SOACS -> Stms SOACS)
-> Stms SOACS
-> DistNestT rep m (DistAcc rep)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Stm SOACS -> Stm SOACS) -> Stms SOACS -> Stms SOACS
forall a b. (a -> b) -> Seq a -> Seq b
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (Certs -> Stm SOACS -> Stm SOACS
forall rep. Certs -> Stm rep -> Stm rep
certify (StmAux () -> Certs
forall dec. StmAux dec -> Certs
stmAuxCerts StmAux ()
StmAux (ExpDec SOACS)
aux))
        (Stms SOACS -> DistNestT rep m (DistAcc rep))
-> DistNestT rep m (Stms SOACS) -> DistNestT rep m (DistAcc rep)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< Builder SOACS () -> DistNestT rep m (Stms SOACS)
forall (m :: * -> *) somerep rep a.
(MonadFreshNames m, HasScope somerep m, SameScope somerep rep) =>
Builder rep a -> m (Stms rep)
runBuilder_ (Pat (LetDec (Rep (BuilderT SOACS (State VNameSource))))
-> SOAC (Rep (BuilderT SOACS (State VNameSource)))
-> Builder SOACS ()
forall (m :: * -> *).
Transformer m =>
Pat (LetDec (Rep m)) -> SOAC (Rep m) -> m ()
FOT.transformSOAC Pat (LetDec (Rep (BuilderT SOACS (State VNameSource))))
Pat (LetDec SOACS)
pat Op SOACS
SOAC (Rep (BuilderT SOACS (State VNameSource)))
soac)
maybeDistributeStm stm :: Stm SOACS
stm@(Let Pat (LetDec SOACS)
pat StmAux (ExpDec SOACS)
_ (Op (Screma SubExp
w [VName]
arrs ScremaForm SOACS
form))) DistAcc rep
acc
  | Just Lambda SOACS
lam <- ScremaForm SOACS -> Maybe (Lambda SOACS)
forall rep. ScremaForm rep -> Maybe (Lambda rep)
isMapSOAC ScremaForm SOACS
form =
      -- Only distribute inside the map if we can distribute everything
      -- following the map.
      DistAcc rep -> DistNestT rep m (Maybe (DistAcc rep))
forall (m :: * -> *) rep.
(MonadFreshNames m, LocalScope rep m, DistRep rep) =>
DistAcc rep -> DistNestT rep m (Maybe (DistAcc rep))
distributeIfPossible DistAcc rep
acc DistNestT rep m (Maybe (DistAcc rep))
-> (Maybe (DistAcc rep) -> DistNestT rep m (DistAcc rep))
-> DistNestT rep m (DistAcc rep)
forall a b.
DistNestT rep m a -> (a -> DistNestT rep m b) -> DistNestT rep m b
forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>= \case
        Maybe (DistAcc rep)
Nothing -> Stm SOACS -> DistAcc rep -> DistNestT rep m (DistAcc rep)
forall (m :: * -> *) rep.
(MonadFreshNames m, DistRep rep) =>
Stm SOACS -> DistAcc rep -> DistNestT rep m (DistAcc rep)
addStmToAcc Stm SOACS
stm DistAcc rep
acc
        Just DistAcc rep
acc' -> DistAcc rep -> DistNestT rep m (DistAcc rep)
forall (m :: * -> *) rep.
(MonadFreshNames m, LocalScope rep m, DistRep rep) =>
DistAcc rep -> DistNestT rep m (DistAcc rep)
distribute (DistAcc rep -> DistNestT rep m (DistAcc rep))
-> DistNestT rep m (DistAcc rep) -> DistNestT rep m (DistAcc rep)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< MapLoop -> DistAcc rep -> DistNestT rep m (DistAcc rep)
forall (m :: * -> *) rep.
Monad m =>
MapLoop -> DistAcc rep -> DistNestT rep m (DistAcc rep)
onInnerMap (Pat Type
-> StmAux () -> SubExp -> Lambda SOACS -> [VName] -> MapLoop
MapLoop Pat Type
Pat (LetDec SOACS)
pat (Stm SOACS -> StmAux (ExpDec SOACS)
forall rep. Stm rep -> StmAux (ExpDec rep)
stmAux Stm SOACS
stm) SubExp
w Lambda SOACS
lam [VName]
arrs) DistAcc rep
acc'
maybeDistributeStm stm :: Stm SOACS
stm@(Let Pat (LetDec SOACS)
pat StmAux (ExpDec SOACS)
aux (Loop [(FParam SOACS, SubExp)]
merge form :: LoopForm
form@ForLoop {} Body SOACS
body)) DistAcc rep
acc
  | (VName -> Bool) -> [VName] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all (VName -> Names -> Bool
`notNameIn` [Type] -> Names
forall a. FreeIn a => a -> Names
freeIn (Pat Type -> [Type]
forall dec. Typed dec => Pat dec -> [Type]
patTypes Pat Type
Pat (LetDec SOACS)
pat)) (Pat Type -> [VName]
forall dec. Pat dec -> [VName]
patNames Pat Type
Pat (LetDec SOACS)
pat),
    Body SOACS -> Bool
bodyContainsParallelism Body SOACS
body =
      DistAcc rep
-> Stm SOACS
-> DistNestT
     rep m (Maybe (PostStms rep, Result, KernelNest, DistAcc rep))
forall (m :: * -> *) rep.
(MonadFreshNames m, LocalScope rep m, DistRep rep) =>
DistAcc rep
-> Stm SOACS
-> DistNestT
     rep m (Maybe (PostStms rep, Result, KernelNest, DistAcc rep))
distributeSingleStm DistAcc rep
acc Stm SOACS
stm DistNestT
  rep m (Maybe (PostStms rep, Result, KernelNest, DistAcc rep))
-> (Maybe (PostStms rep, Result, KernelNest, DistAcc rep)
    -> DistNestT rep m (DistAcc rep))
-> DistNestT rep m (DistAcc rep)
forall a b.
DistNestT rep m a -> (a -> DistNestT rep m b) -> DistNestT rep m b
forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>= \case
        Just (PostStms rep
kernels, Result
res, KernelNest
nest, DistAcc rep
acc')
          | -- XXX: We cannot distribute if this loop depends on
            -- certificates bound within the loop nest (well, we could,
            -- but interchange would not be valid).  This is not a
            -- fundamental restriction, but an artifact of our
            -- certificate representation, which we should probably
            -- rethink.
            Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$
              (LoopForm -> Names
forall a. FreeIn a => a -> Names
freeIn LoopForm
form Names -> Names -> Names
forall a. Semigroup a => a -> a -> a
<> StmAux () -> Names
forall a. FreeIn a => a -> Names
freeIn StmAux ()
StmAux (ExpDec SOACS)
aux)
                Names -> Names -> Bool
`namesIntersect` KernelNest -> Names
boundInKernelNest KernelNest
nest,
            Just ([Int]
perm, [PatElem Type]
pat_unused) <- Pat Type -> Result -> Maybe ([Int], [PatElem Type])
permutationAndMissing Pat Type
Pat (LetDec SOACS)
pat Result
res ->
              -- We need to pretend pat_unused was used anyway, by adding
              -- it to the kernel nest.
              Scope rep
-> DistNestT rep m (DistAcc rep) -> DistNestT rep m (DistAcc rep)
forall a. Scope rep -> DistNestT rep m a -> DistNestT rep m a
forall rep (m :: * -> *) a.
LocalScope rep m =>
Scope rep -> m a -> m a
localScope (DistAcc rep -> Scope rep
forall rep. DistRep rep => DistAcc rep -> Scope rep
typeEnvFromDistAcc DistAcc rep
acc') (DistNestT rep m (DistAcc rep) -> DistNestT rep m (DistAcc rep))
-> DistNestT rep m (DistAcc rep) -> DistNestT rep m (DistAcc rep)
forall a b. (a -> b) -> a -> b
$ do
                PostStms rep -> DistNestT rep m ()
forall (m :: * -> *) rep.
Monad m =>
PostStms rep -> DistNestT rep m ()
addPostStms PostStms rep
kernels
                nest' <- [PatElem Type] -> KernelNest -> DistNestT rep m KernelNest
forall (m :: * -> *).
MonadFreshNames m =>
[PatElem Type] -> KernelNest -> m KernelNest
expandKernelNest [PatElem Type]
pat_unused KernelNest
nest
                types <- asksScope scopeForSOACs

                -- Simplification is key to hoisting out statements that
                -- were variant to the loop, but invariant to the outer maps
                -- (which are now innermost).
                stms <-
                  (`runReaderT` types) $
                    simplifyStms =<< interchangeLoops nest' (SeqLoop perm pat merge form body)
                onTopLevelStms stms
                pure acc'
        Maybe (PostStms rep, Result, KernelNest, DistAcc rep)
_ ->
          Stm SOACS -> DistAcc rep -> DistNestT rep m (DistAcc rep)
forall (m :: * -> *) rep.
(MonadFreshNames m, DistRep rep) =>
Stm SOACS -> DistAcc rep -> DistNestT rep m (DistAcc rep)
addStmToAcc Stm SOACS
stm DistAcc rep
acc
maybeDistributeStm stm :: Stm SOACS
stm@(Let Pat (LetDec SOACS)
pat StmAux (ExpDec SOACS)
_ (Match [SubExp]
cond [Case (Body SOACS)]
cases Body SOACS
defbody MatchDec (BranchType SOACS)
ret)) DistAcc rep
acc
  | (VName -> Bool) -> [VName] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all (VName -> Names -> Bool
`notNameIn` Pat Type -> Names
forall a. FreeIn a => a -> Names
freeIn Pat Type
Pat (LetDec SOACS)
pat) (Pat Type -> [VName]
forall dec. Pat dec -> [VName]
patNames Pat Type
Pat (LetDec SOACS)
pat),
    (Body SOACS -> Bool) -> [Body SOACS] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
any Body SOACS -> Bool
bodyContainsParallelism (Body SOACS
defbody Body SOACS -> [Body SOACS] -> [Body SOACS]
forall a. a -> [a] -> [a]
: (Case (Body SOACS) -> Body SOACS)
-> [Case (Body SOACS)] -> [Body SOACS]
forall a b. (a -> b) -> [a] -> [b]
map Case (Body SOACS) -> Body SOACS
forall body. Case body -> body
caseBody [Case (Body SOACS)]
cases)
      Bool -> Bool -> Bool
|| Bool -> Bool
not ((ExtType -> Bool) -> [ExtType] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all ExtType -> Bool
forall shape u. TypeBase shape u -> Bool
primType (MatchDec ExtType -> [ExtType]
forall rt. MatchDec rt -> [rt]
matchReturns MatchDec ExtType
MatchDec (BranchType SOACS)
ret)) =
      DistAcc rep
-> Stm SOACS
-> DistNestT
     rep m (Maybe (PostStms rep, Result, KernelNest, DistAcc rep))
forall (m :: * -> *) rep.
(MonadFreshNames m, LocalScope rep m, DistRep rep) =>
DistAcc rep
-> Stm SOACS
-> DistNestT
     rep m (Maybe (PostStms rep, Result, KernelNest, DistAcc rep))
distributeSingleStm DistAcc rep
acc Stm SOACS
stm DistNestT
  rep m (Maybe (PostStms rep, Result, KernelNest, DistAcc rep))
-> (Maybe (PostStms rep, Result, KernelNest, DistAcc rep)
    -> DistNestT rep m (DistAcc rep))
-> DistNestT rep m (DistAcc rep)
forall a b.
DistNestT rep m a -> (a -> DistNestT rep m b) -> DistNestT rep m b
forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>= \case
        Just (PostStms rep
kernels, Result
res, KernelNest
nest, DistAcc rep
acc')
          | Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$
              ([SubExp] -> Names
forall a. FreeIn a => a -> Names
freeIn [SubExp]
cond Names -> Names -> Names
forall a. Semigroup a => a -> a -> a
<> MatchDec ExtType -> Names
forall a. FreeIn a => a -> Names
freeIn MatchDec ExtType
MatchDec (BranchType SOACS)
ret) Names -> Names -> Bool
`namesIntersect` KernelNest -> Names
boundInKernelNest KernelNest
nest,
            Just ([Int]
perm, [PatElem Type]
pat_unused) <- Pat Type -> Result -> Maybe ([Int], [PatElem Type])
permutationAndMissing Pat Type
Pat (LetDec SOACS)
pat Result
res ->
              -- We need to pretend pat_unused was used anyway, by adding
              -- it to the kernel nest.
              Scope rep
-> DistNestT rep m (DistAcc rep) -> DistNestT rep m (DistAcc rep)
forall a. Scope rep -> DistNestT rep m a -> DistNestT rep m a
forall rep (m :: * -> *) a.
LocalScope rep m =>
Scope rep -> m a -> m a
localScope (DistAcc rep -> Scope rep
forall rep. DistRep rep => DistAcc rep -> Scope rep
typeEnvFromDistAcc DistAcc rep
acc') (DistNestT rep m (DistAcc rep) -> DistNestT rep m (DistAcc rep))
-> DistNestT rep m (DistAcc rep) -> DistNestT rep m (DistAcc rep)
forall a b. (a -> b) -> a -> b
$ do
                nest' <- [PatElem Type] -> KernelNest -> DistNestT rep m KernelNest
forall (m :: * -> *).
MonadFreshNames m =>
[PatElem Type] -> KernelNest -> m KernelNest
expandKernelNest [PatElem Type]
pat_unused KernelNest
nest
                addPostStms kernels
                types <- asksScope scopeForSOACs
                let branch = [Int]
-> Pat Type
-> [SubExp]
-> [Case (Body SOACS)]
-> Body SOACS
-> MatchDec (BranchType SOACS)
-> Branch
Branch [Int]
perm Pat Type
Pat (LetDec SOACS)
pat [SubExp]
cond [Case (Body SOACS)]
cases Body SOACS
defbody MatchDec (BranchType SOACS)
ret
                stms <-
                  (`runReaderT` types) $
                    simplifyStms . oneStm =<< interchangeBranch nest' branch
                onTopLevelStms stms
                pure acc'
        Maybe (PostStms rep, Result, KernelNest, DistAcc rep)
_ ->
          Stm SOACS -> DistAcc rep -> DistNestT rep m (DistAcc rep)
forall (m :: * -> *) rep.
(MonadFreshNames m, DistRep rep) =>
Stm SOACS -> DistAcc rep -> DistNestT rep m (DistAcc rep)
addStmToAcc Stm SOACS
stm DistAcc rep
acc
maybeDistributeStm stm :: Stm SOACS
stm@(Let Pat (LetDec SOACS)
pat StmAux (ExpDec SOACS)
_ (WithAcc [WithAccInput SOACS]
inputs Lambda SOACS
lam)) DistAcc rep
acc
  | Lambda SOACS -> Bool
lambdaContainsParallelism Lambda SOACS
lam =
      DistAcc rep
-> Stm SOACS
-> DistNestT
     rep m (Maybe (PostStms rep, Result, KernelNest, DistAcc rep))
forall (m :: * -> *) rep.
(MonadFreshNames m, LocalScope rep m, DistRep rep) =>
DistAcc rep
-> Stm SOACS
-> DistNestT
     rep m (Maybe (PostStms rep, Result, KernelNest, DistAcc rep))
distributeSingleStm DistAcc rep
acc Stm SOACS
stm DistNestT
  rep m (Maybe (PostStms rep, Result, KernelNest, DistAcc rep))
-> (Maybe (PostStms rep, Result, KernelNest, DistAcc rep)
    -> DistNestT rep m (DistAcc rep))
-> DistNestT rep m (DistAcc rep)
forall a b.
DistNestT rep m a -> (a -> DistNestT rep m b) -> DistNestT rep m b
forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>= \case
        Just (PostStms rep
kernels, Result
res, KernelNest
nest, DistAcc rep
acc')
          | Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$
              [Type] -> Names
forall a. FreeIn a => a -> Names
freeIn (Int -> [Type] -> [Type]
forall a. Int -> [a] -> [a]
drop Int
num_accs (Lambda SOACS -> [Type]
forall rep. Lambda rep -> [Type]
lambdaReturnType Lambda SOACS
lam))
                Names -> Names -> Bool
`namesIntersect` KernelNest -> Names
boundInKernelNest KernelNest
nest,
            Just ([Int]
perm, [PatElem Type]
pat_unused) <- Pat Type -> Result -> Maybe ([Int], [PatElem Type])
permutationAndMissing Pat Type
Pat (LetDec SOACS)
pat Result
res ->
              -- We need to pretend pat_unused was used anyway, by adding
              -- it to the kernel nest.
              Scope rep
-> DistNestT rep m (DistAcc rep) -> DistNestT rep m (DistAcc rep)
forall a. Scope rep -> DistNestT rep m a -> DistNestT rep m a
forall rep (m :: * -> *) a.
LocalScope rep m =>
Scope rep -> m a -> m a
localScope (DistAcc rep -> Scope rep
forall rep. DistRep rep => DistAcc rep -> Scope rep
typeEnvFromDistAcc DistAcc rep
acc') (DistNestT rep m (DistAcc rep) -> DistNestT rep m (DistAcc rep))
-> DistNestT rep m (DistAcc rep) -> DistNestT rep m (DistAcc rep)
forall a b. (a -> b) -> a -> b
$ do
                nest' <- [PatElem Type] -> KernelNest -> DistNestT rep m KernelNest
forall (m :: * -> *).
MonadFreshNames m =>
[PatElem Type] -> KernelNest -> m KernelNest
expandKernelNest [PatElem Type]
pat_unused KernelNest
nest
                types <- asksScope scopeForSOACs
                addPostStms kernels
                let withacc = [Int]
-> Pat Type -> [WithAccInput SOACS] -> Lambda SOACS -> WithAccStm
WithAccStm [Int]
perm Pat Type
Pat (LetDec SOACS)
pat [WithAccInput SOACS]
inputs Lambda SOACS
lam
                stms <-
                  (`runReaderT` types) $
                    simplifyStms . oneStm =<< interchangeWithAcc nest' withacc
                onTopLevelStms stms
                pure acc'
        Maybe (PostStms rep, Result, KernelNest, DistAcc rep)
_ ->
          Stm SOACS -> DistAcc rep -> DistNestT rep m (DistAcc rep)
forall (m :: * -> *) rep.
(MonadFreshNames m, DistRep rep) =>
Stm SOACS -> DistAcc rep -> DistNestT rep m (DistAcc rep)
addStmToAcc Stm SOACS
stm DistAcc rep
acc
  where
    num_accs :: Int
num_accs = [WithAccInput SOACS] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [WithAccInput SOACS]
inputs
maybeDistributeStm (Let Pat (LetDec SOACS)
pat StmAux (ExpDec SOACS)
aux (Op (Screma SubExp
w [VName]
arrs ScremaForm SOACS
form))) DistAcc rep
acc
  | Just [Reduce Commutativity
comm Lambda SOACS
lam [SubExp]
nes] <- ScremaForm SOACS -> Maybe [Reduce SOACS]
forall rep. ScremaForm rep -> Maybe [Reduce rep]
isReduceSOAC ScremaForm SOACS
form,
    Just BuilderT SOACS (DistNestT rep m) ()
m <- Pat Type
-> SubExp
-> Commutativity
-> Lambda SOACS
-> [(SubExp, VName)]
-> Maybe (BuilderT SOACS (DistNestT rep m) ())
forall (m :: * -> *).
(MonadBuilder m, Rep m ~ SOACS) =>
Pat Type
-> SubExp
-> Commutativity
-> Lambda SOACS
-> [(SubExp, VName)]
-> Maybe (m ())
irwim Pat Type
Pat (LetDec SOACS)
pat SubExp
w Commutativity
comm Lambda SOACS
lam ([(SubExp, VName)] -> Maybe (BuilderT SOACS (DistNestT rep m) ()))
-> [(SubExp, VName)] -> Maybe (BuilderT SOACS (DistNestT rep m) ())
forall a b. (a -> b) -> a -> b
$ [SubExp] -> [VName] -> [(SubExp, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [SubExp]
nes [VName]
arrs = do
      types <- (Scope rep -> Scope SOACS) -> DistNestT rep m (Scope SOACS)
forall a. (Scope rep -> a) -> DistNestT rep m a
forall rep (m :: * -> *) a.
HasScope rep m =>
(Scope rep -> a) -> m a
asksScope Scope rep -> Scope SOACS
forall rep. SameScope rep SOACS => Scope rep -> Scope SOACS
scopeForSOACs
      (_, stms) <- runBuilderT (auxing aux m) types
      distributeMapBodyStms acc stms

-- Parallelise segmented scatters.
maybeDistributeStm stm :: Stm SOACS
stm@(Let Pat (LetDec SOACS)
pat StmAux (ExpDec SOACS)
aux (Op (Scatter SubExp
w [VName]
ivs ScatterSpec VName
as Lambda SOACS
lam))) DistAcc rep
acc =
  DistAcc rep
-> Stm SOACS
-> DistNestT
     rep m (Maybe (PostStms rep, Result, KernelNest, DistAcc rep))
forall (m :: * -> *) rep.
(MonadFreshNames m, LocalScope rep m, DistRep rep) =>
DistAcc rep
-> Stm SOACS
-> DistNestT
     rep m (Maybe (PostStms rep, Result, KernelNest, DistAcc rep))
distributeSingleStm DistAcc rep
acc Stm SOACS
stm DistNestT
  rep m (Maybe (PostStms rep, Result, KernelNest, DistAcc rep))
-> (Maybe (PostStms rep, Result, KernelNest, DistAcc rep)
    -> DistNestT rep m (DistAcc rep))
-> DistNestT rep m (DistAcc rep)
forall a b.
DistNestT rep m a -> (a -> DistNestT rep m b) -> DistNestT rep m b
forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>= \case
    Just (PostStms rep
kernels, Result
res, KernelNest
nest, DistAcc rep
acc')
      | Just ([Int]
perm, [PatElem Type]
pat_unused) <- Pat Type -> Result -> Maybe ([Int], [PatElem Type])
permutationAndMissing Pat Type
Pat (LetDec SOACS)
pat Result
res ->
          Scope rep
-> DistNestT rep m (DistAcc rep) -> DistNestT rep m (DistAcc rep)
forall a. Scope rep -> DistNestT rep m a -> DistNestT rep m a
forall rep (m :: * -> *) a.
LocalScope rep m =>
Scope rep -> m a -> m a
localScope (DistAcc rep -> Scope rep
forall rep. DistRep rep => DistAcc rep -> Scope rep
typeEnvFromDistAcc DistAcc rep
acc') (DistNestT rep m (DistAcc rep) -> DistNestT rep m (DistAcc rep))
-> DistNestT rep m (DistAcc rep) -> DistNestT rep m (DistAcc rep)
forall a b. (a -> b) -> a -> b
$ do
            nest' <- [PatElem Type] -> KernelNest -> DistNestT rep m KernelNest
forall (m :: * -> *).
MonadFreshNames m =>
[PatElem Type] -> KernelNest -> m KernelNest
expandKernelNest [PatElem Type]
pat_unused KernelNest
nest
            lam' <- soacsLambda lam
            addPostStms kernels
            postStm =<< segmentedScatterKernel nest' perm pat (stmAuxCerts aux) w lam' ivs as
            pure acc'
    Maybe (PostStms rep, Result, KernelNest, DistAcc rep)
_ ->
      Stm SOACS -> DistAcc rep -> DistNestT rep m (DistAcc rep)
forall (m :: * -> *) rep.
(MonadFreshNames m, DistRep rep) =>
Stm SOACS -> DistAcc rep -> DistNestT rep m (DistAcc rep)
addStmToAcc Stm SOACS
stm DistAcc rep
acc
-- Parallelise segmented Hist.
--
-- The Hist is first distributed on its own.  If that succeeds and
-- 'permutationAndMissing' can relate this statement's pattern to the
-- distributed result, a segmented histogram kernel is emitted after
-- the kernels produced by the distribution.  In every other case the
-- statement is added to the accumulator unchanged.
maybeDistributeStm stm@(Let pat aux (Op (Hist w as ops lam))) acc =
  distributeSingleStm acc stm >>= \case
    Just (kernels, res, nest, acc')
      | Just (perm, pat_unused) <- permutationAndMissing pat res ->
          localScope (typeEnvFromDistAcc acc') $ do
            lam' <- soacsLambda lam
            -- Pattern elements missing from the distributed result are
            -- added to the kernel nest as if they were used.
            nest' <- expandKernelNest pat_unused nest
            addPostStms kernels
            hist_stms <- segmentedHistKernel nest' perm (stmAuxCerts aux) w ops lam' as
            postStm hist_stms
            pure acc'
    _ -> addStmToAcc stm acc
-- Parallelise Index slices if the result is going to be returned
-- directly from the kernel.  This is because we would otherwise have
-- to sequentialise writing the result, which may be costly.
--
-- Only slices that leave at least one dimension (non-empty 'sliceDims')
-- are considered, and only when the bound name is among the results of
-- the innermost distribution target.
maybeDistributeStm stm@(Let (Pat [pe]) aux (BasicOp (Index arr slice))) acc
  | not (null (sliceDims slice)),
    Var (patElemName pe) `elem` inner_results =
      distributeSingleStm acc stm >>= \case
        Nothing -> addStmToAcc stm acc
        Just (kernels, _res, nest, acc') ->
          localScope (typeEnvFromDistAcc acc') $ do
            addPostStms kernels
            gather_stms <- segmentedGatherKernel nest (stmAuxCerts aux) arr slice
            postStm gather_stms
            pure acc'
  where
    -- The subexpressions returned by the innermost distribution target.
    inner_results = map resSubExp $ snd $ innerTarget $ distTargets acc
-- If the scan can be distributed by itself, we will turn it into a
-- segmented scan.
--
-- If the scan cannot be distributed by itself, it will be
-- sequentialised in the default case for this function.
--
-- Only the map function is converted with 'soacsLambda' here; the scan
-- operator is handed to 'segmentedScanomapKernel' as a SOACS lambda.
-- 'kernelOrNot' decides what to do with the optional kernel result.
maybeDistributeStm stm@(Let pat aux (Op (Screma w arrs form))) acc
  | Just (scans, map_lam) <- isScanomapSOAC form,
    Scan lam nes <- singleScan scans =
      distributeSingleStm acc stm >>= \case
        Just (kernels, res, nest, acc')
          | Just (perm, pat_unused) <- permutationAndMissing pat res ->
              -- We need to pretend pat_unused was used anyway, by adding
              -- it to the kernel nest.
              --
              -- A single 'localScope' suffices: the previous version
              -- re-applied the identical scope around the kernel call,
              -- which had no effect.
              localScope (typeEnvFromDistAcc acc') $ do
                nest' <- expandKernelNest pat_unused nest
                map_lam' <- soacsLambda map_lam
                segmentedScanomapKernel nest' perm (stmAuxCerts aux) w lam map_lam' nes arrs
                  >>= kernelOrNot mempty stm acc kernels acc'
        _ ->
          addStmToAcc stm acc
-- If the map function of the reduction contains parallelism we split
-- it, so that the parallelism can be exploited.
--
-- The redomap is rewritten into a map statement followed by a reduce
-- statement, and both are fed back through 'distributeMapBodyStms'.
maybeDistributeStm (Let pat aux (Op (Screma w arrs form))) acc
  | Just (reds, map_lam) <- isRedomapSOAC form,
    lambdaContainsParallelism map_lam = do
      (map_stm, red_stm) <- redomapToMapAndReduce pat (w, reds, map_lam, arrs)
      -- Only the map half takes over the original statement's aux; the
      -- reduce statement is kept exactly as produced.
      let split_stms = stmsFromList [map_stm {stmAux = aux}, red_stm]
      distributeMapBodyStms acc split_stms
-- If the reduction can be distributed by itself, we will turn it into a
-- segmented reduce.
--
-- If the reduction cannot be distributed by itself, it will be
-- sequentialised in the default case for this function.
--
-- The operator is marked 'Commutative' whenever 'commutativeLambda'
-- says so; otherwise the declared commutativity is kept.
maybeDistributeStm stm@(Let pat aux (Op (Screma w arrs form))) acc
  | Just (reds, map_lam) <- isRedomapSOAC form,
    Reduce comm lam nes <- singleReduce reds =
      distributeSingleStm acc stm >>= \case
        Just (kernels, res, nest, acc')
          | Just (perm, pat_unused) <- permutationAndMissing pat res ->
              -- We need to pretend pat_unused was used anyway, by adding
              -- it to the kernel nest.
              localScope (typeEnvFromDistAcc acc') $ do
                nest' <- expandKernelNest pat_unused nest
                red_lam <- soacsLambda lam
                map_lam' <- soacsLambda map_lam
                let comm' = if commutativeLambda lam then Commutative else comm
                regularSegmentedRedomapKernel nest' perm (stmAuxCerts aux) w comm' red_lam map_lam' nes arrs
                  >>= kernelOrNot mempty stm acc kernels acc'
        _ -> addStmToAcc stm acc
maybeDistributeStm (Let pat aux (Op (Screma w arrs form))) acc = do
  -- This Screma is too complicated for us to immediately do
  -- anything, so split it up and try again.
  --
  -- The pieces are built in a SOACS scope derived from the current one,
  -- each piece receives the original statement's certificates, and the
  -- result is redistributed with 'distributeMapBodyStms'.
  soacs_scope <- asksScope scopeForSOACs
  (_, pieces) <- runBuilderT (dissectScrema pat w form arrs) soacs_scope
  distributeMapBodyStms acc $ certify (stmAuxCerts aux) <$> pieces
-- Distribute a Replicate of an array variable as a single unary
-- statement over the full (outer) array.
maybeDistributeStm stm@(Let _ aux (BasicOp (Replicate shape (Var stm_arr)))) acc =
  distributeSingleUnaryStm acc stm stm_arr replicateOuter
  where
    -- With an empty replication shape the statement is only a copy, so
    -- the outer statement copies the full array.
    replicateOuter _ outerpat arr
      | shape == mempty =
          pure $ oneStm $ Let outerpat aux $ BasicOp $ Replicate mempty $ Var arr
    -- Otherwise the nest dimensions are moved innermost, the array is
    -- replicated, and the nest dimensions are moved back outermost.
    replicateOuter nest outerpat arr =
      runBuilder_ . auxing aux $ do
        arr_t <- lookupType arr
        let arr_r = arrayRank arr_t
            nest_r = length (snd nest) + 1
            res_r = arr_r + shapeRank shape
        -- Move the to-be-replicated dimensions outermost.
        arr_tr <-
          letExp (baseString arr <> "_tr") . BasicOp $
            Rearrange arr ([nest_r .. arr_r - 1] ++ [0 .. nest_r - 1])
        -- Replicate the now-outermost dimensions appropriately.
        arr_tr_rep <-
          letExp (baseString arr <> "_tr_rep") . BasicOp $
            Replicate shape (Var arr_tr)
        -- Move the replicated dimensions back where they belong.
        letBind outerpat . BasicOp $
          Rearrange arr_tr_rep ([res_r - nest_r .. res_r - 1] ++ [0 .. res_r - nest_r - 1])
-- Distribute a Replicate of an arbitrary subexpression by replicating
-- it over the kernel nest's widths as well.  This is only done when
-- 'boundInKernelNest' reports nothing bound in the nest; otherwise the
-- statement is added to the accumulator.
maybeDistributeStm stm@(Let _ aux (BasicOp (Replicate shape v))) acc =
  distributeSingleStm acc stm >>= \case
    Just (kernels, _, nest, acc')
      | boundInKernelNest nest == mempty -> do
          addPostStms kernels
          let outerpat = loopNestingPat (fst nest)
              nest_shape = Shape (kernelNestWidths nest)
              full_rep = BasicOp $ Replicate (nest_shape <> shape) v
          localScope (typeEnvFromDistAcc acc') $ do
            rep_stms <- runBuilder_ $ auxing aux $ letBind outerpat full_rep
            postStm rep_stms
            pure acc'
    _ -> addStmToAcc stm acc
-- Opaques are applied to the full array, because otherwise they can
-- drastically inhibit parallelisation in some cases.
--
-- Only non-primitive (array) results are handled here.  The outer
-- statement produced is a 'Replicate' with an empty shape, i.e. a copy
-- of the full array.
-- NOTE(review): the 'OpaqueOp' is matched with '_' and not re-emitted,
-- so the opaque marker itself is not preserved on the outer array —
-- confirm this is intentional.
maybeDistributeStm stm@(Let (Pat [pe]) aux (BasicOp (Opaque _ (Var stm_arr)))) acc
  | (not . primType . typeOf) pe =
      distributeSingleUnaryStm acc stm stm_arr $ \_nest outerpat arr ->
        let copy_full = BasicOp $ Replicate mempty $ Var arr
         in pure $ oneStm $ Let outerpat aux copy_full
-- Distribute a Rearrange as a single unary statement over the full
-- array: the nest dimensions stay in front and the inner permutation is
-- shifted past them.
maybeDistributeStm stm@(Let _ aux (BasicOp (Rearrange stm_arr perm))) acc =
  distributeSingleUnaryStm acc stm stm_arr $ \nest outerpat arr -> do
    let nest_rank = length (snd nest) + 1
        full_perm = [0 .. nest_rank - 1] ++ map (nest_rank +) perm
    -- We need to add a copy, because the original map nest
    -- will have produced an array without aliases, and so must we.
    arr_copy <- newVName $ baseString arr
    arr_t <- lookupType arr
    let copy_stm = Let (Pat [PatElem arr_copy arr_t]) aux $ BasicOp $ Replicate mempty $ Var arr
        perm_stm = Let outerpat aux $ BasicOp $ Rearrange arr_copy full_perm
    pure $ stmsFromList [copy_stm, perm_stm]
-- Distribute a Reshape as a single unary statement over the full array.
-- The outer (nest) dimensions are kept via 'reshapeCoerce', and the
-- original reshape is lifted past them with 'newshapeInner'.
maybeDistributeStm stm@(Let _ aux (BasicOp (Reshape stm_arr reshape))) acc =
  distributeSingleUnaryStm acc stm stm_arr $ \nest outerpat arr ->
    let outer_shape = Shape (kernelNestWidths nest)
        full_reshape = reshapeCoerce outer_shape <> newshapeInner outer_shape reshape
     in pure . oneStm . Let outerpat aux . BasicOp $ Reshape arr full_reshape
-- Distribute an in-place Update whose slice has at least one remaining
-- dimension (non-empty 'sliceDims') as a segmented update kernel.  This
-- happens only when the distributed result is exactly this statement's
-- pattern names, in order, and 'permutationAndMissing' succeeds;
-- otherwise the statement is added to the accumulator.  The Update's
-- safety flag is not passed on.
maybeDistributeStm stm@(Let pat aux (BasicOp (Update _ arr slice (Var v)))) acc
  | not (null (sliceDims slice)) =
      distributeSingleStm acc stm >>= \case
        Just (kernels, res, nest, acc')
          | map resSubExp res == map Var (patNames pat),
            Just (perm, pat_unused) <- permutationAndMissing pat res -> do
              addPostStms kernels
              localScope (typeEnvFromDistAcc acc') $ do
                nest' <- expandKernelNest pat_unused nest
                update_stms <- segmentedUpdateKernel nest' perm (stmAuxCerts aux) arr slice v
                postStm update_stms
                pure acc'
        _ -> addStmToAcc stm acc
maybeDistributeStm stm :: Stm SOACS
stm@(Let Pat (LetDec SOACS)
_ StmAux (ExpDec SOACS)
aux (BasicOp (Concat Int
d (VName
x :| [VName]
xs) SubExp
w))) DistAcc rep
acc =
  DistAcc rep
-> Stm SOACS
-> DistNestT
     rep m (Maybe (PostStms rep, Result, KernelNest, DistAcc rep))
forall (m :: * -> *) rep.
(MonadFreshNames m, LocalScope rep m, DistRep rep) =>
DistAcc rep
-> Stm SOACS
-> DistNestT
     rep m (Maybe (PostStms rep, Result, KernelNest, DistAcc rep))
distributeSingleStm DistAcc rep
acc Stm SOACS
stm DistNestT
  rep m (Maybe (PostStms rep, Result, KernelNest, DistAcc rep))
-> (Maybe (PostStms rep, Result, KernelNest, DistAcc rep)
    -> DistNestT rep m (DistAcc rep))
-> DistNestT rep m (DistAcc rep)
forall a b.
DistNestT rep m a -> (a -> DistNestT rep m b) -> DistNestT rep m b
forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>= \case
    Just (PostStms rep
kernels, Result
_, KernelNest
nest, DistAcc rep
acc') ->
      Scope rep
-> DistNestT rep m (DistAcc rep) -> DistNestT rep m (DistAcc rep)
forall a. Scope rep -> DistNestT rep m a -> DistNestT rep m a
forall rep (m :: * -> *) a.
LocalScope rep m =>
Scope rep -> m a -> m a
localScope (DistAcc rep -> Scope rep
forall rep. DistRep rep => DistAcc rep -> Scope rep
typeEnvFromDistAcc DistAcc rep
acc') (DistNestT rep m (DistAcc rep) -> DistNestT rep m (DistAcc rep))
-> DistNestT rep m (DistAcc rep) -> DistNestT rep m (DistAcc rep)
forall a b. (a -> b) -> a -> b
$
        KernelNest -> DistNestT rep m (Maybe (Stms rep))
segmentedConcat KernelNest
nest
          DistNestT rep m (Maybe (Stms rep))
-> (Maybe (Stms rep) -> DistNestT rep m (DistAcc rep))
-> DistNestT rep m (DistAcc rep)
forall a b.
DistNestT rep m a -> (a -> DistNestT rep m b) -> DistNestT rep m b
forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>= Certs
-> Stm SOACS
-> DistAcc rep
-> PostStms rep
-> DistAcc rep
-> Maybe (Stms rep)
-> DistNestT rep m (DistAcc rep)
forall (m :: * -> *) rep.
(MonadFreshNames m, DistRep rep) =>
Certs
-> Stm SOACS
-> DistAcc rep
-> PostStms rep
-> DistAcc rep
-> Maybe (Stms rep)
-> DistNestT rep m (DistAcc rep)
kernelOrNot (StmAux () -> Certs
forall dec. StmAux dec -> Certs
stmAuxCerts StmAux ()
StmAux (ExpDec SOACS)
aux) Stm SOACS
stm DistAcc rep
acc PostStms rep
kernels DistAcc rep
acc'
    Maybe (PostStms rep, Result, KernelNest, DistAcc rep)
_ ->
      Stm SOACS -> DistAcc rep -> DistNestT rep m (DistAcc rep)
forall (m :: * -> *) rep.
(MonadFreshNames m, DistRep rep) =>
Stm SOACS -> DistAcc rep -> DistNestT rep m (DistAcc rep)
addStmToAcc Stm SOACS
stm DistAcc rep
acc
  where
    segmentedConcat :: KernelNest -> DistNestT rep m (Maybe (Stms rep))
segmentedConcat KernelNest
nest =
      KernelNest
-> [Int]
-> Names
-> Names
-> [SubExp]
-> [VName]
-> (Pat Type
    -> [(VName, SubExp)]
    -> [KernelInput]
    -> [SubExp]
    -> [VName]
    -> BuilderT rep m ())
-> DistNestT rep m (Maybe (Stms rep))
forall (m :: * -> *) rep.
(MonadFreshNames m, LocalScope rep m, DistRep rep) =>
KernelNest
-> [Int]
-> Names
-> Names
-> [SubExp]
-> [VName]
-> (Pat Type
    -> [(VName, SubExp)]
    -> [KernelInput]
    -> [SubExp]
    -> [VName]
    -> BuilderT rep m ())
-> DistNestT rep m (Maybe (Stms rep))
isSegmentedOp KernelNest
nest [Int
0] Names
forall a. Monoid a => a
mempty Names
forall a. Monoid a => a
mempty [] (VName
x VName -> [VName] -> [VName]
forall a. a -> [a] -> [a]
: [VName]
xs) ((Pat Type
  -> [(VName, SubExp)]
  -> [KernelInput]
  -> [SubExp]
  -> [VName]
  -> BuilderT rep m ())
 -> DistNestT rep m (Maybe (Stms rep)))
-> (Pat Type
    -> [(VName, SubExp)]
    -> [KernelInput]
    -> [SubExp]
    -> [VName]
    -> BuilderT rep m ())
-> DistNestT rep m (Maybe (Stms rep))
forall a b. (a -> b) -> a -> b
$
        \Pat Type
pat [(VName, SubExp)]
_ [KernelInput]
_ [SubExp]
_ (VName
x' : [VName]
xs') ->
          let d' :: Int
d' = Int
d Int -> Int -> Int
forall a. Num a => a -> a -> a
+ [LoopNesting] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length (KernelNest -> [LoopNesting]
forall a b. (a, b) -> b
snd KernelNest
nest) Int -> Int -> Int
forall a. Num a => a -> a -> a
+ Int
1
           in Stm (Rep (BuilderT rep m)) -> BuilderT rep m ()
forall (m :: * -> *). MonadBuilder m => Stm (Rep m) -> m ()
addStm (Stm (Rep (BuilderT rep m)) -> BuilderT rep m ())
-> Stm (Rep (BuilderT rep m)) -> BuilderT rep m ()
forall a b. (a -> b) -> a -> b
$ Pat (LetDec (Rep (BuilderT rep m)))
-> StmAux (ExpDec (Rep (BuilderT rep m)))
-> Exp (Rep (BuilderT rep m))
-> Stm (Rep (BuilderT rep m))
forall rep.
Pat (LetDec rep) -> StmAux (ExpDec rep) -> Exp rep -> Stm rep
Let Pat Type
Pat (LetDec (Rep (BuilderT rep m)))
pat StmAux (ExpDec (Rep (BuilderT rep m)))
StmAux (ExpDec SOACS)
aux (Exp (Rep (BuilderT rep m)) -> Stm (Rep (BuilderT rep m)))
-> Exp (Rep (BuilderT rep m)) -> Stm (Rep (BuilderT rep m))
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp (Rep (BuilderT rep m))
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp (Rep (BuilderT rep m)))
-> BasicOp -> Exp (Rep (BuilderT rep m))
forall a b. (a -> b) -> a -> b
$ Int -> NonEmpty VName -> SubExp -> BasicOp
Concat Int
d' (VName
x' VName -> [VName] -> NonEmpty VName
forall a. a -> [a] -> NonEmpty a
:| [VName]
xs') SubExp
w
maybeDistributeStm Stm SOACS
stm DistAcc rep
acc =
  Stm SOACS -> DistAcc rep -> DistNestT rep m (DistAcc rep)
forall (m :: * -> *) rep.
(MonadFreshNames m, DistRep rep) =>
Stm SOACS -> DistAcc rep -> DistNestT rep m (DistAcc rep)
addStmToAcc Stm SOACS
stm DistAcc rep
acc

-- | Distribute a statement @stm@ that operates on a single array
-- @stm_arr@, using @f@ to build its segmented replacement.
--
-- The special path is taken only when all of the following hold after
-- 'distributeSingleStm' succeeds:
--
-- * the distributed result is exactly the variables of @stm@'s pattern;
--
-- * the outermost nesting maps over exactly one array;
--
-- * the only name bound by the kernel nest that @stm@ mentions is
--   @stm_arr@;
--
-- * @stm_arr@ is reached by perfectly mapping over that single array at
--   every nesting level.
--
-- In that case the post-statements of the distribution are emitted.
-- Then @f@ is called with the new nest, the outermost nesting's pattern
-- and the outermost mapped array, inside the scope of the new
-- accumulator, and its statements are emitted too.  Otherwise @stm@ is
-- simply added to the accumulator.
distributeSingleUnaryStm ::
  (MonadFreshNames m, LocalScope rep m, DistRep rep) =>
  DistAcc rep ->
  Stm SOACS ->
  VName ->
  (KernelNest -> Pat Type -> VName -> DistNestT rep m (Stms rep)) ->
  DistNestT rep m (DistAcc rep)
distributeSingleUnaryStm acc stm stm_arr f = do
  attempt <- distributeSingleStm acc stm
  case attempt of
    Just (kernels, res, nest, acc')
      | resultMatchesPat res,
        (outer, _) <- nest,
        [(_, outer_arr)] <- loopNestingParamsAndArrs outer,
        onlyStmArrFromNest nest,
        perfectlyMapped outer_arr nest -> do
          addPostStms kernels
          localScope (typeEnvFromDistAcc acc') $ do
            new_stms <- f nest (loopNestingPat outer) outer_arr
            postStm new_stms
            pure acc'
    _ -> addStmToAcc stm acc
  where
    -- The distributed result must be exactly the statement's own pattern.
    resultMatchesPat res =
      map resSubExp res == map Var (patNames (stmPat stm))

    -- Of everything bound by the nest, the statement may only use stm_arr.
    onlyStmArrFromNest nest =
      namesIntersection (boundInKernelNest nest) (freeIn stm)
        == oneName stm_arr

    -- Follow the chain of single-array map parameters from the outermost
    -- nesting inwards; it must end at stm_arr.
    perfectlyMapped :: VName -> KernelNest -> Bool
    perfectlyMapped arr (outer, inner) =
      case loopNestingParamsAndArrs outer of
        [(p, arr')]
          | arr == arr' ->
              case inner of
                [] -> paramName p == stm_arr
                next : rest -> perfectlyMapped (paramName p) (next, rest)
        _ -> False

-- | Distribute the statements in the accumulator if possible.
--
-- If 'distributeIfPossible' succeeds, its fresh accumulator is
-- returned.  Otherwise the input accumulator is returned unchanged.
distribute ::
  (MonadFreshNames m, LocalScope rep m, DistRep rep) =>
  DistAcc rep ->
  DistNestT rep m (DistAcc rep)
distribute acc = do
  distributed <- distributeIfPossible acc
  pure $ fromMaybe acc distributed

-- | Lift the environment's segment-level constructor ('distSegLevel'),
-- which builds in the inner monad @m@, into one that builds over
-- 'DistNestT'.
--
-- The inner builder is run via 'liftInner', and the statements it
-- produced are added to the enclosing builder before the level is
-- returned.
mkSegLevel ::
  (MonadFreshNames m, LocalScope rep m, DistRep rep) =>
  DistNestT rep m (MkSegLevel rep (DistNestT rep m))
mkSegLevel = do
  inner_mk_lvl <- asks distSegLevel
  let outer_mk_lvl dims desc recommendation = do
        (lvl, lvl_stms) <-
          lift . liftInner . runBuilderT' $
            inner_mk_lvl dims desc recommendation
        addStms lvl_stms
        pure lvl
  pure outer_mk_lvl

-- | Try to distribute the statements accumulated in @acc@ over the
-- current nesting ('distNest').
--
-- On success the resulting kernel statements are emitted with
-- 'postStm'.  The function then returns a fresh accumulator holding the
-- new targets and no pending statements.  Returns 'Nothing' when
-- 'tryDistribute' fails.
distributeIfPossible ::
  (MonadFreshNames m, LocalScope rep m, DistRep rep) =>
  DistAcc rep ->
  DistNestT rep m (Maybe (DistAcc rep))
distributeIfPossible acc = do
  nestings <- asks distNest
  lvl_fn <- mkSegLevel
  attempt <- tryDistribute lvl_fn nestings (distTargets acc) (distStms acc)
  forM attempt $ \(targets, kernel) -> do
    postStm kernel
    pure DistAcc {distTargets = targets, distStms = mempty}

-- | Distribute the statements accumulated in @acc@, then attempt to
-- distribute @stm@ by itself against the resulting targets.
--
-- On success, returns four things:
--
-- * the statements produced by distributing the accumulator, wrapped as
--   'PostStms' and not yet emitted;
--
-- * the result reported by 'tryDistributeStm';
--
-- * the kernel nest reported by 'tryDistributeStm';
--
-- * a fresh accumulator with the new targets and no pending statements.
--
-- Returns 'Nothing' if either distribution step fails.
distributeSingleStm ::
  (MonadFreshNames m, LocalScope rep m, DistRep rep) =>
  DistAcc rep ->
  Stm SOACS ->
  DistNestT
    rep
    m
    ( Maybe
        ( PostStms rep,
          Result,
          KernelNest,
          DistAcc rep
        )
    )
distributeSingleStm acc stm = do
  nestings <- asks distNest
  lvl_fn <- mkSegLevel
  runMaybeT $ do
    (targets, distributed_stms) <-
      MaybeT $ tryDistribute lvl_fn nestings (distTargets acc) (distStms acc)
    (res, targets', new_kernel_nest) <-
      MaybeT $ tryDistributeStm nestings targets stm
    pure
      ( PostStms distributed_stms,
        res,
        new_kernel_nest,
        DistAcc {distTargets = targets', distStms = mempty}
      )

-- | Turn a scatter nested inside the maps of @nest@ into one flat
-- kernel (built with 'mapKernel') whose results are in-place writes
-- ('WriteReturns').
--
-- @dests@ contains one @(shape, n, array)@ triple per destination.
-- The lambda result is split so that the first @sum (n * length shape)@
-- elements are the write indices and the rest are the written values.
-- The kernel is bound to the outermost nesting's pattern elements,
-- reordered by 'rearrangeShape' @perm@.  The returned statements are
-- renamed.
segmentedScatterKernel ::
  (MonadFreshNames m, LocalScope rep m, DistRep rep) =>
  KernelNest ->
  [Int] ->
  Pat Type ->
  Certs ->
  SubExp ->
  Lambda rep ->
  [VName] ->
  [(Shape, Int, VName)] ->
  DistNestT rep m (Stms rep)
segmentedScatterKernel nest perm scatter_pat cs scatter_w lam ivs dests = do
  -- Some of the checking done by 'isSegmentedOp' is replicated here,
  -- but things differ because a scatter is neither a reduction nor a
  -- scan.
  --
  -- First, pretend that the scatter is itself part of the nesting.  The
  -- KernelNest produced this way is technically not sensible, but it is
  -- good enough for flatKernel to work.
  let lam_body = lambdaBody lam
      lam_res = bodyResult lam_body
      scatter_nesting =
        MapNesting scatter_pat (StmAux cs mempty mempty ()) scatter_w $
          zip (lambdaParams lam) ivs
      extended_nest =
        pushInnerKernelNesting (scatter_pat, lam_res) scatter_nesting nest
  (ispace, kernel_inps) <- flatKernel extended_nest

  let (dest_shapes, dest_counts, dest_names) = unzip3 dests
      num_index_res =
        sum $ zipWith (\shape n -> n * length shape) dest_shapes dest_counts

  -- The destination arrays must correspond to some kernel input, or
  -- else the original nested scatter would have been ill-typed.
  dest_inps <- mapM (findInput kernel_inps) dest_names

  mk_lvl <- mkSegLevel

  let (index_res, value_res) = splitAt num_index_res lam_res
  (index_res', k_body_stms) <- runBuilder $ do
    addStms $ bodyStms lam_body
    pure index_res

  let grouped =
        groupScatterResults
          (zip3 dest_shapes dest_counts dest_inps)
          (index_res' ++ value_res)
      written_inps = [inp | (_, inp, _) <- grouped]

  dest_arrs_ts <- mapM (lookupType . kernelInputArray) written_inps

  let k_body =
        KernelBody () k_body_stms $
          zipWith (inPlaceReturn ispace) dest_arrs_ts grouped
      -- Drop kernel inputs the body does not use, since some of them
      -- may refer to the arrays being scattered into.
      body_free = freeIn k_body
      kernel_inps' =
        filter ((`nameIn` body_free) . kernelInputName) kernel_inps

  (k, k_stms) <- mapKernel mk_lvl ispace kernel_inps' dest_arrs_ts k_body

  stms <- runBuilder_ $ do
    addStms k_stms
    let outer_pes = patElems $ loopNestingPat $ fst nest
    letBind (Pat $ rearrangeShape perm outer_pes) $ Op $ segOp k
  traverse renameStm stms
  where
    -- Find the kernel input whose name is the given array.
    findInput inps a =
      case find ((== a) . kernelInputName) inps of
        Just inp -> pure inp
        Nothing -> error "Ill-typed nested scatter encountered."

    -- One WriteReturns per destination.  Every write is indexed by all
    -- thread indices except the innermost (the scatter's own), followed
    -- by the scatter's index results.
    inPlaceReturn ispace arr_t (_, inp, is_vs) =
      WriteReturns write_cs (kernelInputArray inp) writes
      where
        outer_gtids = init $ map fst ispace
        writes =
          [ ( fullSlice arr_t $
                map DimFix $
                  map Var outer_gtids ++ map resSubExp idxs,
              resSubExp v
            )
            | (idxs, v) <- is_vs
          ]
        write_cs =
          foldMap (foldMap resCerts . fst) is_vs
            <> foldMap (resCerts . snd) is_vs

-- | Turn an in-place update @arr[slice] = v@, nested inside the maps of
-- @nest@, into a flat kernel built with 'mapKernel'.
--
-- The iteration space is the nest's flattened space extended with one
-- fresh index per slice dimension.  Each thread reads one element of @v@
-- (under certificates @cs@) and returns it as a 'WriteReturns' into the
-- kernel input corresponding to @arr@.  The kernel is bound to the
-- outermost nesting's pattern elements, reordered by 'rearrangeShape'
-- @perm@.  The returned statements are renamed.
segmentedUpdateKernel ::
  (MonadFreshNames m, LocalScope rep m, DistRep rep) =>
  KernelNest ->
  [Int] ->
  Certs ->
  VName ->
  Slice SubExp ->
  VName ->
  DistNestT rep m (Stms rep)
segmentedUpdateKernel nest perm cs arr slice v = do
  (base_ispace, kernel_inps) <- flatKernel nest
  let slice_dims = sliceDims slice
  slice_gtids <- replicateM (length slice_dims) (newVName "gtid_slice")

  let ispace = base_ispace ++ zip slice_gtids slice_dims
      base_gtids = map fst base_ispace
      -- The updated array must be one of the kernel inputs, or the
      -- original Update would have been ill-typed.
      dest_arr =
        case find ((== arr) . kernelInputName) kernel_inps of
          Just inp -> kernelInputArray inp
          Nothing -> error "incorrectly typed Update"

  ((dest_t, res), kstms) <- runBuilder $ do
    -- The element of v that this thread writes.
    v_elem <-
      certifying cs $
        letSubExp "v" $
          BasicOp $
            Index v $
              Slice [DimFix (Var gtid) | gtid <- slice_gtids]
    -- Translate the thread's slice-local indices into indices into the
    -- full array.
    slice_is <-
      mapM (toSubExp "index") $
        fixSlice (fmap pe64 slice) [pe64 (Var gtid) | gtid <- slice_gtids]
    dest_arr_t <- lookupType dest_arr
    let write_is = map Var base_gtids ++ slice_is
    pure
      ( dest_arr_t,
        WriteReturns mempty dest_arr [(Slice (map DimFix write_is), v_elem)]
      )

  -- Drop kernel inputs the kernel does not use, since some of them may
  -- refer to the array being updated.
  let used = freeIn kstms <> freeIn res
      kernel_inps' = filter ((`nameIn` used) . kernelInputName) kernel_inps

  mk_lvl <- mkSegLevel
  (k, prestms) <-
    mapKernel mk_lvl ispace kernel_inps' [dest_t] $
      KernelBody () kstms [res]

  stms <- runBuilder_ $ do
    addStms prestms
    let outer_pes = patElems $ loopNestingPat $ fst nest
    letBind (Pat $ rearrangeShape perm outer_pes) $ Op $ segOp k
  traverse renameStm stms

-- | Turn an indexing @arr[slice]@, nested inside the maps of @nest@,
-- into a flat kernel built with 'mapKernel'.
--
-- The iteration space is the nest's flattened space extended with one
-- fresh index per slice dimension.  Each thread composes @slice@ with
-- its own slice indices ('sliceSlice'), reads the resulting element of
-- @arr@ under certificates @cs@, and returns it.  The kernel is bound to
-- the outermost nesting's pattern.  The returned statements are renamed.
segmentedGatherKernel ::
  (MonadFreshNames m, LocalScope rep m, DistRep rep) =>
  KernelNest ->
  Certs ->
  VName ->
  Slice SubExp ->
  DistNestT rep m (Stms rep)
segmentedGatherKernel nest cs arr slice = do
  let slice_dims = sliceDims slice
  slice_gtids <- replicateM (length slice_dims) (newVName "gtid_slice")

  (base_ispace, kernel_inps) <- flatKernel nest
  let ispace = base_ispace ++ zip slice_gtids slice_dims

  ((res_t, res), kstms) <- runBuilder $ do
    -- Compute the indexes into the full array.
    let thread_slice = Slice [DimFix (Var gtid) | gtid <- slice_gtids]
    full_slice <-
      subExpSlice $
        sliceSlice (primExpSlice slice) (primExpSlice thread_slice)
    elem_se <- certifying cs $ letSubExp "v" $ BasicOp $ Index arr full_slice
    elem_t <- subExpType elem_se
    pure (elem_t, Returns ResultMaySimplify mempty elem_se)

  mk_lvl <- mkSegLevel
  (k, prestms) <-
    mapKernel mk_lvl ispace kernel_inps [res_t] $
      KernelBody () kstms [res]

  stms <- runBuilder_ $ do
    addStms prestms
    letBind (loopNestingPat (fst nest)) $ Op $ segOp k
  traverse renameStm stms

-- | Distribute a 'Hist' nested inside @nest@ as a flat 'SegHist'.
--
-- We replicate some of the checking done by 'isSegmentedOp', but
-- things are different because a Hist is not a reduction or
-- scan.
--
-- Each destination array of the operators in @ops@ is looked up among
-- the kernel inputs of the flattened nest and replaced by that input's
-- outer array.  A missing destination is a fatal 'error'.  The result
-- pattern is the outermost nesting's pattern with its elements
-- rearranged by @perm@.  The actual kernel is built by 'histKernel'.
segmentedHistKernel ::
  (MonadFreshNames m, LocalScope rep m, DistRep rep) =>
  KernelNest ->
  [Int] ->
  Certs ->
  SubExp ->
  [SOACS.HistOp SOACS] ->
  Lambda rep ->
  [VName] ->
  DistNestT rep m (Stms rep)
segmentedHistKernel nest perm cs hist_w ops lam arrs = do
  (ispace, inputs) <- flatKernel nest
  let orig_pat = Pat $ rearrangeShape perm $ patElems $ loopNestingPat $ fst nest

  ops' <- mapM (withOuterDests inputs) ops

  mk_lvl <- asks distSegLevel
  onLambda <- asks distOnSOACSLambda
  let onLambda' soacs_lam = fst <$> runBuilder (onLambda soacs_lam)
  liftInner . runBuilderT'_ $ do
    -- It is important not to launch unnecessarily many threads for
    -- histograms, because it may mean we unnecessarily need to reduce
    -- subhistograms as well.
    lvl <- mk_lvl (hist_w : map snd ispace) "seghist" $ NoRecommendation SegNoVirt
    hist_stms <- histKernel onLambda' lvl orig_pat ispace inputs cs hist_w ops' lam arrs
    addStms hist_stms
  where
    -- The input/output arrays _must_ correspond to some kernel input,
    -- or else the original nested Hist would have been ill-typed.
    -- Swap each destination for the array of its kernel input.
    withOuterDests inps (SOACS.HistOp num_bins rf dests nes op) = do
      dests' <- forM dests $ fmap kernelInputArray . findInput inps
      pure $ SOACS.HistOp num_bins rf dests' nes op

    findInput inps a =
      case find ((== a) . kernelInputName) inps of
        Just inp -> pure inp
        Nothing -> error "Ill-typed nested Hist encountered."

-- | Build a 'SegHist' for histogram operators over an already
-- flattened index space @ispace@.
--
-- Each SOACS operator is passed through 'determineReduceOp', and its
-- resulting lambda is converted with @onLambda@.  Kernel inputs whose
-- array is one of the operators' destinations are removed from
-- @inputs@ before calling 'segHist'.  The generated statements are
-- renamed and added under the certificates @cs@.
histKernel ::
  (MonadBuilder m, DistRep (Rep m)) =>
  (Lambda SOACS -> m (Lambda (Rep m))) ->
  SegOpLevel (Rep m) ->
  Pat Type ->
  [(VName, SubExp)] ->
  [KernelInput] ->
  Certs ->
  SubExp ->
  [SOACS.HistOp SOACS] ->
  Lambda (Rep m) ->
  [VName] ->
  m (Stms (Rep m))
histKernel onLambda lvl orig_pat ispace inputs cs hist_w ops lam arrs =
  runBuilderT'_ $ do
    ops' <- mapM convertOp ops
    let dest_arrs = concatMap histDest ops'
        inputs' = [inp | inp <- inputs, kernelInputArray inp `notElem` dest_arrs]
    certifying cs $ do
      seghist_stms <- segHist lvl orig_pat hist_w ispace inputs' ops' lam arrs
      addStms =<< traverse renameStm seghist_stms
  where
    -- Turn a SOACS histogram operator into a SegOp one.
    convertOp (SOACS.HistOp dest_shape rf dests nes op) = do
      (op', nes', shape) <- determineReduceOp op nes
      op'' <- lift $ onLambda op'
      pure $ HistOp dest_shape rf dests nes' shape op''

-- | Prepare a SOACS reduction operator and its neutral elements for use
-- as a segmented operator.
--
-- If every neutral element is a variable and 'isVectorMap' can peel at
-- least one @map@ off @lam@, the peeled lambda is returned together
-- with the widths of the peeled maps as the operator shape.  Each
-- neutral element is then replaced by its element at index 0 in every
-- peeled dimension, bound as @hist_ne@.
--
-- Otherwise, either because some neutral element is a constant or
-- because @lam@ is not a vectorised map, @lam@ and @nes@ are returned
-- unchanged with an empty shape.  Previously the non-vectorised case
-- emitted a no-op full-slice 'Index' per neutral element.
determineReduceOp ::
  (MonadBuilder m) =>
  Lambda SOACS ->
  [SubExp] ->
  m (Lambda SOACS, [SubExp], Shape)
determineReduceOp lam nes =
  -- FIXME? We are assuming that the accumulator is a replicate, and
  -- we fish out its value in a gross way.
  case mapM subExpVar nes of
    Just ne_vs'
      | (shape, lam') <- isVectorMap lam,
        shapeRank shape > 0 -> do
          nes' <- forM ne_vs' $ \ne_v -> do
            ne_v_t <- lookupType ne_v
            letSubExp "hist_ne" $
              BasicOp $
                Index ne_v $
                  fullSlice ne_v_t $
                    replicate (shapeRank shape) $
                      DimFix $
                        intConst Int64 0
          pure (lam', nes', shape)
    _ ->
      -- Nothing to peel: with an empty shape, the 'Index' above would
      -- just be a full slice of the neutral element itself.
      pure (lam, nes, mempty)

-- | Peel nested element-wise maps off an operator lambda.
--
-- A lambda matches when its body is a single 'Screma' statement that
-- is a map ('isMapSOAC'), its arguments are exactly the lambda's
-- parameters, and its result is exactly the body's result.  Matching
-- maps are peeled recursively.  The result is the widths of the peeled
-- maps, outermost first, and the innermost lambda.  A lambda that does
-- not match is returned as-is with an empty shape.
isVectorMap :: Lambda SOACS -> (Shape, Lambda SOACS)
isVectorMap lam
  | Just (w, inner_lam) <- peel = first (Shape [w] <>) $ isVectorMap inner_lam
  | otherwise = (mempty, lam)
  where
    body = lambdaBody lam
    -- The width and inner lambda of the map, if the body is one.
    peel
      | [Let (Pat pes) _ (Op (Screma w arrs form))] <- stmsToList (bodyStms body),
        map resSubExp (bodyResult body) == map (Var . patElemName) pes,
        Just map_lam <- isMapSOAC form,
        arrs == map paramName (lambdaParams lam) =
          Just (w, map_lam)
      | otherwise = Nothing

-- | Distribute a scanomap nested inside @nest@ as a flat 'SegScan'.
--
-- The scan operator @lam@ is passed through 'determineReduceOp', which
-- may peel vectorised maps off it.  It is then converted from SOACS with
-- the environment's 'distOnSOACSLambda' and always used as
-- 'Noncommutative'.  The nest checks are done by 'isSegmentedOp', and
-- the result is 'Nothing' if it rejects the nest.
segmentedScanomapKernel ::
  (MonadFreshNames m, LocalScope rep m, DistRep rep) =>
  KernelNest ->
  [Int] ->
  Certs ->
  SubExp ->
  Lambda SOACS ->
  Lambda rep ->
  [SubExp] ->
  [VName] ->
  DistNestT rep m (Maybe (Stms rep))
segmentedScanomapKernel nest perm cs segment_size lam map_lam nes arrs = do
  mk_lvl <- asks distSegLevel
  onLambda <- asks distOnSOACSLambda
  isSegmentedOp nest perm (freeIn lam) (freeIn map_lam) nes [] $
    \pat ispace inps nes' _ -> do
      (scan_lam, scan_nes, scan_shape) <- determineReduceOp lam nes'
      -- Only the converted lambda is kept; the statements produced by
      -- the conversion builder are discarded.
      (scan_lam', _) <- runBuilder $ onLambda scan_lam
      lvl <- mk_lvl (segment_size : map snd ispace) "segscan" $ NoRecommendation SegNoVirt
      let scan_op = SegBinOp Noncommutative scan_lam' scan_nes scan_shape
      scan_stms <- segScan lvl pat cs segment_size [scan_op] map_lam arrs ispace inps
      addStms =<< traverse renameStm scan_stms

-- | Distribute a regular redomap nested inside @nest@ as a flat
-- 'SegRed'.
--
-- The reduction operator @lam@ is used directly, with commutativity
-- @comm@ and an empty operator shape.  The nest checks are done by
-- 'isSegmentedOp', and the result is 'Nothing' if it rejects the nest.
regularSegmentedRedomapKernel ::
  (MonadFreshNames m, LocalScope rep m, DistRep rep) =>
  KernelNest ->
  [Int] ->
  Certs ->
  SubExp ->
  Commutativity ->
  Lambda rep ->
  Lambda rep ->
  [SubExp] ->
  [VName] ->
  DistNestT rep m (Maybe (Stms rep))
regularSegmentedRedomapKernel nest perm cs segment_size comm lam map_lam nes arrs = do
  mk_lvl <- asks distSegLevel
  let mkSegRed pat ispace inps nes' _ = do
        lvl <- mk_lvl (segment_size : map snd ispace) "segred" $ NoRecommendation SegNoVirt
        red_stms <-
          segRed lvl pat cs segment_size [SegBinOp comm lam nes' mempty] map_lam arrs ispace inps
        addStms =<< traverse renameStm red_stms
  isSegmentedOp nest perm (freeIn lam) (freeIn map_lam) nes [] mkSegRed

-- | Try to express an operation nested inside @nest@ as a flat
-- segmented operation.
--
-- The attempt fails, returning 'Nothing' and emitting nothing, if any
-- of the following holds:
--
-- * a name free in the operator (@free_in_op@) is bound by the nest;
-- * a neutral element in @nes@ is a variable bound by the nest;
-- * an array in @arrs@ is neither a kernel input indexed exactly by
--   the flattened segment indices, nor free in the whole nest.
--
-- Arrays that are free in the nest are replicated across the segment
-- dimensions (bound as @<arr>_repd@).  On success, the continuation
-- @m@ is run in a fresh builder and the statements it adds are
-- returned.  It receives the result pattern (the outermost nesting's
-- pattern rearranged by @perm@), the flattened index space, the kernel
-- inputs, the neutral elements and the prepared arrays.  The fold
-- lambda's free names (@_free_in_fold_op@) are not inspected here.
isSegmentedOp ::
  (MonadFreshNames m, LocalScope rep m, DistRep rep) =>
  KernelNest ->
  [Int] ->
  Names ->
  Names ->
  [SubExp] ->
  [VName] ->
  ( Pat Type ->
    [(VName, SubExp)] ->
    [KernelInput] ->
    [SubExp] ->
    [VName] ->
    BuilderT rep m ()
  ) ->
  DistNestT rep m (Maybe (Stms rep))
isSegmentedOp nest perm free_in_op _free_in_fold_op nes arrs m = runMaybeT $ do
  -- We must verify that array inputs to the operation are inputs to
  -- the outermost loop nesting or free in the loop nest.  Nothing
  -- free in the op may be bound by the nest.  Furthermore, the
  -- neutral elements must be free in the loop nest.
  let bound_by_nest = boundInKernelNest nest

  (ispace, kernel_inps) <- flatKernel nest

  when (free_in_op `namesIntersect` bound_by_nest) $
    fail "Non-fold lambda uses nest-bound parameters."

  nes' <- forM nes $ \case
    Var v | v `nameIn` bound_by_nest -> fail "Neutral element bound in nest"
    ne -> pure ne

  let segment_idxs = map (Var . fst) ispace
      segment_dims = map snd ispace

      -- A builder action producing the array to use inside the
      -- segmented operation, or failure if the input is unsuitable.
      prepareArr arr =
        case find ((== arr) . kernelInputName) kernel_inps of
          Just inp
            | kernelInputIndices inp == segment_idxs ->
                pure $ pure $ kernelInputArray inp
          Nothing
            | arr `notNameIn` bound_by_nest ->
                -- This input is something that is free inside the loop
                -- nesting. We will have to replicate it.
                pure $
                  letExp (baseString arr ++ "_repd") $
                    BasicOp $
                      Replicate (Shape segment_dims) (Var arr)
          _ ->
            fail "Input not free, perfectly mapped, or outermost."

  mk_arrs <- mapM prepareArr arrs

  let pat = Pat $ rearrangeShape perm $ patElems $ loopNestingPat $ fst nest

  lift . liftInner . runBuilderT'_ $
    m pat ispace kernel_inps nes' =<< sequence mk_arrs

-- | Relate a pattern to a body result.
--
-- Pattern elements whose names are not free in @res@ are the /unused/
-- elements.  The result is extended with those unused elements as
-- variables.  The pattern elements (as variables) are then checked with
-- 'isPermutationOf' against this extended result.  On success the
-- permutation and the unused elements are returned, otherwise
-- 'Nothing'.
permutationAndMissing :: Pat Type -> Result -> Maybe ([Int], [PatElem Type])
permutationAndMissing (Pat pes) res = do
  let res_names = freeIn res
      unused = filter ((`notNameIn` res_names) . patElemName) pes
      asSubExps = map (Var . patElemName)
  perm <- asSubExps pes `isPermutationOf` (map resSubExp res ++ asSubExps unused)
  pure (perm, unused)

-- | Add extra pattern elements to every kernel nesting level.
--
-- Every element of @pes@ is appended, under a fresh name derived from
-- its original name, to the pattern of each loop nesting.  At each
-- level the element's type is wrapped in the widths of that level and
-- of all levels inside it.  The outermost level therefore gets every
-- width, and the innermost only its own.
expandKernelNest ::
  (MonadFreshNames m) => [PatElem Type] -> KernelNest -> m KernelNest
expandKernelNest pes (outer_nest, inner_nests) = do
  let widths = map loopNestingWidth (outer_nest : inner_nests)
  outer_nest' <- expandWith outer_nest widths
  -- 'drop 1 (tails widths)' gives each inner level its own width
  -- followed by those of the levels nested inside it.
  inner_nests' <- zipWithM expandWith inner_nests (drop 1 (tails widths))
  pure (outer_nest', inner_nests')
  where
    expandWith nest dims = do
      pes' <- forM pes $ \pe -> do
        name <- newVName $ baseString $ patElemName pe
        pure
          pe
            { patElemName = name,
              patElemDec = patElemType pe `arrayOfShape` Shape dims
            }
      pure nest {loopNestingPat = Pat $ patElems (loopNestingPat nest) <> pes'}

-- | Commit the outcome of an attempt to distribute a statement.
--
-- With 'Nothing', the statement @stm@ (certified with @cs@) is added to
-- the original accumulator @acc@.  With @Just stms@, the post-statements
-- @kernels@ are added, the statements in @stms@ (each certified with
-- @cs@) are posted, and the accumulator @acc'@ is returned.
kernelOrNot ::
  (MonadFreshNames m, DistRep rep) =>
  Certs ->
  Stm SOACS ->
  DistAcc rep ->
  PostStms rep ->
  DistAcc rep ->
  Maybe (Stms rep) ->
  DistNestT rep m (DistAcc rep)
kernelOrNot cs stm acc kernels acc' = \case
  Nothing ->
    addStmToAcc (certify cs stm) acc
  Just stms -> do
    addPostStms kernels
    postStm $ certify cs <$> stms
    pure acc'

-- | Distribute the body of a 'MapLoop'.
--
-- A fresh accumulator is built from the incoming one. The map's
-- pattern, paired with the result of the lambda body, is pushed as a
-- new inner target onto the existing targets, and the pending
-- statements start out empty.
--
-- The lambda's body statements are then processed with
-- 'distributeMapBodyStms' and 'distribute' inside the nesting built by
-- 'mapNesting' (from the pattern, aux, width, lambda and input arrays).
-- 'distribute' is applied once more to the accumulator that
-- 'mapNesting' returns.
distributeMap ::
  (MonadFreshNames m, LocalScope rep m, DistRep rep) =>
  MapLoop ->
  DistAcc rep ->
  DistNestT rep m (DistAcc rep)
distributeMap (MapLoop pat aux w lam arrs) acc =
  distribute
    =<< mapNesting
      pat
      aux
      w
      lam
      arrs
      (distribute =<< distributeMapBodyStms acc' lam_stms)
  where
    -- Accumulator for the inner distribution: the map's pattern and the
    -- lambda's result become the innermost target, and no statements
    -- are pending yet.
    acc' =
      DistAcc
        { distTargets =
            pushInnerTarget
              (pat, bodyResult $ lambdaBody lam)
              $ distTargets acc,
          distStms = mempty
        }

    -- The statements of the map's lambda body, to be distributed.
    lam_stms = bodyStms $ lambdaBody lam