module Futhark.CodeGen.ImpGen.Multicore.SegHist
( compileSegHist,
)
where
import Control.Monad
import Data.List (zip4)
import Data.Maybe (listToMaybe)
import Futhark.CodeGen.ImpCode.Multicore qualified as Imp
import Futhark.CodeGen.ImpGen
import Futhark.CodeGen.ImpGen.Multicore.Base
import Futhark.CodeGen.ImpGen.Multicore.SegRed (compileSegRed')
import Futhark.IR.MCMem
import Futhark.Transform.Rename (renameLambda)
import Futhark.Util (chunks, splitFromEnd, takeLast)
import Futhark.Util.IntegralExp (rem)
import Prelude hiding (quot, rem)
-- | Generate multicore code for a @SegHist@.
--
-- The choice of strategy depends only on the number of dimensions in the
-- segment space: a space with exactly one dimension is handled by
-- 'nonsegmentedHist' (which is also given @nsubtasks@), and any other
-- space is handled by 'segmentedHist'.
compileSegHist ::
  Pat LetDecMem ->
  SegSpace ->
  [HistOp MCMem] ->
  KernelBody MCMem ->
  TV Int32 ->
  MulticoreGen Imp.MCCode
compileSegHist pat space histops kbody nsubtasks
  | [_] <- unSegSpace space =
      nonsegmentedHist pat space histops kbody nsubtasks
  | otherwise =
      segmentedHist pat space histops kbody
-- | Split a list into consecutive chunks, one chunk per 'HistOp', where
-- the length of each chunk is the number of neutral elements
-- ('histNeutral') of the corresponding operator.
segHistOpChunks :: [HistOp rep] -> [a] -> [[a]]
segHistOpChunks = chunks . map (length . histNeutral)
-- | The product of all dimensions of the 'histShape' of an operator, as a
-- 64-bit expression.
histSize :: HistOp MCMem -> Imp.TExp Int64
histSize = product . map pe64 . shapeDims . histShape
-- | Bring the parameters of the operator lambda ('histOp') of the given
-- 'HistOp' into scope via 'dScope'.
genHistOpParams :: HistOp MCMem -> MulticoreGen ()
genHistOpParams histops =
  dScope Nothing $ scopeOfLParams $ lambdaParams $ histOp histops
-- | Return the given 'HistOp' with its operator lambda replaced by the
-- result of 'renameLambda'.  All other fields are left as they were.
renameHistop :: HistOp MCMem -> MulticoreGen (HistOp MCMem)
renameHistop histop = do
  let op = histOp histop
  lambda' <- renameLambda op
  pure histop {histOp = lambda'}
-- | Generate code for a histogram over a segment space, choosing at
-- run-time between two strategies.
--
-- The subhistogram strategy ('subHistogram') is selected when
-- @num_histos * hist_width <= product of the space dimensions@, where
-- @hist_width@ is the 'histSize' of the /first/ operator (or 0 when there
-- are no operators).  Otherwise 'atomicHistogram' is used, on a copy of
-- the operators produced by 'renameHistOpLambda'.
--
-- The generated code is collected and returned rather than emitted.
nonsegmentedHist ::
  Pat LetDecMem ->
  SegSpace ->
  [HistOp MCMem] ->
  KernelBody MCMem ->
  TV Int32 ->
  MulticoreGen Imp.MCCode
nonsegmentedHist pat space histops kbody num_histos = do
  let ns = map snd $ unSegSpace space
      ns_64 = map pe64 ns
      num_histos' = tvExp num_histos
      -- Only the first operator determines the width; an empty operator
      -- list yields a width of zero instead of failing.
      hist_width = maybe 0 histSize $ listToMaybe histops
      use_subhistogram = sExt64 num_histos' * hist_width .<=. product ns_64

  histops' <- renameHistOpLambda histops

  -- The whole histogram is guarded so that nothing runs when the
  -- iteration space is empty.
  collect $
    sUnless (product ns_64 .==. 0) $ do
      sIf
        use_subhistogram
        (subHistogram pat space histops num_histos kbody)
        (atomicHistogram pat space histops' kbody)
-- | Produce the update action for a single 'HistOp'.
--
-- The kind of update is decided by 'atomicUpdateLocking' applied to the
-- host atomics from the environment and the operator lambda.  The
-- 'AtomicPrim' and 'AtomicCAS' cases return their action unchanged.  The
-- 'AtomicLocking' case first allocates a zero-initialised static array of
-- locks and supplies a 'Locking' description that maps an index to a lock
-- by flattening it and taking the remainder modulo the number of locks.
onOpAtomic :: HistOp MCMem -> MulticoreGen ([VName] -> [Imp.TExp Int64] -> MulticoreGen ())
onOpAtomic op = do
  atomics <- hostAtomics <$> askEnv
  let lambda = histOp op
      do_op = atomicUpdateLocking atomics lambda
  case do_op of
    AtomicPrim f -> pure f
    AtomicCAS f -> pure f
    AtomicLocking f -> do
      -- NOTE(review): the origin of the constant 100151 is not visible
      -- here; it fixes the size of the lock array.
      let num_locks = 100151 :: Int
          dims = map pe64 $ shapeDims (histOpShape op <> histShape op)
      locks <-
        sStaticArray "hist_locks" int32 $
          Imp.ArrayZeros num_locks
      let l' = Locking locks 0 1 0 (pure . (`rem` fromIntegral num_locks) . flattenIndex dims)
      pure $ f l'
-- | Generate a parallel loop that updates the destination histograms
-- directly, using the per-operator update actions from 'onOpAtomic'.
--
-- The first @num_red_res@ pattern elements are treated as histogram
-- destinations and the remainder as map-out results.  For every flat
-- index of the chunk loop the kernel body is compiled and then, for each
-- operator:
--
--   * the map-out results are copied to the map-out pattern elements at
--     the current space indices, and
--
--   * if the bucket is within the bounds of the destination shape, the
--     operator lambda parameters are declared, the values are copied into
--     the value parameters for every index of the operator shape, and the
--     update action is run on the destinations.
--
-- The collected body is emitted as a @ParLoop@ named @atomic_seg_hist@.
atomicHistogram ::
  Pat LetDecMem ->
  SegSpace ->
  [HistOp MCMem] ->
  KernelBody MCMem ->
  MulticoreGen ()
atomicHistogram pat space histops kbody = do
  let (is, ns) = unzip $ unSegSpace space
      ns_64 = map pe64 ns
  let num_red_res = length histops + sum (map (length . histNeutral) histops)
      (all_red_pes, map_pes) = splitAt num_red_res $ patElems pat

  atomicOps <- mapM onOpAtomic histops

  body <- collect $ do
    dPrim_ (segFlat space) int64
    sOp $ Imp.GetTaskId (segFlat space)
    generateChunkLoop "SegHist" Scalar $ \flat_idx -> do
      -- Recover the per-dimension indices from the flat loop index.
      zipWithM_ dPrimV_ is $ unflattenIndex ns_64 flat_idx
      compileStms mempty (kernelBodyStms kbody) $ do
        let (red_res, map_res) =
              splitFromEnd (length map_pes) $ kernelBodyResult kbody
            red_res_split = splitHistResults histops $ map kernelResultSubExp red_res

        let pes_per_op = chunks (map (length . histDest) histops) all_red_pes
        forM_ (zip4 histops red_res_split atomicOps pes_per_op) $
          \(HistOp dest_shape _ _ _ shape lam, (bucket, vs'), do_op, dest_res) -> do
            let (_is_params, vs_params) = splitAt (length vs') $ lambdaParams lam
                dest_shape' = map pe64 $ shapeDims dest_shape
                bucket' = map pe64 bucket
                bucket_in_bounds = inBounds (Slice (map DimFix bucket')) dest_shape'

            -- NOTE(review): this copy sits inside the per-operator loop,
            -- so it is generated once for every operator.
            sComment "save map-out results" $
              forM_ (zip map_pes map_res) $ \(pe, res) ->
                copyDWIMFix (patElemName pe) (map Imp.le64 is) (kernelResultSubExp res) []

            sComment "perform updates" $
              sWhen bucket_in_bounds $ do
                -- All space indices except the last, followed by the
                -- bucket.  NOTE(review): 'init' is partial; this relies on
                -- the segment space being non-empty.
                let bucket_is = map Imp.le64 (init is) ++ bucket'
                dLParams $ lambdaParams lam
                sLoopNest shape $ \is' -> do
                  forM_ (zip vs_params vs') $ \(p, res) ->
                    copyDWIMFix (paramName p) [] res is'
                  do_op (map patElemName dest_res) (bucket_is ++ is')

  free_params <- freeParams body
  emit $ Imp.Op $ Imp.ParLoop "atomic_seg_hist" body free_params
-- | Generate code for applying an operator to one bucket of a set of
-- histogram arrays.
--
-- In order, the generated code:
--
--   1. copies the element at @bucket@ of each array in @arrs@ into the
--      corresponding parameter in @uni_acc@,
--
--   2. runs the body of the operator lambda, and
--
--   3. writes each result of the lambda body back to @bucket@ of the
--      corresponding array, with every write wrapped in
--      'extractVectorLane' for lane @j@.
updateHisto ::
  HistOp MCMem ->
  [VName] ->
  [Imp.TExp Int64] ->
  Imp.TExp Int64 ->
  [Param LParamMem] ->
  MulticoreGen ()
updateHisto op arrs bucket j uni_acc = do
  let bind_acc_params =
        forM_ (zip uni_acc arrs) $ \(acc_u, arr) -> do
          copyDWIMFix (paramName acc_u) [] (Var arr) bucket
      op_body = compileBody' [] $ lambdaBody $ histOp op
      writeArray arr val = extractVectorLane j $ collect $ copyDWIMFix arr bucket val []
      do_hist = zipWithM_ writeArray arrs $ map resSubExp $ bodyResult $ lambdaBody $ histOp op

  sComment "Start of body" $ do
    bind_acc_params
    op_body
    do_hist
subHistogram ::
Pat LetDecMem ->
SegSpace ->
[HistOp MCMem] ->
TV Int32 ->
KernelBody MCMem ->
MulticoreGen ()
subHistogram :: Pat LParamMem
-> SegSpace
-> [HistOp MCMem]
-> TV Int32
-> KernelBody MCMem
-> MulticoreGen ()
subHistogram Pat LParamMem
pat SegSpace
space [HistOp MCMem]
histops TV Int32
num_histos KernelBody MCMem
kbody = do
Code Multicore -> MulticoreGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code Multicore -> MulticoreGen ())
-> Code Multicore -> MulticoreGen ()
forall a b. (a -> b) -> a -> b
$ String -> Maybe Exp -> Code Multicore
forall a. String -> Maybe Exp -> Code a
Imp.DebugPrint String
"subHistogram segHist" Maybe Exp
forall a. Maybe a
Nothing
let ([VName]
is, [SubExp]
ns) = [(VName, SubExp)] -> ([VName], [SubExp])
forall a b. [(a, b)] -> ([a], [b])
unzip ([(VName, SubExp)] -> ([VName], [SubExp]))
-> [(VName, SubExp)] -> ([VName], [SubExp])
forall a b. (a -> b) -> a -> b
$ SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
space
ns_64 :: [TExp Int64]
ns_64 = (SubExp -> TExp Int64) -> [SubExp] -> [TExp Int64]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TExp Int64
pe64 [SubExp]
ns
let pes :: [PatElem LParamMem]
pes = Pat LParamMem -> [PatElem LParamMem]
forall dec. Pat dec -> [PatElem dec]
patElems Pat LParamMem
pat
num_red_res :: Int
num_red_res = [HistOp MCMem] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [HistOp MCMem]
histops Int -> Int -> Int
forall a. Num a => a -> a -> a
+ [Int] -> Int
forall a. Num a => [a] -> a
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
sum ((HistOp MCMem -> Int) -> [HistOp MCMem] -> [Int]
forall a b. (a -> b) -> [a] -> [b]
map ([SubExp] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length ([SubExp] -> Int)
-> (HistOp MCMem -> [SubExp]) -> HistOp MCMem -> Int
forall b c a. (b -> c) -> (a -> b) -> a -> c
. HistOp MCMem -> [SubExp]
forall rep. HistOp rep -> [SubExp]
histNeutral) [HistOp MCMem]
histops)
map_pes :: [PatElem LParamMem]
map_pes = Int -> [PatElem LParamMem] -> [PatElem LParamMem]
forall a. Int -> [a] -> [a]
drop Int
num_red_res [PatElem LParamMem]
pes
per_red_pes :: [[PatElem LParamMem]]
per_red_pes = [HistOp MCMem] -> [PatElem LParamMem] -> [[PatElem LParamMem]]
forall rep a. [HistOp rep] -> [a] -> [[a]]
segHistOpChunks [HistOp MCMem]
histops ([PatElem LParamMem] -> [[PatElem LParamMem]])
-> [PatElem LParamMem] -> [[PatElem LParamMem]]
forall a b. (a -> b) -> a -> b
$ Pat LParamMem -> [PatElem LParamMem]
forall dec. Pat dec -> [PatElem dec]
patElems Pat LParamMem
pat
global_subhistograms <- [HistOp MCMem]
-> (HistOp MCMem -> ImpM MCMem HostEnv Multicore [VName])
-> ImpM MCMem HostEnv Multicore [[VName]]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM [HistOp MCMem]
histops ((HistOp MCMem -> ImpM MCMem HostEnv Multicore [VName])
-> ImpM MCMem HostEnv Multicore [[VName]])
-> (HistOp MCMem -> ImpM MCMem HostEnv Multicore [VName])
-> ImpM MCMem HostEnv Multicore [[VName]]
forall a b. (a -> b) -> a -> b
$ \HistOp MCMem
histop ->
[Type]
-> (Type -> ImpM MCMem HostEnv Multicore VName)
-> ImpM MCMem HostEnv Multicore [VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM (HistOp MCMem -> [Type]
forall rep. HistOp rep -> [Type]
histType HistOp MCMem
histop) ((Type -> ImpM MCMem HostEnv Multicore VName)
-> ImpM MCMem HostEnv Multicore [VName])
-> (Type -> ImpM MCMem HostEnv Multicore VName)
-> ImpM MCMem HostEnv Multicore [VName]
forall a b. (a -> b) -> a -> b
$ \Type
t -> do
let shape :: ShapeBase SubExp
shape = [SubExp] -> ShapeBase SubExp
forall d. [d] -> ShapeBase d
Shape [TV Int32 -> SubExp
forall {k} (t :: k). TV t -> SubExp
tvSize TV Int32
num_histos] ShapeBase SubExp -> ShapeBase SubExp -> ShapeBase SubExp
forall a. Semigroup a => a -> a -> a
<> Type -> ShapeBase SubExp
forall shape u. ArrayShape shape => TypeBase shape u -> shape
arrayShape Type
t
String
-> PrimType
-> ShapeBase SubExp
-> Space
-> ImpM MCMem HostEnv Multicore VName
forall rep r op.
String
-> PrimType -> ShapeBase SubExp -> Space -> ImpM rep r op VName
sAllocArray String
"subhistogram" (Type -> PrimType
forall shape u. TypeBase shape u -> PrimType
elemType Type
t) ShapeBase SubExp
shape Space
DefaultSpace
let tid' = VName -> TExp Int64
forall a. a -> TPrimExp Int64 a
Imp.le64 (VName -> TExp Int64) -> VName -> TExp Int64
forall a b. (a -> b) -> a -> b
$ SegSpace -> VName
segFlat SegSpace
space
body <- collect $ do
dPrim_ (segFlat space) int64
sOp $ Imp.GetTaskId (segFlat space)
local_subhistograms <- forM (zip per_red_pes histops) $ \([PatElem LParamMem]
pes', HistOp MCMem
histop) -> do
op_local_subhistograms <- [Type]
-> (Type -> ImpM MCMem HostEnv Multicore VName)
-> ImpM MCMem HostEnv Multicore [VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM (HistOp MCMem -> [Type]
forall rep. HistOp rep -> [Type]
histType HistOp MCMem
histop) ((Type -> ImpM MCMem HostEnv Multicore VName)
-> ImpM MCMem HostEnv Multicore [VName])
-> (Type -> ImpM MCMem HostEnv Multicore VName)
-> ImpM MCMem HostEnv Multicore [VName]
forall a b. (a -> b) -> a -> b
$ \Type
t ->
String
-> PrimType
-> ShapeBase SubExp
-> Space
-> ImpM MCMem HostEnv Multicore VName
forall rep r op.
String
-> PrimType -> ShapeBase SubExp -> Space -> ImpM rep r op VName
sAllocArray String
"subhistogram" (Type -> PrimType
forall shape u. TypeBase shape u -> PrimType
elemType Type
t) (Type -> ShapeBase SubExp
forall shape u. ArrayShape shape => TypeBase shape u -> shape
arrayShape Type
t) Space
DefaultSpace
forM_ (zip3 pes' op_local_subhistograms (histNeutral histop)) $ \(PatElem LParamMem
pe, VName
hist, SubExp
ne) ->
TPrimExp Bool VName
-> MulticoreGen () -> MulticoreGen () -> MulticoreGen ()
forall rep r op.
TPrimExp Bool VName
-> ImpM rep r op () -> ImpM rep r op () -> ImpM rep r op ()
sIf
(TExp Int64
tid' TExp Int64 -> TExp Int64 -> TPrimExp Bool VName
forall {k} v (t :: k).
Eq v =>
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TExp Int64
0)
(VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> MulticoreGen ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
hist [] (VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ PatElem LParamMem -> VName
forall dec. PatElem dec -> VName
patElemName PatElem LParamMem
pe) [])
( ShapeBase SubExp
-> ([TExp Int64] -> MulticoreGen ()) -> MulticoreGen ()
forall rep r op.
ShapeBase SubExp
-> ([TExp Int64] -> ImpM rep r op ()) -> ImpM rep r op ()
sLoopNest (HistOp MCMem -> ShapeBase SubExp
forall rep. HistOp rep -> ShapeBase SubExp
histShape HistOp MCMem
histop) (([TExp Int64] -> MulticoreGen ()) -> MulticoreGen ())
-> ([TExp Int64] -> MulticoreGen ()) -> MulticoreGen ()
forall a b. (a -> b) -> a -> b
$ \[TExp Int64]
shape_is ->
ShapeBase SubExp
-> ([TExp Int64] -> MulticoreGen ()) -> MulticoreGen ()
forall rep r op.
ShapeBase SubExp
-> ([TExp Int64] -> ImpM rep r op ()) -> ImpM rep r op ()
sLoopNest (HistOp MCMem -> ShapeBase SubExp
forall rep. HistOp rep -> ShapeBase SubExp
histOpShape HistOp MCMem
histop) (([TExp Int64] -> MulticoreGen ()) -> MulticoreGen ())
-> ([TExp Int64] -> MulticoreGen ()) -> MulticoreGen ()
forall a b. (a -> b) -> a -> b
$ \[TExp Int64]
vec_is ->
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> MulticoreGen ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
hist ([TExp Int64]
shape_is [TExp Int64] -> [TExp Int64] -> [TExp Int64]
forall a. Semigroup a => a -> a -> a
<> [TExp Int64]
vec_is) SubExp
ne []
)
pure op_local_subhistograms
inISPC $
generateChunkLoop "SegRed" Vectorized $ \TExp Int64
i -> do
(VName -> TExp Int64 -> MulticoreGen ())
-> [VName] -> [TExp Int64] -> MulticoreGen ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ VName -> TExp Int64 -> MulticoreGen ()
forall {k} (t :: k) rep r op. VName -> TExp t -> ImpM rep r op ()
dPrimV_ [VName]
is ([TExp Int64] -> MulticoreGen ())
-> [TExp Int64] -> MulticoreGen ()
forall a b. (a -> b) -> a -> b
$ [TExp Int64] -> TExp Int64 -> [TExp Int64]
forall num. IntegralExp num => [num] -> num -> [num]
unflattenIndex [TExp Int64]
ns_64 TExp Int64
i
Names -> Stms MCMem -> MulticoreGen () -> MulticoreGen ()
forall rep r op.
Names -> Stms rep -> ImpM rep r op () -> ImpM rep r op ()
compileStms Names
forall a. Monoid a => a
mempty (KernelBody MCMem -> Stms MCMem
forall rep. KernelBody rep -> Stms rep
kernelBodyStms KernelBody MCMem
kbody) (MulticoreGen () -> MulticoreGen ())
-> MulticoreGen () -> MulticoreGen ()
forall a b. (a -> b) -> a -> b
$ do
let ([SubExp]
red_res, [SubExp]
map_res) =
Int -> [SubExp] -> ([SubExp], [SubExp])
forall a. Int -> [a] -> ([a], [a])
splitFromEnd ([PatElem LParamMem] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [PatElem LParamMem]
map_pes) ([SubExp] -> ([SubExp], [SubExp]))
-> [SubExp] -> ([SubExp], [SubExp])
forall a b. (a -> b) -> a -> b
$
(KernelResult -> SubExp) -> [KernelResult] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map KernelResult -> SubExp
kernelResultSubExp ([KernelResult] -> [SubExp]) -> [KernelResult] -> [SubExp]
forall a b. (a -> b) -> a -> b
$
KernelBody MCMem -> [KernelResult]
forall rep. KernelBody rep -> [KernelResult]
kernelBodyResult KernelBody MCMem
kbody
Text -> MulticoreGen () -> MulticoreGen ()
forall rep r op. Text -> ImpM rep r op () -> ImpM rep r op ()
sComment Text
"save map-out results" (MulticoreGen () -> MulticoreGen ())
-> MulticoreGen () -> MulticoreGen ()
forall a b. (a -> b) -> a -> b
$
[(PatElem LParamMem, SubExp)]
-> ((PatElem LParamMem, SubExp) -> MulticoreGen ())
-> MulticoreGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([PatElem LParamMem] -> [SubExp] -> [(PatElem LParamMem, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [PatElem LParamMem]
map_pes [SubExp]
map_res) (((PatElem LParamMem, SubExp) -> MulticoreGen ())
-> MulticoreGen ())
-> ((PatElem LParamMem, SubExp) -> MulticoreGen ())
-> MulticoreGen ()
forall a b. (a -> b) -> a -> b
$ \(PatElem LParamMem
pe, SubExp
res) ->
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> MulticoreGen ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix (PatElem LParamMem -> VName
forall dec. PatElem dec -> VName
patElemName PatElem LParamMem
pe) ((VName -> TExp Int64) -> [VName] -> [TExp Int64]
forall a b. (a -> b) -> [a] -> [b]
map VName -> TExp Int64
forall a. a -> TPrimExp Int64 a
Imp.le64 [VName]
is) SubExp
res []
[(HistOp MCMem, [VName], ([SubExp], [SubExp]))]
-> ((HistOp MCMem, [VName], ([SubExp], [SubExp]))
-> MulticoreGen ())
-> MulticoreGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([HistOp MCMem]
-> [[VName]]
-> [([SubExp], [SubExp])]
-> [(HistOp MCMem, [VName], ([SubExp], [SubExp]))]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [HistOp MCMem]
histops [[VName]]
local_subhistograms ([HistOp MCMem] -> [SubExp] -> [([SubExp], [SubExp])]
forall rep. [HistOp rep] -> [SubExp] -> [([SubExp], [SubExp])]
splitHistResults [HistOp MCMem]
histops [SubExp]
red_res)) (((HistOp MCMem, [VName], ([SubExp], [SubExp])) -> MulticoreGen ())
-> MulticoreGen ())
-> ((HistOp MCMem, [VName], ([SubExp], [SubExp]))
-> MulticoreGen ())
-> MulticoreGen ()
forall a b. (a -> b) -> a -> b
$
\( histop :: HistOp MCMem
histop@(HistOp ShapeBase SubExp
dest_shape SubExp
_ [VName]
_ [SubExp]
_ ShapeBase SubExp
shape Lambda MCMem
_),
[VName]
histop_subhistograms,
([SubExp]
bucket, [SubExp]
vs')
) -> do
histop' <- HistOp MCMem -> MulticoreGen (HistOp MCMem)
renameHistop HistOp MCMem
histop
let bucket' = (SubExp -> TExp Int64) -> [SubExp] -> [TExp Int64]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TExp Int64
pe64 [SubExp]
bucket
dest_shape' = (SubExp -> TExp Int64) -> [SubExp] -> [TExp Int64]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TExp Int64
pe64 ([SubExp] -> [TExp Int64]) -> [SubExp] -> [TExp Int64]
forall a b. (a -> b) -> a -> b
$ ShapeBase SubExp -> [SubExp]
forall d. ShapeBase d -> [d]
shapeDims ShapeBase SubExp
dest_shape
acc_params' = (Lambda MCMem -> [LParam MCMem]
Lambda MCMem -> [Param LParamMem]
forall rep. Lambda rep -> [LParam rep]
lambdaParams (Lambda MCMem -> [Param LParamMem])
-> (HistOp MCMem -> Lambda MCMem)
-> HistOp MCMem
-> [Param LParamMem]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. HistOp MCMem -> Lambda MCMem
forall rep. HistOp rep -> Lambda rep
histOp) HistOp MCMem
histop'
vs_params' = Int -> [Param LParamMem] -> [Param LParamMem]
forall a. Int -> [a] -> [a]
takeLast ([SubExp] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
vs') ([Param LParamMem] -> [Param LParamMem])
-> [Param LParamMem] -> [Param LParamMem]
forall a b. (a -> b) -> a -> b
$ Lambda MCMem -> [LParam MCMem]
forall rep. Lambda rep -> [LParam rep]
lambdaParams (Lambda MCMem -> [LParam MCMem]) -> Lambda MCMem -> [LParam MCMem]
forall a b. (a -> b) -> a -> b
$ HistOp MCMem -> Lambda MCMem
forall rep. HistOp rep -> Lambda rep
histOp HistOp MCMem
histop'
generateUniformizeLoop $ \TExp Int64
j ->
Text -> MulticoreGen () -> MulticoreGen ()
forall rep r op. Text -> ImpM rep r op () -> ImpM rep r op ()
sComment Text
"perform updates" (MulticoreGen () -> MulticoreGen ())
-> MulticoreGen () -> MulticoreGen ()
forall a b. (a -> b) -> a -> b
$ do
extract_buckets <- (TExp Int64 -> ImpM MCMem HostEnv Multicore (TV Int64))
-> [TExp Int64] -> ImpM MCMem HostEnv Multicore [TV Int64]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM (String -> PrimType -> ImpM MCMem HostEnv Multicore (TV Int64)
forall {k} rep r op (t :: k).
String -> PrimType -> ImpM rep r op (TV t)
dPrimSV String
"extract_bucket" (PrimType -> ImpM MCMem HostEnv Multicore (TV Int64))
-> (TExp Int64 -> PrimType)
-> TExp Int64
-> ImpM MCMem HostEnv Multicore (TV Int64)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Exp -> PrimType
forall v. PrimExp v -> PrimType
primExpType (Exp -> PrimType) -> (TExp Int64 -> Exp) -> TExp Int64 -> PrimType
forall b c a. (b -> c) -> (a -> b) -> a -> c
. TExp Int64 -> Exp
forall {k} (t :: k) v. TPrimExp t v -> PrimExp v
untyped)) [TExp Int64]
bucket'
forM_ (zip extract_buckets bucket') $ \(TV Int64
x, TExp Int64
y) ->
Code Multicore -> MulticoreGen ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code Multicore -> MulticoreGen ())
-> Code Multicore -> MulticoreGen ()
forall a b. (a -> b) -> a -> b
$ Multicore -> Code Multicore
forall a. a -> Code a
Imp.Op (Multicore -> Code Multicore) -> Multicore -> Code Multicore
forall a b. (a -> b) -> a -> b
$ VName -> Exp -> Exp -> Multicore
Imp.ExtractLane (TV Int64 -> VName
forall {k} (t :: k). TV t -> VName
tvVar TV Int64
x) (TExp Int64 -> Exp
forall {k} (t :: k) v. TPrimExp t v -> PrimExp v
untyped TExp Int64
y) (TExp Int64 -> Exp
forall {k} (t :: k) v. TPrimExp t v -> PrimExp v
untyped TExp Int64
j)
let bucket'' = (TV Int64 -> TExp Int64) -> [TV Int64] -> [TExp Int64]
forall a b. (a -> b) -> [a] -> [b]
map TV Int64 -> TExp Int64
forall {k} (t :: k). TV t -> TExp t
tvExp [TV Int64]
extract_buckets
bucket_in_bounds =
Slice (TExp Int64) -> [TExp Int64] -> TPrimExp Bool VName
inBounds ([DimIndex (TExp Int64)] -> Slice (TExp Int64)
forall d. [DimIndex d] -> Slice d
Slice ((TExp Int64 -> DimIndex (TExp Int64))
-> [TExp Int64] -> [DimIndex (TExp Int64)]
forall a b. (a -> b) -> [a] -> [b]
map TExp Int64 -> DimIndex (TExp Int64)
forall d. d -> DimIndex d
DimFix [TExp Int64]
bucket'')) [TExp Int64]
dest_shape'
sWhen bucket_in_bounds $ do
genHistOpParams histop'
sLoopNest shape $ \[TExp Int64]
is' -> do
[(Param LParamMem, SubExp)]
-> ((Param LParamMem, SubExp) -> MulticoreGen ())
-> MulticoreGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param LParamMem] -> [SubExp] -> [(Param LParamMem, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param LParamMem]
vs_params' [SubExp]
vs') (((Param LParamMem, SubExp) -> MulticoreGen ()) -> MulticoreGen ())
-> ((Param LParamMem, SubExp) -> MulticoreGen ())
-> MulticoreGen ()
forall a b. (a -> b) -> a -> b
$ \(Param LParamMem
p, SubExp
res) ->
Type -> (PrimType -> MulticoreGen ()) -> MulticoreGen ()
forall {f :: * -> *} {shape} {u}.
Applicative f =>
TypeBase shape u -> (PrimType -> f ()) -> f ()
ifPrimType (Param LParamMem -> Type
forall dec. Typed dec => Param dec -> Type
paramType Param LParamMem
p) ((PrimType -> MulticoreGen ()) -> MulticoreGen ())
-> (PrimType -> MulticoreGen ()) -> MulticoreGen ()
forall a b. (a -> b) -> a -> b
$ \PrimType
pt -> do
tmp <- String -> PrimType -> ImpM MCMem HostEnv Multicore VName
forall rep r op. String -> PrimType -> ImpM rep r op VName
dPrimS String
"tmp" PrimType
pt
copyDWIMFix tmp [] res is'
extractVectorLane j . pure $
Imp.SetScalar (paramName p) (toExp' pt tmp)
HistOp MCMem
-> [VName]
-> [TExp Int64]
-> TExp Int64
-> [Param LParamMem]
-> MulticoreGen ()
updateHisto HistOp MCMem
histop' [VName]
histop_subhistograms ([TExp Int64]
bucket'' [TExp Int64] -> [TExp Int64] -> [TExp Int64]
forall a. [a] -> [a] -> [a]
++ [TExp Int64]
is') TExp Int64
j [Param LParamMem]
acc_params'
forM_ (zip (concat global_subhistograms) (concat local_subhistograms)) $
\(VName
global, VName
local) -> VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> MulticoreGen ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
global [TExp Int64
tid'] (VName -> SubExp
Var VName
local) []
free_params <- freeParams body
emit $ Imp.Op $ Imp.ParLoop "seghist_stage_1" body free_params
forM_ (zip3 per_red_pes global_subhistograms histops) $ \([PatElem LParamMem]
red_pes, [VName]
hists, HistOp MCMem
op) -> do
bucket_ids <-
Int
-> ImpM MCMem HostEnv Multicore VName
-> ImpM MCMem HostEnv Multicore [VName]
forall (m :: * -> *) a. Applicative m => Int -> m a -> m [a]
replicateM (ShapeBase SubExp -> Int
forall a. ArrayShape a => a -> Int
shapeRank (HistOp MCMem -> ShapeBase SubExp
forall rep. HistOp rep -> ShapeBase SubExp
histShape HistOp MCMem
op)) (String -> ImpM MCMem HostEnv Multicore VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName String
"bucket_id")
subhistogram_id <- newVName "subhistogram_id"
let segred_space =
VName -> [(VName, SubExp)] -> SegSpace
SegSpace (SegSpace -> VName
segFlat SegSpace
space) ([(VName, SubExp)] -> SegSpace) -> [(VName, SubExp)] -> SegSpace
forall a b. (a -> b) -> a -> b
$
[(VName, SubExp)]
segment_dims
[(VName, SubExp)] -> [(VName, SubExp)] -> [(VName, SubExp)]
forall a. [a] -> [a] -> [a]
++ [VName] -> [SubExp] -> [(VName, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
bucket_ids (ShapeBase SubExp -> [SubExp]
forall d. ShapeBase d -> [d]
shapeDims (HistOp MCMem -> ShapeBase SubExp
forall rep. HistOp rep -> ShapeBase SubExp
histShape HistOp MCMem
op))
[(VName, SubExp)] -> [(VName, SubExp)] -> [(VName, SubExp)]
forall a. [a] -> [a] -> [a]
++ [(VName
subhistogram_id, TV Int32 -> SubExp
forall {k} (t :: k). TV t -> SubExp
tvSize TV Int32
num_histos)]
segred_op = Commutativity
-> Lambda MCMem -> [SubExp] -> ShapeBase SubExp -> SegBinOp MCMem
forall rep.
Commutativity
-> Lambda rep -> [SubExp] -> ShapeBase SubExp -> SegBinOp rep
SegBinOp Commutativity
Noncommutative (HistOp MCMem -> Lambda MCMem
forall rep. HistOp rep -> Lambda rep
histOp HistOp MCMem
op) (HistOp MCMem -> [SubExp]
forall rep. HistOp rep -> [SubExp]
histNeutral HistOp MCMem
op) (HistOp MCMem -> ShapeBase SubExp
forall rep. HistOp rep -> ShapeBase SubExp
histOpShape HistOp MCMem
op)
red_code <- collect $ do
nsubtasks <- dPrim "nsubtasks"
sOp $ Imp.GetNumTasks $ tvVar nsubtasks
emit <=< compileSegRed' (Pat red_pes) segred_space [segred_op] nsubtasks $ \[[(SubExp, [TExp Int64])]] -> MulticoreGen ()
red_cont ->
[[(SubExp, [TExp Int64])]] -> MulticoreGen ()
red_cont ([[(SubExp, [TExp Int64])]] -> MulticoreGen ())
-> [[(SubExp, [TExp Int64])]] -> MulticoreGen ()
forall a b. (a -> b) -> a -> b
$
[SegBinOp MCMem]
-> [(SubExp, [TExp Int64])] -> [[(SubExp, [TExp Int64])]]
forall rep a. [SegBinOp rep] -> [a] -> [[a]]
segBinOpChunks [SegBinOp MCMem
segred_op] ([(SubExp, [TExp Int64])] -> [[(SubExp, [TExp Int64])]])
-> [(SubExp, [TExp Int64])] -> [[(SubExp, [TExp Int64])]]
forall a b. (a -> b) -> a -> b
$
((VName -> (SubExp, [TExp Int64]))
-> [VName] -> [(SubExp, [TExp Int64])])
-> [VName]
-> (VName -> (SubExp, [TExp Int64]))
-> [(SubExp, [TExp Int64])]
forall a b c. (a -> b -> c) -> b -> a -> c
flip (VName -> (SubExp, [TExp Int64]))
-> [VName] -> [(SubExp, [TExp Int64])]
forall a b. (a -> b) -> [a] -> [b]
map [VName]
hists ((VName -> (SubExp, [TExp Int64])) -> [(SubExp, [TExp Int64])])
-> (VName -> (SubExp, [TExp Int64])) -> [(SubExp, [TExp Int64])]
forall a b. (a -> b) -> a -> b
$ \VName
subhisto ->
( VName -> SubExp
Var VName
subhisto,
(VName -> TExp Int64) -> [VName] -> [TExp Int64]
forall a b. (a -> b) -> [a] -> [b]
map VName -> TExp Int64
forall a. a -> TPrimExp Int64 a
Imp.le64 ([VName] -> [TExp Int64]) -> [VName] -> [TExp Int64]
forall a b. (a -> b) -> a -> b
$
((VName, SubExp) -> VName) -> [(VName, SubExp)] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map (VName, SubExp) -> VName
forall a b. (a, b) -> a
fst [(VName, SubExp)]
segment_dims [VName] -> [VName] -> [VName]
forall a. [a] -> [a] -> [a]
++ [VName
subhistogram_id] [VName] -> [VName] -> [VName]
forall a. [a] -> [a] -> [a]
++ [VName]
bucket_ids
)
let ns_red = ((VName, SubExp) -> TExp Int64)
-> [(VName, SubExp)] -> [TExp Int64]
forall a b. (a -> b) -> [a] -> [b]
map (SubExp -> TExp Int64
pe64 (SubExp -> TExp Int64)
-> ((VName, SubExp) -> SubExp) -> (VName, SubExp) -> TExp Int64
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (VName, SubExp) -> SubExp
forall a b. (a, b) -> b
snd) ([(VName, SubExp)] -> [TExp Int64])
-> [(VName, SubExp)] -> [TExp Int64]
forall a b. (a -> b) -> a -> b
$ SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
segred_space
iterations = [TExp Int64] -> TExp Int64
forall a. Num a => [a] -> a
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product ([TExp Int64] -> TExp Int64) -> [TExp Int64] -> TExp Int64
forall a b. (a -> b) -> a -> b
$ [TExp Int64] -> [TExp Int64]
forall a. HasCallStack => [a] -> [a]
init [TExp Int64]
ns_red
scheduler_info = Exp -> Scheduling -> SchedulerInfo
Imp.SchedulerInfo (TExp Int64 -> Exp
forall {k} (t :: k) v. TPrimExp t v -> PrimExp v
untyped TExp Int64
iterations) Scheduling
Imp.Static
red_task = Code Multicore -> ParallelTask
Imp.ParallelTask Code Multicore
red_code
free_params_red <- freeParams red_code
emit $ Imp.Op $ Imp.SegOp "seghist_red" free_params_red red_task Nothing mempty scheduler_info
where
segment_dims :: [(VName, SubExp)]
segment_dims = [(VName, SubExp)] -> [(VName, SubExp)]
forall a. HasCallStack => [a] -> [a]
init ([(VName, SubExp)] -> [(VName, SubExp)])
-> [(VName, SubExp)] -> [(VName, SubExp)]
forall a b. (a -> b) -> a -> b
$ SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
space
ifPrimType :: TypeBase shape u -> (PrimType -> f ()) -> f ()
ifPrimType (Prim PrimType
pt) PrimType -> f ()
f = PrimType -> f ()
f PrimType
pt
ifPrimType TypeBase shape u
_ PrimType -> f ()
_ = () -> f ()
forall a. a -> f a
forall (f :: * -> *) a. Applicative f => a -> f a
pure ()
-- | Generate code for a histogram over a segmented space.
--
-- 'compileSegHist' dispatches here whenever the 'SegSpace' does not have
-- exactly one dimension.
--
-- A debug print ("Segmented segHist") is emitted into the /enclosing/ code
-- stream.  The returned code is collected separately and consists of a
-- single parallel loop, named @"segmented_hist"@, whose body is produced by
-- 'compileSegHistBody'.
segmentedHist ::
  Pat LetDecMem ->
  SegSpace ->
  [HistOp MCMem] ->
  KernelBody MCMem ->
  MulticoreGen Imp.MCCode
segmentedHist pat space histops kbody = do
  -- Emitted before 'collect', so this is not part of the returned code.
  emit $ Imp.DebugPrint "Segmented segHist" Nothing
  collect $ do
    body <- compileSegHistBody pat space histops kbody
    -- The parallel loop must be told which free variables its body uses.
    free_params <- freeParams body
    emit $ Imp.Op $ Imp.ParLoop "segmented_hist" body free_params
compileSegHistBody ::
Pat LetDecMem ->
SegSpace ->
[HistOp MCMem] ->
KernelBody MCMem ->
MulticoreGen Imp.MCCode
compileSegHistBody :: Pat LParamMem
-> SegSpace
-> [HistOp MCMem]
-> KernelBody MCMem
-> MulticoreGen (Code Multicore)
compileSegHistBody Pat LParamMem
pat SegSpace
space [HistOp MCMem]
histops KernelBody MCMem
kbody = MulticoreGen () -> MulticoreGen (Code Multicore)
forall rep r op. ImpM rep r op () -> ImpM rep r op (Code op)
collect (MulticoreGen () -> MulticoreGen (Code Multicore))
-> MulticoreGen () -> MulticoreGen (Code Multicore)
forall a b. (a -> b) -> a -> b
$ do
let ([VName]
is, [SubExp]
ns) = [(VName, SubExp)] -> ([VName], [SubExp])
forall a b. [(a, b)] -> ([a], [b])
unzip ([(VName, SubExp)] -> ([VName], [SubExp]))
-> [(VName, SubExp)] -> ([VName], [SubExp])
forall a b. (a -> b) -> a -> b
$ SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
space
ns_64 :: [TExp Int64]
ns_64 = (SubExp -> TExp Int64) -> [SubExp] -> [TExp Int64]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TExp Int64
pe64 [SubExp]
ns
let num_red_res :: Int
num_red_res = [HistOp MCMem] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [HistOp MCMem]
histops Int -> Int -> Int
forall a. Num a => a -> a -> a
+ [Int] -> Int
forall a. Num a => [a] -> a
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
sum ((HistOp MCMem -> Int) -> [HistOp MCMem] -> [Int]
forall a b. (a -> b) -> [a] -> [b]
map ([SubExp] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length ([SubExp] -> Int)
-> (HistOp MCMem -> [SubExp]) -> HistOp MCMem -> Int
forall b c a. (b -> c) -> (a -> b) -> a -> c
. HistOp MCMem -> [SubExp]
forall rep. HistOp rep -> [SubExp]
histNeutral) [HistOp MCMem]
histops)
map_pes :: [PatElem LParamMem]
map_pes = Int -> [PatElem LParamMem] -> [PatElem LParamMem]
forall a. Int -> [a] -> [a]
drop Int
num_red_res ([PatElem LParamMem] -> [PatElem LParamMem])
-> [PatElem LParamMem] -> [PatElem LParamMem]
forall a b. (a -> b) -> a -> b
$ Pat LParamMem -> [PatElem LParamMem]
forall dec. Pat dec -> [PatElem dec]
patElems Pat LParamMem
pat
per_red_pes :: [[PatElem LParamMem]]
per_red_pes = [HistOp MCMem] -> [PatElem LParamMem] -> [[PatElem LParamMem]]
forall rep a. [HistOp rep] -> [a] -> [[a]]
segHistOpChunks [HistOp MCMem]
histops ([PatElem LParamMem] -> [[PatElem LParamMem]])
-> [PatElem LParamMem] -> [[PatElem LParamMem]]
forall a b. (a -> b) -> a -> b
$ Pat LParamMem -> [PatElem LParamMem]
forall dec. Pat dec -> [PatElem dec]
patElems Pat LParamMem
pat
VName -> PrimType -> MulticoreGen ()
forall rep r op. VName -> PrimType -> ImpM rep r op ()
dPrim_ (SegSpace -> VName
segFlat SegSpace
space) PrimType
int64
Multicore -> MulticoreGen ()
forall op rep r. op -> ImpM rep r op ()
sOp (Multicore -> MulticoreGen ()) -> Multicore -> MulticoreGen ()
forall a b. (a -> b) -> a -> b
$ VName -> Multicore
Imp.GetTaskId (SegSpace -> VName
segFlat SegSpace
space)
String
-> ChunkLoopVectorization
-> (TExp Int64 -> MulticoreGen ())
-> MulticoreGen ()
generateChunkLoop String
"SegHist" ChunkLoopVectorization
Scalar ((TExp Int64 -> MulticoreGen ()) -> MulticoreGen ())
-> (TExp Int64 -> MulticoreGen ()) -> MulticoreGen ()
forall a b. (a -> b) -> a -> b
$ \TExp Int64
idx -> do
let inner_bound :: TExp Int64
inner_bound = [TExp Int64] -> TExp Int64
forall a. HasCallStack => [a] -> a
last [TExp Int64]
ns_64
String
-> TExp Int64 -> (TExp Int64 -> MulticoreGen ()) -> MulticoreGen ()
forall {k} (t :: k) rep r op.
String
-> TExp t -> (TExp t -> ImpM rep r op ()) -> ImpM rep r op ()
sFor String
"i" TExp Int64
inner_bound ((TExp Int64 -> MulticoreGen ()) -> MulticoreGen ())
-> (TExp Int64 -> MulticoreGen ()) -> MulticoreGen ()
forall a b. (a -> b) -> a -> b
$ \TExp Int64
i -> do
(VName -> TExp Int64 -> MulticoreGen ())
-> [VName] -> [TExp Int64] -> MulticoreGen ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ VName -> TExp Int64 -> MulticoreGen ()
forall {k} (t :: k) rep r op. VName -> TExp t -> ImpM rep r op ()
dPrimV_ ([VName] -> [VName]
forall a. HasCallStack => [a] -> [a]
init [VName]
is) ([TExp Int64] -> MulticoreGen ())
-> [TExp Int64] -> MulticoreGen ()
forall a b. (a -> b) -> a -> b
$ [TExp Int64] -> TExp Int64 -> [TExp Int64]
forall num. IntegralExp num => [num] -> num -> [num]
unflattenIndex ([TExp Int64] -> [TExp Int64]
forall a. HasCallStack => [a] -> [a]
init [TExp Int64]
ns_64) TExp Int64
idx
VName -> TExp Int64 -> MulticoreGen ()
forall {k} (t :: k) rep r op. VName -> TExp t -> ImpM rep r op ()
dPrimV_ ([VName] -> VName
forall a. HasCallStack => [a] -> a
last [VName]
is) TExp Int64
i
Names -> Stms MCMem -> MulticoreGen () -> MulticoreGen ()
forall rep r op.
Names -> Stms rep -> ImpM rep r op () -> ImpM rep r op ()
compileStms Names
forall a. Monoid a => a
mempty (KernelBody MCMem -> Stms MCMem
forall rep. KernelBody rep -> Stms rep
kernelBodyStms KernelBody MCMem
kbody) (MulticoreGen () -> MulticoreGen ())
-> MulticoreGen () -> MulticoreGen ()
forall a b. (a -> b) -> a -> b
$ do
let ([SubExp]
red_res, [SubExp]
map_res) =
Int -> [SubExp] -> ([SubExp], [SubExp])
forall a. Int -> [a] -> ([a], [a])
splitFromEnd ([PatElem LParamMem] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [PatElem LParamMem]
map_pes) ([SubExp] -> ([SubExp], [SubExp]))
-> [SubExp] -> ([SubExp], [SubExp])
forall a b. (a -> b) -> a -> b
$
(KernelResult -> SubExp) -> [KernelResult] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map KernelResult -> SubExp
kernelResultSubExp ([KernelResult] -> [SubExp]) -> [KernelResult] -> [SubExp]
forall a b. (a -> b) -> a -> b
$
KernelBody MCMem -> [KernelResult]
forall rep. KernelBody rep -> [KernelResult]
kernelBodyResult KernelBody MCMem
kbody
[([PatElem LParamMem], HistOp MCMem, ([SubExp], [SubExp]))]
-> (([PatElem LParamMem], HistOp MCMem, ([SubExp], [SubExp]))
-> MulticoreGen ())
-> MulticoreGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([[PatElem LParamMem]]
-> [HistOp MCMem]
-> [([SubExp], [SubExp])]
-> [([PatElem LParamMem], HistOp MCMem, ([SubExp], [SubExp]))]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [[PatElem LParamMem]]
per_red_pes [HistOp MCMem]
histops ([HistOp MCMem] -> [SubExp] -> [([SubExp], [SubExp])]
forall rep. [HistOp rep] -> [SubExp] -> [([SubExp], [SubExp])]
splitHistResults [HistOp MCMem]
histops [SubExp]
red_res)) ((([PatElem LParamMem], HistOp MCMem, ([SubExp], [SubExp]))
-> MulticoreGen ())
-> MulticoreGen ())
-> (([PatElem LParamMem], HistOp MCMem, ([SubExp], [SubExp]))
-> MulticoreGen ())
-> MulticoreGen ()
forall a b. (a -> b) -> a -> b
$
\([PatElem LParamMem]
red_pes, HistOp ShapeBase SubExp
dest_shape SubExp
_ [VName]
_ [SubExp]
_ ShapeBase SubExp
shape Lambda MCMem
lam, ([SubExp]
bucket, [SubExp]
vs')) -> do
let ([Param LParamMem]
is_params, [Param LParamMem]
vs_params) = Int -> [Param LParamMem] -> ([Param LParamMem], [Param LParamMem])
forall a. Int -> [a] -> ([a], [a])
splitAt ([SubExp] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
vs') ([Param LParamMem] -> ([Param LParamMem], [Param LParamMem]))
-> [Param LParamMem] -> ([Param LParamMem], [Param LParamMem])
forall a b. (a -> b) -> a -> b
$ Lambda MCMem -> [LParam MCMem]
forall rep. Lambda rep -> [LParam rep]
lambdaParams Lambda MCMem
lam
bucket' :: [TExp Int64]
bucket' = (SubExp -> TExp Int64) -> [SubExp] -> [TExp Int64]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TExp Int64
pe64 [SubExp]
bucket
dest_shape' :: [TExp Int64]
dest_shape' = (SubExp -> TExp Int64) -> [SubExp] -> [TExp Int64]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TExp Int64
pe64 ([SubExp] -> [TExp Int64]) -> [SubExp] -> [TExp Int64]
forall a b. (a -> b) -> a -> b
$ ShapeBase SubExp -> [SubExp]
forall d. ShapeBase d -> [d]
shapeDims ShapeBase SubExp
dest_shape
bucket_in_bounds :: TPrimExp Bool VName
bucket_in_bounds = Slice (TExp Int64) -> [TExp Int64] -> TPrimExp Bool VName
inBounds ([DimIndex (TExp Int64)] -> Slice (TExp Int64)
forall d. [DimIndex d] -> Slice d
Slice ((TExp Int64 -> DimIndex (TExp Int64))
-> [TExp Int64] -> [DimIndex (TExp Int64)]
forall a b. (a -> b) -> [a] -> [b]
map TExp Int64 -> DimIndex (TExp Int64)
forall d. d -> DimIndex d
DimFix [TExp Int64]
bucket')) [TExp Int64]
dest_shape'
Text -> MulticoreGen () -> MulticoreGen ()
forall rep r op. Text -> ImpM rep r op () -> ImpM rep r op ()
sComment Text
"save map-out results" (MulticoreGen () -> MulticoreGen ())
-> MulticoreGen () -> MulticoreGen ()
forall a b. (a -> b) -> a -> b
$
[(PatElem LParamMem, SubExp)]
-> ((PatElem LParamMem, SubExp) -> MulticoreGen ())
-> MulticoreGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([PatElem LParamMem] -> [SubExp] -> [(PatElem LParamMem, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [PatElem LParamMem]
map_pes [SubExp]
map_res) (((PatElem LParamMem, SubExp) -> MulticoreGen ())
-> MulticoreGen ())
-> ((PatElem LParamMem, SubExp) -> MulticoreGen ())
-> MulticoreGen ()
forall a b. (a -> b) -> a -> b
$ \(PatElem LParamMem
pe, SubExp
res) ->
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> MulticoreGen ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix (PatElem LParamMem -> VName
forall dec. PatElem dec -> VName
patElemName PatElem LParamMem
pe) ((VName -> TExp Int64) -> [VName] -> [TExp Int64]
forall a b. (a -> b) -> [a] -> [b]
map VName -> TExp Int64
forall a. a -> TPrimExp Int64 a
Imp.le64 [VName]
is) SubExp
res []
Text -> MulticoreGen () -> MulticoreGen ()
forall rep r op. Text -> ImpM rep r op () -> ImpM rep r op ()
sComment Text
"perform updates" (MulticoreGen () -> MulticoreGen ())
-> MulticoreGen () -> MulticoreGen ()
forall a b. (a -> b) -> a -> b
$
TPrimExp Bool VName -> MulticoreGen () -> MulticoreGen ()
forall rep r op.
TPrimExp Bool VName -> ImpM rep r op () -> ImpM rep r op ()
sWhen TPrimExp Bool VName
bucket_in_bounds (MulticoreGen () -> MulticoreGen ())
-> MulticoreGen () -> MulticoreGen ()
forall a b. (a -> b) -> a -> b
$ do
[LParam MCMem] -> MulticoreGen ()
forall rep (inner :: * -> *) r op.
Mem rep inner =>
[LParam rep] -> ImpM rep r op ()
dLParams ([LParam MCMem] -> MulticoreGen ())
-> [LParam MCMem] -> MulticoreGen ()
forall a b. (a -> b) -> a -> b
$ Lambda MCMem -> [LParam MCMem]
forall rep. Lambda rep -> [LParam rep]
lambdaParams Lambda MCMem
lam
ShapeBase SubExp
-> ([TExp Int64] -> MulticoreGen ()) -> MulticoreGen ()
forall rep r op.
ShapeBase SubExp
-> ([TExp Int64] -> ImpM rep r op ()) -> ImpM rep r op ()
sLoopNest ShapeBase SubExp
shape (([TExp Int64] -> MulticoreGen ()) -> MulticoreGen ())
-> ([TExp Int64] -> MulticoreGen ()) -> MulticoreGen ()
forall a b. (a -> b) -> a -> b
$ \[TExp Int64]
vec_is -> do
[(PatElem LParamMem, Param LParamMem)]
-> ((PatElem LParamMem, Param LParamMem) -> MulticoreGen ())
-> MulticoreGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([PatElem LParamMem]
-> [Param LParamMem] -> [(PatElem LParamMem, Param LParamMem)]
forall a b. [a] -> [b] -> [(a, b)]
zip [PatElem LParamMem]
red_pes [Param LParamMem]
is_params) (((PatElem LParamMem, Param LParamMem) -> MulticoreGen ())
-> MulticoreGen ())
-> ((PatElem LParamMem, Param LParamMem) -> MulticoreGen ())
-> MulticoreGen ()
forall a b. (a -> b) -> a -> b
$ \(PatElem LParamMem
pe, Param LParamMem
p) ->
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> MulticoreGen ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix
(Param LParamMem -> VName
forall dec. Param dec -> VName
paramName Param LParamMem
p)
[]
(VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ PatElem LParamMem -> VName
forall dec. PatElem dec -> VName
patElemName PatElem LParamMem
pe)
((VName -> TExp Int64) -> [VName] -> [TExp Int64]
forall a b. (a -> b) -> [a] -> [b]
map VName -> TExp Int64
forall a. a -> TPrimExp Int64 a
Imp.le64 ([VName] -> [VName]
forall a. HasCallStack => [a] -> [a]
init [VName]
is) [TExp Int64] -> [TExp Int64] -> [TExp Int64]
forall a. [a] -> [a] -> [a]
++ [TExp Int64]
bucket' [TExp Int64] -> [TExp Int64] -> [TExp Int64]
forall a. [a] -> [a] -> [a]
++ [TExp Int64]
vec_is)
[(Param LParamMem, SubExp)]
-> ((Param LParamMem, SubExp) -> MulticoreGen ())
-> MulticoreGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param LParamMem] -> [SubExp] -> [(Param LParamMem, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param LParamMem]
vs_params [SubExp]
vs') (((Param LParamMem, SubExp) -> MulticoreGen ()) -> MulticoreGen ())
-> ((Param LParamMem, SubExp) -> MulticoreGen ())
-> MulticoreGen ()
forall a b. (a -> b) -> a -> b
$ \(Param LParamMem
p, SubExp
v) ->
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> MulticoreGen ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix (Param LParamMem -> VName
forall dec. Param dec -> VName
paramName Param LParamMem
p) [] SubExp
v [TExp Int64]
vec_is
Names -> Stms MCMem -> MulticoreGen () -> MulticoreGen ()
forall rep r op.
Names -> Stms rep -> ImpM rep r op () -> ImpM rep r op ()
compileStms Names
forall a. Monoid a => a
mempty (Body MCMem -> Stms MCMem
forall rep. Body rep -> Stms rep
bodyStms (Body MCMem -> Stms MCMem) -> Body MCMem -> Stms MCMem
forall a b. (a -> b) -> a -> b
$ Lambda MCMem -> Body MCMem
forall rep. Lambda rep -> Body rep
lambdaBody Lambda MCMem
lam) (MulticoreGen () -> MulticoreGen ())
-> MulticoreGen () -> MulticoreGen ()
forall a b. (a -> b) -> a -> b
$
[(PatElem LParamMem, SubExp)]
-> ((PatElem LParamMem, SubExp) -> MulticoreGen ())
-> MulticoreGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([PatElem LParamMem] -> [SubExp] -> [(PatElem LParamMem, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [PatElem LParamMem]
red_pes ([SubExp] -> [(PatElem LParamMem, SubExp)])
-> [SubExp] -> [(PatElem LParamMem, SubExp)]
forall a b. (a -> b) -> a -> b
$ (SubExpRes -> SubExp) -> [SubExpRes] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map SubExpRes -> SubExp
resSubExp ([SubExpRes] -> [SubExp]) -> [SubExpRes] -> [SubExp]
forall a b. (a -> b) -> a -> b
$ Body MCMem -> [SubExpRes]
forall rep. Body rep -> [SubExpRes]
bodyResult (Body MCMem -> [SubExpRes]) -> Body MCMem -> [SubExpRes]
forall a b. (a -> b) -> a -> b
$ Lambda MCMem -> Body MCMem
forall rep. Lambda rep -> Body rep
lambdaBody Lambda MCMem
lam) (((PatElem LParamMem, SubExp) -> MulticoreGen ())
-> MulticoreGen ())
-> ((PatElem LParamMem, SubExp) -> MulticoreGen ())
-> MulticoreGen ()
forall a b. (a -> b) -> a -> b
$
\(PatElem LParamMem
pe, SubExp
se) ->
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> MulticoreGen ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix
(PatElem LParamMem -> VName
forall dec. PatElem dec -> VName
patElemName PatElem LParamMem
pe)
((VName -> TExp Int64) -> [VName] -> [TExp Int64]
forall a b. (a -> b) -> [a] -> [b]
map VName -> TExp Int64
forall a. a -> TPrimExp Int64 a
Imp.le64 ([VName] -> [VName]
forall a. HasCallStack => [a] -> [a]
init [VName]
is) [TExp Int64] -> [TExp Int64] -> [TExp Int64]
forall a. [a] -> [a] -> [a]
++ [TExp Int64]
bucket' [TExp Int64] -> [TExp Int64] -> [TExp Int64]
forall a. [a] -> [a] -> [a]
++ [TExp Int64]
vec_is)
SubExp
se
[]