{-# LANGUAGE TypeFamilies #-}
module Futhark.CodeGen.ImpGen.GPU.SegScan.TwoPass (compileSegScan) where
import Control.Monad
import Control.Monad.State
import Data.List qualified as L
import Data.Maybe
import Futhark.CodeGen.ImpCode.GPU qualified as Imp
import Futhark.CodeGen.ImpGen
import Futhark.CodeGen.ImpGen.GPU.Base
import Futhark.IR.GPUMem
import Futhark.IR.Mem.LMAD qualified as LMAD
import Futhark.Transform.Rename
import Futhark.Util (takeLast)
import Futhark.Util.IntegralExp (divUp, quot, rem)
import Prelude hiding (quot, rem)
-- | Create the per-operator scratch arrays used by the scan, returning
-- one list of array names per 'SegBinOp' (one array per "x" parameter
-- of the operator's lambda).
--
-- Memory blocks created here are tracked in a 'StateT' pool so that
-- later operators can reuse blocks declared for earlier ones.  Each
-- pooled block is allocated once at the end, in space @"shared"@, with
-- the maximum of all the sizes that were recorded against it.
makeLocalArrays ::
  Count BlockSize SubExp ->
  SubExp ->
  [SegBinOp GPUMem] ->
  InKernelGen [[VName]]
makeLocalArrays (Count tblock_size) num_threads scans = do
  (arrs, mems_and_sizes) <- runStateT (mapM onScan scans) mempty
  -- The fold starts from 1, so the allocation size is never below one
  -- byte even if the recorded sizes are all smaller.
  let maxSize sizes = Imp.bytes $ L.foldl' sMax64 1 $ map Imp.unCount sizes
  forM_ mems_and_sizes $ \(sizes, mem) ->
    sAlloc_ mem (maxSize sizes) (Space "shared")
  pure arrs
  where
    -- Build the arrays for a single operator.  Blocks handed out by
    -- 'getMem' are removed from the pool while this operator is being
    -- processed (so two parameters of the same operator never share a
    -- block) and are put back afterwards for the next operator.
    onScan (SegBinOp _ scan_op nes _) = do
      let (scan_x_params, _scan_y_params) =
            splitAt (length nes) $ lambdaParams scan_op
      (arrs, used_mems) <- fmap unzip $
        forM scan_x_params $ \p ->
          case paramDec p of
            -- Array-typed parameter: reuse the memory block the
            -- parameter already lives in, with one row per thread.
            MemArray pt shape _ (ArrayIn mem _) -> do
              let shape' = Shape [num_threads] <> shape
              arr <-
                lift . sArray "scan_arr" pt shape' mem $
                  LMAD.iota 0 (map pe64 $ shapeDims shape')
              pure (arr, [])
            -- Any other parameter: one element per thread in the
            -- block, placed in a pooled memory block.
            _ -> do
              let pt = elemType $ paramType p
                  shape = Shape [tblock_size]
              (sizes, mem') <- getMem pt shape
              arr <- lift $ sArrayInMem "scan_arr" pt shape mem'
              pure (arr, [(sizes, mem')])
      modify (<> concat used_mems)
      pure arrs

    -- Obtain a memory block for an array of the given element type and
    -- shape.  Preference order: a pooled block that already has this
    -- exact size recorded; otherwise the first pooled block (recording
    -- the new size against it); otherwise a freshly declared block.
    getMem pt shape = do
      let size = typeSize $ Array pt shape NoUniqueness
      mems <- get
      case (L.find ((size `elem`) . fst) mems, mems) of
        (Just mem, _) -> do
          modify $ L.delete mem
          pure mem
        (Nothing, (size', mem) : mems') -> do
          put mems'
          pure (size : size', mem)
        (Nothing, []) -> do
          mem <- lift $ sDeclareMem "scan_arr_mem" $ Space "shared"
          pure ([size], mem)
-- | An optional predicate over two flat indices @from@ and @to@.  As
-- constructed in 'scanStage1', it is 'Nothing' when the iteration
-- space has fewer than two dimensions; otherwise it tests
-- @(to - from) > (to `rem` segment_size)@, where @segment_size@ is the
-- innermost dimension.
type CrossesSegment = Maybe (Imp.TExp Int64 -> Imp.TExp Int64 -> Imp.TExp Bool)
-- | The index a thread uses into a scratch array holding values of
-- type @t@: the (sign-extended) local thread id when @t@ is a
-- primitive type, and the global thread id otherwise.  This mirrors
-- the shapes chosen by 'makeLocalArrays' (block-sized arrays for
-- primitive parameters, one row per thread for array parameters).
localArrayIndex :: KernelConstants -> Type -> Imp.TExp Int64
localArrayIndex constants t =
  if primType t
    then sExt64 (kernelLocalThreadId constants)
    else sExt64 (kernelGlobalThreadId constants)
-- | Pick the barrier to use for a scan operator.  Returns a triple of:
--
-- * whether the operator returns any non-primitive (array) value,
--
-- * the fence to use ('Imp.FenceGlobal' in that case, otherwise
--   'Imp.FenceLocal'),
--
-- * an action that emits an 'Imp.Barrier' with that fence.
barrierFor :: Lambda GPUMem -> (Bool, Imp.Fence, InKernelGen ())
barrierFor scan_op = (array_scan, fence, sOp $ Imp.Barrier fence)
  where
    array_scan = not $ all primType $ lambdaReturnType scan_op
    fence
      | array_scan = Imp.FenceGlobal
      | otherwise = Imp.FenceLocal
-- | Split the parameters of a scan operator's lambda at the number of
-- neutral elements: 'xParams' is the leading part and 'yParams' is the
-- remainder.
xParams, yParams :: SegBinOp GPUMem -> [LParam GPUMem]
xParams scan =
  take (length (segBinOpNeutral scan)) (lambdaParams (segBinOpLambda scan))
yParams scan =
  drop (length (segBinOpNeutral scan)) (lambdaParams (segBinOpLambda scan))
-- | Store the per-thread scan results.
--
-- When the operator has a non-empty shape (rank > 0), each result is
-- copied into the corresponding pattern element at the index given by
-- @gtids@.  Otherwise each result is copied into the corresponding
-- "y" parameter of the operator.
writeToScanValues ::
  [VName] ->
  ([PatElem LetDecMem], SegBinOp GPUMem, [KernelResult]) ->
  InKernelGen ()
writeToScanValues gtids (pes, scan, scan_res)
  | shapeRank (segBinOpShape scan) > 0 =
      forM_ (zip pes scan_res) $ \(pe, res) ->
        copyDWIMFix
          (patElemName pe)
          (map Imp.le64 gtids)
          (kernelResultSubExp res)
          []
  | otherwise =
      forM_ (zip (yParams scan) scan_res) $ \(p, res) ->
        copyDWIMFix (paramName p) [] (kernelResultSubExp res) []
-- | Load the operator's "y" parameters from the pattern elements at
-- index @is@.  This only happens when the operator has a non-empty
-- shape (rank > 0); otherwise this is a no-op.
readToScanValues ::
  [Imp.TExp Int64] ->
  [PatElem LetDecMem] ->
  SegBinOp GPUMem ->
  InKernelGen ()
readToScanValues is pes scan
  | shapeRank (segBinOpShape scan) > 0 =
      forM_ (zip (yParams scan) pes) $ \(p, pe) ->
        copyDWIMFix (paramName p) [] (Var (patElemName pe)) is
  | otherwise =
      pure ()
-- | Initialise the operator's "x" parameters (the carries) for a chunk.
-- Only does anything when the operator has a non-empty shape
-- (rank > 0).
--
-- For the thread with local id 0 in any chunk after the first
-- (@chunk_id > 0@), the carries are read from the pattern elements at
-- the flat position @chunk_offset - 1@ (unflattened over @dims'@)
-- followed by @vec_is@.  In every other case the carries are set to
-- the operator's neutral elements.
readCarries ::
  Imp.TExp Int64 ->
  Imp.TExp Int64 ->
  [Imp.TExp Int64] ->
  [Imp.TExp Int64] ->
  [PatElem LetDecMem] ->
  SegBinOp GPUMem ->
  InKernelGen ()
readCarries chunk_id chunk_offset dims' vec_is pes scan
  | shapeRank (segBinOpShape scan) > 0 = do
      ltid <- kernelLocalThreadId . kernelConstants <$> askEnv
      sIf
        (chunk_id .>. 0 .&&. ltid .==. 0)
        ( do
            -- Index of the last element preceding this chunk.
            let is = unflattenIndex dims' $ chunk_offset - 1
            forM_ (zip (xParams scan) pes) $ \(p, pe) ->
              copyDWIMFix (paramName p) [] (Var (patElemName pe)) (is ++ vec_is)
        )
        ( forM_ (zip (xParams scan) (segBinOpNeutral scan)) $ \(p, ne) ->
            copyDWIMFix (paramName p) [] ne []
        )
  | otherwise =
      pure ()
scanStage1 ::
Pat LetDecMem ->
Count NumBlocks SubExp ->
Count BlockSize SubExp ->
SegSpace ->
[SegBinOp GPUMem] ->
KernelBody GPUMem ->
CallKernelGen (TV Int32, Imp.TExp Int64, CrossesSegment)
scanStage1 :: Pat LParamMem
-> Count NumBlocks SubExp
-> Count BlockSize SubExp
-> SegSpace
-> [SegBinOp GPUMem]
-> KernelBody GPUMem
-> CallKernelGen (TV Int32, TPrimExp Int64 VName, CrossesSegment)
scanStage1 (Pat [PatElem LParamMem]
all_pes) Count NumBlocks SubExp
num_tblocks Count BlockSize SubExp
tblock_size SegSpace
space [SegBinOp GPUMem]
scans KernelBody GPUMem
kbody = do
let num_tblocks' :: Count NumBlocks (TPrimExp Int64 VName)
num_tblocks' = (SubExp -> TPrimExp Int64 VName)
-> Count NumBlocks SubExp -> Count NumBlocks (TPrimExp Int64 VName)
forall a b. (a -> b) -> Count NumBlocks a -> Count NumBlocks b
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap SubExp -> TPrimExp Int64 VName
pe64 Count NumBlocks SubExp
num_tblocks
tblock_size' :: Count BlockSize (TPrimExp Int64 VName)
tblock_size' = (SubExp -> TPrimExp Int64 VName)
-> Count BlockSize SubExp -> Count BlockSize (TPrimExp Int64 VName)
forall a b. (a -> b) -> Count BlockSize a -> Count BlockSize b
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap SubExp -> TPrimExp Int64 VName
pe64 Count BlockSize SubExp
tblock_size
num_threads <- SpaceId
-> TPrimExp Int32 VName -> ImpM GPUMem HostEnv HostOp (TV Int32)
forall {k} (t :: k) rep r op.
SpaceId -> TExp t -> ImpM rep r op (TV t)
dPrimV SpaceId
"num_threads" (TPrimExp Int32 VName -> ImpM GPUMem HostEnv HostOp (TV Int32))
-> TPrimExp Int32 VName -> ImpM GPUMem HostEnv HostOp (TV Int32)
forall a b. (a -> b) -> a -> b
$ TPrimExp Int64 VName -> TPrimExp Int32 VName
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (TPrimExp Int64 VName -> TPrimExp Int32 VName)
-> TPrimExp Int64 VName -> TPrimExp Int32 VName
forall a b. (a -> b) -> a -> b
$ Count NumBlocks (TPrimExp Int64 VName) -> TPrimExp Int64 VName
forall {k} (u :: k) e. Count u e -> e
unCount Count NumBlocks (TPrimExp Int64 VName)
num_tblocks' TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* Count BlockSize (TPrimExp Int64 VName) -> TPrimExp Int64 VName
forall {k} (u :: k) e. Count u e -> e
unCount Count BlockSize (TPrimExp Int64 VName)
tblock_size'
let (gtids, dims) = unzip $ unSegSpace space
dims' = (SubExp -> TPrimExp Int64 VName)
-> [SubExp] -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TPrimExp Int64 VName
pe64 [SubExp]
dims
let num_elements = [TPrimExp Int64 VName] -> TPrimExp Int64 VName
forall a. Num a => [a] -> a
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product [TPrimExp Int64 VName]
dims'
elems_per_thread = TPrimExp Int64 VName
num_elements TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall e. IntegralExp e => e -> e -> e
`divUp` TPrimExp Int32 VName -> TPrimExp Int64 VName
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (TV Int32 -> TPrimExp Int32 VName
forall {k} (t :: k). TV t -> TExp t
tvExp TV Int32
num_threads)
elems_per_group = Count BlockSize (TPrimExp Int64 VName) -> TPrimExp Int64 VName
forall {k} (u :: k) e. Count u e -> e
unCount Count BlockSize (TPrimExp Int64 VName)
tblock_size' TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* TPrimExp Int64 VName
elems_per_thread
let crossesSegment =
case [TPrimExp Int64 VName] -> [TPrimExp Int64 VName]
forall a. [a] -> [a]
reverse [TPrimExp Int64 VName]
dims' of
TPrimExp Int64 VName
segment_size : TPrimExp Int64 VName
_ : [TPrimExp Int64 VName]
_ -> (TPrimExp Int64 VName -> TPrimExp Int64 VName -> TExp Bool)
-> CrossesSegment
forall a. a -> Maybe a
Just ((TPrimExp Int64 VName -> TPrimExp Int64 VName -> TExp Bool)
-> CrossesSegment)
-> (TPrimExp Int64 VName -> TPrimExp Int64 VName -> TExp Bool)
-> CrossesSegment
forall a b. (a -> b) -> a -> b
$ \TPrimExp Int64 VName
from TPrimExp Int64 VName
to ->
(TPrimExp Int64 VName
to TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
- TPrimExp Int64 VName
from) TPrimExp Int64 VName -> TPrimExp Int64 VName -> TExp Bool
forall {k} v (t :: k).
Eq v =>
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.>. (TPrimExp Int64 VName
to TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall e. IntegralExp e => e -> e -> e
`rem` TPrimExp Int64 VName
segment_size)
[TPrimExp Int64 VName]
_ -> CrossesSegment
forall a. Maybe a
Nothing
sKernelThread "scan_stage1" (segFlat space) (defKernelAttrs num_tblocks tblock_size) $ do
constants <- kernelConstants <$> askEnv
all_local_arrs <- makeLocalArrays tblock_size (tvSize num_threads) scans
forM_ scans $ \SegBinOp GPUMem
scan -> do
Maybe (Exp GPUMem)
-> Scope GPUMem -> ImpM GPUMem KernelEnv KernelOp ()
forall rep (inner :: * -> *) r op.
Mem rep inner =>
Maybe (Exp rep) -> Scope rep -> ImpM rep r op ()
dScope Maybe (Exp GPUMem)
forall a. Maybe a
Nothing (Scope GPUMem -> ImpM GPUMem KernelEnv KernelOp ())
-> Scope GPUMem -> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ [Param LParamMem] -> Scope GPUMem
forall rep dec. (LParamInfo rep ~ dec) => [Param dec] -> Scope rep
scopeOfLParams ([Param LParamMem] -> Scope GPUMem)
-> [Param LParamMem] -> Scope GPUMem
forall a b. (a -> b) -> a -> b
$ Lambda GPUMem -> [LParam GPUMem]
forall rep. Lambda rep -> [LParam rep]
lambdaParams (Lambda GPUMem -> [LParam GPUMem])
-> Lambda GPUMem -> [LParam GPUMem]
forall a b. (a -> b) -> a -> b
$ SegBinOp GPUMem -> Lambda GPUMem
forall rep. SegBinOp rep -> Lambda rep
segBinOpLambda SegBinOp GPUMem
scan
[(Param LParamMem, SubExp)]
-> ((Param LParamMem, SubExp) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param LParamMem] -> [SubExp] -> [(Param LParamMem, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip (SegBinOp GPUMem -> [LParam GPUMem]
xParams SegBinOp GPUMem
scan) (SegBinOp GPUMem -> [SubExp]
forall rep. SegBinOp rep -> [SubExp]
segBinOpNeutral SegBinOp GPUMem
scan)) (((Param LParamMem, SubExp) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ())
-> ((Param LParamMem, SubExp) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(Param LParamMem
p, SubExp
ne) ->
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix (Param LParamMem -> VName
forall dec. Param dec -> VName
paramName Param LParamMem
p) [] SubExp
ne []
sFor "j" elems_per_thread $ \TPrimExp Int64 VName
j -> do
chunk_offset <-
SpaceId
-> TPrimExp Int64 VName
-> ImpM GPUMem KernelEnv KernelOp (TV Int64)
forall {k} (t :: k) rep r op.
SpaceId -> TExp t -> ImpM rep r op (TV t)
dPrimV SpaceId
"chunk_offset" (TPrimExp Int64 VName -> ImpM GPUMem KernelEnv KernelOp (TV Int64))
-> TPrimExp Int64 VName
-> ImpM GPUMem KernelEnv KernelOp (TV Int64)
forall a b. (a -> b) -> a -> b
$
TPrimExp Int64 VName -> TPrimExp Int64 VName
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (KernelConstants -> TPrimExp Int64 VName
kernelBlockSize KernelConstants
constants) TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* TPrimExp Int64 VName
j
TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
+ TPrimExp Int32 VName -> TPrimExp Int64 VName
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (KernelConstants -> TPrimExp Int32 VName
kernelBlockId KernelConstants
constants) TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* TPrimExp Int64 VName
elems_per_group
flat_idx <-
dPrimV "flat_idx" $
tvExp chunk_offset + sExt64 (kernelLocalThreadId constants)
zipWithM_ dPrimV_ gtids $ unflattenIndex dims' $ tvExp flat_idx
let per_scan_pes = [SegBinOp GPUMem] -> [PatElem LParamMem] -> [[PatElem LParamMem]]
forall rep a. [SegBinOp rep] -> [a] -> [[a]]
segBinOpChunks [SegBinOp GPUMem]
scans [PatElem LParamMem]
all_pes
in_bounds =
(TExp Bool -> TExp Bool -> TExp Bool) -> [TExp Bool] -> TExp Bool
forall a. (a -> a -> a) -> [a] -> a
forall (t :: * -> *) a. Foldable t => (a -> a -> a) -> t a -> a
foldl1 TExp Bool -> TExp Bool -> TExp Bool
forall v.
Eq v =>
TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
(.&&.) ([TExp Bool] -> TExp Bool) -> [TExp Bool] -> TExp Bool
forall a b. (a -> b) -> a -> b
$ (TPrimExp Int64 VName -> TPrimExp Int64 VName -> TExp Bool)
-> [TPrimExp Int64 VName] -> [TPrimExp Int64 VName] -> [TExp Bool]
forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith TPrimExp Int64 VName -> TPrimExp Int64 VName -> TExp Bool
forall {k} v (t :: k).
Eq v =>
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
(.<.) ((VName -> TPrimExp Int64 VName)
-> [VName] -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> [a] -> [b]
map VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
Imp.le64 [VName]
gtids) [TPrimExp Int64 VName]
dims'
when_in_bounds = Names
-> Stms GPUMem
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
Names -> Stms rep -> ImpM rep r op () -> ImpM rep r op ()
compileStms Names
forall a. Monoid a => a
mempty (KernelBody GPUMem -> Stms GPUMem
forall rep. KernelBody rep -> Stms rep
kernelBodyStms KernelBody GPUMem
kbody) (ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ do
let ([KernelResult]
all_scan_res, [KernelResult]
map_res) =
Int -> [KernelResult] -> ([KernelResult], [KernelResult])
forall a. Int -> [a] -> ([a], [a])
splitAt ([SegBinOp GPUMem] -> Int
forall rep. [SegBinOp rep] -> Int
segBinOpResults [SegBinOp GPUMem]
scans) ([KernelResult] -> ([KernelResult], [KernelResult]))
-> [KernelResult] -> ([KernelResult], [KernelResult])
forall a b. (a -> b) -> a -> b
$ KernelBody GPUMem -> [KernelResult]
forall rep. KernelBody rep -> [KernelResult]
kernelBodyResult KernelBody GPUMem
kbody
per_scan_res :: [[KernelResult]]
per_scan_res =
[SegBinOp GPUMem] -> [KernelResult] -> [[KernelResult]]
forall rep a. [SegBinOp rep] -> [a] -> [[a]]
segBinOpChunks [SegBinOp GPUMem]
scans [KernelResult]
all_scan_res
Text
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op. Text -> ImpM rep r op () -> ImpM rep r op ()
sComment Text
"write to-scan values to parameters" (ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
(([PatElem LParamMem], SegBinOp GPUMem, [KernelResult])
-> ImpM GPUMem KernelEnv KernelOp ())
-> [([PatElem LParamMem], SegBinOp GPUMem, [KernelResult])]
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
(a -> m b) -> t a -> m ()
mapM_ ([VName]
-> ([PatElem LParamMem], SegBinOp GPUMem, [KernelResult])
-> ImpM GPUMem KernelEnv KernelOp ()
writeToScanValues [VName]
gtids) ([([PatElem LParamMem], SegBinOp GPUMem, [KernelResult])]
-> ImpM GPUMem KernelEnv KernelOp ())
-> [([PatElem LParamMem], SegBinOp GPUMem, [KernelResult])]
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
[[PatElem LParamMem]]
-> [SegBinOp GPUMem]
-> [[KernelResult]]
-> [([PatElem LParamMem], SegBinOp GPUMem, [KernelResult])]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [[PatElem LParamMem]]
per_scan_pes [SegBinOp GPUMem]
scans [[KernelResult]]
per_scan_res
Text
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op. Text -> ImpM rep r op () -> ImpM rep r op ()
sComment Text
"write mapped values results to global memory" (ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
[(PatElem LParamMem, KernelResult)]
-> ((PatElem LParamMem, KernelResult)
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([PatElem LParamMem]
-> [KernelResult] -> [(PatElem LParamMem, KernelResult)]
forall a b. [a] -> [b] -> [(a, b)]
zip (Int -> [PatElem LParamMem] -> [PatElem LParamMem]
forall a. Int -> [a] -> [a]
takeLast ([KernelResult] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [KernelResult]
map_res) [PatElem LParamMem]
all_pes) [KernelResult]
map_res) (((PatElem LParamMem, KernelResult)
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ())
-> ((PatElem LParamMem, KernelResult)
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(PatElem LParamMem
pe, KernelResult
se) ->
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix
(PatElem LParamMem -> VName
forall dec. PatElem dec -> VName
patElemName PatElem LParamMem
pe)
((VName -> TPrimExp Int64 VName)
-> [VName] -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> [a] -> [b]
map VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
Imp.le64 [VName]
gtids)
(KernelResult -> SubExp
kernelResultSubExp KernelResult
se)
[]
sComment "threads in bounds read input" $
sWhen in_bounds when_in_bounds
unless (all (null . segBinOpShape) scans) $
sOp $
Imp.Barrier Imp.FenceGlobal
forM_ (zip3 per_scan_pes scans all_local_arrs) $
\([PatElem LParamMem]
pes, scan :: SegBinOp GPUMem
scan@(SegBinOp Commutativity
_ Lambda GPUMem
scan_op [SubExp]
nes Shape
vec_shape), [VName]
local_arrs) ->
Text
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op. Text -> ImpM rep r op () -> ImpM rep r op ()
sComment Text
"do one intra-group scan operation" (ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ do
let rets :: [TypeBase Shape NoUniqueness]
rets = Lambda GPUMem -> [TypeBase Shape NoUniqueness]
forall rep. Lambda rep -> [TypeBase Shape NoUniqueness]
lambdaReturnType Lambda GPUMem
scan_op
scan_x_params :: [LParam GPUMem]
scan_x_params = SegBinOp GPUMem -> [LParam GPUMem]
xParams SegBinOp GPUMem
scan
(Bool
array_scan, Fence
fence, ImpM GPUMem KernelEnv KernelOp ()
barrier) = Lambda GPUMem -> (Bool, Fence, ImpM GPUMem KernelEnv KernelOp ())
barrierFor Lambda GPUMem
scan_op
Bool
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when Bool
array_scan ImpM GPUMem KernelEnv KernelOp ()
barrier
Shape
-> ([TPrimExp Int64 VName] -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
Shape
-> ([TPrimExp Int64 VName] -> ImpM rep r op ()) -> ImpM rep r op ()
sLoopNest Shape
vec_shape (([TPrimExp Int64 VName] -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ())
-> ([TPrimExp Int64 VName] -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \[TPrimExp Int64 VName]
vec_is -> do
Text
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op. Text -> ImpM rep r op () -> ImpM rep r op ()
sComment Text
"maybe restore some to-scan values to parameters, or read neutral" (ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
TExp Bool
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
TExp Bool
-> ImpM rep r op () -> ImpM rep r op () -> ImpM rep r op ()
sIf
TExp Bool
in_bounds
( do
[TPrimExp Int64 VName]
-> [PatElem LParamMem]
-> SegBinOp GPUMem
-> ImpM GPUMem KernelEnv KernelOp ()
readToScanValues ((VName -> TPrimExp Int64 VName)
-> [VName] -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> [a] -> [b]
map VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
Imp.le64 [VName]
gtids [TPrimExp Int64 VName]
-> [TPrimExp Int64 VName] -> [TPrimExp Int64 VName]
forall a. [a] -> [a] -> [a]
++ [TPrimExp Int64 VName]
vec_is) [PatElem LParamMem]
pes SegBinOp GPUMem
scan
TPrimExp Int64 VName
-> TPrimExp Int64 VName
-> [TPrimExp Int64 VName]
-> [TPrimExp Int64 VName]
-> [PatElem LParamMem]
-> SegBinOp GPUMem
-> ImpM GPUMem KernelEnv KernelOp ()
readCarries TPrimExp Int64 VName
j (TV Int64 -> TPrimExp Int64 VName
forall {k} (t :: k). TV t -> TExp t
tvExp TV Int64
chunk_offset) [TPrimExp Int64 VName]
dims' [TPrimExp Int64 VName]
vec_is [PatElem LParamMem]
pes SegBinOp GPUMem
scan
)
( [(Param LParamMem, SubExp)]
-> ((Param LParamMem, SubExp) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param LParamMem] -> [SubExp] -> [(Param LParamMem, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip (SegBinOp GPUMem -> [LParam GPUMem]
yParams SegBinOp GPUMem
scan) (SegBinOp GPUMem -> [SubExp]
forall rep. SegBinOp rep -> [SubExp]
segBinOpNeutral SegBinOp GPUMem
scan)) (((Param LParamMem, SubExp) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ())
-> ((Param LParamMem, SubExp) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(Param LParamMem
p, SubExp
ne) ->
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix (Param LParamMem -> VName
forall dec. Param dec -> VName
paramName Param LParamMem
p) [] SubExp
ne []
)
Text
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op. Text -> ImpM rep r op () -> ImpM rep r op ()
sComment Text
"combine with carry and write to shared memory" (ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
Names
-> Stms GPUMem
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
Names -> Stms rep -> ImpM rep r op () -> ImpM rep r op ()
compileStms Names
forall a. Monoid a => a
mempty (Body GPUMem -> Stms GPUMem
forall rep. Body rep -> Stms rep
bodyStms (Body GPUMem -> Stms GPUMem) -> Body GPUMem -> Stms GPUMem
forall a b. (a -> b) -> a -> b
$ Lambda GPUMem -> Body GPUMem
forall rep. Lambda rep -> Body rep
lambdaBody Lambda GPUMem
scan_op) (ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
[(TypeBase Shape NoUniqueness, VName, SubExp)]
-> ((TypeBase Shape NoUniqueness, VName, SubExp)
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([TypeBase Shape NoUniqueness]
-> [VName]
-> [SubExp]
-> [(TypeBase Shape NoUniqueness, VName, SubExp)]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [TypeBase Shape NoUniqueness]
rets [VName]
local_arrs ([SubExp] -> [(TypeBase Shape NoUniqueness, VName, SubExp)])
-> [SubExp] -> [(TypeBase Shape NoUniqueness, VName, SubExp)]
forall a b. (a -> b) -> a -> b
$ (SubExpRes -> SubExp) -> [SubExpRes] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map SubExpRes -> SubExp
resSubExp ([SubExpRes] -> [SubExp]) -> [SubExpRes] -> [SubExp]
forall a b. (a -> b) -> a -> b
$ Body GPUMem -> [SubExpRes]
forall rep. Body rep -> [SubExpRes]
bodyResult (Body GPUMem -> [SubExpRes]) -> Body GPUMem -> [SubExpRes]
forall a b. (a -> b) -> a -> b
$ Lambda GPUMem -> Body GPUMem
forall rep. Lambda rep -> Body rep
lambdaBody Lambda GPUMem
scan_op) (((TypeBase Shape NoUniqueness, VName, SubExp)
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ())
-> ((TypeBase Shape NoUniqueness, VName, SubExp)
-> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
\(TypeBase Shape NoUniqueness
t, VName
arr, SubExp
se) ->
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix VName
arr [KernelConstants
-> TypeBase Shape NoUniqueness -> TPrimExp Int64 VName
localArrayIndex KernelConstants
constants TypeBase Shape NoUniqueness
t] SubExp
se []
let crossesSegment' :: Maybe (TPrimExp Int32 VName -> TPrimExp Int32 VName -> TExp Bool)
crossesSegment' = do
f <- CrossesSegment
crossesSegment
Just $ \TPrimExp Int32 VName
from TPrimExp Int32 VName
to ->
let from' :: TPrimExp Int64 VName
from' = TPrimExp Int32 VName -> TPrimExp Int64 VName
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TPrimExp Int32 VName
from TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
+ TV Int64 -> TPrimExp Int64 VName
forall {k} (t :: k). TV t -> TExp t
tvExp TV Int64
chunk_offset
to' :: TPrimExp Int64 VName
to' = TPrimExp Int32 VName -> TPrimExp Int64 VName
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TPrimExp Int32 VName
to TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
+ TV Int64 -> TPrimExp Int64 VName
forall {k} (t :: k). TV t -> TExp t
tvExp TV Int64
chunk_offset
in TPrimExp Int64 VName -> TPrimExp Int64 VName -> TExp Bool
f TPrimExp Int64 VName
from' TPrimExp Int64 VName
to'
KernelOp -> ImpM GPUMem KernelEnv KernelOp ()
forall op rep r. op -> ImpM rep r op ()
sOp (KernelOp -> ImpM GPUMem KernelEnv KernelOp ())
-> KernelOp -> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ Fence -> KernelOp
Imp.ErrorSync Fence
fence
scan_op_renamed <- Lambda GPUMem -> ImpM GPUMem KernelEnv KernelOp (Lambda GPUMem)
forall rep (m :: * -> *).
(Renameable rep, MonadFreshNames m) =>
Lambda rep -> m (Lambda rep)
renameLambda Lambda GPUMem
scan_op
blockScan
crossesSegment'
(sExt64 $ tvExp num_threads)
(sExt64 $ kernelBlockSize constants)
scan_op_renamed
local_arrs
sComment "threads in bounds write partial scan result" $
sWhen in_bounds $
forM_ (zip3 rets pes local_arrs) $ \(TypeBase Shape NoUniqueness
t, PatElem LParamMem
pe, VName
arr) ->
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix
(PatElem LParamMem -> VName
forall dec. PatElem dec -> VName
patElemName PatElem LParamMem
pe)
((VName -> TPrimExp Int64 VName)
-> [VName] -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> [a] -> [b]
map VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
Imp.le64 [VName]
gtids [TPrimExp Int64 VName]
-> [TPrimExp Int64 VName] -> [TPrimExp Int64 VName]
forall a. [a] -> [a] -> [a]
++ [TPrimExp Int64 VName]
vec_is)
(VName -> SubExp
Var VName
arr)
[KernelConstants
-> TypeBase Shape NoUniqueness -> TPrimExp Int64 VName
localArrayIndex KernelConstants
constants TypeBase Shape NoUniqueness
t]
barrier
let load_carry =
[(VName, Param LParamMem)]
-> ((VName, Param LParamMem) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> [Param LParamMem] -> [(VName, Param LParamMem)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
local_arrs [LParam GPUMem]
[Param LParamMem]
scan_x_params) (((VName, Param LParamMem) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ())
-> ((VName, Param LParamMem) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(VName
arr, Param LParamMem
p) ->
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix
(Param LParamMem -> VName
forall dec. Param dec -> VName
paramName Param LParamMem
p)
[]
(VName -> SubExp
Var VName
arr)
[ if TypeBase Shape NoUniqueness -> Bool
forall shape u. TypeBase shape u -> Bool
primType (TypeBase Shape NoUniqueness -> Bool)
-> TypeBase Shape NoUniqueness -> Bool
forall a b. (a -> b) -> a -> b
$ Param LParamMem -> TypeBase Shape NoUniqueness
forall dec. Typed dec => Param dec -> TypeBase Shape NoUniqueness
paramType Param LParamMem
p
then TPrimExp Int64 VName -> TPrimExp Int64 VName
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (KernelConstants -> TPrimExp Int64 VName
kernelBlockSize KernelConstants
constants) TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
- TPrimExp Int64 VName
1
else
(TPrimExp Int32 VName -> TPrimExp Int64 VName
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (KernelConstants -> TPrimExp Int32 VName
kernelBlockId KernelConstants
constants) TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
+ TPrimExp Int64 VName
1)
TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* TPrimExp Int64 VName -> TPrimExp Int64 VName
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (KernelConstants -> TPrimExp Int64 VName
kernelBlockSize KernelConstants
constants)
TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
- TPrimExp Int64 VName
1
]
load_neutral =
[(SubExp, Param LParamMem)]
-> ((SubExp, Param LParamMem) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([SubExp] -> [Param LParamMem] -> [(SubExp, Param LParamMem)]
forall a b. [a] -> [b] -> [(a, b)]
zip [SubExp]
nes [LParam GPUMem]
[Param LParamMem]
scan_x_params) (((SubExp, Param LParamMem) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ())
-> ((SubExp, Param LParamMem) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(SubExp
ne, Param LParamMem
p) ->
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix (Param LParamMem -> VName
forall dec. Param dec -> VName
paramName Param LParamMem
p) [] SubExp
ne []
sComment "first thread reads last element as carry-in for next iteration" $ do
crosses_segment <- dPrimVE "crosses_segment" $
case crossesSegment of
CrossesSegment
Nothing -> TExp Bool
forall v. TPrimExp Bool v
false
Just TPrimExp Int64 VName -> TPrimExp Int64 VName -> TExp Bool
f ->
TPrimExp Int64 VName -> TPrimExp Int64 VName -> TExp Bool
f
( TV Int64 -> TPrimExp Int64 VName
forall {k} (t :: k). TV t -> TExp t
tvExp TV Int64
chunk_offset
TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
+ TPrimExp Int64 VName -> TPrimExp Int64 VName
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (KernelConstants -> TPrimExp Int64 VName
kernelBlockSize KernelConstants
constants)
TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
- TPrimExp Int64 VName
1
)
( TV Int64 -> TPrimExp Int64 VName
forall {k} (t :: k). TV t -> TExp t
tvExp TV Int64
chunk_offset
TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
+ TPrimExp Int64 VName -> TPrimExp Int64 VName
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (KernelConstants -> TPrimExp Int64 VName
kernelBlockSize KernelConstants
constants)
)
should_load_carry <-
dPrimVE "should_load_carry" $
kernelLocalThreadId constants .==. 0 .&&. bNot crosses_segment
sWhen should_load_carry load_carry
when array_scan barrier
sUnless should_load_carry load_neutral
barrier
pure (num_threads, elems_per_group, crossesSegment)
-- | Generate the @scan_stage2@ kernel.
--
-- The kernel is launched with a single thread block (the grid size is the
-- constant 1) whose block size is the value carried by @num_tblocks@.
-- Each thread addresses the flat index
-- @(local_thread_id + 1) * elems_per_group - 1@, i.e. the last element of
-- the chunk of @elems_per_group@ elements associated with that thread.
-- For every scan operator (and every index of its vector shape) the
-- kernel:
--
--   1. copies the value at that index from the pattern element into the
--      operator's local array, or the neutral element if the index is out
--      of bounds;
--
--   2. runs 'blockScan' over the local arrays;
--
--   3. writes the scanned value back to the pattern element, for threads
--      that are in bounds.
--
-- NOTE(review): judging by the parameter names, the values read here are
-- the per-block results left behind by the stage-1 kernel - confirm
-- against the caller.
scanStage2 ::
  -- | Pattern elements holding the arrays that are read and updated.
  Pat LetDecMem ->
  -- | Passed (as a size) to 'makeLocalArrays' and to 'blockScan';
  -- presumably the total thread count of stage 1.
  TV Int32 ->
  -- | Number of elements in the chunk associated with each thread.
  Imp.TExp Int64 ->
  -- | Reused as the block size of this kernel.
  Count NumBlocks SubExp ->
  -- | Optional predicate telling whether two flat indices lie in
  -- different segments.
  CrossesSegment ->
  -- | Iteration space; supplies the index variables, their dimensions
  -- and the flat thread index name.
  SegSpace ->
  -- | The scan operators.
  [SegBinOp GPUMem] ->
  CallKernelGen ()
scanStage2 (Pat all_pes) stage1_num_threads elems_per_group num_tblocks crossesSegment space scans = do
  let (gtids, dims) = unzip $ unSegSpace space
      dims' = map pe64 dims

  -- The block size of this kernel is the 'num_tblocks' value, merely
  -- re-tagged from a block count to a block size.
  let tblock_size = Count $ unCount num_tblocks

  -- 'blockScan' hands us 32-bit indices; translate each of them to the
  -- flat index of the last element of the corresponding chunk before
  -- consulting the caller-provided predicate.
  let crossesSegment' = do
        f <- crossesSegment
        Just $ \from to ->
          f
            ((sExt64 from + 1) * elems_per_group - 1)
            ((sExt64 to + 1) * elems_per_group - 1)

  sKernelThread
    "scan_stage2"
    (segFlat space)
    (defKernelAttrs (Count (intConst Int64 1)) tblock_size)
    $ do
      constants <- kernelConstants <$> askEnv
      per_scan_local_arrs <-
        makeLocalArrays tblock_size (tvSize stage1_num_threads) scans
      let per_scan_rets = map (lambdaReturnType . segBinOpLambda) scans
          per_scan_pes = segBinOpChunks scans all_pes

      flat_idx <-
        dPrimV "flat_idx" $
          (sExt64 (kernelLocalThreadId constants) + 1) * elems_per_group - 1
      -- Bind the index variables of the iteration space from the flat index.
      zipWithM_ dPrimV_ gtids $ unflattenIndex dims' $ tvExp flat_idx

      forM_ (L.zip4 scans per_scan_local_arrs per_scan_rets per_scan_pes) $
        \(SegBinOp _ scan_op nes vec_shape, local_arrs, rets, pes) ->
          sLoopNest vec_shape $ \vec_is -> do
            let glob_is = map Imp.le64 gtids ++ vec_is

                -- 'foldl1' is partial: this relies on the iteration space
                -- having at least one dimension.
                in_bounds =
                  foldl1 (.&&.) $ zipWith (.<.) (map Imp.le64 gtids) dims'

                -- Pattern element -> local array.
                when_in_bounds =
                  forM_ (zip3 rets local_arrs pes) $ \(t, arr, pe) ->
                    copyDWIMFix
                      arr
                      [localArrayIndex constants t]
                      (Var $ patElemName pe)
                      glob_is

                -- Neutral element -> local array.
                when_out_of_bounds =
                  forM_ (zip3 rets local_arrs nes) $ \(t, arr, ne) ->
                    copyDWIMFix arr [localArrayIndex constants t] ne []

                (_, _, barrier) =
                  barrierFor scan_op

            sComment "threads in bound read carries; others get neutral element" $
              sIf in_bounds when_in_bounds when_out_of_bounds

            -- Issued after every thread has filled its local array slot and
            -- before the block-wide scan reads them.
            barrier

            blockScan
              crossesSegment'
              (sExt64 $ tvExp stage1_num_threads)
              (sExt64 $ kernelBlockSize constants)
              scan_op
              local_arrs

            sComment "threads in bounds write scanned carries" $
              sWhen in_bounds $
                forM_ (zip3 rets pes local_arrs) $ \(t, pe, arr) ->
                  copyDWIMFix
                    (patElemName pe)
                    glob_is
                    (Var arr)
                    [localArrayIndex constants t]
-- | Third stage of the scan: launch the @scan_stage3@ kernel, in which
-- each (virtualised) thread handles one flat index of the segment space
-- and, where required, combines the carry-in of the preceding chunk of
-- @elems_per_group@ elements with the thread's own element.
--
-- Arguments, in order: the pattern whose elements hold the scan results
-- (read and updated in place), the number of thread blocks and the
-- thread block size used for the launch, the number of elements per
-- chunk, an optional segment-crossing predicate, the segment space, and
-- the scan operators.
scanStage3 ::
  Pat LetDecMem ->
  Count NumBlocks SubExp ->
  Count BlockSize SubExp ->
  Imp.TExp Int64 ->
  CrossesSegment ->
  SegSpace ->
  [SegBinOp GPUMem] ->
  CallKernelGen ()
scanStage3 (Pat all_pes) num_tblocks tblock_size elems_per_group crossesSegment space scans = do
  let tblock_size' = fmap pe64 tblock_size
      (gtids, dims) = unzip $ unSegSpace space
      dims' = map pe64 dims
      -- The thread's logical indices as expressions; used both for the
      -- bounds check and for addressing the thread's own element.
      gtid_exps = map Imp.le64 gtids

  -- One thread per element: the number of blocks needed to cover the
  -- whole (flattened) space, rounded up.
  required_groups <-
    dPrimVE "required_groups" $
      sExt32 $
        product dims' `divUp` sExt64 (unCount tblock_size')

  sKernelThread "scan_stage3" (segFlat space) (defKernelAttrs num_tblocks tblock_size) $
    virtualiseBlocks SegVirt required_groups $ \virt_tblock_id -> do
      constants <- kernelConstants <$> askEnv

      -- Flat index handled by this thread, and the corresponding
      -- per-dimension indices bound to the names of the segment space.
      flat_idx <-
        dPrimVE "flat_idx" $
          sExt64 virt_tblock_id * sExt64 (unCount tblock_size')
            + sExt64 (kernelLocalThreadId constants)
      zipWithM_ dPrimV_ gtids $ unflattenIndex dims' flat_idx

      -- Which chunk of elems_per_group elements this index falls in.
      orig_group <- dPrimV "orig_group" $ flat_idx `quot` elems_per_group
      -- Flat index of the last element of the preceding chunk; this is
      -- where the carry-in is read from.
      carry_in_flat_idx <-
        dPrimV "carry_in_flat_idx" $
          tvExp orig_group * elems_per_group - 1
      let carry_in_idx = unflattenIndex dims' $ tvExp carry_in_flat_idx

      -- NOTE: 'foldl1' is partial; this relies on the segment space
      -- having at least one dimension.
      let in_bounds =
            foldl1 (.&&.) $ zipWith (.<.) gtid_exps dims'
          -- With no segment-crossing predicate, nothing ever crosses.
          crosses_segment =
            fromMaybe false $
              crossesSegment
                <*> pure (tvExp carry_in_flat_idx)
                <*> pure flat_idx
          -- True when this thread's element is itself the last element
          -- of its chunk.
          is_a_carry = flat_idx .==. (tvExp orig_group + 1) * elems_per_group - 1
          -- No carry is applied for the first chunk, for the last
          -- element of a chunk (NOTE(review): presumably already
          -- updated by stage 2 -- confirm), or across a segment
          -- boundary.
          no_carry_in = tvExp orig_group .==. 0 .||. is_a_carry .||. crosses_segment

      let per_scan_pes = segBinOpChunks scans all_pes
      sWhen in_bounds $
        sUnless no_carry_in $
          forM_ (zip per_scan_pes scans) $
            \(pes, SegBinOp _ scan_op nes vec_shape) -> do
              let op_params = lambdaParams scan_op
                  (scan_x_params, scan_y_params) =
                    splitAt (length nes) op_params
              dScope Nothing $ scopeOfLParams op_params

              sLoopNest vec_shape $ \vec_is -> do
                -- x parameters: the carry-in from the preceding chunk.
                forM_ (zip scan_x_params pes) $ \(p, pe) ->
                  copyDWIMFix
                    (paramName p)
                    []
                    (Var $ patElemName pe)
                    (carry_in_idx ++ vec_is)

                -- y parameters: this thread's own element.
                forM_ (zip scan_y_params pes) $ \(p, pe) ->
                  copyDWIMFix
                    (paramName p)
                    []
                    (Var $ patElemName pe)
                    (gtid_exps ++ vec_is)

                -- NOTE(review): presumably leaves the operator's
                -- results in the x parameters -- confirm against
                -- compileBody'.
                compileBody' scan_x_params $ lambdaBody scan_op

                -- Write the x parameters back to this thread's element.
                forM_ (zip scan_x_params pes) $ \(p, pe) ->
                  copyDWIMFix
                    (patElemName pe)
                    (gtid_exps ++ vec_is)
                    (Var $ paramName p)
                    []
-- | Generate host-level code for a segmented scan by running the three
-- stages ('scanStage1', 'scanStage2', 'scanStage3') in sequence.
--
-- Arguments, in order: the result pattern, the level from which the
-- kernel attributes are derived, the segment space, the scan operators,
-- and the kernel body (passed on to stage 1).
compileSegScan ::
  Pat LetDecMem ->
  SegLevel ->
  SegSpace ->
  [SegBinOp GPUMem] ->
  KernelBody GPUMem ->
  CallKernelGen ()
compileSegScan pat lvl space scans kbody = do
  attrs <- lvlKernelAttrs lvl

  -- Cap the number of stage-1 thread blocks at the maximum thread block
  -- size reported at run time.
  -- NOTE(review): presumably because stage 2 uses one thread per
  -- stage-1 block -- confirm against scanStage2.
  stage1_max_num_tblocks <- dPrim "stage1_max_num_tblocks"
  sOp $ Imp.GetSizeMax (tvVar stage1_max_num_tblocks) SizeThreadBlock

  stage1_num_tblocks <-
    fmap (Imp.Count . tvSize) $
      dPrimV "stage1_num_tblocks" $
        sMin64 (tvExp stage1_max_num_tblocks) $
          pe64 . Imp.unCount . kAttrNumBlocks $
            attrs

  (stage1_num_threads, elems_per_group, crossesSegment) <-
    scanStage1 pat stage1_num_tblocks (kAttrBlockSize attrs) space scans kbody

  emit $ Imp.DebugPrint "elems_per_group" $ Just $ untyped elems_per_group

  scanStage2 pat stage1_num_threads elems_per_group stage1_num_tblocks crossesSegment space scans
  -- Stage 3 is launched with the uncapped block count from the attributes.
  scanStage3 pat (kAttrNumBlocks attrs) (kAttrBlockSize attrs) elems_per_group crossesSegment space scans