{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE Strict #-}
{-# LANGUAGE TypeFamilies #-}

module Futhark.CodeGen.ImpGen
  ( -- * Entry Points
    compileProg,

    -- * Pluggable Compiler
    OpCompiler,
    ExpCompiler,
    CopyCompiler,
    StmsCompiler,
    AllocCompiler,
    Operations (..),
    defaultOperations,
    MemLoc (..),
    sliceMemLoc,
    MemEntry (..),
    ScalarEntry (..),

    -- * Monadic Compiler Interface
    ImpM,
    localDefaultSpace,
    askFunction,
    newVNameForFun,
    nameForFun,
    askEnv,
    localEnv,
    localOps,
    VTable,
    getVTable,
    localVTable,
    subImpM,
    subImpM_,
    emit,
    emitFunction,
    hasFunction,
    collect,
    collect',
    VarEntry (..),
    ArrayEntry (..),

    -- * Lookups
    lookupVar,
    lookupArray,
    lookupArraySpace,
    lookupMemory,
    lookupAcc,
    askAttrs,
    askProvenance,

    -- * Building Blocks
    TV,
    MkTV (..),
    tvSize,
    tvExp,
    tvVar,
    ToExp (..),
    compileAlloc,
    everythingVolatile,
    compileBody,
    compileBody',
    compileLoopBody,
    defCompileStms,
    compileStms,
    compileExp,
    defCompileExp,
    fullyIndexArray,
    fullyIndexArray',
    copy,
    copyDWIM,
    copyDWIMFix,
    lmadCopy,
    typeSize,
    inBounds,
    caseMatch,

    -- * Constructing code.
    newVName,
    dLParams,
    dFParams,
    addLoopVar,
    dScope,
    dArray,
    dPrim,
    dPrimS,
    dPrimSV,
    dPrimVol,
    dPrim_,
    dPrimV_,
    dPrimV,
    dPrimVE,
    dIndexSpace,
    dIndexSpace',
    sFor,
    sWhile,
    sComment,
    sIf,
    sWhen,
    sUnless,
    sOp,
    sDeclareMem,
    sAlloc,
    sAlloc_,
    sArray,
    sArrayInMem,
    sAllocArray,
    sAllocArrayPerm,
    sStaticArray,
    sWrite,
    sUpdate,
    sLoopNest,
    sLoopSpace,
    (<--),
    (<~~),
    function,
    genConstants,
    warn,
    module Language.Futhark.Warnings,
  )
where

import Control.Monad
import Control.Monad.Reader
import Control.Monad.State
import Control.Parallel.Strategies
import Data.Bifunctor (first)
import Data.DList qualified as DL
import Data.Either
import Data.List (find)
import Data.List.NonEmpty (NonEmpty (..))
import Data.Map.Strict qualified as M
import Data.Maybe
import Data.Set qualified as S
import Data.String
import Data.Text qualified as T
import Futhark.CodeGen.ImpCode
  ( Bytes,
    Count,
    Elements,
    elements,
  )
import Futhark.CodeGen.ImpCode qualified as Imp
import Futhark.Construct hiding (ToExp (..))
import Futhark.IR.Mem
import Futhark.IR.Mem.LMAD qualified as LMAD
import Futhark.IR.SOACS (SOACS)
import Futhark.Transform.Rename (renameLambda)
import Futhark.Util
import Futhark.Util.IntegralExp
import Futhark.Util.Pretty hiding (nest, space)
import Language.Futhark.Warnings
import Prelude hiding (mod, quot)

-- | How to compile an t'Op', given the pattern it is bound to.
type OpCompiler rep r op =
  Pat (LetDec rep) ->
  Op rep ->
  ImpM rep r op ()

-- | How to compile some 'Stms'.  The final argument is the action to
-- run once the statements are in scope.
type StmsCompiler rep r op =
  Names ->
  Stms rep ->
  ImpM rep r op () ->
  ImpM rep r op ()

-- | How to compile an 'Exp', given the pattern it is bound to.
type ExpCompiler rep r op =
  Pat (LetDec rep) ->
  Exp rep ->
  ImpM rep r op ()

-- | How to copy elements of the given 'PrimType' between two memory
-- locations.
-- NOTE(review): presumably the first 'MemLoc' is the destination and
-- the second the source -- confirm against 'lmadCopy'.
type CopyCompiler rep r op = PrimType -> MemLoc -> MemLoc -> ImpM rep r op ()

-- | An alternate way of compiling an allocation: given the name of the
-- memory block and its size in bytes, emit the allocation code.
type AllocCompiler rep r op =
  VName ->
  Count Bytes (Imp.TExp Int64) ->
  ImpM rep r op ()

-- | A record of pluggable compilers, one for each kind of construct
-- that a backend may want to compile in its own way.
data Operations rep r op = Operations
  { -- | How to compile expressions.
    opsExpCompiler :: ExpCompiler rep r op,
    -- | How to compile backend-specific 'Op's.
    opsOpCompiler :: OpCompiler rep r op,
    -- | How to compile sequences of statements.
    opsStmsCompiler :: StmsCompiler rep r op,
    -- | How to copy between memory locations.
    opsCopyCompiler :: CopyCompiler rep r op,
    -- | Alternative allocation compilers, keyed by memory space.
    opsAllocCompilers :: M.Map Space (AllocCompiler rep r op)
  }

-- | An operations set that uses the given 'OpCompiler' for 'Op's and
-- the defaults for everything else: 'defCompileExp' for expressions,
-- 'defCompileStms' for statements, 'lmadCopy' for copies, and no
-- alternative allocation compilers.
defaultOperations ::
  (Mem rep inner, FreeIn op) =>
  OpCompiler rep r op ->
  Operations rep r op
defaultOperations opc =
  Operations
    { opsExpCompiler = defCompileExp,
      opsOpCompiler = opc,
      opsStmsCompiler = defCompileStms,
      opsCopyCompiler = lmadCopy,
      opsAllocCompilers = mempty
    }

-- | When an array is declared, this is where it is stored.
data MemLoc = MemLoc
  { -- | The memory block holding the array.
    memLocName :: VName,
    -- | The shape of the array.
    memLocShape :: [Imp.DimSize],
    -- | The layout of the array within the memory block.
    memLocLMAD :: LMAD.LMAD (Imp.TExp Int64)
  }
  deriving (Eq, Show)

-- | Apply a 'Slice' to the LMAD of a memory location.  The memory
-- block and the recorded 'memLocShape' are left unchanged.
sliceMemLoc :: MemLoc -> Slice (Imp.TExp Int64) -> MemLoc
sliceMemLoc (MemLoc mem shape lmad) slice =
  MemLoc mem shape $ LMAD.slice lmad slice

-- | Apply a 'FlatSlice' to the LMAD of a memory location.  The memory
-- block and the recorded 'memLocShape' are left unchanged.
flatSliceMemLoc :: MemLoc -> FlatSlice (Imp.TExp Int64) -> MemLoc
flatSliceMemLoc (MemLoc mem shape lmad) slice =
  MemLoc mem shape $ LMAD.flatSlice lmad slice

-- | Symbol-table information about an array variable.
data ArrayEntry = ArrayEntry
  { -- | Where the array is stored.
    entryArrayLoc :: MemLoc,
    -- | The type of the array's elements.
    entryArrayElemType :: PrimType
  }
  deriving (Show)

-- | The shape of an array, taken from its memory location.
entryArrayShape :: ArrayEntry -> [Imp.DimSize]
entryArrayShape = memLocShape . entryArrayLoc

-- | Symbol-table information about a memory block.
newtype MemEntry = MemEntry
  { -- | The memory space the block lives in.
    entryMemSpace :: Imp.Space
  }
  deriving (Show)

-- | Symbol-table information about a scalar variable.
newtype ScalarEntry = ScalarEntry
  { -- | The primitive type of the scalar.
    entryScalarType :: PrimType
  }
  deriving (Show)

-- | An entry in the symbol table ('VTable'), describing an array, a
-- scalar, a memory block, or an accumulator.  The 'Maybe' 'Exp' in
-- each constructor optionally holds the expression that bound the
-- variable.
data VarEntry rep
  = ArrayVar (Maybe (Exp rep)) ArrayEntry
  | ScalarVar (Maybe (Exp rep)) ScalarEntry
  | MemVar (Maybe (Exp rep)) MemEntry
  | AccVar (Maybe (Exp rep)) (VName, Shape, [Type])
  deriving (Show)

-- | A destination for a value: a scalar variable, a memory block, or
-- an array.
data ValueDestination
  = ScalarDestination VName
  | MemoryDestination VName
  | -- | The 'MemLoc' is 'Just' if a copy is
    -- required.  If it is 'Nothing', then a
    -- copy/assignment of a memory block somewhere
    -- takes care of this array.
    ArrayDestination (Maybe MemLoc)
  deriving (Show)

-- | The read-only environment of 'ImpM'.
data Env rep r op = Env
  { -- | How to compile expressions.
    envExpCompiler :: ExpCompiler rep r op,
    -- | How to compile sequences of statements.
    envStmsCompiler :: StmsCompiler rep r op,
    -- | How to compile backend-specific 'Op's.
    envOpCompiler :: OpCompiler rep r op,
    -- | How to copy between memory locations.
    envCopyCompiler :: CopyCompiler rep r op,
    -- | Alternative allocation compilers, keyed by memory space.
    envAllocCompilers :: M.Map Space (AllocCompiler rep r op),
    -- | The default memory space.
    envDefaultSpace :: Imp.Space,
    -- | The volatility to use for generated declarations.
    envVolatility :: Imp.Volatility,
    -- | User-extensible environment.
    envEnv :: r,
    -- | Name of the function we are compiling, if any.
    envFunction :: Maybe Name,
    -- | The set of attributes that are active on the enclosing
    -- statements (including the one we are currently compiling).
    envAttrs :: Attrs,
    -- | The provenance of whatever we are currently generating code for. This
    -- can be used to insert information in the generated code.
    envProvenance :: Provenance
  }

-- | Build an initial environment from a user environment, an
-- operations set and a default memory space.  Volatility starts out
-- as 'Imp.Nonvolatile', and there is no current function, no active
-- attributes and no provenance.
newEnv :: r -> Operations rep r op -> Imp.Space -> Env rep r op
newEnv r ops ds =
  Env
    { envExpCompiler = opsExpCompiler ops,
      envStmsCompiler = opsStmsCompiler ops,
      envOpCompiler = opsOpCompiler ops,
      envCopyCompiler = opsCopyCompiler ops,
      -- NOTE(review): unlike 'subImpM', this ignores
      -- 'opsAllocCompilers ops' and starts with no alternative
      -- allocation compilers -- confirm this is intentional.
      envAllocCompilers = mempty,
      envDefaultSpace = ds,
      envVolatility = Imp.Nonvolatile,
      envEnv = r,
      envFunction = Nothing,
      envAttrs = mempty,
      envProvenance = mempty
    }

-- | The symbol table used during compilation, mapping each variable
-- name to its 'VarEntry'.
type VTable rep =
  M.Map VName (VarEntry rep)

-- | The mutable state threaded through 'ImpM'.
data ImpState rep r op = ImpState
  { -- | The symbol table.
    stateVTable :: VTable rep,
    -- | Functions emitted so far.
    stateFunctions :: Imp.Functions op,
    -- | Code emitted so far in the current 'collect' scope.
    stateCode :: Imp.Code op,
    -- | Constants emitted so far.
    stateConstants :: Imp.Constants op,
    -- | Warnings accumulated so far.
    stateWarnings :: Warnings,
    -- | Maps the arrays backing each accumulator to their
    -- update function and neutral elements.  This works
    -- because an array name can only become part of a single
    -- accumulator throughout its lifetime.  If the arrays
    -- backing an accumulator are not in this mapping, the
    -- accumulator is scatter-like.
    stateAccs :: M.Map VName ([VName], Maybe (Lambda rep, [SubExp])),
    -- | Source of fresh names.
    stateNameSource :: VNameSource
  }

-- | An initial state with the given name source and everything else
-- empty.
newState :: VNameSource -> ImpState rep r op
newState src =
  ImpState
    { stateVTable = mempty,
      stateFunctions = mempty,
      stateCode = mempty,
      stateConstants = mempty,
      stateWarnings = mempty,
      stateAccs = mempty,
      stateNameSource = src
    }

-- | The code generation monad: a reader over 'Env' layered on top of a
-- state monad over 'ImpState'.
newtype ImpM rep r op a
  = ImpM (ReaderT (Env rep r op) (State (ImpState rep r op)) a)
  deriving
    ( Functor,
      Applicative,
      Monad,
      MonadState (ImpState rep r op),
      MonadReader (Env rep r op)
    )

-- | Fresh names are drawn from 'stateNameSource'.
instance MonadFreshNames (ImpM rep r op) where
  getNameSource = gets stateNameSource
  putNameSource src = modify $ \s -> s {stateNameSource = src}

-- Cannot be a KernelsMem scope because the index functions have
-- the wrong leaves (VName instead of Imp.Exp).  Instead, the symbol
-- table is exposed as a 'SOACS' scope by converting each 'VarEntry'
-- to its plain 'Type'.
instance HasScope SOACS (ImpM rep r op) where
  askScope = gets $ M.map (LetName . entryType) . stateVTable
    where
      entryType :: VarEntry rep -> Type
      entryType (MemVar _ memEntry) =
        Mem (entryMemSpace memEntry)
      entryType (ArrayVar _ arrayEntry) =
        Array
          (entryArrayElemType arrayEntry)
          (Shape $ entryArrayShape arrayEntry)
          NoUniqueness
      entryType (ScalarVar _ scalarEntry) =
        Prim $ entryScalarType scalarEntry
      entryType (AccVar _ (acc, ispace, ts)) =
        Acc acc ispace ts NoUniqueness

-- | Run an 'ImpM' action from the given initial state, in an
-- environment built by 'newEnv' from the user environment, operations
-- set and default space.  Returns the result and the final state.
runImpM ::
  ImpM rep r op a ->
  r ->
  Operations rep r op ->
  Imp.Space ->
  ImpState rep r op ->
  (a, ImpState rep r op)
runImpM (ImpM m) r ops space =
  runState (runReaderT m $ newEnv r ops space)

-- | Like 'subImpM', but discard the action's result and return only
-- the code it generated.
subImpM_ ::
  r' ->
  Operations rep r' op' ->
  ImpM rep r' op' a ->
  ImpM rep r op (Imp.Code op')
subImpM_ r ops m = snd <$> subImpM r ops m

-- | Run an action with a different user environment and operations
-- set, possibly generating a different kind of operation.
--
-- The sub-action inherits the current symbol table, accumulator map
-- and name source.  Its generated code is returned rather than
-- emitted.  Its name source and warnings are propagated back.
-- Everything else it changes (symbol table, emitted functions and
-- constants) is discarded.
subImpM ::
  r' ->
  Operations rep r' op' ->
  ImpM rep r' op' a ->
  ImpM rep r op (a, Imp.Code op')
subImpM r ops (ImpM m) = do
  env <- ask
  s <- get

  -- Type-changing record update: every field mentioning 'r' or 'op'
  -- is replaced.
  let env' =
        env
          { envExpCompiler = opsExpCompiler ops,
            envStmsCompiler = opsStmsCompiler ops,
            envCopyCompiler = opsCopyCompiler ops,
            envOpCompiler = opsOpCompiler ops,
            envAllocCompilers = opsAllocCompilers ops,
            envEnv = r
          }
      s' =
        ImpState
          { stateVTable = stateVTable s,
            stateFunctions = mempty,
            stateCode = mempty,
            stateNameSource = stateNameSource s,
            stateConstants = mempty,
            stateWarnings = mempty,
            stateAccs = stateAccs s
          }
      (x, s'') = runState (runReaderT m env') s'

  putNameSource $ stateNameSource s''
  warnings $ stateWarnings s''
  pure (x, stateCode s'')

-- | Execute a code generation action, returning the code that was
-- emitted instead of adding it to the current code.
collect :: ImpM rep r op () -> ImpM rep r op (Imp.Code op)
collect = fmap snd . collect'

-- | Execute a code generation action, returning both its result and
-- the code it emitted.  The emitted code is not added to the current
-- code; the code that existed before the action is restored afterwards.
collect' :: ImpM rep r op a -> ImpM rep r op (a, Imp.Code op)
collect' m = do
  prev_code <- gets stateCode
  modify $ \s -> s {stateCode = mempty}
  x <- m
  new_code <- gets stateCode
  modify $ \s -> s {stateCode = prev_code}
  pure (x, new_code)

-- | Emit some generated imperative code, appending it to the code
-- emitted so far.
emit :: Imp.Code op -> ImpM rep r op ()
emit code = modify $ \s -> s {stateCode = stateCode s <> code}

-- | Record some warnings.  The new warnings are combined in front of
-- the existing ones with '<>'.
warnings :: Warnings -> ImpM rep r op ()
warnings ws = modify $ \s -> s {stateWarnings = ws <> stateWarnings s}

-- | Emit a warning about something the user should be aware of.  The
-- first location is the primary one; the list holds related locations.
warn :: (Located loc) => loc -> [loc] -> T.Text -> ImpM rep r op ()
warn loc locs problem =
  warnings $ singleWarning' (locOf loc) (map locOf locs) (pretty problem)

-- | Emit a function in the generated code.  The function is prepended
-- to the functions emitted so far; no check for an existing function
-- of the same name is made here.
emitFunction :: Name -> Imp.Function op -> ImpM rep r op ()
emitFunction fname fun = do
  Imp.Functions fs <- gets stateFunctions
  modify $ \s -> s {stateFunctions = Imp.Functions $ (fname, fun) : fs}

-- | Check whether a function of the given name has been emitted in the
-- current state.
hasFunction :: Name -> ImpM rep r op Bool
hasFunction fname = gets $ \s ->
  let Imp.Functions fs = stateFunctions s
   in isJust $ lookup fname fs

-- | Build a symbol table for the top-level constants.
--
-- Every name bound by a pattern in the given statements is mapped to a
-- 'VarEntry' derived from its memory annotation ('letDecMem').  The
-- binding expression is recorded alongside each entry; for a pattern
-- with several elements, all of them share the same expression.
constsVTable :: (Mem rep inner) => Stms rep -> VTable rep
constsVTable = foldMap stmVtable
  where
    stmVtable (Let pat _ e) =
      foldMap (peVtable e) $ patElems pat
    peVtable e (PatElem name dec) =
      M.singleton name $ memBoundToVarEntry (Just e) $ letDecMem dec

-- | Compile a whole program to imperative code.
--
-- Each function is compiled independently (in parallel, via 'parMap'
-- with 'rpar'), starting from a fresh state.  Each of these states
-- starts from the same name source and has the top-level constants in
-- its symbol table.  The per-function states are then merged, and the
-- constants are compiled on top of the merged state.  The result pairs
-- the accumulated warnings with the generated definitions.  The final
-- name source is threaded back through 'modifyNameSource'.
compileProg ::
  (Mem rep inner, FreeIn op, MonadFreshNames m) =>
  r ->
  Operations rep r op ->
  Imp.Space ->
  Prog rep ->
  m (Warnings, Imp.Definitions op)
compileProg r ops space (Prog types consts funs) =
  modifyNameSource $ \src ->
    let (_, ss) =
          unzip $ parMap rpar (compileFunDef' src) funs
        -- Names free in any of the compiled functions; handed to
        -- 'compileConsts' as the names the constants must provide.
        free_in_funs =
          freeIn $ mconcat $ map stateFunctions ss
        ((), s') =
          runImpM (compileConsts free_in_funs consts) r ops space $
            combineStates ss
     in ( ( stateWarnings s',
            Imp.Definitions
              types
              (foldMap stateConstants ss <> stateConstants s')
              (stateFunctions s')
          ),
          stateNameSource s'
        )
  where
    -- Compile one function definition from scratch, with the constants
    -- visible in the symbol table.
    compileFunDef' src fdef =
      runImpM
        (compileFunDef types fdef)
        r
        ops
        space
        (newState src) {stateVTable = constsVTable consts}

    -- Merge the per-function states.  Functions are deduplicated by name
    -- ('M.fromList' keeps the last occurrence); name sources and warnings
    -- are combined with their 'Monoid' instances.
    combineStates ss =
      let Imp.Functions funs' = mconcat $ map stateFunctions ss
          src = mconcat (map stateNameSource ss)
       in (newState src)
            { stateFunctions =
                Imp.Functions $ M.toList $ M.fromList funs',
              stateWarnings =
                mconcat $ map stateWarnings ss
            }

-- | Compile the top-level constant statements through 'genConstants'.
--
-- The given names are passed to 'compileStms' and also returned to
-- 'genConstants'.  Presumably they are the constants that must stay
-- available to the rest of the program (in 'compileProg' they are the
-- names free in the compiled functions); TODO confirm against
-- 'genConstants'.
compileConsts :: Names -> Stms rep -> ImpM rep r op ()
compileConsts used_consts stms = genConstants $ do
  compileStms used_consts stms $ pure ()
  pure (used_consts, ())

-- | Find the definition of a named opaque type.  The name must be
-- present in the table; an unknown name is treated as an internal
-- compiler error.
lookupOpaqueType :: Name -> OpaqueTypes -> OpaqueType
lookupOpaqueType v (OpaqueTypes types) =
  case lookup v types of
    Just t -> t
    Nothing -> error $ "Unknown opaque type: " ++ show v

-- | The signedness component of a 'ValueType'.
valueTypeSign :: ValueType -> Signedness
valueTypeSign (ValueType sign _ _) = sign

-- | The signedness of each value making up an entry point type, in
-- order.
--
-- A transparent type contributes exactly one entry.  An opaque type
-- contributes one entry per underlying value type; record (and record
-- array) types recurse into the types of their fields.
entryPointSignedness :: OpaqueTypes -> EntryPointType -> [Signedness]
entryPointSignedness _ (TypeTransparent vt) = [valueTypeSign vt]
entryPointSignedness types (TypeOpaque desc) =
  case lookupOpaqueType desc types of
    OpaqueType vts -> map valueTypeSign vts
    OpaqueArray _ _ vts -> map valueTypeSign vts
    OpaqueRecordArray _ _ fs -> fields fs
    OpaqueRecord fs -> fields fs
    OpaqueSum vts _ -> map valueTypeSign vts
  where
    -- Concatenate the signedness of every field, in field order.
    fields fs = foldMap (entryPointSignedness types . snd) fs

-- | How many value parameters are accepted by this entry point?  This
-- is used to determine which of the function parameters correspond to
-- the parameters of the original function (they must all come at the
-- end).
--
-- By definition this is the number of values described by
-- 'entryPointSignedness': one for a transparent type, and the total
-- over the underlying values of an opaque type.  Deriving it from there
-- keeps the two in agreement.  Callers depend on that when they zip a
-- slice of 'entryPointSize' parameters with the signedness list.
entryPointSize :: OpaqueTypes -> EntryPointType -> Int
entryPointSize types = length . entryPointSignedness types

-- | Compile a single function parameter.
--
-- Scalar and memory-block parameters become an 'Imp.Param' ('Left').
-- Array parameters are not passed as parameters themselves.  They are
-- described by an 'ArrayDecl' ('Right') that records the memory block,
-- shape and LMAD they are stored with.  Accumulator parameters are
-- rejected with an error.
compileInParam ::
  (Mem rep inner) =>
  FParam rep ->
  ImpM rep r op (Either Imp.Param ArrayDecl)
compileInParam fparam = case paramDec fparam of
  MemPrim bt ->
    pure $ Left $ Imp.ScalarParam name bt
  MemMem space ->
    pure $ Left $ Imp.MemParam name space
  MemArray bt shape _ (ArrayIn mem lmad) ->
    pure $ Right $ ArrayDecl name bt $ MemLoc mem (shapeDims shape) lmad
  MemAcc {} ->
    error "Functions may not have accumulator parameters."
  where
    name = paramName fparam

-- | Where an array-typed function parameter lives (see 'compileInParam').
data ArrayDecl
  = ArrayDecl
      VName -- ^ Name of the array parameter.
      PrimType -- ^ Element type of the array.
      MemLoc -- ^ Memory block, shape and LMAD the array is stored with.

-- | Compile the parameters of a function.
--
-- Returns the scalar and memory-block parameters as 'Imp.Param's, the
-- array parameters as 'ArrayDecl's and, when entry point parameter
-- descriptions are given, the external value for each of them.
--
-- The entry point values are matched against the /trailing/ function
-- parameters: the total 'entryPointSize' of the descriptions determines
-- how many parameters at the end are value parameters.  The leading
-- ones are context and are skipped.  An array value is only described
-- if its memory block is one of this function's memory parameters, and
-- a value that is neither such an array nor a scalar is dropped.
compileInParams ::
  (Mem rep inner) =>
  OpaqueTypes ->
  [FParam rep] ->
  Maybe [EntryParam] ->
  ImpM rep r op ([Imp.Param], [ArrayDecl], Maybe [((Name, Uniqueness), Imp.ExternalValue)])
compileInParams types params eparams = do
  (inparams, arrayds) <- partitionEithers <$> mapM compileInParam params
  let findArray x = find (isArrayDecl x) arrayds

      -- Memory spaces of the memory-block parameters, keyed by name.
      summaries = M.fromList $ mapMaybe memSummary params
        where
          memSummary param
            | MemMem space <- paramDec param =
                Just (paramName param, space)
            | otherwise =
                Nothing

      findMemInfo :: VName -> Maybe Space
      findMemInfo = flip M.lookup summaries

      -- Describe one parameter as a value, if possible.
      mkValueDesc fparam signedness =
        case (findArray $ paramName fparam, paramType fparam) of
          (Just (ArrayDecl _ bt (MemLoc mem shape _)), _) -> do
            memspace <- findMemInfo mem
            Just $ Imp.ArrayValue mem memspace bt signedness shape
          (_, Prim bt) ->
            Just $ Imp.ScalarValue bt signedness $ paramName fparam
          _ ->
            Nothing

      -- Walk the entry point descriptions, consuming as many parameters
      -- as each one needs.
      mkExts (EntryParam v u et@(TypeOpaque desc) : epts) fparams =
        let signs = entryPointSignedness types et
            n = entryPointSize types et
            (fparams', rest) = splitAt n fparams
         in ( (v, u),
              Imp.OpaqueValue
                desc
                (catMaybes $ zipWith mkValueDesc fparams' signs)
            )
              : mkExts epts rest
      mkExts (EntryParam v u (TypeTransparent (ValueType s _ _)) : epts) (fparam : fparams) =
        maybeToList (((v, u),) . Imp.TransparentValue <$> mkValueDesc fparam s)
          ++ mkExts epts fparams
      mkExts _ _ = []

  pure
    ( inparams,
      arrayds,
      case eparams of
        Just eparams' ->
          let num_val_params = sum (map (entryPointSize types . entryParamType) eparams')
              (_ctx_params, val_params) = splitAt (length params - num_val_params) params
           in Just $ mkExts eparams' val_params
        Nothing -> Nothing
    )
  where
    isArrayDecl x (ArrayDecl y _ _) = x == y

-- | Produce the output parameter and destination for one return value
-- of a function.
--
-- Scalars and memory blocks get a freshly named output parameter
-- (@prim_out@ / @mem_out@) with a matching destination.  Arrays get no
-- parameter of their own and an 'ArrayDestination' with no known
-- location.  Returning an accumulator is an error.
compileOutParam ::
  FunReturns -> ImpM rep r op (Maybe Imp.Param, ValueDestination)
compileOutParam (MemPrim t) = do
  name <- newVName "prim_out"
  pure (Just $ Imp.ScalarParam name t, ScalarDestination name)
compileOutParam (MemMem space) = do
  name <- newVName "mem_out"
  pure (Just $ Imp.MemParam name space, MemoryDestination name)
compileOutParam MemArray {} =
  pure (Nothing, ArrayDestination Nothing)
compileOutParam MemAcc {} =
  error "Functions may not return accumulators."

compileExternalValues ::
  (Mem rep inner) =>
  OpaqueTypes ->
  [RetType rep] ->
  [EntryResult] ->
  [Maybe Imp.Param] ->
  ImpM rep r op [(Uniqueness, Imp.ExternalValue)]
compileExternalValues :: forall rep (inner :: * -> *) r op.
Mem rep inner =>
OpaqueTypes
-> [RetType rep]
-> [EntryResult]
-> [Maybe Param]
-> ImpM rep r op [(Uniqueness, ExternalValue)]
compileExternalValues OpaqueTypes
types [RetType rep]
orig_rts [EntryResult]
orig_epts [Maybe Param]
maybe_params = do
  let ([RetTypeMem]
ctx_rts, [RetTypeMem]
val_rts) =
        Int -> [RetTypeMem] -> ([RetTypeMem], [RetTypeMem])
forall a. Int -> [a] -> ([a], [a])
splitAt
          ([RetTypeMem] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [RetType rep]
[RetTypeMem]
orig_rts Int -> Int -> Int
forall a. Num a => a -> a -> a
- [Int] -> Int
forall a. Num a => [a] -> a
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
sum ((EntryResult -> Int) -> [EntryResult] -> [Int]
forall a b. (a -> b) -> [a] -> [b]
map (OpaqueTypes -> EntryPointType -> Int
entryPointSize OpaqueTypes
types (EntryPointType -> Int)
-> (EntryResult -> EntryPointType) -> EntryResult -> Int
forall b c a. (b -> c) -> (a -> b) -> a -> c
. EntryResult -> EntryPointType
entryResultType) [EntryResult]
orig_epts))
          [RetType rep]
[RetTypeMem]
orig_rts

  let nthOut :: Int -> VName
nthOut Int
i = case Int -> [Maybe Param] -> Maybe (Maybe Param)
forall int a. Integral int => int -> [a] -> Maybe a
maybeNth Int
i [Maybe Param]
maybe_params of
        Just (Just Param
p) -> Param -> VName
Imp.paramName Param
p
        Just Maybe Param
Nothing -> [Char] -> VName
forall a. HasCallStack => [Char] -> a
error ([Char] -> VName) -> [Char] -> VName
forall a b. (a -> b) -> a -> b
$ [Char]
"Output " [Char] -> ShowS
forall a. [a] -> [a] -> [a]
++ Int -> [Char]
forall a. Show a => a -> [Char]
show Int
i [Char] -> ShowS
forall a. [a] -> [a] -> [a]
++ [Char]
" not a param."
        Maybe (Maybe Param)
Nothing -> [Char] -> VName
forall a. HasCallStack => [Char] -> a
error ([Char] -> VName) -> [Char] -> VName
forall a b. (a -> b) -> a -> b
$ [Char]
"Param " [Char] -> ShowS
forall a. [a] -> [a] -> [a]
++ Int -> [Char]
forall a. Show a => a -> [Char]
show Int
i [Char] -> ShowS
forall a. [a] -> [a] -> [a]
++ [Char]
" does not exist."

      mkValueDesc :: Int -> Signedness -> RetTypeMem -> ImpM rep r op ValueDesc
mkValueDesc Int
_ Signedness
signedness (MemArray PrimType
t ShapeBase (Ext SubExp)
shape Uniqueness
_ MemReturn
ret) = do
        (mem, space) <-
          case MemReturn
ret of
            ReturnsNewBlock Space
space Int
j ExtLMAD
_lmad ->
              (VName, Space) -> ImpM rep r op (VName, Space)
forall a. a -> ImpM rep r op a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Int -> VName
nthOut Int
j, Space
space)
            ReturnsInBlock VName
mem ExtLMAD
_lmad -> do
              space <- MemEntry -> Space
entryMemSpace (MemEntry -> Space)
-> ImpM rep r op MemEntry -> ImpM rep r op Space
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> VName -> ImpM rep r op MemEntry
forall rep r op. VName -> ImpM rep r op MemEntry
lookupMemory VName
mem
              pure (mem, space)
        pure $ Imp.ArrayValue mem space t signedness $ map f $ shapeDims shape
        where
          f :: Ext SubExp -> SubExp
f (Free SubExp
v) = SubExp
v
          f (Ext Int
i) = VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ Int -> VName
nthOut Int
i
      mkValueDesc Int
i Signedness
signedness (MemPrim PrimType
bt) =
        ValueDesc -> ImpM rep r op ValueDesc
forall a. a -> ImpM rep r op a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (ValueDesc -> ImpM rep r op ValueDesc)
-> ValueDesc -> ImpM rep r op ValueDesc
forall a b. (a -> b) -> a -> b
$ PrimType -> Signedness -> VName -> ValueDesc
Imp.ScalarValue PrimType
bt Signedness
signedness (VName -> ValueDesc) -> VName -> ValueDesc
forall a b. (a -> b) -> a -> b
$ Int -> VName
nthOut Int
i
      mkValueDesc Int
_ Signedness
_ MemAcc {} =
        [Char] -> ImpM rep r op ValueDesc
forall a. HasCallStack => [Char] -> a
error [Char]
"mkValueDesc: unexpected MemAcc output."
      mkValueDesc Int
_ Signedness
_ MemMem {} =
        [Char] -> ImpM rep r op ValueDesc
forall a. HasCallStack => [Char] -> a
error [Char]
"mkValueDesc: unexpected MemMem output."

      mkExts :: Int
-> [EntryResult]
-> [RetTypeMem]
-> ImpM rep r op [(Uniqueness, ExternalValue)]
mkExts Int
i (EntryResult Uniqueness
u et :: EntryPointType
et@(TypeOpaque Name
desc) : [EntryResult]
epts) [RetTypeMem]
rets = do
        let signs :: [Signedness]
signs = OpaqueTypes -> EntryPointType -> [Signedness]
entryPointSignedness OpaqueTypes
types EntryPointType
et
            n :: Int
n = OpaqueTypes -> EntryPointType -> Int
entryPointSize OpaqueTypes
types EntryPointType
et
            ([RetTypeMem]
rets', [RetTypeMem]
rest) = Int -> [RetTypeMem] -> ([RetTypeMem], [RetTypeMem])
forall a. Int -> [a] -> ([a], [a])
splitAt Int
n [RetTypeMem]
rets
        vds <- [(Int, Signedness, RetTypeMem)]
-> ((Int, Signedness, RetTypeMem) -> ImpM rep r op ValueDesc)
-> ImpM rep r op [ValueDesc]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM ([Int]
-> [Signedness] -> [RetTypeMem] -> [(Int, Signedness, RetTypeMem)]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [Int
i ..] [Signedness]
signs [RetTypeMem]
rets') (((Int, Signedness, RetTypeMem) -> ImpM rep r op ValueDesc)
 -> ImpM rep r op [ValueDesc])
-> ((Int, Signedness, RetTypeMem) -> ImpM rep r op ValueDesc)
-> ImpM rep r op [ValueDesc]
forall a b. (a -> b) -> a -> b
$ \(Int
j, Signedness
s, RetTypeMem
r) -> Int -> Signedness -> RetTypeMem -> ImpM rep r op ValueDesc
mkValueDesc Int
j Signedness
s RetTypeMem
r
        ((u, Imp.OpaqueValue desc vds) :) <$> mkExts (i + n) epts rest
      mkExts Int
i (EntryResult Uniqueness
u (TypeTransparent (ValueType Signedness
s Rank
_ PrimType
_)) : [EntryResult]
epts) (RetTypeMem
ret : [RetTypeMem]
rets) = do
        vd <- Int -> Signedness -> RetTypeMem -> ImpM rep r op ValueDesc
mkValueDesc Int
i Signedness
s RetTypeMem
ret
        ((u, Imp.TransparentValue vd) :) <$> mkExts (i + 1) epts rets
      mkExts Int
_ [EntryResult]
_ [RetTypeMem]
_ = [(Uniqueness, ExternalValue)]
-> ImpM rep r op [(Uniqueness, ExternalValue)]
forall a. a -> ImpM rep r op a
forall (f :: * -> *) a. Applicative f => a -> f a
pure []

  Int
-> [EntryResult]
-> [RetTypeMem]
-> ImpM rep r op [(Uniqueness, ExternalValue)]
mkExts ([RetTypeMem] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [RetTypeMem]
ctx_rts) [EntryResult]
orig_epts [RetTypeMem]
val_rts

-- | Compile the return types of a function into output parameters.
--
-- Returns, in order:
--
-- * the external (entry-point) view of the results, computed only when
--   entry-point result descriptions are supplied;
--
-- * the Imp parameters for the results, keeping only those results for
--   which 'compileOutParam' produced a parameter;
--
-- * one 'ValueDestination' per return type.
compileOutParams ::
  (Mem rep inner) =>
  OpaqueTypes ->
  [RetType rep] ->
  Maybe [EntryResult] ->
  ImpM rep r op (Maybe [(Uniqueness, Imp.ExternalValue)], [Imp.Param], [ValueDestination])
compileOutParams types orig_rts maybe_orig_epts = do
  (maybe_params, dests) <- mapAndUnzipM compileOutParam orig_rts
  -- Without result descriptions there is no external view to compute.
  evs <- forM maybe_orig_epts $ \orig_epts ->
    compileExternalValues types orig_rts orig_epts maybe_params
  pure (evs, catMaybes maybe_params, dests)

-- | Compile a function definition and add the result to the program
-- with 'emitFunction'.
--
-- While the function is being compiled, 'envFunction' is set to the
-- entry-point name if the function has one, and to the function's own
-- name otherwise.  The emitted function is only marked as an entry
-- point when the entry-point name and the external views of both its
-- results and its arguments are all available.
compileFunDef ::
  (Mem rep inner) =>
  OpaqueTypes ->
  FunDef rep ->
  ImpM rep r op ()
compileFunDef types (FunDef entry _ fname rettype params body) =
  local (\env -> env {envFunction = name_entry `mplus` Just fname}) $ do
    ((outparams, inparams, results, args), body') <- collect' compile
    let entry' = Imp.EntryPoint <$> name_entry <*> results <*> args
    emitFunction fname $ Imp.Function entry' outparams inparams body'
  where
    (name_entry, params_entry, ret_entry) = case entry of
      Nothing -> (Nothing, Nothing, Nothing)
      Just (x, y, z) -> (Just x, Just y, Just z)

    compile = do
      (inparams, arrayds, args) <- compileInParams types params params_entry
      (results, outparams, dests) <- compileOutParams types (map fst rettype) ret_entry
      addFParams params
      addArrays arrayds

      let Body _ stms ses = body
      -- Copy each body result into the corresponding output destination.
      compileStms (freeIn ses) stms $
        forM_ (zip dests ses) $
          \(d, SubExpRes _ se) -> copyDWIMDest d [] se []

      pure (outparams, inparams, results, args)

-- | Compile the statements of a body, then copy each of its results
-- into the corresponding destination of the given pattern.
compileBody :: Pat (LetDec rep) -> Body rep -> ImpM rep r op ()
compileBody pat (Body _ stms ses) = do
  dests <- destinationFromPat pat
  compileStms (freeIn ses) stms $
    forM_ (zip dests ses) $ \(d, SubExpRes _ se) ->
      copyDWIMDest d [] se []

-- | Compile the statements of a body, then copy each of its results
-- into the corresponding parameter (addressed by its 'paramName').
compileBody' :: [Param dec] -> Body rep -> ImpM rep r op ()
compileBody' params (Body _ stms ses) =
  compileStms (freeIn ses) stms $
    forM_ (zip params ses) $ \(param, SubExpRes _ se) ->
      copyDWIM (paramName param) [] se []

-- | Compile a loop body and write its results back into the loop's
-- merge parameters.
compileLoopBody :: (Typed dec) => [Param dec] -> Body rep -> ImpM rep r op ()
compileLoopBody mergeparams (Body _ stms ses) = do
  -- We cannot write the results to the merge parameters immediately,
  -- as some of the results may actually *be* merge parameters, and
  -- would thus be clobbered.  Therefore, we first copy to new
  -- variables mirroring the merge parameters, and then copy this
  -- buffer to the merge parameters.  This is efficient, because the
  -- copies are all scalar or memory-block assignments.
  tmpnames <- mapM (newVName . (++ "_tmp") . baseString . paramName) mergeparams
  compileStms (freeIn ses) stms $ do
    -- First stage every result in its temporary, collecting the
    -- deferred write-backs...
    copy_to_merge_params <- forM (zip3 mergeparams tmpnames ses) $ \(p, tmp, SubExpRes _ se) ->
      case typeOf p of
        Prim pt -> do
          emit $ Imp.DeclareScalar tmp Imp.Nonvolatile pt
          emit $ Imp.SetScalar tmp $ toExp' pt se
          pure $ emit $ Imp.SetScalar (paramName p) $ Imp.var tmp pt
        Mem space | Var v <- se -> do
          emit $ Imp.DeclareMem tmp space
          emit $ Imp.SetMem tmp v space
          pure $ emit $ Imp.SetMem (paramName p) tmp space
        -- NOTE(review): any other merge parameter type (e.g. arrays)
        -- gets no copy here; presumably such parameters are updated
        -- through their memory parameters -- confirm.
        _ -> pure $ pure ()
    -- ...and only then perform the write-backs.
    sequence_ copy_to_merge_params

-- | Compile statements with the statement compiler installed in the
-- environment ('envStmsCompiler').  The set of names that are used
-- after the statements, the statements themselves, and the
-- continuation action are all passed on unchanged to that hook.
compileStms :: Names -> Stms rep -> ImpM rep r op () -> ImpM rep r op ()
compileStms alive_after_stms all_stms m = do
  cb <- asks envStmsCompiler
  cb alive_after_stms all_stms m

-- | Prefix a piece of code with a provenance metadata annotation.
-- Empty code ('Imp.Skip') and empty provenance ('mempty') are returned
-- unchanged, so no metadata is emitted in either case.
attachProvenance :: Provenance -> Imp.Code op -> Imp.Code op
attachProvenance _ Imp.Skip = Imp.Skip
attachProvenance p c
  | p == mempty = c
  | otherwise = Imp.Meta (Imp.MetaProvenance p) <> c

-- | The default statement compiler.  Compiles each statement in turn,
-- declaring its pattern variables before compiling its expression, and
-- finally runs the given continuation.
--
-- We keep track of any memory blocks produced by the statements, and
-- after the last time that memory block is used, we insert a Free.
-- This is very conservative, but can cut down on lifetimes in some
-- cases.  Only blocks bound by an earlier statement in the sequence are
-- candidates; a candidate is freed right after the statement whose code
-- is its last use.  Nothing is freed after the continuation.
defCompileStms ::
  (Mem rep inner, FreeIn op) =>
  Names ->
  Stms rep ->
  ImpM rep r op () ->
  ImpM rep r op ()
defCompileStms alive_after_stms all_stms m =
  void $ compileStms' mempty $ stmsToList all_stms
  where
    -- Returns the names free in the code generated for the statements
    -- and the continuation, together with 'alive_after_stms'.
    compileStms' allocs (Let pat aux e : bs) = do
      let loc = stmAuxLoc aux
      e_code <-
        fmap (attachProvenance loc)
          . localAttrs (stmAuxAttrs aux)
          . collect
          . localProvenance loc
          $ do
            dVars (Just e) (patElems pat)
            compileExp pat e
      -- The remaining statements are compiled (but not yet emitted)
      -- first, so that we know which names are still used afterwards.
      (live_after, bs_code) <- collect' $ compileStms' (patternAllocs pat <> allocs) bs
      let dies_here v =
            (v `notNameIn` live_after) && (v `nameIn` freeIn e_code)
          to_free = S.filter (dies_here . fst) allocs

      emit e_code
      mapM_ (emit . uncurry Imp.Free) to_free
      emit bs_code

      pure $ freeIn e_code <> live_after
    compileStms' _ [] = do
      code <- collect m
      emit code
      pure $ freeIn code <> alive_after_stms

    -- Memory blocks bound by a pattern, paired with their spaces.
    patternAllocs = S.fromList . mapMaybe isMemPatElem . patElems

    isMemPatElem pe = case patElemType pe of
      Mem space -> Just (patElemName pe, space)
      _ -> Nothing

-- | Compile an expression bound to a pattern, using the expression
-- compiler installed in the environment ('envExpCompiler').
compileExp :: Pat (LetDec rep) -> Exp rep -> ImpM rep r op ()
compileExp pat e = do
  ec <- asks envExpCompiler
  ec pat e

-- | Generate an expression that is true if the subexpressions match
-- the case pattern.  Each pattern element is matched against the
-- subexpression in the same position: 'Nothing' matches anything, and
-- @Just v@ requires the subexpression to equal @v@.  Elements beyond
-- the length of the shorter list are ignored.
caseMatch :: [SubExp] -> [Maybe PrimValue] -> Imp.TExp Bool
caseMatch ses vs = foldl (.&&.) true (zipWith cmp ses vs)
  where
    -- Matching against @true@ needs no comparison: the boolean
    -- subexpression itself is the test.
    cmp se (Just (BoolValue True)) =
      isBool $ toExp' Bool se
    cmp se (Just v) =
      isBool $ toExp' (primValueType v) se ~==~ ValueExp v
    cmp _ Nothing = true

defCompileExp ::
  (Mem rep inner) =>
  Pat (LetDec rep) ->
  Exp rep ->
  ImpM rep r op ()
defCompileExp :: forall rep (inner :: * -> *) r op.
Mem rep inner =>
Pat (LetDec rep) -> Exp rep -> ImpM rep r op ()
defCompileExp Pat (LetDec rep)
pat (Match [SubExp]
ses [Case (Body rep)]
cases Body rep
defbody MatchDec (BranchType rep)
_) =
  (Case (Body rep) -> ImpM rep r op () -> ImpM rep r op ())
-> ImpM rep r op () -> [Case (Body rep)] -> ImpM rep r op ()
forall a b. (a -> b -> b) -> b -> [a] -> b
forall (t :: * -> *) a b.
Foldable t =>
(a -> b -> b) -> b -> t a -> b
foldr Case (Body rep) -> ImpM rep r op () -> ImpM rep r op ()
f (Pat (LetDec rep) -> Body rep -> ImpM rep r op ()
forall rep r op. Pat (LetDec rep) -> Body rep -> ImpM rep r op ()
compileBody Pat (LetDec rep)
pat Body rep
defbody) [Case (Body rep)]
cases
  where
    f :: Case (Body rep) -> ImpM rep r op () -> ImpM rep r op ()
f (Case [Maybe PrimValue]
vs Body rep
body) = TExp Bool
-> ImpM rep r op () -> ImpM rep r op () -> ImpM rep r op ()
forall rep r op.
TExp Bool
-> ImpM rep r op () -> ImpM rep r op () -> ImpM rep r op ()
sIf ([SubExp] -> [Maybe PrimValue] -> TExp Bool
caseMatch [SubExp]
ses [Maybe PrimValue]
vs) (Pat (LetDec rep) -> Body rep -> ImpM rep r op ()
forall rep r op. Pat (LetDec rep) -> Body rep -> ImpM rep r op ()
compileBody Pat (LetDec rep)
pat Body rep
body)
defCompileExp Pat (LetDec rep)
pat (Apply Name
fname [(SubExp, Diet)]
args [(RetType rep, RetAls)]
_ Safety
_) = do
  dest <- Pat (LetDec rep) -> ImpM rep r op [ValueDestination]
forall rep r op.
Pat (LetDec rep) -> ImpM rep r op [ValueDestination]
destinationFromPat Pat (LetDec rep)
pat
  targets <- funcallTargets dest
  args' <- catMaybes <$> mapM compileArg args
  emit $ Imp.Call targets fname args'
  where
    compileArg :: (SubExp, b) -> m (Maybe Arg)
compileArg (SubExp
se, b
_) = do
      t <- SubExp -> m Type
forall t (m :: * -> *). HasScope t m => SubExp -> m Type
subExpType SubExp
se
      case (se, t) of
        (SubExp
_, Prim PrimType
pt) -> Maybe Arg -> m (Maybe Arg)
forall a. a -> m a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Maybe Arg -> m (Maybe Arg)) -> Maybe Arg -> m (Maybe Arg)
forall a b. (a -> b) -> a -> b
$ Arg -> Maybe Arg
forall a. a -> Maybe a
Just (Arg -> Maybe Arg) -> Arg -> Maybe Arg
forall a b. (a -> b) -> a -> b
$ Exp -> Arg
Imp.ExpArg (Exp -> Arg) -> Exp -> Arg
forall a b. (a -> b) -> a -> b
$ PrimType -> SubExp -> Exp
forall a. ToExp a => PrimType -> a -> Exp
toExp' PrimType
pt SubExp
se
        (Var VName
v, Mem {}) -> Maybe Arg -> m (Maybe Arg)
forall a. a -> m a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Maybe Arg -> m (Maybe Arg)) -> Maybe Arg -> m (Maybe Arg)
forall a b. (a -> b) -> a -> b
$ Arg -> Maybe Arg
forall a. a -> Maybe a
Just (Arg -> Maybe Arg) -> Arg -> Maybe Arg
forall a b. (a -> b) -> a -> b
$ VName -> Arg
Imp.MemArg VName
v
        (SubExp, Type)
_ -> Maybe Arg -> m (Maybe Arg)
forall a. a -> m a
forall (f :: * -> *) a. Applicative f => a -> f a
pure Maybe Arg
forall a. Maybe a
Nothing
defCompileExp Pat (LetDec rep)
pat (BasicOp BasicOp
op) = Pat (LetDec rep) -> BasicOp -> ImpM rep r op ()
forall rep (inner :: * -> *) r op.
Mem rep inner =>
Pat (LetDec rep) -> BasicOp -> ImpM rep r op ()
defCompileBasicOp Pat (LetDec rep)
pat BasicOp
op
defCompileExp Pat (LetDec rep)
pat (Loop [(FParam rep, SubExp)]
merge LoopForm
form Body rep
body) = do
  attrs <- ImpM rep r op Attrs
forall rep r op. ImpM rep r op Attrs
askAttrs
  when ("unroll" `inAttrs` attrs) $
    warn (noLoc :: SrcLoc) [] "#[unroll] on loop with unknown number of iterations." -- FIXME: no location.
  dFParams params
  forM_ merge $ \(Param FParamMem
p, SubExp
se) ->
    Bool -> ImpM rep r op () -> ImpM rep r op ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when ((Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== Int
0) (Int -> Bool) -> Int -> Bool
forall a b. (a -> b) -> a -> b
$ Type -> Int
forall shape u. ArrayShape shape => TypeBase shape u -> Int
arrayRank (Type -> Int) -> Type -> Int
forall a b. (a -> b) -> a -> b
$ Param FParamMem -> Type
forall dec. Typed dec => Param dec -> Type
paramType Param FParamMem
p) (ImpM rep r op () -> ImpM rep r op ())
-> ImpM rep r op () -> ImpM rep r op ()
forall a b. (a -> b) -> a -> b
$
      VName
-> [DimIndex (TExp Int64)]
-> SubExp
-> [DimIndex (TExp Int64)]
-> ImpM rep r op ()
forall rep r op.
VName
-> [DimIndex (TExp Int64)]
-> SubExp
-> [DimIndex (TExp Int64)]
-> ImpM rep r op ()
copyDWIM (Param FParamMem -> VName
forall dec. Param dec -> VName
paramName Param FParamMem
p) [] SubExp
se []

  let doBody = [Param FParamMem] -> Body rep -> ImpM rep r op ()
forall dec rep r op.
Typed dec =>
[Param dec] -> Body rep -> ImpM rep r op ()
compileLoopBody [Param FParamMem]
params Body rep
body

  case form of
    ForLoop VName
i IntType
_ SubExp
bound -> do
      bound' <- SubExp -> ImpM rep r op Exp
forall a rep r op. ToExp a => a -> ImpM rep r op Exp
forall rep r op. SubExp -> ImpM rep r op Exp
toExp SubExp
bound
      sFor' i bound' doBody
    WhileLoop VName
cond ->
      TExp Bool -> ImpM rep r op () -> ImpM rep r op ()
forall rep r op. TExp Bool -> ImpM rep r op () -> ImpM rep r op ()
sWhile (Exp -> TExp Bool
forall {k} (t :: k) v. PrimExp v -> TPrimExp t v
TPrimExp (Exp -> TExp Bool) -> Exp -> TExp Bool
forall a b. (a -> b) -> a -> b
$ VName -> PrimType -> Exp
Imp.var VName
cond PrimType
Bool) ImpM rep r op ()
doBody

  pat_dests <- destinationFromPat pat
  forM_ (zip pat_dests $ map (Var . paramName . fst) merge) $ \(ValueDestination
d, SubExp
r) ->
    ValueDestination
-> [DimIndex (TExp Int64)]
-> SubExp
-> [DimIndex (TExp Int64)]
-> ImpM rep r op ()
forall rep r op.
ValueDestination
-> [DimIndex (TExp Int64)]
-> SubExp
-> [DimIndex (TExp Int64)]
-> ImpM rep r op ()
copyDWIMDest ValueDestination
d [] SubExp
r []
  where
    params :: [Param FParamMem]
params = ((Param FParamMem, SubExp) -> Param FParamMem)
-> [(Param FParamMem, SubExp)] -> [Param FParamMem]
forall a b. (a -> b) -> [a] -> [b]
map (Param FParamMem, SubExp) -> Param FParamMem
forall a b. (a, b) -> a
fst [(FParam rep, SubExp)]
[(Param FParamMem, SubExp)]
merge
defCompileExp Pat (LetDec rep)
pat (WithAcc [WithAccInput rep]
inputs Lambda rep
lam) = do
  [LParam rep] -> ImpM rep r op ()
forall rep (inner :: * -> *) r op.
Mem rep inner =>
[LParam rep] -> ImpM rep r op ()
dLParams ([LParam rep] -> ImpM rep r op ())
-> [LParam rep] -> ImpM rep r op ()
forall a b. (a -> b) -> a -> b
$ Lambda rep -> [LParam rep]
forall rep. Lambda rep -> [LParam rep]
lambdaParams Lambda rep
lam
  [(WithAccInput rep, Param LParamMem)]
-> ((WithAccInput rep, Param LParamMem) -> ImpM rep r op ())
-> ImpM rep r op ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([WithAccInput rep]
-> [Param LParamMem] -> [(WithAccInput rep, Param LParamMem)]
forall a b. [a] -> [b] -> [(a, b)]
zip [WithAccInput rep]
inputs ([Param LParamMem] -> [(WithAccInput rep, Param LParamMem)])
-> [Param LParamMem] -> [(WithAccInput rep, Param LParamMem)]
forall a b. (a -> b) -> a -> b
$ Lambda rep -> [LParam rep]
forall rep. Lambda rep -> [LParam rep]
lambdaParams Lambda rep
lam) (((WithAccInput rep, Param LParamMem) -> ImpM rep r op ())
 -> ImpM rep r op ())
-> ((WithAccInput rep, Param LParamMem) -> ImpM rep r op ())
-> ImpM rep r op ()
forall a b. (a -> b) -> a -> b
$ \((Shape
_, [VName]
arrs, Maybe (Lambda rep, [SubExp])
op), Param LParamMem
p) ->
    (ImpState rep r op -> ImpState rep r op) -> ImpM rep r op ()
forall s (m :: * -> *). MonadState s m => (s -> s) -> m ()
modify ((ImpState rep r op -> ImpState rep r op) -> ImpM rep r op ())
-> (ImpState rep r op -> ImpState rep r op) -> ImpM rep r op ()
forall a b. (a -> b) -> a -> b
$ \ImpState rep r op
s ->
      ImpState rep r op
s {stateAccs = M.insert (paramName p) (arrs, op) $ stateAccs s}
  Names -> Stms rep -> ImpM rep r op () -> ImpM rep r op ()
forall rep r op.
Names -> Stms rep -> ImpM rep r op () -> ImpM rep r op ()
compileStms Names
forall a. Monoid a => a
mempty (Body rep -> Stms rep
forall rep. Body rep -> Stms rep
bodyStms (Body rep -> Stms rep) -> Body rep -> Stms rep
forall a b. (a -> b) -> a -> b
$ Lambda rep -> Body rep
forall rep. Lambda rep -> Body rep
lambdaBody Lambda rep
lam) (ImpM rep r op () -> ImpM rep r op ())
-> ImpM rep r op () -> ImpM rep r op ()
forall a b. (a -> b) -> a -> b
$ do
    let nonacc_res :: Result
nonacc_res = Int -> Result -> Result
forall a. Int -> [a] -> [a]
drop Int
num_accs (Body rep -> Result
forall rep. Body rep -> Result
bodyResult (Lambda rep -> Body rep
forall rep. Lambda rep -> Body rep
lambdaBody Lambda rep
lam))
        nonacc_pat_names :: [VName]
nonacc_pat_names = Int -> [VName] -> [VName]
forall a. Int -> [a] -> [a]
takeLast (Result -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length Result
nonacc_res) (Pat (LetDec rep) -> [VName]
forall dec. Pat dec -> [VName]
patNames Pat (LetDec rep)
pat)
    [(VName, SubExpRes)]
-> ((VName, SubExpRes) -> ImpM rep r op ()) -> ImpM rep r op ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> Result -> [(VName, SubExpRes)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
nonacc_pat_names Result
nonacc_res) (((VName, SubExpRes) -> ImpM rep r op ()) -> ImpM rep r op ())
-> ((VName, SubExpRes) -> ImpM rep r op ()) -> ImpM rep r op ()
forall a b. (a -> b) -> a -> b
$ \(VName
v, SubExpRes Certs
_ SubExp
se) ->
      VName
-> [DimIndex (TExp Int64)]
-> SubExp
-> [DimIndex (TExp Int64)]
-> ImpM rep r op ()
forall rep r op.
VName
-> [DimIndex (TExp Int64)]
-> SubExp
-> [DimIndex (TExp Int64)]
-> ImpM rep r op ()
copyDWIM VName
v [] SubExp
se []
  where
    num_accs :: Int
num_accs = [WithAccInput rep] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [WithAccInput rep]
inputs
defCompileExp Pat (LetDec rep)
pat (Op Op rep
op) = do
  opc <- (Env rep r op
 -> Pat (LetDec rep) -> MemOp inner rep -> ImpM rep r op ())
-> ImpM
     rep r op (Pat (LetDec rep) -> MemOp inner rep -> ImpM rep r op ())
forall r (m :: * -> *) a. MonadReader r m => (r -> a) -> m a
asks Env rep r op -> OpCompiler rep r op
Env rep r op
-> Pat (LetDec rep) -> MemOp inner rep -> ImpM rep r op ()
forall rep r op. Env rep r op -> OpCompiler rep r op
envOpCompiler
  opc pat op

-- | Emit a trace print of a scalar value.  The printed message is the
-- label, then @": "@, then the value (interpreted at the given
-- 'PrimType'), then a newline.
tracePrim :: T.Text -> PrimType -> SubExp -> ImpM rep r op ()
tracePrim label pt val =
  emit $ Imp.TracePrint $ ErrorMsg parts
  where
    parts =
      [ ErrorString (label <> ": "),
        ErrorVal pt (toExp' pt val),
        ErrorString "\n"
      ]

-- | Emit a trace print of an array value.  The label and @": "@ are
-- printed first.  Then, for every index visited by 'sLoopNest' over the
-- given shape, the element is copied into a fresh scalar and printed,
-- followed by a space.  Finally a newline is printed.
traceArray :: T.Text -> PrimType -> Shape -> SubExp -> ImpM rep r op ()
traceArray label pt shape arr = do
  -- Every piece of output goes through the same trace-print statement.
  let printMsg = emit . Imp.TracePrint . ErrorMsg
  printMsg [ErrorString (label <> ": ")]
  sLoopNest shape $ \idxs -> do
    -- Elements are read into a scalar temporary before being printed.
    elem_var <- dPrimS "arr_elem" pt
    copyDWIMFix elem_var [] arr idxs
    printMsg [ErrorVal pt (toExp' pt elem_var), " "]
  printMsg ["\n"]

-- | Default compilation of a 'BasicOp' whose result is bound to the
-- given pattern.
--
-- Most operations require a single-element pattern.  'Assert' and
-- 'UpdateAcc' ignore the pattern.  'FlatIndex', 'Scratch', 'Rearrange',
-- 'Reshape', and 'Index' (unless it binds a single element and is
-- indexed by fixed indices only) generate no code in this function.
-- Any other pattern/operation combination is reported with 'error'.
defCompileBasicOp ::
  (Mem rep inner) =>
  Pat (LetDec rep) ->
  BasicOp ->
  ImpM rep r op ()
defCompileBasicOp (Pat [dest]) (SubExp src) =
  copyDWIM (patElemName dest) [] src []
defCompileBasicOp (Pat [dest]) (Opaque opaque src) = do
  -- The operand is always copied to the destination unchanged; tracing
  -- only adds output on top of that.
  copyDWIM (patElemName dest) [] src []
  case opaque of
    OpaqueNil -> pure ()
    OpaqueTrace label ->
      sComment ("Trace: " <> label) $
        subExpType src >>= \case
          Prim pt -> tracePrim label pt src
          Array pt shape _ -> traceArray label pt shape src
          src_t ->
            warn [mempty :: SrcLoc] mempty $
              label <> ": cannot trace value of this (core) type: " <> prettyText src_t
defCompileBasicOp (Pat [dest]) (UnOp unop arg) =
  (patElemName dest <~~) . Imp.UnOpExp unop =<< toExp arg
defCompileBasicOp (Pat [dest]) (ConvOp conv arg) =
  (patElemName dest <~~) . Imp.ConvOpExp conv =<< toExp arg
defCompileBasicOp (Pat [dest]) (BinOp bop lhs rhs) = do
  -- The applicative chain translates the left operand before the right.
  res <- Imp.BinOpExp bop <$> toExp lhs <*> toExp rhs
  patElemName dest <~~ res
defCompileBasicOp (Pat [dest]) (CmpOp cop lhs rhs) = do
  res <- Imp.CmpOpExp cop <$> toExp lhs <*> toExp rhs
  patElemName dest <~~ res
defCompileBasicOp _ (Assert cond msg) = do
  cond' <- toExp cond
  msg' <- traverse toExp msg
  Imp.Provenance locs loc <- askProvenance
  let rev_locs = reverse locs
  emit $ Imp.Assert cond' msg' (loc, rev_locs)

  -- Optionally tell the user that a run-time check was needed here.
  attrs <- askAttrs
  when (AttrComp "warn" ["safety_checks"] `inAttrs` attrs) $
    warn loc rev_locs "Safety check required at run-time."
defCompileBasicOp (Pat [dest]) (Index src slice)
  | Just idxs <- sliceIndices slice =
      copyDWIM (patElemName dest) [] (Var src) [DimFix (pe64 idx) | idx <- idxs]
defCompileBasicOp _ Index {} =
  pure ()
defCompileBasicOp (Pat [dest]) (Update safety _ slice se) =
  guardBounds $ sUpdate (patElemName dest) slice' se
  where
    slice' = fmap pe64 slice
    -- Safe updates only write when the slice is within the array bounds.
    guardBounds = case safety of
      Unsafe -> id
      Safe -> sWhen $ inBounds slice' $ map pe64 $ arrayDims $ patElemType dest
defCompileBasicOp _ FlatIndex {} =
  pure ()
defCompileBasicOp (Pat [dest]) (FlatUpdate _ slice src) = do
  dest_loc <- entryArrayLoc <$> lookupArray (patElemName dest)
  src_loc <- entryArrayLoc <$> lookupArray src
  copy (elemType (patElemType dest)) (flatSliceMemLoc dest_loc (pe64 <$> slice)) src_loc
defCompileBasicOp (Pat [dest]) (Replicate shape se)
  | Acc {} <- patElemType dest = pure ()
  | shape == mempty = copyDWIM (patElemName dest) [] se []
  | otherwise = sLoopNest shape $ \idxs -> copyDWIMFix (patElemName dest) idxs se []
defCompileBasicOp _ Scratch {} =
  pure ()
defCompileBasicOp (Pat [dest]) (Iota n start stride it) = do
  start' <- toExp start
  stride' <- toExp stride
  sFor "i" (pe64 n) $ \i -> do
    -- Element i is start + i * stride, computed in the iota's int type.
    let step = BinOpExp (Mul it OverflowUndef) (sExt it (untyped i)) stride'
    x <- dPrimV "x" $ TPrimExp $ BinOpExp (Add it OverflowUndef) start' step
    copyDWIMFix (patElemName dest) [i] (Var (tvVar x)) []
defCompileBasicOp (Pat [dest]) (Manifest src _) =
  copyDWIM (patElemName dest) [] (Var src) []
defCompileBasicOp (Pat [dest]) (Concat d (arr0 :| arrs_rest) _) = do
  -- Running offset along the concatenation dimension.
  offs <- dPrimV "tmp_offs" 0
  forM_ (arr0 : arrs_rest) $ \arr -> do
    arr_dims <- arrayDims <$> lookupType arr
    let (outer_dims, inner_dims) = splitAt d arr_dims
        rows = case inner_dims of
          r : _ -> pe64 r
          [] -> error $ "defCompileBasicOp Concat: empty array shape for " ++ prettyString arr
        whole dim = DimSlice 0 dim 1
        dest_slice = map (whole . pe64) outer_dims ++ [DimSlice (tvExp offs) rows 1]
    copyDWIM (patElemName dest) dest_slice (Var arr) []
    offs <-- tvExp offs + rows
defCompileBasicOp (Pat [dest]) (ArrayVal vs t) = do
  dest_loc <- entryArrayLoc <$> lookupArray (patElemName dest)
  static_array <- newVNameForFun "static_array"
  emit $ Imp.DeclareArray static_array t $ Imp.ArrayValues vs
  let num_vs = length vs
      static_src =
        MemLoc static_array [intConst Int64 (fromIntegral num_vs)] $
          LMAD.iota 0 [fromIntegral num_vs]
  addVar static_array $ MemVar Nothing $ MemEntry DefaultSpace
  copy t dest_loc static_src
defCompileBasicOp (Pat [dest]) (ArrayLit es _) =
  case mapM isLiteral es of
    -- A non-empty literal made only of constants becomes an 'ArrayVal'.
    Just vs@(v : _) ->
      defCompileBasicOp (Pat [dest]) (ArrayVal vs (primValueType v))
    _ ->
      forM_ (zip [0 :: Integer ..] es) $ \(i, e) ->
        copyDWIMFix (patElemName dest) [fromInteger i] e []
  where
    isLiteral (Constant v) = Just v
    isLiteral _ = Nothing
defCompileBasicOp _ Rearrange {} =
  pure ()
defCompileBasicOp _ Reshape {} =
  pure ()
defCompileBasicOp _ (UpdateAcc safety acc is vs) = sComment "UpdateAcc" $ do
  let is' = map pe64 is

  -- We need to figure out whether we are updating a scatter-like
  -- accumulator or a generalised reduction.  This also binds the
  -- index parameters.
  (_, _, arrs, dims, op) <- lookupAcc acc is'

  let checked m = case safety of
        Safe -> sWhen (inBounds (Slice (map DimFix is')) dims) m
        _ -> m
  checked $ case op of
    Nothing ->
      -- Scatter-like: write each value directly at the index.
      forM_ (zip arrs vs) $ \(arr, v) -> copyDWIMFix arr is' v []
    Just lam -> do
      -- Generalised reduction: the first half of the lambda parameters
      -- receives the current elements, the second half the new values.
      let lam_params = lambdaParams lam
          (x_params, y_params) = splitAt (length vs) $ map paramName lam_params
          lam_body = lambdaBody lam
      dLParams lam_params
      forM_ (zip x_params arrs) $ \(xp, arr) ->
        copyDWIMFix xp [] (Var arr) is'
      forM_ (zip y_params vs) $ \(yp, v) ->
        copyDWIM yp [] v []
      compileStms mempty (bodyStms lam_body) $
        forM_ (zip arrs (bodyResult lam_body)) $ \(arr, SubExpRes _ se) ->
          copyDWIMFix arr is' se []
defCompileBasicOp pat e =
  error . concat $
    [ "ImpGen.defCompileBasicOp: Invalid pattern\n  ",
      prettyString pat,
      "\nfor expression\n  ",
      prettyString e
    ]

-- | Bind every array declaration with 'addVar' as an 'ArrayVar' that has
-- no associated expression, using the declaration's memory location and
-- element type.
--
-- Note: a hack to be used only for functions.
addArrays :: [ArrayDecl] -> ImpM rep r op ()
addArrays decls =
  forM_ decls $ \(ArrayDecl name bt location) ->
    addVar name . ArrayVar Nothing $
      ArrayEntry
        { entryArrayLoc = location,
          entryArrayElemType = bt
        }

-- | Like 'dFParams', but does not create new declarations: each
-- parameter is only entered into the symbol table, with its uniqueness
-- information stripped.
--
-- Note: a hack to be used only for functions.
addFParams :: (Mem rep inner) => [FParam rep] -> ImpM rep r op ()
addFParams = mapM_ $ \fparam ->
  addVar (paramName fparam)
    . memBoundToVarEntry Nothing
    . noUniquenessReturns
    $ paramDec fparam

-- | Enter an integer variable of the given type into the symbol table
-- as a scalar, without emitting a declaration for it.
--
-- Another hack.
addLoopVar :: VName -> IntType -> ImpM rep r op ()
addLoopVar i it =
  addVar i . ScalarVar Nothing . ScalarEntry $ IntType it

-- | Declare every pattern element (via 'dScope'), all sharing the same
-- optional defining expression.
dVars ::
  (Mem rep inner) =>
  Maybe (Exp rep) ->
  [PatElem (LetDec rep)] ->
  ImpM rep r op ()
dVars e = mapM_ (dScope e . scopeOfPatElem)

-- | Declare function parameters through 'dScope', with no defining
-- expression.
dFParams :: (Mem rep inner) => [FParam rep] -> ImpM rep r op ()
dFParams params = dScope Nothing $ scopeOfFParams params

-- | Declare lambda parameters through 'dScope', with no defining
-- expression.
dLParams :: (Mem rep inner) => [LParam rep] -> ImpM rep r op ()
dLParams params = dScope Nothing $ scopeOfLParams params

-- | Declare a fresh /volatile/ scalar of the given type (its name
-- derived from the given string), record it in the symbol table, and
-- assign it the given expression.
dPrimVol :: String -> PrimType -> Imp.TExp t -> ImpM rep r op (TV t)
dPrimVol name t e = do
  v <- newVName name
  emit $ Imp.DeclareScalar v Imp.Volatile t
  addVar v . ScalarVar Nothing $ ScalarEntry t
  let tv = TV v t
  v <~~ untyped e
  pure tv

-- | Emit a non-volatile scalar declaration for an already-chosen name
-- and record it in the symbol table.
dPrim_ :: VName -> PrimType -> ImpM rep r op ()
dPrim_ name t =
  emit (Imp.DeclareScalar name Imp.Nonvolatile t)
    >> addVar name (ScalarVar Nothing (ScalarEntry t))

-- | Create a variable of some provided dynamic type, returning its
-- fresh name.  You'll need this when you are compiling program code of
-- Haskell-level unknown type.  For other things, use other functions.
dPrimS :: String -> PrimType -> ImpM rep r op VName
dPrimS name t = do
  name' <- newVName name
  name' <$ dPrim_ name' t

-- | Create a 'TV' of some provided dynamic type.  There is no
-- guarantee that the dynamic type matches the inferred type @t@.
dPrimSV :: String -> PrimType -> ImpM rep r op (TV t)
dPrimSV name t = do
  v <- dPrimS name t
  pure $ TV v t

-- | Create a 'TV' of some fixed type.  The dynamic type comes from the
-- 'MkTV' instance for @t@.
dPrim :: (MkTV t) => String -> ImpM rep r op (TV t)
dPrim name = do
  v <- newVName name
  let tv = mkTV v
  tv <$ dPrim_ v (tvType tv)

-- | Declare a scalar with the given (already chosen) name, typed after
-- the given expression, and assign the expression to it.
dPrimV_ :: VName -> Imp.TExp t -> ImpM rep r op ()
dPrimV_ name e = do
  let pt = primExpType (untyped e)
  dPrim_ name pt
  TV name pt <-- e

-- | Declare a fresh scalar (named after the given string), typed after
-- the given expression, assign the expression to it, and return the
-- resulting typed variable.
dPrimV :: String -> Imp.TExp t -> ImpM rep r op (TV t)
dPrimV name e = do
  let pt = primExpType $ untyped e
  tv <- flip TV pt <$> dPrimS name pt
  tv <-- e
  pure tv

-- | Like 'dPrimV', but returns the new variable as a typed expression
-- rather than as a 'TV'.
--
-- This used to duplicate the body of 'dPrimV' line for line (declare
-- via 'dPrimS' at the expression's type, then assign); delegating keeps
-- the two in sync.
dPrimVE :: String -> Imp.TExp t -> ImpM rep r op (Imp.TExp t)
dPrimVE name e = tvExp <$> dPrimV name e

-- | Turn memory-annotated type information into a symbol-table entry,
-- attaching the optional defining expression.
memBoundToVarEntry ::
  Maybe (Exp rep) ->
  MemBound NoUniqueness ->
  VarEntry rep
memBoundToVarEntry e = \case
  MemPrim bt ->
    ScalarVar e ScalarEntry {entryScalarType = bt}
  MemMem space ->
    MemVar e $ MemEntry space
  MemAcc acc ispace ts _ ->
    AccVar e (acc, ispace, ts)
  MemArray bt shape _ (ArrayIn mem lmad) ->
    ArrayVar e $
      ArrayEntry
        { entryArrayLoc = MemLoc mem (shapeDims shape) lmad,
          entryArrayElemType = bt
        }

-- | Extract the uniqueness-free memory information from any kind of
-- name binding.  Index variables become primitive integers.
infoDec ::
  (Mem rep inner) =>
  NameInfo rep ->
  MemInfo SubExp NoUniqueness MemBind
infoDec info = case info of
  LetName dec -> letDecMem dec
  FParamName dec -> noUniquenessReturns dec
  LParamName dec -> dec
  IndexName it -> MemPrim (IntType it)

-- | Declare a single name and record it in the symbol table.
--
-- Memory blocks and scalars get an explicit declaration emitted.
-- Arrays and accumulators are only added to the symbol table.
dInfo ::
  (Mem rep inner) =>
  Maybe (Exp rep) ->
  VName ->
  NameInfo rep ->
  ImpM rep r op ()
dInfo e name info = do
  case entry of
    MemVar _ mem ->
      emit $ Imp.DeclareMem name (entryMemSpace mem)
    ScalarVar _ scalar ->
      emit $
        Imp.DeclareScalar name Imp.Nonvolatile (entryScalarType scalar)
    ArrayVar {} ->
      pure ()
    AccVar {} ->
      pure ()
  addVar name entry
  where
    entry = memBoundToVarEntry e $ infoDec info

-- | Declare every name in a scope with 'dInfo', in the key order of the
-- underlying map, all sharing the same optional defining expression.
dScope ::
  (Mem rep inner) =>
  Maybe (Exp rep) ->
  Scope rep ->
  ImpM rep r op ()
dScope e scope =
  forM_ (M.toList scope) $ \(name, info) -> dInfo e name info

-- | Record an array in the symbol table, located in the given memory
-- block with the given shape and LMAD.  No code is emitted.
dArray :: VName -> PrimType -> ShapeBase SubExp -> VName -> LMAD -> ImpM rep r op ()
dArray name pt shape mem lmad =
  addVar name . ArrayVar Nothing $
    ArrayEntry
      { entryArrayLoc = MemLoc mem (shapeDims shape) lmad,
        entryArrayElemType = pt
      }

-- | Run an action with the environment's volatility set to
-- 'Imp.Volatile'.
everythingVolatile :: ImpM rep r op a -> ImpM rep r op a
everythingVolatile = local setVolatile
  where
    setVolatile env = env {envVolatility = Imp.Volatile}

-- | The names that receive results of a function call: scalar and
-- memory destinations contribute their name, array destinations
-- contribute nothing.
funcallTargets :: [ValueDestination] -> ImpM rep r op [VName]
funcallTargets dests = pure $ concatMap target dests
  where
    target :: ValueDestination -> [VName]
    target (ScalarDestination name) = [name]
    target (ArrayDestination _) = []
    target (MemoryDestination name) = [name]

-- | A typed variable, which we can turn into a typed expression, or
-- use as the target for an assignment.  This is used to aid in type
-- safety when doing code generation, by keeping the types straight.
-- It is still easy to cheat when you need to.
--
-- The 'VName' is the variable itself, and the 'PrimType' is its dynamic
-- type.  The parameter @t@ appears in no field: it is a phantom,
-- static type tag.
data TV t = TV VName PrimType

-- | A type class that helps ensure that the type annotation in a
-- 'TV' is correct, by associating the static type @t@ with a single
-- dynamic 'PrimType'.
class MkTV t where
  -- | Create a typed variable from a name.  The dynamic type is
  -- determined by the instance, not supplied by the caller.
  mkTV :: VName -> TV t

  -- | The dynamic type associated with @t@.  The instances in this
  -- module ignore the argument and return a fixed type.
  tvType :: TV t -> PrimType

-- | Haskell 'Bool' corresponds to the primitive 'Bool' type.
instance MkTV Bool where
  tvType _ = Bool
  mkTV = flip TV Bool

-- | Haskell 'Int8' corresponds to the primitive 8-bit integer type.
instance MkTV Int8 where
  tvType _ = IntType Int8
  mkTV = flip TV $ IntType Int8

-- | Haskell 'Int16' corresponds to the primitive 16-bit integer type.
instance MkTV Int16 where
  tvType _ = IntType Int16
  mkTV = flip TV $ IntType Int16

-- | Haskell 'Int32' corresponds to the primitive 32-bit integer type.
instance MkTV Int32 where
  tvType _ = IntType Int32
  mkTV = flip TV $ IntType Int32

-- | Haskell 'Int64' corresponds to the primitive 64-bit integer type.
instance MkTV Int64 where
  tvType _ = IntType Int64
  mkTV = flip TV $ IntType Int64

-- | Haskell 'Half' corresponds to the primitive 16-bit float type.
instance MkTV Half where
  tvType _ = FloatType Float16
  mkTV = flip TV $ FloatType Float16

-- | Haskell 'Float' corresponds to the primitive 32-bit float type.
instance MkTV Float where
  tvType _ = FloatType Float32
  mkTV = flip TV $ FloatType Float32

-- | Haskell 'Double' corresponds to the primitive 64-bit float type.
instance MkTV Double where
  tvType _ = FloatType Float64
  mkTV = flip TV $ FloatType Float64

-- | Convert a typed variable to a size (a SubExp), namely a 'Var'
-- referring to the underlying variable.
tvSize :: TV t -> Imp.DimSize
tvSize (TV v _) = Var v

-- | Convert a typed variable to a similarly typed expression that reads
-- the variable at its stored dynamic type.
tvExp :: TV t -> Imp.TExp t
tvExp (TV name pt) = Imp.TPrimExp (Imp.var name pt)

-- | The name of the variable underlying a typed variable.
tvVar :: TV t -> VName
tvVar tv = case tv of
  TV name _ -> name

-- | Compile things to 'Imp.Exp'.
class ToExp a where
  -- | Compile to an 'Imp.Exp', where the type (which must still be a
  -- primitive) is deduced monadically.  For example, the 'SubExp'
  -- instance looks variables up in the symbol table.
  toExp :: a -> ImpM rep r op Imp.Exp

  -- | Compile where we know the type in advance.  In the instances in
  -- this module, the given 'PrimType' is used as the type of variables
  -- without consulting the symbol table.
  toExp' :: PrimType -> a -> Imp.Exp

-- | Constants become value expressions.  With 'toExp', variables are
-- looked up in the symbol table and must be scalars; with 'toExp'',
-- variables take the supplied type.
instance ToExp SubExp where
  toExp se = case se of
    Constant val ->
      pure $ Imp.ValueExp val
    Var name -> do
      entry <- lookupVar name
      case entry of
        ScalarVar _ (ScalarEntry pt) ->
          pure $ Imp.var name pt
        _ ->
          error $
            "toExp SubExp: SubExp is not a primitive type: " ++ prettyString name

  toExp' _ (Constant val) = Imp.ValueExp val
  toExp' pt (Var name) = Imp.var name pt

-- | A variable name is compiled exactly like the 'SubExp' 'Var' it
-- denotes.
instance ToExp VName where
  toExp v = toExp (Var v)
  toExp' t v = toExp' t (Var v)

-- | Expressions are already compiled; both methods return them
-- unchanged (any supplied type is ignored).
instance ToExp (PrimExp VName) where
  toExp e = pure e
  toExp' _ e = e

-- | Insert (or overwrite) a symbol-table entry for the given name.
addVar :: VName -> VarEntry rep -> ImpM rep r op ()
addVar name entry = modify insertEntry
  where
    insertEntry s = s {stateVTable = M.insert name entry (stateVTable s)}

-- | Run an action with the environment's default space
-- (@envDefaultSpace@) replaced by the given one.
localDefaultSpace :: Imp.Space -> ImpM rep r op a -> ImpM rep r op a
localDefaultSpace space = local setSpace
  where
    setSpace env = env {envDefaultSpace = space}

-- | The function name stored in the environment (@envFunction@), if
-- any.  Presumably the function currently being compiled; confirm
-- against where the environment is built.
askFunction :: ImpM rep r op (Maybe Name)
askFunction = envFunction <$> ask

-- | Generate a 'VName', prefixed with the result of 'askFunction' and a
-- dot if a function name is present.
newVNameForFun :: String -> ImpM rep r op VName
newVNameForFun s = do
  prefix <- maybe "" ((++ ".") . nameToString) <$> askFunction
  newVName $ prefix ++ s

-- | Generate a 'Name', prefixed with the result of 'askFunction' and a
-- dot if a function name is present.
nameForFun :: String -> ImpM rep r op Name
nameForFun s = do
  prefix <- maybe "" (<> ".") <$> askFunction
  pure $ prefix <> nameFromString s

-- | The user-supplied environment component (@envEnv@).
askEnv :: ImpM rep r op r
askEnv = envEnv <$> ask

-- | Run an action with the user-supplied environment component
-- transformed by the given function.
localEnv :: (r -> r) -> ImpM rep r op a -> ImpM rep r op a
localEnv f = local modifyEnv
  where
    modifyEnv env = env {envEnv = f $ envEnv env}

-- | The active attributes, including those for the statement
-- currently being compiled.
askAttrs :: ImpM rep r op Attrs
askAttrs = envAttrs <$> ask

-- | Run an action with more attributes added to what is returned by
-- 'askAttrs'.  The new attributes are combined on the left of the
-- existing ones.
localAttrs :: Attrs -> ImpM rep r op a -> ImpM rep r op a
localAttrs attrs = local addAttrs
  where
    addAttrs env = env {envAttrs = attrs <> envAttrs env}

-- | The provenance of whatever we are currently generating code for.
askProvenance :: ImpM rep r op Provenance
askProvenance = envProvenance <$> ask

-- | Wrap any code emitted in the enclosed section with the current
-- provenance, if any.  When the provenance is 'mempty', the action is
-- run unchanged; otherwise its code is collected and emitted after a
-- provenance 'Imp.Meta' marker.
withProvenance :: ImpM rep r op () -> ImpM rep r op ()
withProvenance m = do
  p <- askProvenance
  if p == mempty
    then m
    else emit . (Imp.Meta (Imp.MetaProvenance p) <>) =<< collect m

-- | Replace (*not* extend) the provenance while executing some action.
localProvenance :: Provenance -> ImpM rep r op a -> ImpM rep r op a
localProvenance p = local setProvenance
  where
    setProvenance env = env {envProvenance = p}

-- | Run an action with the expression, statement, copy, operation and
-- allocation compilers taken from the given 'Operations', overriding those
-- currently installed in the environment.
localOps :: Operations rep r op -> ImpM rep r op a -> ImpM rep r op a
localOps ops = local installOps
  where
    installOps env =
      env
        { envExpCompiler = opsExpCompiler ops,
          envStmsCompiler = opsStmsCompiler ops,
          envCopyCompiler = opsCopyCompiler ops,
          envOpCompiler = opsOpCompiler ops,
          envAllocCompilers = opsAllocCompilers ops
        }

-- | Get the current symbol table from the compiler state.
getVTable :: ImpM rep r op (VTable rep)
getVTable = stateVTable <$> get

-- | Overwrite the symbol table in the compiler state with the given one,
-- leaving the rest of the state untouched.
putVTable :: VTable rep -> ImpM rep r op ()
putVTable vtable = modify replace
  where
    replace s = s {stateVTable = vtable}

-- | Run an action with a modified symbol table.  All changes to the
-- symbol table will be reverted once the action is done!  Only the symbol
-- table is restored; any other state changes made by the action are kept.
localVTable :: (VTable rep -> VTable rep) -> ImpM rep r op a -> ImpM rep r op a
localVTable f m = do
  saved <- getVTable
  putVTable $ f saved
  result <- m
  result <$ putVTable saved

-- | Look up a name in the symbol table.  Fails with an internal error if the
-- name is not bound.
lookupVar :: VName -> ImpM rep r op (VarEntry rep)
lookupVar name =
  gets (M.lookup name . stateVTable) >>= \case
    Just entry -> pure entry
    Nothing -> error $ "Unknown variable: " ++ prettyString name

-- | Look up the 'ArrayEntry' of a name.  Fails with an internal error if
-- the name is bound to something other than an array.
lookupArray :: VName -> ImpM rep r op ArrayEntry
lookupArray name =
  lookupVar name >>= \case
    ArrayVar _ entry -> pure entry
    _ -> error $ "ImpGen.lookupArray: not an array: " ++ prettyString name

-- | Look up the 'MemEntry' of a name.  An unbound name is already reported
-- by 'lookupVar'; this function additionally fails with an internal error
-- if the name is bound to something other than a memory block.
lookupMemory :: VName -> ImpM rep r op MemEntry
lookupMemory name = do
  res <- lookupVar name
  case res of
    MemVar _ entry -> pure entry
    -- The name is known here (lookupVar succeeded), so report the actual
    -- problem instead of claiming the block is unknown.
    _ -> error $ "ImpGen.lookupMemory: not a memory block: " ++ prettyString name

-- | In which memory space is this array allocated?  This is the space of
-- the memory block named by the array's memory location.
lookupArraySpace :: VName -> ImpM rep r op Space
lookupArraySpace name = do
  arr <- lookupArray name
  mem_entry <- lookupMemory $ memLocName $ entryArrayLoc arr
  pure $ entryMemSpace mem_entry

-- | Look up an accumulator, returning its accumulator name, the memory
-- space of its first array, its arrays, its index space (as 64-bit
-- expressions) and its combining operator, if it has one.
--
-- In the case of a histogram-like accumulator (one with an operator), also
-- sets the index parameters: the operator is renamed, its leading
-- parameters (one per given index) are bound to the given indices, and the
-- returned lambda has those parameters removed.
lookupAcc ::
  (Mem rep inner) =>
  VName ->
  [Imp.TExp Int64] ->
  ImpM rep r op (VName, Space, [VName], [Imp.TExp Int64], Maybe (Lambda rep))
lookupAcc name is =
  lookupVar name >>= \case
    AccVar _ (acc, ispace, _) ->
      gets (M.lookup acc . stateAccs) >>= \case
        Nothing ->
          error $ "ImpGen.lookupAcc: unlisted accumulator: " ++ prettyString name
        Just ([], _) ->
          error $ "Accumulator with no arrays: " ++ prettyString name
        Just (arrs@(arr : _), maybe_op) -> do
          space <- lookupArraySpace arr
          case maybe_op of
            Nothing ->
              pure (acc, space, arrs, map pe64 (shapeDims ispace), Nothing)
            Just (op_orig, _) -> do
              -- We must rename the lambda in order to avoid duplicate names
              -- in the likely case where the accumulator is used multiple
              -- times.
              lam <- renameLambda op_orig
              let (i_params, rest_params) = splitAt (length is) (lambdaParams lam)
              zipWithM_ dPrimV_ (map paramName i_params) is
              pure
                ( acc,
                  space,
                  arrs,
                  map pe64 (shapeDims ispace),
                  Just (lam {lambdaParams = rest_params})
                )
    _ ->
      error $ "ImpGen.lookupAcc: not an accumulator: " ++ prettyString name

-- | Compute one 'ValueDestination' per pattern element, based on how the
-- element's name is bound in the symbol table.  Memory blocks become memory
-- destinations and scalars become scalar destinations.  Arrays and
-- accumulators become array destinations with no memory location.
destinationFromPat :: Pat (LetDec rep) -> ImpM rep r op [ValueDestination]
destinationFromPat = traverse destinationFor . patElems
  where
    destinationFor :: PatElem dec -> ImpM rep r op ValueDestination
    destinationFor pe = do
      let v = patElemName pe
      lookupVar v >>= \case
        MemVar {} -> pure $ MemoryDestination v
        ScalarVar {} -> pure $ ScalarDestination v
        ArrayVar _ (ArrayEntry MemLoc {} _) -> pure $ ArrayDestination Nothing
        AccVar {} -> pure $ ArrayDestination Nothing

-- | Find the memory block, memory space and element offset of a single
-- element of the named array, given a list of index expressions.  See
-- 'fullyIndexArray'' for the variant taking a memory location directly.
fullyIndexArray ::
  VName ->
  [Imp.TExp Int64] ->
  ImpM rep r op (VName, Imp.Space, Count Elements (Imp.TExp Int64))
fullyIndexArray name indices =
  lookupArray name >>= \arr -> fullyIndexArray' (entryArrayLoc arr) indices

-- | Like 'fullyIndexArray', but starting from a memory location.  The
-- element offset is obtained by indexing the location's LMAD with the given
-- indices.  The memory space comes from the symbol-table entry of the
-- location's memory block.
fullyIndexArray' ::
  MemLoc ->
  [Imp.TExp Int64] ->
  ImpM rep r op (VName, Imp.Space, Count Elements (Imp.TExp Int64))
fullyIndexArray' (MemLoc mem _ lmad) indices =
  lookupMemory mem >>= \(MemEntry space) ->
    pure (mem, space, elements $ LMAD.index lmad indices)

-- More complicated read/write operations that use index functions.

-- | Copy from the source memory location to the destination memory
-- location using the copy compiler installed in the environment.
--
-- Nothing is generated if both locations use the same memory block and
-- their LMADs are statically equivalent.  Otherwise the copy is guarded by
-- a run-time check, and is skipped when the memory blocks are the same and
-- the LMADs are dynamically equal.
copy :: CopyCompiler rep r op
copy bt dst src = do
  let MemLoc dst_mem _ dst_lmad = dst
      MemLoc src_mem _ src_lmad = src
      same_mem = dst_mem == src_mem
  if same_mem && dst_lmad `LMAD.equivalent` src_lmad
    then pure ()
    else
      sUnless (fromBool same_mem .&&. LMAD.dynamicEqualsLMAD dst_lmad src_lmad) $
        asks envCopyCompiler >>= \cc -> cc bt dst src

-- | A 'CopyCompiler' that emits a single 'Imp.Copy' statement, wrapped in
-- the current provenance.  The copied shape is that of the destination
-- LMAD.  Each side is described by its memory block, memory space, element
-- offset and per-dimension element strides.
lmadCopy :: CopyCompiler rep r op
lmadCopy t (MemLoc dstmem _ dstlmad) (MemLoc srcmem _ srclmad) = do
  srcspace <- entryMemSpace <$> lookupMemory srcmem
  dstspace <- entryMemSpace <$> lookupMemory dstmem
  withProvenance . emit $
    Imp.Copy
      t
      (elements <$> LMAD.shape dstlmad)
      (dstmem, dstspace)
      (offsetAndStrides dstlmad)
      (srcmem, srcspace)
      (offsetAndStrides srclmad)
  where
    -- Offset and strides of an LMAD, measured in elements.
    offsetAndStrides lmad =
      let lmad' = elements <$> lmad
       in (LMAD.offset lmad', map LMAD.ldStride (LMAD.dims lmad'))

-- | Copy from here to there; both destination and source may be
-- indexed.  The generated code is returned rather than emitted.
--
-- When both slices fix every dimension of their arrays, the copy is a
-- single element read into a fresh temporary followed by a write.
-- Otherwise both slices are completed to full slices and applied to their
-- memory locations.  The sliced locations are then handed to 'copy';
-- nothing is generated if they are identical.  The ranks of the two sliced
-- locations must agree.
copyArrayDWIM ::
  PrimType ->
  MemLoc ->
  [DimIndex (Imp.TExp Int64)] ->
  MemLoc ->
  [DimIndex (Imp.TExp Int64)] ->
  ImpM rep r op (Imp.Code op)
copyArrayDWIM
  bt
  dst_loc@(MemLoc _ dst_shape _)
  dst_slice
  src_loc@(MemLoc _ src_shape _)
  src_slice
    | Just dst_is <- traverse dimFix dst_slice,
      Just src_is <- traverse dimFix src_slice,
      length src_is == length src_shape,
      length dst_is == length dst_shape = do
        (dst_mem, dst_space, dst_offset) <- fullyIndexArray' dst_loc dst_is
        (src_mem, src_space, src_offset) <- fullyIndexArray' src_loc src_is
        volatility <- asks envVolatility
        collect $ do
          tmp <- dPrimS "tmp" bt
          emit $ Imp.Read tmp src_mem src_offset bt src_space volatility
          emit . Imp.Write dst_mem dst_offset bt dst_space volatility $ Imp.var tmp bt
    | otherwise = do
        let dst_slice' = fullSliceNum (map pe64 dst_shape) dst_slice
            src_slice' = fullSliceNum (map pe64 src_shape) src_slice
            dst_rank = length $ sliceDims dst_slice'
            src_rank = length $ sliceDims src_slice'
            dst_loc' = sliceMemLoc dst_loc dst_slice'
            src_loc' = sliceMemLoc src_loc src_slice'
        if dst_rank == src_rank
          then
            if dst_loc' == src_loc'
              then pure mempty -- Copy would be no-op.
              else collect $ copy bt dst_loc' src_loc'
          else
            error . concat $
              [ "copyArrayDWIM: cannot copy to ",
                prettyString (memLocName dst_loc),
                " from ",
                prettyString (memLocName src_loc),
                " because ranks do not match (",
                prettyString dst_rank,
                " vs ",
                prettyString src_rank,
                ")"
              ]

-- Like 'copyDWIM', but the target is a 'ValueDestination' instead of
-- a variable name.
copyDWIMDest ::
  ValueDestination ->
  [DimIndex (Imp.TExp Int64)] ->
  SubExp ->
  [DimIndex (Imp.TExp Int64)] ->
  ImpM rep r op ()
copyDWIMDest :: forall rep r op.
ValueDestination
-> [DimIndex (TExp Int64)]
-> SubExp
-> [DimIndex (TExp Int64)]
-> ImpM rep r op ()
copyDWIMDest ValueDestination
_ [DimIndex (TExp Int64)]
_ (Constant PrimValue
v) (DimIndex (TExp Int64)
_ : [DimIndex (TExp Int64)]
_) =
  [Char] -> ImpM rep r op ()
forall a. HasCallStack => [Char] -> a
error ([Char] -> ImpM rep r op ()) -> [Char] -> ImpM rep r op ()
forall a b. (a -> b) -> a -> b
$
    [[Char]] -> [Char]
unwords [[Char]
"copyDWIMDest: constant source", PrimValue -> [Char]
forall a. Pretty a => a -> [Char]
prettyString PrimValue
v, [Char]
"cannot be indexed."]
copyDWIMDest ValueDestination
pat [DimIndex (TExp Int64)]
dest_slice (Constant PrimValue
v) [] =
  case (DimIndex (TExp Int64) -> Maybe (TExp Int64))
-> [DimIndex (TExp Int64)] -> Maybe [TExp Int64]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM DimIndex (TExp Int64) -> Maybe (TExp Int64)
forall d. DimIndex d -> Maybe d
dimFix [DimIndex (TExp Int64)]
dest_slice of
    Maybe [TExp Int64]
Nothing ->
      [Char] -> ImpM rep r op ()
forall a. HasCallStack => [Char] -> a
error ([Char] -> ImpM rep r op ()) -> [Char] -> ImpM rep r op ()
forall a b. (a -> b) -> a -> b
$
        [[Char]] -> [Char]
unwords [[Char]
"copyDWIMDest: constant source", PrimValue -> [Char]
forall a. Pretty a => a -> [Char]
prettyString PrimValue
v, [Char]
"with slice destination."]
    Just [TExp Int64]
dest_is ->
      case ValueDestination
pat of
        ScalarDestination VName
name ->
          Code op -> ImpM rep r op ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code op -> ImpM rep r op ()) -> Code op -> ImpM rep r op ()
forall a b. (a -> b) -> a -> b
$ VName -> Exp -> Code op
forall a. VName -> Exp -> Code a
Imp.SetScalar VName
name (Exp -> Code op) -> Exp -> Code op
forall a b. (a -> b) -> a -> b
$ PrimValue -> Exp
forall v. PrimValue -> PrimExp v
Imp.ValueExp PrimValue
v
        MemoryDestination {} ->
          [Char] -> ImpM rep r op ()
forall a. HasCallStack => [Char] -> a
error ([Char] -> ImpM rep r op ()) -> [Char] -> ImpM rep r op ()
forall a b. (a -> b) -> a -> b
$
            [[Char]] -> [Char]
unwords [[Char]
"copyDWIMDest: constant source", PrimValue -> [Char]
forall a. Pretty a => a -> [Char]
prettyString PrimValue
v, [Char]
"cannot be written to memory destination."]
        ArrayDestination (Just MemLoc
dest_loc) -> do
          (dest_mem, dest_space, dest_i) <-
            MemLoc
-> [TExp Int64]
-> ImpM rep r op (VName, Space, Count Elements (TExp Int64))
forall rep r op.
MemLoc
-> [TExp Int64]
-> ImpM rep r op (VName, Space, Count Elements (TExp Int64))
fullyIndexArray' MemLoc
dest_loc [TExp Int64]
dest_is
          vol <- asks envVolatility
          emit $ Imp.Write dest_mem dest_i bt dest_space vol $ Imp.ValueExp v
        ArrayDestination Maybe MemLoc
Nothing ->
          [Char] -> ImpM rep r op ()
forall a. HasCallStack => [Char] -> a
error [Char]
"copyDWIMDest: ArrayDestination Nothing"
  where
    bt :: PrimType
bt = PrimValue -> PrimType
primValueType PrimValue
v
copyDWIMDest ValueDestination
dest [DimIndex (TExp Int64)]
dest_slice (Var VName
src) [DimIndex (TExp Int64)]
src_slice = do
  src_entry <- VName -> ImpM rep r op (VarEntry rep)
forall rep r op. VName -> ImpM rep r op (VarEntry rep)
lookupVar VName
src
  case (dest, src_entry) of
    (MemoryDestination VName
mem, MemVar Maybe (Exp rep)
_ (MemEntry Space
space)) ->
      Code op -> ImpM rep r op ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code op -> ImpM rep r op ()) -> Code op -> ImpM rep r op ()
forall a b. (a -> b) -> a -> b
$ VName -> VName -> Space -> Code op
forall a. VName -> VName -> Space -> Code a
Imp.SetMem VName
mem VName
src Space
space
    (MemoryDestination {}, VarEntry rep
_) ->
      [Char] -> ImpM rep r op ()
forall a. HasCallStack => [Char] -> a
error ([Char] -> ImpM rep r op ()) -> [Char] -> ImpM rep r op ()
forall a b. (a -> b) -> a -> b
$
        [[Char]] -> [Char]
unwords [[Char]
"copyDWIMDest: cannot write", VName -> [Char]
forall a. Pretty a => a -> [Char]
prettyString VName
src, [Char]
"to memory destination."]
    (ValueDestination
_, MemVar {}) ->
      [Char] -> ImpM rep r op ()
forall a. HasCallStack => [Char] -> a
error ([Char] -> ImpM rep r op ()) -> [Char] -> ImpM rep r op ()
forall a b. (a -> b) -> a -> b
$
        [[Char]] -> [Char]
unwords [[Char]
"copyDWIMDest: source", VName -> [Char]
forall a. Pretty a => a -> [Char]
prettyString VName
src, [Char]
"is a memory block."]
    (ValueDestination
_, ScalarVar Maybe (Exp rep)
_ (ScalarEntry PrimType
_))
      | Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ [DimIndex (TExp Int64)] -> Bool
forall a. [a] -> Bool
forall (t :: * -> *) a. Foldable t => t a -> Bool
null [DimIndex (TExp Int64)]
src_slice ->
          [Char] -> ImpM rep r op ()
forall a. HasCallStack => [Char] -> a
error ([Char] -> ImpM rep r op ()) -> [Char] -> ImpM rep r op ()
forall a b. (a -> b) -> a -> b
$
            [[Char]] -> [Char]
unwords [[Char]
"copyDWIMDest: prim-typed source", VName -> [Char]
forall a. Pretty a => a -> [Char]
prettyString VName
src, [Char]
"with slice", [DimIndex (TExp Int64)] -> [Char]
forall a. Pretty a => a -> [Char]
prettyString [DimIndex (TExp Int64)]
src_slice]
    (ScalarDestination VName
name, VarEntry rep
_)
      | Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ [DimIndex (TExp Int64)] -> Bool
forall a. [a] -> Bool
forall (t :: * -> *) a. Foldable t => t a -> Bool
null [DimIndex (TExp Int64)]
dest_slice ->
          [Char] -> ImpM rep r op ()
forall a. HasCallStack => [Char] -> a
error ([Char] -> ImpM rep r op ()) -> [Char] -> ImpM rep r op ()
forall a b. (a -> b) -> a -> b
$
            [[Char]] -> [Char]
unwords [[Char]
"copyDWIMDest: prim-typed target", VName -> [Char]
forall a. Pretty a => a -> [Char]
prettyString VName
name, [Char]
"with slice", [DimIndex (TExp Int64)] -> [Char]
forall a. Pretty a => a -> [Char]
prettyString [DimIndex (TExp Int64)]
dest_slice]
    (ScalarDestination VName
name, ScalarVar Maybe (Exp rep)
_ (ScalarEntry PrimType
pt)) ->
      Code op -> ImpM rep r op ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code op -> ImpM rep r op ()) -> Code op -> ImpM rep r op ()
forall a b. (a -> b) -> a -> b
$ VName -> Exp -> Code op
forall a. VName -> Exp -> Code a
Imp.SetScalar VName
name (Exp -> Code op) -> Exp -> Code op
forall a b. (a -> b) -> a -> b
$ VName -> PrimType -> Exp
Imp.var VName
src PrimType
pt
    (ScalarDestination VName
name, ArrayVar Maybe (Exp rep)
_ ArrayEntry
arr)
      | Just [TExp Int64]
src_is <- (DimIndex (TExp Int64) -> Maybe (TExp Int64))
-> [DimIndex (TExp Int64)] -> Maybe [TExp Int64]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM DimIndex (TExp Int64) -> Maybe (TExp Int64)
forall d. DimIndex d -> Maybe d
dimFix [DimIndex (TExp Int64)]
src_slice,
        [DimIndex (TExp Int64)] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [DimIndex (TExp Int64)]
src_slice Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== [SubExp] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length (ArrayEntry -> [SubExp]
entryArrayShape ArrayEntry
arr) -> do
          let bt :: PrimType
bt = ArrayEntry -> PrimType
entryArrayElemType ArrayEntry
arr
          (mem, space, i) <-
            MemLoc
-> [TExp Int64]
-> ImpM rep r op (VName, Space, Count Elements (TExp Int64))
forall rep r op.
MemLoc
-> [TExp Int64]
-> ImpM rep r op (VName, Space, Count Elements (TExp Int64))
fullyIndexArray' (ArrayEntry -> MemLoc
entryArrayLoc ArrayEntry
arr) [TExp Int64]
src_is
          vol <- asks envVolatility
          emit $ Imp.Read name mem i bt space vol
      | Bool
otherwise ->
          [Char] -> ImpM rep r op ()
forall a. HasCallStack => [Char] -> a
error ([Char] -> ImpM rep r op ()) -> [Char] -> ImpM rep r op ()
forall a b. (a -> b) -> a -> b
$
            [[Char]] -> [Char]
unwords
              [ [Char]
"copyDWIMDest: prim-typed target",
                VName -> [Char]
forall a. Pretty a => a -> [Char]
prettyString VName
name,
                [Char]
"and array-typed source",
                VName -> [Char]
forall a. Pretty a => a -> [Char]
prettyString VName
src,
                [Char]
"of shape",
                [SubExp] -> [Char]
forall a. Pretty a => a -> [Char]
prettyString (ArrayEntry -> [SubExp]
entryArrayShape ArrayEntry
arr),
                [Char]
"sliced with",
                [DimIndex (TExp Int64)] -> [Char]
forall a. Pretty a => a -> [Char]
prettyString [DimIndex (TExp Int64)]
src_slice
              ]
    (ArrayDestination (Just MemLoc
dest_loc), ArrayVar Maybe (Exp rep)
_ ArrayEntry
src_arr) -> do
      let src_loc :: MemLoc
src_loc = ArrayEntry -> MemLoc
entryArrayLoc ArrayEntry
src_arr
          bt :: PrimType
bt = ArrayEntry -> PrimType
entryArrayElemType ArrayEntry
src_arr
      Code op -> ImpM rep r op ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code op -> ImpM rep r op ())
-> ImpM rep r op (Code op) -> ImpM rep r op ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< PrimType
-> MemLoc
-> [DimIndex (TExp Int64)]
-> MemLoc
-> [DimIndex (TExp Int64)]
-> ImpM rep r op (Code op)
forall rep r op.
PrimType
-> MemLoc
-> [DimIndex (TExp Int64)]
-> MemLoc
-> [DimIndex (TExp Int64)]
-> ImpM rep r op (Code op)
copyArrayDWIM PrimType
bt MemLoc
dest_loc [DimIndex (TExp Int64)]
dest_slice MemLoc
src_loc [DimIndex (TExp Int64)]
src_slice
    (ArrayDestination (Just MemLoc
dest_loc), ScalarVar Maybe (Exp rep)
_ (ScalarEntry PrimType
bt))
      | Just [TExp Int64]
dest_is <- (DimIndex (TExp Int64) -> Maybe (TExp Int64))
-> [DimIndex (TExp Int64)] -> Maybe [TExp Int64]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM DimIndex (TExp Int64) -> Maybe (TExp Int64)
forall d. DimIndex d -> Maybe d
dimFix [DimIndex (TExp Int64)]
dest_slice,
        [TExp Int64] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [TExp Int64]
dest_is Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== [SubExp] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length (MemLoc -> [SubExp]
memLocShape MemLoc
dest_loc) -> do
          (dest_mem, dest_space, dest_i) <- MemLoc
-> [TExp Int64]
-> ImpM rep r op (VName, Space, Count Elements (TExp Int64))
forall rep r op.
MemLoc
-> [TExp Int64]
-> ImpM rep r op (VName, Space, Count Elements (TExp Int64))
fullyIndexArray' MemLoc
dest_loc [TExp Int64]
dest_is
          vol <- asks envVolatility
          emit $ Imp.Write dest_mem dest_i bt dest_space vol (Imp.var src bt)
      | Bool
otherwise ->
          [Char] -> ImpM rep r op ()
forall a. HasCallStack => [Char] -> a
error ([Char] -> ImpM rep r op ()) -> [Char] -> ImpM rep r op ()
forall a b. (a -> b) -> a -> b
$
            [[Char]] -> [Char]
unwords
              [ [Char]
"copyDWIMDest: array-typed target and prim-typed source",
                VName -> [Char]
forall a. Pretty a => a -> [Char]
prettyString VName
src,
                [Char]
"with slice",
                [DimIndex (TExp Int64)] -> [Char]
forall a. Pretty a => a -> [Char]
prettyString [DimIndex (TExp Int64)]
dest_slice
              ]
    (ArrayDestination Maybe MemLoc
Nothing, VarEntry rep
_) ->
      () -> ImpM rep r op ()
forall a. a -> ImpM rep r op a
forall (f :: * -> *) a. Applicative f => a -> f a
pure () -- Nothing to do; something else set some memory
      -- somewhere.
    (ValueDestination
_, AccVar {}) ->
      () -> ImpM rep r op ()
forall a. a -> ImpM rep r op a
forall (f :: * -> *) a. Applicative f => a -> f a
pure () -- Nothing to do; accumulators are phantoms.

-- | Copy from here to there; both destination and source may be
-- indexed.  If so, they better be arrays of enough dimensions.
-- This function will generally just Do What I Mean, and Do The Right
-- Thing.  Both destination and source must be in scope.
--
-- The destination is classified by its 'VarEntry' (scalar, array,
-- memory block or accumulator) and the actual copy is delegated to
-- @copyDWIMDest@.
copyDWIM ::
  VName ->
  [DimIndex (Imp.TExp Int64)] ->
  SubExp ->
  [DimIndex (Imp.TExp Int64)] ->
  ImpM rep r op ()
copyDWIM dest dest_slice src src_slice = do
  dest_entry <- lookupVar dest
  let dest_target = case dest_entry of
        ScalarVar _ _ -> ScalarDestination dest
        -- The array's memory location is used as-is; no need to take
        -- it apart and rebuild it.
        ArrayVar _ entry -> ArrayDestination $ Just $ entryArrayLoc entry
        MemVar _ _ -> MemoryDestination dest
        -- Does not matter; accumulators are phantoms.
        AccVar {} -> ArrayDestination Nothing
  copyDWIMDest dest_target dest_slice src src_slice

-- | As 'copyDWIM', but every destination and source index is wrapped
-- in 'DimFix', i.e. each index selects a single position.
copyDWIMFix ::
  VName ->
  [Imp.TExp Int64] ->
  SubExp ->
  [Imp.TExp Int64] ->
  ImpM rep r op ()
copyDWIMFix dest dest_is src src_is =
  copyDWIM dest (fixed dest_is) src (fixed src_is)
  where
    fixed = map DimFix

-- | @compileAlloc pat size space@ allocates @size@ bytes of memory in
-- @space@, writing the result to @pat@, which must contain a single
-- memory-typed element.
--
-- The allocation is performed by 'sAlloc_', so an allocator compiler
-- registered for @space@ (in 'envAllocCompilers') takes precedence
-- over a plain 'Imp.Allocate'.  Any other pattern shape is an
-- internal error.
compileAlloc ::
  (Mem rep inner) => Pat (LetDec rep) -> SubExp -> Space -> ImpM rep r op ()
compileAlloc (Pat [mem]) e space =
  sAlloc_ (patElemName mem) (Imp.bytes $ pe64 e) space
compileAlloc pat _ _ =
  error $ "compileAlloc: Invalid pattern: " ++ prettyString pat

-- | The number of bytes needed to represent the array in a
-- straightforward contiguous format, as an t'Int64' expression.
-- This is the element size multiplied by the product of all the
-- array's dimensions.
typeSize :: Type -> Count Bytes (Imp.TExp Int64)
typeSize t = Imp.bytes $ elem_bytes * num_elems
  where
    elem_bytes = primByteSize $ elemType t
    num_elems = product $ map pe64 $ arrayDims t

-- | Is this indexing in-bounds for an array of the given shape?  This
-- is useful for things like scatter, which ignores out-of-bounds
-- writes.
--
-- Indexes and dimensions are paired positionally (surplus entries on
-- either side are ignored).  A 'DimFix' @i@ against dimension @d@
-- requires @0 <= i < d@.  A 'DimSlice' @i n s@ requires @0 <= i@ and
-- @i + (n-1)*s < d@.  The result is the conjunction of all per-dimension
-- conditions.  An empty pairing yields 'true'.
inBounds :: Slice (Imp.TExp Int64) -> [Imp.TExp Int64] -> Imp.TExp Bool
inBounds (Slice slice) dims =
  case zipWith condInBounds slice dims of
    -- 'foldl1' is partial; an empty index is vacuously in bounds.
    [] -> true
    conds -> foldl1 (.&&.) conds
  where
    condInBounds :: DimIndex (Imp.TExp Int64) -> Imp.TExp Int64 -> Imp.TExp Bool
    condInBounds (DimFix i) d =
      0 .<=. i .&&. i .<. d
    condInBounds (DimSlice i n s) d =
      0 .<=. i .&&. i + (n - 1) * s .<. d

--- Building blocks for constructing code.

-- | Emit an 'Imp.For' loop over the loop variable @i@ with the given
-- @bound@.  The body is generated by the given action.  The bound must
-- have an integer type, which is also the type registered for @i@;
-- any other type is an internal error.
sFor' :: VName -> Imp.Exp -> ImpM rep r op () -> ImpM rep r op ()
sFor' i bound body = do
  addLoopVar i loop_t
  body' <- collect body
  emit $ Imp.For i bound body'
  where
    loop_t = case primExpType bound of
      IntType bound_t -> bound_t
      t ->
        error $
          "sFor': bound " ++ prettyString bound ++ " is of type " ++ prettyString t

-- | Emit a loop with a fresh loop variable, named after the given
-- string, up to @bound@.  The body receives the loop variable as a
-- typed expression whose primitive type is that of the bound.
sFor :: String -> Imp.TExp t -> (Imp.TExp t -> ImpM rep r op ()) -> ImpM rep r op ()
sFor i bound body = do
  i' <- newVName i
  let ubound = untyped bound
  sFor' i' ubound . body . TPrimExp . Imp.var i' $ primExpType ubound

-- | Emit an 'Imp.While' loop with condition @cond@.  The loop body is
-- generated (collected) from the given action.
sWhile :: Imp.TExp Bool -> ImpM rep r op () -> ImpM rep r op ()
sWhile cond body = emit . Imp.While cond =<< collect body

-- | Execute a code generation action, and emit its generated code
-- preceded by an 'Imp.MetaComment' carrying the given description.
sComment :: T.Text -> ImpM rep r op () -> ImpM rep r op ()
sComment s code = emit . (Imp.Meta (Imp.MetaComment s) <>) =<< collect code

-- | Emit a conditional.  Both branches are always generated (true
-- branch first).  When @cond@ is syntactically 'true' or 'false', only
-- the corresponding branch's code is emitted, without an 'Imp.If'.
sIf :: Imp.TExp Bool -> ImpM rep r op () -> ImpM rep r op () -> ImpM rep r op ()
sIf cond tbranch fbranch = do
  tcode <- collect tbranch
  fcode <- collect fbranch
  emit $ pick tcode fcode
  where
    -- Avoid generating a branch if the condition is known statically.
    pick tcode fcode
      | cond == true = tcode
      | cond == false = fcode
      | otherwise = Imp.If cond tcode fcode

-- | Run the given code only when @cond@ holds; 'sIf' with an empty
-- false branch.
sWhen :: Imp.TExp Bool -> ImpM rep r op () -> ImpM rep r op ()
sWhen cond tbranch = sIf cond tbranch $ pure ()

-- | Run the given code only when @cond@ does not hold; 'sIf' with an
-- empty true branch.
sUnless :: Imp.TExp Bool -> ImpM rep r op () -> ImpM rep r op ()
sUnless cond fbranch = sIf cond (pure ()) fbranch

-- | Emit a backend-specific operation as an 'Imp.Op' statement.
sOp :: op -> ImpM rep r op ()
sOp x = emit $ Imp.Op x

-- | Declare a fresh memory block in @space@, named after the given
-- string, and register it as a memory variable.  Returns the new name.
-- No allocation is performed.
sDeclareMem :: String -> Space -> ImpM rep r op VName
sDeclareMem name space = do
  mem <- newVName name
  emit $ Imp.DeclareMem mem space
  addVar mem . MemVar Nothing $ MemEntry space
  pure mem

-- | Allocate @size@ bytes in @space@ for an already-declared memory
-- block.  An allocator compiler registered for @space@ in
-- 'envAllocCompilers' is used when present.  Otherwise a plain
-- 'Imp.Allocate' is emitted.
sAlloc_ :: VName -> Count Bytes (Imp.TExp Int64) -> Space -> ImpM rep r op ()
sAlloc_ mem size space =
  asks (M.lookup space . envAllocCompilers) >>= \case
    Nothing -> emit $ Imp.Allocate mem size space
    Just allocate -> allocate mem size

-- | Declare a fresh memory block (see 'sDeclareMem') and allocate
-- @size@ bytes for it (see 'sAlloc_').  Returns the block's name.
sAlloc :: String -> Count Bytes (Imp.TExp Int64) -> Space -> ImpM rep r op VName
sAlloc name size space = do
  mem <- sDeclareMem name space
  mem <$ sAlloc_ mem size space

-- | Declare a fresh array, named after the given string, stored in
-- @mem@ with the given element type, shape and LMAD.  Returns the
-- array's name.
sArray :: String -> PrimType -> ShapeBase SubExp -> VName -> LMAD -> ImpM rep r op VName
sArray name bt shape mem lmad = do
  arr <- newVName name
  arr <$ dArray arr bt shape mem lmad

-- | Declare an array in row-major order in the given memory block,
-- i.e. with an iota LMAD at offset 0 over the array's shape.
sArrayInMem :: String -> PrimType -> ShapeBase SubExp -> VName -> ImpM rep r op VName
sArrayInMem name pt shape mem =
  sArray name pt shape mem $ LMAD.iota 0 dims
  where
    dims = map (isInt64 . primExpFromSubExp int64) (shapeDims shape)

-- | Like 'sAllocArray', but permute the in-memory representation of the indices as specified.
--
-- Memory of 'typeSize' bytes is allocated in a fresh block named
-- @<name>_mem@.  The array is laid out as an iota LMAD over the
-- permuted dimensions, then permuted back by the inverse permutation,
-- so its logical shape is @shape@.
sAllocArrayPerm :: String -> PrimType -> ShapeBase SubExp -> Space -> [Int] -> ImpM rep r op VName
sAllocArrayPerm name pt shape space perm = do
  let permuted_dims = rearrangeShape perm $ shapeDims shape
      num_bytes = typeSize $ Array pt shape NoUniqueness
  mem <- sAlloc (name ++ "_mem") num_bytes space
  let iota_lmad = LMAD.iota 0 $ map (isInt64 . primExpFromSubExp int64) permuted_dims
      lmad = LMAD.permute iota_lmad $ rearrangeInverse perm
  sArray name pt shape mem lmad

-- | Allocate and declare an array using a linear/iota index function,
-- i.e. 'sAllocArrayPerm' with the identity permutation.
sAllocArray :: String -> PrimType -> ShapeBase SubExp -> Space -> ImpM rep r op VName
sAllocArray name pt shape space =
  sAllocArrayPerm name pt shape space $ take (shapeRank shape) [0 ..]

-- | Declare a one-dimensional static array with the given contents.
-- Its memory block (named @<name>_mem@, in 'DefaultSpace') comes from
-- 'newVNameForFun'.  The array uses a linear/iota index function.
sStaticArray :: String -> PrimType -> Imp.ArrayContents -> ImpM rep r op VName
sStaticArray name pt vs = do
  mem <- newVNameForFun $ name ++ "_mem"
  emit $ Imp.DeclareArray mem pt vs
  addVar mem $ MemVar Nothing $ MemEntry DefaultSpace
  sArray name pt (Shape [intConst Int64 $ toInteger num_elems]) mem $
    LMAD.iota 0 [fromIntegral num_elems]
  where
    num_elems = case vs of
      Imp.ArrayValues vs' -> length vs'
      Imp.ArrayZeros n -> n

-- | Write the scalar @v@ to the element of array @arr@ at the given
-- (full) index.  The write uses the current volatility setting, and
-- its element type is the type of @v@.
sWrite :: VName -> [Imp.TExp Int64] -> Imp.Exp -> ImpM rep r op ()
sWrite arr is v = do
  (mem, space, offset) <- fullyIndexArray arr is
  vol <- asks envVolatility
  let elem_t = primExpType v
  emit $ Imp.Write mem offset elem_t space vol v

-- | Copy @v@ into the part of @arr@ selected by @slice@ (via
-- 'copyDWIM', with an empty source slice).
sUpdate :: VName -> Slice (Imp.TExp Int64) -> SubExp -> ImpM rep r op ()
sUpdate arr (Slice is) v = copyDWIM arr is v []

-- | Create a sequential 'Imp.For' loop nest covering a space of the
-- given shape, outermost dimension first.  The function is called with
-- the indexes for a given iteration, in the same order as the
-- dimensions.
sLoopSpace ::
  [Imp.TExp t] ->
  ([Imp.TExp t] -> ImpM rep r op ()) ->
  ImpM rep r op ()
sLoopSpace dims f = go dims []
  where
    -- Indexes are accumulated innermost-first, hence the final reverse.
    go [] acc = f $ reverse acc
    go (d : ds) acc = sFor "nest_i" d $ \i -> go ds (i : acc)

-- | 'sLoopSpace' over the dimensions of a 'Shape', converted to
-- 'Int64' expressions.
sLoopNest ::
  Shape ->
  ([Imp.TExp Int64] -> ImpM rep r op ()) ->
  ImpM rep r op ()
sLoopNest shape = sLoopSpace $ map pe64 $ shapeDims shape

-- | Untyped assignment: emit an 'Imp.SetScalar' of the expression to
-- the named variable.
(<~~) :: VName -> Imp.Exp -> ImpM rep r op ()
(<~~) x = emit . Imp.SetScalar x

infixl 3 <~~

-- | Typed assignment: emit an 'Imp.SetScalar' of the (untyped)
-- expression to the variable underlying the 'TV'.
(<--) :: TV t -> Imp.TExp t -> ImpM rep r op ()
tv <-- e = emit $ Imp.SetScalar (tvVar tv) (untyped e)

infixl 3 <--

-- | Construct an ad-hoc function that does not correspond to any of
-- the IR functions in the input program.
--
-- The body is generated by running the given action with
-- 'envFunction' set to the new function's name and with every output
-- and input parameter registered in the variable table.  The
-- collected code is then emitted as an 'Imp.Function' with no entry
-- point information.
function ::
  Name ->
  [Imp.Param] ->
  [Imp.Param] ->
  ImpM rep r op () ->
  ImpM rep r op ()
function fname outputs inputs m =
  local inFunction $ do
    body <- collect $ forM_ (outputs ++ inputs) bindParam >> m
    emitFunction fname $ Imp.Function Nothing outputs inputs body
  where
    -- Mark the environment as being inside the function @fname@.
    inFunction env = env {envFunction = Just fname}

    -- Make a parameter visible to the body as an ordinary variable.
    bindParam = \case
      Imp.MemParam name space ->
        addVar name . MemVar Nothing $ MemEntry space
      Imp.ScalarParam name bt ->
        addVar name . ScalarVar Nothing $ ScalarEntry bt

-- Fish out those top-level declarations in the constant
-- initialisation code whose names are in @used@, turning them into
-- parameters.
--
-- Only the top-level ':>>:' spine is traversed.  A used memory or
-- scalar declaration becomes a parameter and is removed from the code.
-- A used array declaration becomes a 'DefaultSpace' memory parameter
-- but keeps its declaration in the code.  Everything else is left
-- untouched.
constParams :: Names -> Imp.Code a -> (DL.DList Imp.Param, Imp.Code a)
constParams used = go
  where
    go code = case code of
      x Imp.:>>: y ->
        go x <> go y
      Imp.DeclareMem name space
        | name `nameIn` used ->
            (DL.singleton (Imp.MemParam name space), mempty)
      Imp.DeclareScalar name _ t
        | name `nameIn` used ->
            (DL.singleton (Imp.ScalarParam name t), mempty)
      Imp.DeclareArray name _ _
        | name `nameIn` used ->
            (DL.singleton (Imp.MemParam name DefaultSpace), code)
      -- Unused declarations (guards above failed) and all other
      -- statements fall through to here unchanged.
      _ ->
        (mempty, code)

-- | Generate constants that get put outside of all functions.  They
-- will be executed at program startup.
--
-- The action must return the names that should be made available,
-- together with its actual result.  Declarations of those names at the
-- top level of the generated code are lifted into constant parameters
-- (see 'constParams').  The resulting 'Imp.Constants' are appended to
-- the 'stateConstants' accumulated so far.
--
-- This one has real sharp edges: do not use it inside 'subImpM', and
-- do not use any variable from the surrounding context.
genConstants :: ImpM rep r op (Names, a) -> ImpM rep r op a
genConstants m = do
  ((avail, a), code) <- collect' m
  let (params, initCode) = constParams avail code
      consts = Imp.Constants (DL.toList params) initCode
  modify $ \s -> s {stateConstants = stateConstants s <> consts}
  pure a

-- | Given dimensions @[d1, d2, ..., dn]@, return the per-dimension
-- slice sizes @[d2*...*dn, ..., dn, 1]@, i.e. for each dimension the
-- product of all dimensions inside it.
--
-- Each nontrivial product is bound to a fresh variable named after
-- @"slice"@.  Note that the full product @d1*...*dn@ is also declared
-- this way even though it is not part of the result.
dSlices :: [Imp.TExp Int64] -> ImpM rep r op [Imp.TExp Int64]
dSlices dims = do
  (_, slices) <- go dims
  pure $ drop 1 slices
  where
    -- Returns the product of the given dimensions, together with the
    -- suffix products for every position (terminated by a literal 1).
    -- Inner dimensions are processed before outer ones.
    go [] = pure (1, [1])
    go (d : ds) = do
      (inner, rest) <- go ds
      s <- dPrimVE "slice" $ d * inner
      pure (s, s : rest)

-- | @dIndexSpace vs_ds j@ unflattens the flat index @j@ into a
-- multi-dimensional index into an array whose dimensions are the
-- second components of @vs_ds@, outermost first (row-major order).
--
-- For each pair @(v, d)@ the index into dimension @d@ is declared as
-- the scalar variable @v@.  Nothing is returned; callers read the
-- indices through the names they supplied.  Intermediate remainders
-- are bound to fresh variables named after @"remnant"@.
dIndexSpace ::
  [(VName, Imp.TExp Int64)] ->
  Imp.TExp Int64 ->
  ImpM rep r op ()
dIndexSpace vs_ds j = do
  slices <- dSlices (map snd vs_ds)
  loop (zip (map fst vs_ds) slices) j
  where
    -- Peel off one dimension at a time.  The quotient by the slice size
    -- (product of the inner dimensions) is this dimension's index, and
    -- the remainder is the flat index into the remaining dimensions.
    loop ((v, size) : rest) i = do
      dPrimV_ v (i `quot` size)
      i' <- dPrimVE "remnant" $ i - Imp.le64 v * size
      loop rest i'
    loop _ _ = pure ()

-- | Like 'dIndexSpace', but invent fresh names for the indexes, based
-- on the given template, instead of taking them from the caller.
-- Returns one index expression per dimension, outermost first.
dIndexSpace' ::
  String ->
  [Imp.TExp Int64] ->
  Imp.TExp Int64 ->
  ImpM rep r op [Imp.TExp Int64]
dIndexSpace' desc ds j = do
  -- One fresh name per dimension.
  ivs <- forM ds (const (newVName desc))
  dIndexSpace (zip ivs ds) j
  pure (Imp.le64 <$> ivs)