{-# LANGUAGE TypeFamilies #-}

-- | Turn certain uses of accumulators into SegHists.
module Futhark.Optimise.HistAccs (histAccsGPU) where

import Control.Monad
import Control.Monad.Reader
import Control.Monad.State
import Data.Map.Strict qualified as M
import Futhark.IR.GPU
import Futhark.MonadFreshNames
import Futhark.Pass
import Futhark.Tools
import Futhark.Transform.Rename
import Prelude hiding (quot)

-- | A mapping from accumulator variables to their source.
type Accs rep = M.Map VName (WithAccInput rep)

-- | The monad used by the functions in this module: a reader carrying
-- the GPU 'Scope', layered over a 'State' holding the 'VNameSource'.
type OptM = ReaderT (Scope GPU) (State VNameSource)

-- | Optimise the statements of a body with 'optimiseStms', given the
-- accumulators currently known.  The result of the body is kept
-- unchanged.
optimiseBody :: Accs GPU -> Body GPU -> OptM (Body GPU)
optimiseBody accs body =
  mkBody <$> optimiseStms accs (bodyStms body) <*> pure (bodyResult body)

-- | Optimise an expression by running 'optimiseBody' on every body
-- that 'mapExpM' hands to 'mapOnBody'.  Everything else is left to
-- 'identityMapper'.
optimiseExp :: Accs GPU -> Exp GPU -> OptM (Exp GPU)
optimiseExp accs = mapExpM mapper
  where
    mapper :: Mapper GPU GPU OptM
    mapper =
      identityMapper
        { -- The scope supplied for the body is brought into the
          -- local scope before the body itself is optimised.
          mapOnBody = \scope body ->
            localScope scope $ optimiseBody accs body
        }

-- | Search the statements, front to back, for the first statement
-- that binds exactly the name @v@ to the result of an 'UpdateAcc'.
--
-- On success, produces the 'WithAccInput' registered for the updated
-- accumulator, the accumulator name, the update indices, the update
-- values, and the statements with the matching statement removed
-- (all other statements keep their relative order).
--
-- Produces 'Nothing' when 'stmsHead' yields nothing before a match is
-- found, or when the updated accumulator is not a key of @accs@.
extractUpdate ::
  Accs rep ->
  VName ->
  Stms rep ->
  Maybe ((WithAccInput rep, VName, [SubExp], [SubExp]), Stms rep)
extractUpdate accs v stms = do
  (stm, stms') <- stmsHead stms
  case stm of
    Let (Pat [PatElem pe_v _]) _ (BasicOp (UpdateAcc _ acc is vs))
      | pe_v == v -> do
          -- A failed lookup aborts the whole search rather than
          -- continuing with the remaining statements.
          acc_input <- M.lookup acc accs
          pure ((acc_input, acc, is, vs), stms')
    _ -> do
      -- Not the statement we are looking for: keep it and search on.
      (x, stms'') <- extractUpdate accs v stms'
      pure (x, oneStm stm <> stms'')

-- | Try to turn a kernel body into one that returns the indices and
-- values of an accumulator update instead of the updated accumulator.
--
-- Only applies to a kernel body whose sole result is a 'Returns' of a
-- variable, and only if 'extractUpdate' finds the update that
-- produces that variable.  The new body has the update statement
-- removed and returns first the indices and then the values, each
-- with the manifest and certificates of the original result.  Also
-- produces the 'WithAccInput' and name of the updated accumulator.
mkHistBody :: Accs GPU -> KernelBody GPU -> Maybe (KernelBody GPU, WithAccInput GPU, VName)
mkHistBody accs (KernelBody () stms [Returns rm cs (Var v)]) = do
  ((acc_input, acc, is, vs), stms') <- extractUpdate accs v stms
  pure
    ( KernelBody () stms' $ map (Returns rm cs) is ++ map (Returns rm cs) vs,
      acc_input,
      acc
    )
mkHistBody _ _ = Nothing

-- | Drop as many leading parameters from the lambda as the rank of
-- the given shape, and rename the resulting lambda with
-- 'renameLambda'.
withAccLamToHistLam :: (MonadFreshNames m) => Shape -> Lambda GPU -> m (Lambda GPU)
withAccLamToHistLam shape lam =
  renameLambda $ lam {lambdaParams = drop (shapeRank shape) (lambdaParams lam)}

-- | Construct a 'SegMap' expression over the given shape that updates
-- the accumulator @acc@ with the elements of the arrays @arrs@.
--
-- The segmented space has one fresh index per dimension of @shape@.
-- In the kernel body, each array is indexed at those indices and the
-- resulting values are passed to a single 'UpdateAcc' at the same
-- indices.  The body returns the updated accumulator.
addArrsToAcc ::
  (MonadBuilder m, Rep m ~ GPU) =>
  SegLevel ->
  Shape ->
  [VName] ->
  VName ->
  m (Exp GPU)
addArrsToAcc lvl shape arrs acc = do
  flat <- newVName "phys_tid"
  gtids <- replicateM (shapeRank shape) (newVName "gtid")
  let space = SegSpace flat $ zip gtids $ shapeDims shape

  -- The body statements are collected rather than emitted, so that
  -- they end up inside the kernel body constructed below.
  (acc', stms) <- localScope (scopeOfSegSpace space) . collectStms $ do
    vs <- forM arrs $ \arr -> do
      arr_t <- lookupType arr
      letSubExp (baseString arr <> "_elem") $
        BasicOp $
          Index arr $
            fullSlice arr_t $
              map (DimFix . Var) gtids
    letExp (baseString acc <> "_upd") $
      BasicOp $
        UpdateAcc Safe acc (map Var gtids) vs

  acc_t <- lookupType acc
  pure . Op . SegOp . SegMap lvl space [acc_t] $
    KernelBody () stms [Returns ResultMaySimplify mempty (Var acc')]

-- | Replace a segmented space with a one-dimensional space and adapt
-- the kernel body accordingly.
--
-- The new space keeps the flat index of the original space and has a
-- single fresh index whose bound is the product of the original
-- dimensions.  The new kernel body starts with statements binding the
-- original index names to the unflattened form of the new index,
-- followed by the original statements.  Other fields of the kernel
-- body are kept as they are.
flatKernelBody ::
  (MonadBuilder m) =>
  SegSpace ->
  KernelBody (Rep m) ->
  m (SegSpace, KernelBody (Rep m))
flatKernelBody space kbody = do
  gtid <- newVName "gtid"
  dims_prod <-
    letSubExp "dims_prod"
      =<< foldBinOp (Mul Int64 OverflowUndef) (intConst Int64 1) (segSpaceDims space)

  let space' = SegSpace (segFlat space) [(gtid, dims_prod)]

  kbody_stms <- localScope (scopeOfSegSpace space') . collectStms_ $ do
    let new_inds =
          unflattenIndex (map pe64 (segSpaceDims space)) (pe64 $ Var gtid)
    -- Rebind the original index names, so that the original
    -- statements can be reused without modification.
    zipWithM_ letBindNames (map (pure . fst) (unSegSpace space))
      =<< mapM toExp new_inds
    addStms $ kernelBodyStms kbody

  pure (space', kbody {kernelBodyStms = kbody_stms})

-- | Optimise a single statement, producing the statements that
-- replace it.
--
-- * For a 'WithAcc', the body of the lambda is optimised with the
--   accumulator parameters of the lambda added to the known
--   accumulators.
--
-- * For a 'SegMap', when at least one accumulator is known,
--   'mkHistBody' succeeds on the kernel body, the accumulator has a
--   combining lambda, and that lambda returns only primitive types:
--   the statement is replaced by a 'SegHist' into freshly replicated
--   destination arrays, followed by the expression built by
--   'addArrsToAcc', which is bound to the original pattern.
--
-- * Any other statement has its expression processed by
--   'optimiseExp'.
optimiseStm :: Accs GPU -> Stm GPU -> OptM (Stms GPU)
-- TODO: this is very restricted currently, but shows the idea.
optimiseStm accs (Let pat aux (WithAcc inputs lam)) =
  localScope (scopeOfLParams (lambdaParams lam)) $ do
    body' <- optimiseBody accs' $ lambdaBody lam
    let lam' = lam {lambdaBody = body'}
    pure $ oneStm $ Let pat aux $ WithAcc inputs lam'
  where
    -- The lambda parameters remaining after dropping one parameter
    -- per input are taken to be the accumulators, in input order.
    acc_names :: [VName]
    acc_names = map paramName $ drop (length inputs) $ lambdaParams lam
    -- Map union is left-biased, so the accumulators of this WithAcc
    -- take precedence over those already known.
    accs' :: Accs GPU
    accs' = M.fromList (zip acc_names inputs) <> accs
optimiseStm accs (Let pat aux (Op (SegOp (SegMap lvl space _ kbody))))
  | accs /= mempty,
    Just (kbody', (acc_shape, _, Just (acc_lam, acc_nes)), acc) <-
      mkHistBody accs kbody,
    all primType $ lambdaReturnType acc_lam = runBuilder_ $ do
      -- One destination array per neutral element.
      hist_dests <- forM acc_nes $ \ne ->
        letExp "hist_dest" $ BasicOp $ Replicate acc_shape ne

      acc_lam' <- withAccLamToHistLam acc_shape acc_lam

      -- The kernel body returns one index per dimension of the
      -- accumulator shape, followed by the values.
      let ts' =
            replicate (shapeRank acc_shape) (Prim int64)
              ++ lambdaReturnType acc_lam
          histop =
            HistOp
              { histShape = acc_shape,
                histRaceFactor = intConst Int64 1,
                histDest = hist_dests,
                histNeutral = acc_nes,
                histOpShape = mempty,
                histOp = acc_lam'
              }

      (space', kbody'') <- flatKernelBody space kbody'

      hist_dest_upd <-
        letTupExp "hist_dest_upd" $ Op $ SegOp $ SegHist lvl space' ts' kbody'' [histop]

      addStm . Let pat aux =<< addArrsToAcc lvl acc_shape hist_dest_upd acc
optimiseStm accs (Let pat aux e) =
  oneStm . Let pat aux <$> optimiseExp accs e

-- | Optimise each statement with 'optimiseStm' and concatenate the
-- results in order.  The scope of all the statements is brought into
-- the local scope before any of them is processed.
optimiseStms :: Accs GPU -> Stms GPU -> OptM (Stms GPU)
optimiseStms accs stms =
  localScope (scopeOf stms) $
    mconcat <$> mapM (optimiseStm accs) (stmsToList stms)

-- | The pass for GPU kernels.
histAccsGPU :: Pass GPU GPU
histAccsGPU =
  Pass "hist accs" "Turn certain accumulations into histograms" $
    intraproceduralTransformation onStms
  where
    -- Run the optimisation with the given scope and no known
    -- accumulators, threading the name source of the enclosing
    -- monad through the 'State' layer of 'OptM'.
    onStms scope stms =
      modifyNameSource . runState $
        runReaderT (optimiseStms mempty stms) scope