{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE PartialTypeSignatures #-}
{-# LANGUAGE TypeFamilies #-}

-- | Perform array short circuiting
module Futhark.Optimise.ArrayShortCircuiting
  ( optimiseSeqMem,
    optimiseGPUMem,
    optimiseMCMem,
  )
where

import Control.Monad
import Control.Monad.Reader
import Data.Function ((&))
import Data.List qualified as L
import Data.Map qualified as M
import Data.Maybe (fromMaybe)
import Futhark.Analysis.Alias qualified as AnlAls
import Futhark.IR.Aliases
import Futhark.IR.GPUMem
import Futhark.IR.MCMem
import Futhark.IR.Mem.LMAD qualified as LMAD
import Futhark.IR.SeqMem
import Futhark.Optimise.ArrayShortCircuiting.ArrayCoalescing
import Futhark.Optimise.ArrayShortCircuiting.DataStructs
import Futhark.Pass (Pass (..))
import Futhark.Pass qualified as Pass
import Futhark.Util

-- | Read-only context available while rewriting the body of one function.
data Env inner = Env
  { -- | Coalescing table for the function whose body is being rewritten.
    envCoalesceTab :: CoalsTab,
    -- | Rewriter applied to the representation-specific inner operations
    -- found in 'Op' expressions.
    onInner :: inner -> UpdateM inner inner,
    -- | Names whose binding statements are filtered out of every statement
    -- sequence by 'removeAllocsInStms'.
    memAllocsToRemove :: Names
  }

-- | Rewriting happens in a plain reader over 'Env'.
type UpdateM inner a = Reader (Env inner) a

-- | Array short-circuiting for sequential code.  'SeqMem' has no inner
-- operations worth rewriting, so they are passed through with 'pure'.
optimiseSeqMem :: Pass SeqMem SeqMem
optimiseSeqMem =
  pass
    "short-circuit"
    "Array Short-Circuiting"
    mkCoalsTab
    pure
    replaceInParams

-- | Array short-circuiting for GPU code.  Host operations are rewritten with
-- 'replaceInHostOp'.
optimiseGPUMem :: Pass GPUMem GPUMem
optimiseGPUMem =
  pass
    "short-circuit-gpu"
    "Array Short-Circuiting (GPU)"
    mkCoalsTabGPU
    replaceInHostOp
    replaceInParams

-- | Array short-circuiting for multicore code.  Multicore operations are
-- rewritten with 'replaceInMCOp'.
optimiseMCMem :: Pass MCMem MCMem
optimiseMCMem =
  pass
    "short-circuit-mc"
    "Array Short-Circuiting (MC)"
    mkCoalsTabMC
    replaceInMCOp
    replaceInParams

-- | Redirect function parameters into the memory blocks chosen by the
-- coalescing table.
--
-- * A memory parameter whose name is a key of @coalstab@ is renamed to that
--   entry's 'dstmem'.  The destination is also added to the returned
--   'Names'.
-- * An array parameter whose memory block is a key of @coalstab@ keeps its
--   name, element type, shape, uniqueness and index function, but is rebound
--   to the entry's 'dstmem'.
-- * Every other parameter is returned unchanged.
--
-- The order of the parameters is preserved.
replaceInParams :: CoalsTab -> [Param FParamMem] -> (Names, [Param FParamMem])
replaceInParams coalstab fparams =
  -- A single order-preserving map replaces the original lazy 'foldl' over a
  -- tuple accumulator followed by 'reverse'.  'Names' is a set, so merging the
  -- per-parameter contributions with 'mconcat' gives the same result.
  let (removals, fparams') = unzip $ map replaceInParam fparams
   in (mconcat removals, fparams')
  where
    replaceInParam :: Param FParamMem -> (Names, Param FParamMem)
    replaceInParam (Param attrs name dec) =
      case dec of
        MemMem _
          | Just entry <- M.lookup name coalstab ->
              (oneName (dstmem entry), Param attrs (dstmem entry) dec)
        MemArray pt shp u (ArrayIn m ixf)
          | Just entry <- M.lookup m coalstab ->
              ( mempty,
                Param attrs name (MemArray pt shp u $ ArrayIn (dstmem entry) ixf)
              )
        _ -> (mempty, Param attrs name dec)

-- | Drop every statement whose first bound name is listed in
-- 'memAllocsToRemove'.
--
-- NOTE(review): only the first pattern element is examined.  This
-- presumably targets allocation statements, which bind a single memory
-- block; confirm before widening the check.
--
-- Statements with an empty pattern are kept.  The previous
-- @head . patNames@ crashed on them.
removeAllocsInStms :: Stms rep -> UpdateM inner (Stms rep)
removeAllocsInStms stms = do
  to_remove <- asks memAllocsToRemove
  let bindsRemovedMem stm =
        any (`nameIn` to_remove) $ take 1 $ patNames $ stmPat stm
  stmsToList stms
    & filter (not . bindsRemovedMem)
    & stmsFromList
    & pure

-- | Build an array short-circuiting pass for a memory-annotated
-- representation.
--
-- The pass runs alias analysis on the whole program.  It then asks @mk@ for
-- a coalescing table per function, looked up by function name.  Each
-- function is then rewritten independently:
--
-- * @on_fparams@ rewrites the parameters and reports the names whose binding
--   statements must be dropped from the body;
-- * the body statements are rewritten by 'updateStms', with @on_inner@ used
--   for inner operations;
-- * the body results are redirected with 'replaceResMem'.
--
-- The program constants are left untouched (@pure@).
pass ::
  (Mem rep inner, LetDec rep ~ LetDecMem, AliasableRep rep) =>
  String ->
  String ->
  (Prog (Aliases rep) -> Pass.PassM (M.Map Name CoalsTab)) ->
  (inner rep -> UpdateM (inner rep) (inner rep)) ->
  (CoalsTab -> [FParam (Aliases rep)] -> (Names, [FParam (Aliases rep)])) ->
  Pass rep rep
pass flag desc mk on_inner on_fparams = Pass flag desc transform
  where
    transform prog = do
      coaltabs <- mk $ AnlAls.aliasAnalysis prog
      Pass.intraproceduralTransformationWithConsts pure (onFun coaltabs) prog

    onFun coaltabs _ fundef = do
      let coaltab = coaltabs M.! funDefName fundef
          (removed_mems, fparams') = on_fparams coaltab $ funDefParams fundef
      pure $
        fundef
          { funDefBody = onBody coaltab removed_mems $ funDefBody fundef,
            funDefParams = fparams'
          }

    onBody coaltab removed_mems body =
      let env = Env coaltab on_inner removed_mems
       in body
            { bodyStms = runReader (updateStms $ bodyStms body) env,
              bodyResult = map (replaceResMem coaltab) $ bodyResult body
            }

-- | If a result is a variable that is a key of the coalescing table, replace
-- it with that entry's destination memory ('dstmem').  Otherwise return the
-- result unchanged.
replaceResMem :: CoalsTab -> SubExpRes -> SubExpRes
replaceResMem coaltab res =
  maybe res redirect $ subExpResVName res >>= (`M.lookup` coaltab)
  where
    redirect entry = res {resSubExp = Var $ dstmem entry}

-- | Rewrite every statement with 'replaceInStm', then drop the statements
-- selected by 'removeAllocsInStms'.
updateStms ::
  (Mem rep inner, LetDec rep ~ LetDecMem) =>
  Stms rep ->
  UpdateM (inner rep) (Stms rep)
updateStms stms = removeAllocsInStms =<< mapM replaceInStm stms

-- | Rewrite a single statement.
--
-- * Array pattern elements that appear in some entry's 'vartab' get their
--   memory annotation rebuilt by 'lookupAndReplace'.
-- * The expression is rewritten by 'replaceInExp', which sees the already
--   rewritten pattern elements.
-- * If any originally bound name appears in the 'vartab' of one or more
--   coalescing entries, those entries' 'certs' are appended to the
--   statement's certificates.
replaceInStm ::
  (Mem rep inner, LetDec rep ~ LetDecMem) =>
  Stm rep ->
  UpdateM (inner rep) (Stm rep)
replaceInStm (Let (Pat elems) (StmAux cs attrs loc dec) e) = do
  elems' <- mapM replaceInPatElem elems
  e' <- replaceInExp elems' e
  entries <- asks (M.elems . envCoalesceTab)
  let bound = map patElemName elems
      touches entry = any (`M.member` vartab entry) bound
      cs' = case filter touches entries of
        [] -> cs
        touched -> cs <> foldMap certs touched
  pure $ Let (Pat elems') (StmAux cs' attrs loc dec) e'
  where
    replaceInPatElem :: PatElem LetDecMem -> UpdateM inner (PatElem LetDecMem)
    replaceInPatElem p@(PatElem vname (MemArray _ _ u _)) =
      fromMaybe p <$> lookupAndReplace vname PatElem u
    replaceInPatElem p = pure p

-- | Rewrite the parts of an expression that may refer to coalesced memory.
--
-- * 'BasicOp', 'WithAcc' and 'Apply' are returned unchanged.
-- * 'Match': the case bodies and the default body are rewritten with
--   'replaceInIfBody'.  The branch return types are zipped with the
--   (already rewritten) pattern elements and passed through
--   'generalizeIxfun'.
-- * 'Loop': the loop parameters are rewritten with 'replaceInFParam' (the
--   initial values are kept), the body statements with 'updateStms', and the
--   body results with 'replaceResMem'.
-- * 'Op': 'Inner' operations go to the 'onInner' rewriter from the
--   environment.  Any other operation is returned unchanged.
replaceInExp ::
  (Mem rep inner, LetDec rep ~ LetDecMem) =>
  [PatElem LetDecMem] ->
  Exp rep ->
  UpdateM (inner rep) (Exp rep)
replaceInExp _ e@(BasicOp _) = pure e
replaceInExp pat_elems (Match cond_ses cases defbody dec) = do
  cases' <- forM cases $ \(Case vals body) -> Case vals <$> replaceInIfBody body
  defbody' <- replaceInIfBody defbody
  rets' <- zipWithM (generalizeIxfun pat_elems) pat_elems (matchReturns dec)
  pure $ Match cond_ses cases' defbody' (dec {matchReturns = rets'})
replaceInExp _ (Loop loop_inits loop_form (Body dec stms res)) = do
  let (params, args) = unzip loop_inits
  params' <- mapM replaceInFParam params
  stms' <- updateStms stms
  coalstab <- asks envCoalesceTab
  let res' = map (replaceResMem coalstab) res
  pure $ Loop (zip params' args) loop_form (Body dec stms' res')
replaceInExp _ (Op op) =
  case op of
    Inner i -> asks onInner >>= \on_op -> Op . Inner <$> on_op i
    _ -> pure (Op op)
replaceInExp _ e@WithAcc {} = pure e
replaceInExp _ e@Apply {} = pure e

-- | Rewrite the statements of a segmented operation's kernel body with
-- 'updateStms'.  The level, space, types and operator components
-- (bin-ops / histogram ops) are kept as they are.
replaceInSegOp ::
  (Mem rep inner, LetDec rep ~ LetDecMem) =>
  SegOp lvl rep ->
  UpdateM (inner rep) (SegOp lvl rep)
replaceInSegOp segop =
  case segop of
    SegMap lvl sp tps body ->
      SegMap lvl sp tps <$> onKernelBody body
    SegRed lvl sp tps body binops ->
      (\body' -> SegRed lvl sp tps body' binops) <$> onKernelBody body
    SegScan lvl sp tps body binops ->
      (\body' -> SegScan lvl sp tps body' binops) <$> onKernelBody body
    SegHist lvl sp tps body hist_ops ->
      (\body' -> SegHist lvl sp tps body' hist_ops) <$> onKernelBody body
  where
    -- Only the statements of the kernel body are rewritten.
    onKernelBody body = do
      stms' <- updateStms $ kernelBodyStms body
      pure body {kernelBodyStms = stms'}

-- | GPU inner-operation rewriter: segmented operations are rewritten with
-- 'replaceInSegOp'.  Every other host operation is returned unchanged.
replaceInHostOp :: HostOp NoOp GPUMem -> UpdateM (HostOp NoOp GPUMem) (HostOp NoOp GPUMem)
replaceInHostOp op =
  case op of
    SegOp seg_op -> SegOp <$> replaceInSegOp seg_op
    _ -> pure op

-- | Multicore inner-operation rewriter.  For a 'ParOp', both the optional
-- segmented operation and the mandatory one are rewritten with
-- 'replaceInSegOp'.  Every other operation is returned unchanged.
replaceInMCOp :: MCOp NoOp MCMem -> UpdateM (MCOp NoOp MCMem) (MCOp NoOp MCMem)
replaceInMCOp op =
  case op of
    ParOp par_op seg_op -> do
      par_op' <- traverse replaceInSegOp par_op
      seg_op' <- replaceInSegOp seg_op
      pure $ ParOp par_op' seg_op'
    _ -> pure op

-- | Adjust one branch return type of a 'Match'.
--
-- This applies when the corresponding pattern element is an array in memory
-- @mem@ with index function @ixf@, the return type is an array, and the
-- pattern element's name appears in the 'vartab' of some coalescing entry.
-- In that case the return type becomes
-- @MemArray pt shp u (ReturnsInBlock mem ixf')@, where @ixf'@ is
-- 'existentialiseLMAD' applied to the names of all pattern elements and
-- @ixf@.  Otherwise the return type is unchanged.
--
-- NOTE(review): presumably 'existentialiseLMAD' turns references to the
-- pattern names into existentials; confirm against its definition.
generalizeIxfun :: [PatElem dec] -> PatElem LetDecMem -> BodyReturns -> UpdateM inner BodyReturns
generalizeIxfun pat_elems (PatElem vname (MemArray _ _ _ (ArrayIn mem ixf))) m@(MemArray pt shp u _) = do
  is_coalesced <- asks $ any (M.member vname . vartab) . envCoalesceTab
  pure $
    if is_coalesced
      then MemArray pt shp u $ ReturnsInBlock mem $ existentialiseLMAD (map patElemName pat_elems) ixf
      else m
generalizeIxfun _ _ m = pure m

-- | Rewrite a branch body of a 'Match'.  The statements go through
-- 'updateStms' and the results through 'replaceResMem'.
replaceInIfBody :: (Mem rep inner, LetDec rep ~ LetDecMem) => Body rep -> UpdateM (inner rep) (Body rep)
replaceInIfBody b@(Body _ stms res) = do
  stms' <- updateStms stms
  coaltab <- asks envCoalesceTab
  let res' = map (replaceResMem coaltab) res
  pure b {bodyStms = stms', bodyResult = res'}

-- | Rewrite a loop parameter.  If it is an array that appears in some
-- coalescing entry's 'vartab', its memory annotation is rebuilt by
-- 'lookupAndReplace'.  Otherwise the parameter is returned unchanged.
--
-- The parameter's attributes are kept.  Previously the rebuilt parameter was
-- created with @Param mempty@, which silently dropped them.  This also makes
-- the rewrite consistent with 'replaceInParams', which preserves attributes.
replaceInFParam :: Param FParamMem -> UpdateM inner (Param FParamMem)
replaceInFParam p@(Param attrs vname (MemArray _ _ u _)) =
  fromMaybe p <$> lookupAndReplace vname (Param attrs) u
replaceInFParam p = pure p

-- | Look up @vname@ in the 'vartab's of the coalescing entries in the
-- environment.
--
-- If @vname@ was coalesced into block @mem@ with index function @ixf@ and
-- free-variable substitutions @subs@, build
-- @MemArray pt shp u (ArrayIn mem ixf')@ and return @Just (f vname ...)@.
-- Here @ixf'@ is @ixf@ after repeatedly applying @LMAD.substitute subs@
-- (via 'fixPoint').  Returns 'Nothing' when no entry mentions @vname@.
--
-- When several entries' 'vartab's contain @vname@, the entry with the
-- smallest key in the coalescing table wins.  This matches the left-biased
-- 'foldMap' union used before.  Unlike that union, the lookup no longer
-- rebuilds a merged map on every call, and it stops at the first hit.
lookupAndReplace ::
  VName ->
  (VName -> MemBound u -> a) ->
  u ->
  UpdateM inner (Maybe a)
lookupAndReplace vname f u = do
  coaltab <- asks envCoalesceTab
  case msum $ map (M.lookup vname . vartab) $ M.elems coaltab of
    Just (Coalesced _ (MemBlock pt shp mem ixf) subs) ->
      ixf
        & fixPoint (LMAD.substitute subs)
        & ArrayIn mem
        & MemArray pt shp u
        & f vname
        & Just
        & pure
    Nothing -> pure Nothing