{-# LANGUAGE TypeFamilies #-}

-- | Code generation for segmented and non-segmented scans.  Uses a
-- fast single-pass algorithm, which only works on NVIDIA GPUs and
-- only for operators satisfying some constraints.  We use it whenever
-- we can.
module Futhark.CodeGen.ImpGen.GPU.SegScan.SinglePass (compileSegScan) where

import Control.Monad
import Data.List (zip4, zip7)
import Data.Map qualified as M
import Data.Maybe
import Futhark.CodeGen.ImpCode.GPU qualified as Imp
import Futhark.CodeGen.ImpGen
import Futhark.CodeGen.ImpGen.GPU.Base
import Futhark.IR.GPUMem
import Futhark.IR.Mem.LMAD qualified as LMAD
import Futhark.Transform.Rename
import Futhark.Util (mapAccumLM, takeLast)
import Futhark.Util.IntegralExp (IntegralExp (mod, rem), divUp, nextMul, quot)
import Prelude hiding (mod, quot, rem)

-- | The lambda parameters of a scan operator, split into the left
-- operand ('xParams') and the right operand ('yParams').  The number
-- of neutral elements of the operator determines how many parameters
-- belong to the left operand; all remaining parameters belong to the
-- right operand.
xParams, yParams :: SegBinOp GPUMem -> [LParam GPUMem]
xParams scan =
  take (length (segBinOpNeutral scan)) (lambdaParams (segBinOpLambda scan))
yParams scan =
  drop (length (segBinOpNeutral scan)) (lambdaParams (segBinOpLambda scan))

-- | Allocate one shared-memory buffer (@local_mem@) and create the
-- arrays used by the single-pass scan kernel as views into it.
--
-- The buffer size is the maximum (not the sum) of three requirements,
-- so the different groups of arrays reuse the same memory:
--
-- * the prefix arrays: one @block_size@-element array per type, packed
--   one after another, each starting at an offset aligned to its
--   element size;
--
-- * the transpose arrays: @chunk * block_size@ elements of the largest
--   type in @types@;
--
-- * the look-back area: @warp_size@ bytes (for @warpscan@) followed by
--   one @warp_size@-element exchange array per type, packed like the
--   prefix arrays.
--
-- Returns @(shared_id, transpose arrays, prefix arrays, warpscan,
-- warp exchange arrays)@.  Each list has one entry per element of
-- @types@, in the same order.
createLocalArrays ::
  Count BlockSize SubExp ->
  SubExp ->
  [PrimType] ->
  InKernelGen (VName, [VName], [VName], VName, [VName])
createLocalArrays (Count block_size) chunk types = do
  let block_sizeE = pe64 block_size
      workSize = pe64 chunk * block_sizeE

      elem_sizes :: [Imp.TExp Int64]
      elem_sizes = map primByteSize types

      -- Fixed warp size; this module only targets NVIDIA GPUs.
      warp_size :: (Num a) => a
      warp_size = 32

      -- Byte offset just past an array of @count@ elements, each of
      -- @ty_size@ bytes, placed at the first @ty_size@-aligned offset
      -- at or after @off@.
      packedEnd ::
        Imp.TExp Int64 -> Imp.TExp Int64 -> Imp.TExp Int64 -> Imp.TExp Int64
      packedEnd count off ty_size = nextMul off ty_size + ty_size * count

      prefixArraysSize = foldl (packedEnd block_sizeE) 0 elem_sizes
      -- Written as a non-partial fold: identical to 'foldl1' for a
      -- non-empty list, and 0 for an empty one.
      maxTransposedArraySize =
        case map (workSize *) elem_sizes of
          [] -> 0
          s : ss -> foldl sMax64 s ss
      maxWarpExchangeSize = foldl (packedEnd warp_size) 0 elem_sizes
      -- The exchange arrays are placed after the warp_size-byte
      -- warpscan array.
      maxLookbackSize = maxWarpExchangeSize + warp_size
      size =
        Imp.bytes $
          maxLookbackSize `sMax64` prefixArraysSize `sMax64` maxTransposedArraySize

      -- Accumulator step for laying out packed arrays.  It binds the
      -- offset just past the current array and returns the array's own
      -- start.  The start is the *aligned* offset: the unaligned running
      -- offset would be rounded down by the element-size division below,
      -- making the array overlap the end of the previous one.
      place ::
        String ->
        Imp.TExp Int64 ->
        Imp.TExp Int64 ->
        Imp.TExp Int64 ->
        InKernelGen (Imp.TExp Int64, Imp.TExp Int64)
      place desc count off ty_size = do
        let start = nextMul off ty_size
        next <- dPrimVE desc $ start + ty_size * count
        pure (next, start)

  (_, byteOffsets) <-
    mapAccumLM (place "byte_offsets" block_sizeE) 0 elem_sizes

  (_, warpByteOffsets) <-
    mapAccumLM (place "warp_byte_offset" warp_size) warp_size elem_sizes

  sComment "Allocate reusable shared memory" $ pure ()

  localMem <- sAlloc "local_mem" size (Space "shared")
  transposeArrayLength <- dPrimV "trans_arr_len" workSize

  sharedId <- sArrayInMem "shared_id" int32 (Shape [constant (1 :: Int32)]) localMem

  transposedArrays <-
    forM types $ \ty ->
      sArrayInMem
        "local_transpose_arr"
        ty
        (Shape [tvSize transposeArrayLength])
        localMem

  -- Byte offsets are converted to element offsets for each view.
  prefixArrays <-
    forM (zip byteOffsets types) $ \(off, ty) ->
      sArray
        "local_prefix_arr"
        ty
        (Shape [block_size])
        localMem
        $ LMAD.iota (off `quot` primByteSize ty) [block_sizeE]

  warpscan <- sArrayInMem "warpscan" int8 (Shape [constant (warp_size :: Int64)]) localMem
  warpExchanges <-
    forM (zip warpByteOffsets types) $ \(off, ty) ->
      sArray
        "warp_exchange"
        ty
        (Shape [constant (warp_size :: Int64)])
        localMem
        $ LMAD.iota (off `quot` primByteSize ty) [warp_size]

  pure (sharedId, transposedArrays, prefixArrays, warpscan, warpExchanges)

-- | Status flag values used by the single-pass scan.  The per-virtual-
-- block @status_flags@ array is initialised to 'statusX' before the
-- kernel is launched.  The in-block look-back scan takes the right
-- operand unchanged when its flag is 'statusX' or 'statusP'.
--
-- NOTE(review): presumably these follow the usual decoupled look-back
-- scheme: 'statusX' = nothing published yet, 'statusA' = the block's
-- aggregate is available, 'statusP' = its inclusive prefix is
-- available.  Confirm against the kernel body.
statusX, statusA, statusP :: (Num a) => a
statusX = 0
statusA = 1
statusP = 2

-- | Scan the elements held by each group of 32 consecutive threads of
-- the thread block.  In this function "block" refers to such a
-- 32-thread group: @block_size = 32@, and @in_block_id@ is the thread's
-- position within its group.
--
-- Every element is a pair of an @int8@ status flag, read from
-- @flag_arr@, and the operands of @scan_lam@, read from @arrs@.  Two
-- elements combine as follows:
--
-- * if the right operand's flag is 'statusP' or 'statusX', the result
--   is the right operand (flag and values) unchanged;
--
-- * otherwise the result keeps the left operand's flag, and its values
--   are computed by @scan_lam@.
--
-- The scan proceeds in steps whose distance starts at 1 and doubles
-- until it reaches 32.  In each step, every thread with
-- @in_block_id >= distance@ reads the element @distance@ positions
-- behind it, combines it with its own, and writes the result back.
-- Afterwards the thread's result is in the first half of @scan_lam@'s
-- parameters.  Flags and primitive results are also written to
-- @flag_arr@ / @arrs@ at the local thread id.
--
-- @arrs_full_size@ is an index offset used only when reading
-- non-primitive (array) left operands.  NOTE(review): the array layout
-- this offset assumes is set up by the caller; confirm there.
--
-- The whole body runs under 'everythingVolatile'.  No barrier is
-- emitted inside the scan loop.  A single global barrier is emitted
-- after the initial reads, and only when some operand is an array.
inBlockScanLookback ::
  KernelConstants ->
  Imp.TExp Int64 ->
  VName ->
  [VName] ->
  Lambda GPUMem ->
  InKernelGen ()
inBlockScanLookback constants arrs_full_size flag_arr arrs scan_lam = everythingVolatile $ do
  flg_x :: TV Int8 <- dPrim "flg_x"
  flg_y :: TV Int8 <- dPrim "flg_y"
  let flg_param_x = Param mempty (tvVar flg_x) (MemPrim int8)
      flg_param_y = Param mempty (tvVar flg_y) (MemPrim int8)
      flg_y_exp = tvExp flg_y
      statusP_e = statusP :: Imp.TExp Int8
      statusX_e = statusX :: Imp.TExp Int8

  dLParams (lambdaParams scan_lam)

  skip_threads :: TV Int32 <- dPrim "skip_threads"
  let in_block_thread_active =
        tvExp skip_threads .<=. in_block_id
      actual_params = lambdaParams scan_lam
      -- First half: left operand (x); second half: right operand (y).
      (x_params, y_params) =
        splitAt (length actual_params `div` 2) actual_params
      y_to_x =
        forM_ (zip x_params y_params) $ \(x, y) ->
          when (primType (paramType x)) $
            copyDWIM (paramName x) [] (Var (paramName y)) []
      y_to_x_flg =
        copyDWIM (tvVar flg_x) [] (Var (tvVar flg_y)) []

  -- Set initial y values
  sComment "read input for in-block scan" $ do
    zipWithM_ readInitial (flg_param_y : y_params) (flag_arr : arrs)
    -- Since the final result is expected to be in x_params, we may
    -- need to copy it there for the first thread in the block.
    sWhen (in_block_id .==. 0) $ do
      y_to_x
      y_to_x_flg

  -- A barrier is only needed when some operand is an array.
  when array_scan $ sOp $ Imp.Barrier Imp.FenceGlobal

  let op_to_x =
        sIf
          (flg_y_exp .==. statusP_e .||. flg_y_exp .==. statusX_e)
          ( do
              y_to_x_flg
              y_to_x
          )
          (compileBody' x_params $ lambdaBody scan_lam)

  sComment "in-block scan (hopefully no barriers needed)" $ do
    skip_threads <-- 1

    sWhile (tvExp skip_threads .<. block_size) $ do
      sWhen in_block_thread_active $ do
        sComment "read operands" $
          zipWithM_
            (readParam (sExt64 $ tvExp skip_threads))
            (flg_param_x : x_params)
            (flag_arr : arrs)
        sComment "perform operation" op_to_x

        sComment "write result" $
          sequence_ $
            zipWith3
              writeResult
              (flg_param_x : x_params)
              (flg_param_y : y_params)
              (flag_arr : arrs)

      skip_threads <-- tvExp skip_threads * 2
  where
    block_size :: Imp.TExp Int32
    block_size = 32

    block_id, in_block_id, ltid32 :: Imp.TExp Int32
    block_id = ltid32 `quot` block_size
    in_block_id = ltid32 - block_id * block_size
    ltid32 = kernelLocalThreadId constants

    ltid, gtid :: Imp.TExp Int64
    ltid = sExt64 ltid32
    gtid = sExt64 $ kernelGlobalThreadId constants

    array_scan :: Bool
    array_scan = not $ all primType $ lambdaReturnType scan_lam

    -- Primitive operands are indexed by local thread id, array operands
    -- by global thread id.
    readInitial p arr
      | primType $ paramType p =
          copyDWIMFix (paramName p) [] (Var arr) [ltid]
      | otherwise =
          copyDWIMFix (paramName p) [] (Var arr) [gtid]

    -- Read the element @behind@ positions before this thread's own.
    readParam behind p arr
      | primType $ paramType p =
          copyDWIMFix (paramName p) [] (Var arr) [ltid - behind]
      | otherwise =
          copyDWIMFix (paramName p) [] (Var arr) [gtid - behind + arrs_full_size]

    -- Publish a primitive result at the local thread id, and carry the
    -- result over as the right operand of the next step.
    writeResult x y arr = do
      when (isPrimParam x) $
        copyDWIMFix arr [ltid] (Var $ paramName x) []
      copyDWIM (paramName y) [] (Var $ paramName x) []

-- | Compile 'SegScan' instance to host-level code with calls to a
-- single-pass kernel.
compileSegScan ::
  Pat LetDecMem ->
  SegLevel ->
  SegSpace ->
  SegBinOp GPUMem ->
  KernelBody GPUMem ->
  CallKernelGen ()
compileSegScan :: Pat (MemInfo SubExp NoUniqueness MemBind)
-> SegLevel
-> SegSpace
-> SegBinOp GPUMem
-> KernelBody GPUMem
-> CallKernelGen ()
compileSegScan Pat (MemInfo SubExp NoUniqueness MemBind)
pat SegLevel
lvl SegSpace
space SegBinOp GPUMem
scan_op KernelBody GPUMem
map_kbody = do
  attrs <- SegLevel -> CallKernelGen KernelAttrs
lvlKernelAttrs SegLevel
lvl
  let Pat all_pes = pat

      scanop_nes = SegBinOp GPUMem -> [SubExp]
forall rep. SegBinOp rep -> [SubExp]
segBinOpNeutral SegBinOp GPUMem
scan_op

      n = [TExp Int64] -> TExp Int64
forall a. Num a => [a] -> a
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product ([TExp Int64] -> TExp Int64) -> [TExp Int64] -> TExp Int64
forall a b. (a -> b) -> a -> b
$ (SubExp -> TExp Int64) -> [SubExp] -> [TExp Int64]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TExp Int64
pe64 ([SubExp] -> [TExp Int64]) -> [SubExp] -> [TExp Int64]
forall a b. (a -> b) -> a -> b
$ SegSpace -> [SubExp]
segSpaceDims SegSpace
space

      tys' = Lambda GPUMem -> [TypeBase (ShapeBase SubExp) NoUniqueness]
forall rep.
Lambda rep -> [TypeBase (ShapeBase SubExp) NoUniqueness]
lambdaReturnType (Lambda GPUMem -> [TypeBase (ShapeBase SubExp) NoUniqueness])
-> Lambda GPUMem -> [TypeBase (ShapeBase SubExp) NoUniqueness]
forall a b. (a -> b) -> a -> b
$ SegBinOp GPUMem -> Lambda GPUMem
forall rep. SegBinOp rep -> Lambda rep
segBinOpLambda SegBinOp GPUMem
scan_op

      tys = (TypeBase (ShapeBase SubExp) NoUniqueness -> PrimType)
-> [TypeBase (ShapeBase SubExp) NoUniqueness] -> [PrimType]
forall a b. (a -> b) -> [a] -> [b]
map TypeBase (ShapeBase SubExp) NoUniqueness -> PrimType
forall shape u. TypeBase shape u -> PrimType
elemType [TypeBase (ShapeBase SubExp) NoUniqueness]
tys'

      tblock_size_e = SubExp -> TExp Int64
pe64 (SubExp -> TExp Int64) -> SubExp -> TExp Int64
forall a b. (a -> b) -> a -> b
$ Count BlockSize SubExp -> SubExp
forall {k} (u :: k) e. Count u e -> e
unCount (Count BlockSize SubExp -> SubExp)
-> Count BlockSize SubExp -> SubExp
forall a b. (a -> b) -> a -> b
$ KernelAttrs -> Count BlockSize SubExp
kAttrBlockSize KernelAttrs
attrs
      num_phys_blocks_e = SubExp -> TExp Int64
pe64 (SubExp -> TExp Int64) -> SubExp -> TExp Int64
forall a b. (a -> b) -> a -> b
$ Count NumBlocks SubExp -> SubExp
forall {k} (u :: k) e. Count u e -> e
unCount (Count NumBlocks SubExp -> SubExp)
-> Count NumBlocks SubExp -> SubExp
forall a b. (a -> b) -> a -> b
$ KernelAttrs -> Count NumBlocks SubExp
kAttrNumBlocks KernelAttrs
attrs

  let chunk_const = [TypeBase (ShapeBase SubExp) NoUniqueness] -> KernelConstExp
getChunkSize [TypeBase (ShapeBase SubExp) NoUniqueness]
tys'
  chunk_v <- dPrimV "chunk_size" . isInt64 =<< kernelConstToExp chunk_const
  let chunk = TV Int64 -> TExp Int64
forall {k} (t :: k). TV t -> TExp t
tvExp TV Int64
chunk_v

  num_virt_blocks <-
    tvSize <$> dPrimV "num_virt_blocks" (n `divUp` (tblock_size_e * chunk))
  let num_virt_blocks_e = SubExp -> TExp Int64
pe64 SubExp
num_virt_blocks

  num_virt_threads <-
    dPrimVE "num_virt_threads" $ num_virt_blocks_e * tblock_size_e

  let (gtids, dims) = unzip $ unSegSpace space
      dims' = (SubExp -> TExp Int64) -> [SubExp] -> [TExp Int64]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TExp Int64
pe64 [SubExp]
dims
      segmented = [TExp Int64] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [TExp Int64]
dims' Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
> Int
1
      not_segmented_e = Bool -> TPrimExp Bool VName
forall v. Bool -> TPrimExp Bool v
fromBool (Bool -> TPrimExp Bool VName) -> Bool -> TPrimExp Bool VName
forall a b. (a -> b) -> a -> b
$ Bool -> Bool
not Bool
segmented
      segment_size = [TExp Int64] -> TExp Int64
forall a. HasCallStack => [a] -> a
last [TExp Int64]
dims'

  emit $ Imp.DebugPrint "Sequential elements per thread (chunk)" $ Just $ untyped chunk

  statusFlags <- sAllocArray "status_flags" int8 (Shape [num_virt_blocks]) (Space "device")
  sReplicate statusFlags $ intConst Int8 statusX

  (aggregateArrays, incprefixArrays) <-
    fmap unzip $
      forM tys $ \PrimType
ty ->
        (,)
          (VName -> VName -> (VName, VName))
-> ImpM GPUMem HostEnv HostOp VName
-> ImpM GPUMem HostEnv HostOp (VName -> (VName, VName))
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> String
-> PrimType
-> ShapeBase SubExp
-> Space
-> ImpM GPUMem HostEnv HostOp VName
forall rep r op.
String
-> PrimType -> ShapeBase SubExp -> Space -> ImpM rep r op VName
sAllocArray String
"aggregates" PrimType
ty ([SubExp] -> ShapeBase SubExp
forall d. [d] -> ShapeBase d
Shape [SubExp
num_virt_blocks]) (String -> Space
Space String
"device")
          ImpM GPUMem HostEnv HostOp (VName -> (VName, VName))
-> ImpM GPUMem HostEnv HostOp VName
-> ImpM GPUMem HostEnv HostOp (VName, VName)
forall a b.
ImpM GPUMem HostEnv HostOp (a -> b)
-> ImpM GPUMem HostEnv HostOp a -> ImpM GPUMem HostEnv HostOp b
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> String
-> PrimType
-> ShapeBase SubExp
-> Space
-> ImpM GPUMem HostEnv HostOp VName
forall rep r op.
String
-> PrimType -> ShapeBase SubExp -> Space -> ImpM rep r op VName
sAllocArray String
"incprefixes" PrimType
ty ([SubExp] -> ShapeBase SubExp
forall d. [d] -> ShapeBase d
Shape [SubExp
num_virt_blocks]) (String -> Space
Space String
"device")

  global_id <- genZeroes "global_dynid" 1

  let attrs' = KernelAttrs
attrs {kAttrConstExps = M.singleton (tvVar chunk_v) chunk_const}

  sKernelThread "segscan" (segFlat space) attrs' $ do
    chunk32 <- dPrimVE "chunk_size_32b" $ sExt32 $ tvExp chunk_v

    constants <- kernelConstants <$> askEnv

    let ltid32 = KernelConstants -> TExp Int32
kernelLocalThreadId KernelConstants
constants
        ltid = TExp Int32 -> TExp Int64
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
ltid32

    (sharedId, transposedArrays, prefixArrays, warpscan, exchanges) <-
      createLocalArrays (kAttrBlockSize attrs) (tvSize chunk_v) tys

    -- We wrap the entire kernel body in a virtualisation loop to
    -- handle the case where we do not have enough thread blocks to
    -- cover the iteration space. Dynamic block indexing has no
    -- implication on this, since each block simply fetches a new
    -- dynamic ID upon entry into the virtualisation loop.
    --
    -- We could use virtualiseBlocks, but this introduces a barrier which is
    -- redundant in this case, and also we don't need to base virtual block IDs
    -- on the loop variable, but rather on the dynamic IDs.
    phys_block_id <- dPrim "phys_block_id"
    sOp $ Imp.GetBlockId (tvVar phys_block_id) 0
    iters <-
      dPrimVE "virtloop_bound" $
        (num_virt_blocks_e - tvExp phys_block_id)
          `divUp` num_phys_blocks_e

    sFor "virtloop_i" iters $ const $ do
      dyn_id <- dPrim "dynamic_id"
      sComment "First thread in block fetches this block's dynamic_id" $ do
        sWhen (ltid32 .==. 0) $ do
          (globalIdMem, _, globalIdOff) <- fullyIndexArray global_id [0]
          sOp $
            Imp.Atomic DefaultSpace $
              Imp.AtomicAdd
                Int32
                (tvVar dyn_id)
                globalIdMem
                (Count $ unCount globalIdOff)
                (untyped (1 :: Imp.TExp Int32))
          sComment "Set dynamic id for this block" $ do
            copyDWIMFix sharedId [0] (tvSize dyn_id) []

          sComment "First thread in last (virtual) block resets global dynamic_id" $ do
            sWhen (tvExp dyn_id .==. num_virt_blocks_e - 1) $
              copyDWIMFix global_id [0] (intConst Int32 0) []

      let local_barrier = Fence -> KernelOp
Imp.Barrier Fence
Imp.FenceLocal
          local_fence = Fence -> KernelOp
Imp.MemFence Fence
Imp.FenceLocal
          global_fence = Fence -> KernelOp
Imp.MemFence Fence
Imp.FenceGlobal

      sOp local_barrier
      copyDWIMFix (tvVar dyn_id) [] (Var sharedId) [0]
      sOp local_barrier

      block_offset <-
        dPrimVE "block_offset" $
          sExt64 (tvExp dyn_id) * chunk * tblock_size_e
      sgm_idx <- dPrimVE "sgm_idx" $ block_offset `mod` segment_size
      boundary <-
        dPrimVE "boundary" $
          sExt32 $
            sMin64 (chunk * tblock_size_e) (segment_size - sgm_idx)
      segsize_compact <-
        dPrimVE "segsize_compact" $
          sExt32 $
            sMin64 (chunk * tblock_size_e) segment_size
      private_chunks <-
        forM tys $ \PrimType
ty ->
          String
-> PrimType
-> ShapeBase SubExp
-> Space
-> ImpM GPUMem KernelEnv KernelOp VName
forall rep r op.
String
-> PrimType -> ShapeBase SubExp -> Space -> ImpM rep r op VName
sAllocArray
            String
"private"
            PrimType
ty
            ([SubExp] -> ShapeBase SubExp
forall d. [d] -> ShapeBase d
Shape [TV Int64 -> SubExp
forall {k} (t :: k). TV t -> SubExp
tvSize TV Int64
chunk_v])
            ([SubExp] -> PrimType -> Space
ScalarSpace [TV Int64 -> SubExp
forall {k} (t :: k). TV t -> SubExp
tvSize TV Int64
chunk_v] PrimType
ty)

      thd_offset <- dPrimVE "thd_offset" $ block_offset + ltid

      sComment "Load and map" $
        sFor "i" chunk $ \TExp Int64
i -> do
          -- The map's input index
          virt_tid <- String -> TExp Int64 -> ImpM GPUMem KernelEnv KernelOp (TExp Int64)
forall {k} (t :: k) rep r op.
String -> TExp t -> ImpM rep r op (TExp t)
dPrimVE String
"virt_tid" (TExp Int64 -> ImpM GPUMem KernelEnv KernelOp (TExp Int64))
-> TExp Int64 -> ImpM GPUMem KernelEnv KernelOp (TExp Int64)
forall a b. (a -> b) -> a -> b
$ TExp Int64
thd_offset TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
+ TExp Int64
i TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* TExp Int64
tblock_size_e
          dIndexSpace (zip gtids dims') virt_tid
          -- Perform the map
          let in_bounds =
                Names
-> Stms GPUMem
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
Names -> Stms rep -> ImpM rep r op () -> ImpM rep r op ()
compileStms Names
forall a. Monoid a => a
mempty (KernelBody GPUMem -> Stms GPUMem
forall rep. KernelBody rep -> Stms rep
kernelBodyStms KernelBody GPUMem
map_kbody) (ImpM GPUMem KernelEnv KernelOp ()
 -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ do
                  let ([KernelResult]
all_scan_res, [KernelResult]
map_res) =
                        Int -> [KernelResult] -> ([KernelResult], [KernelResult])
forall a. Int -> [a] -> ([a], [a])
splitAt ([SegBinOp GPUMem] -> Int
forall rep. [SegBinOp rep] -> Int
segBinOpResults [SegBinOp GPUMem
scan_op]) ([KernelResult] -> ([KernelResult], [KernelResult]))
-> [KernelResult] -> ([KernelResult], [KernelResult])
forall a b. (a -> b) -> a -> b
$ KernelBody GPUMem -> [KernelResult]
forall rep. KernelBody rep -> [KernelResult]
kernelBodyResult KernelBody GPUMem
map_kbody

                  -- Write map results to their global memory destinations
                  [(PatElem (MemInfo SubExp NoUniqueness MemBind), KernelResult)]
-> ((PatElem (MemInfo SubExp NoUniqueness MemBind), KernelResult)
    -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([PatElem (MemInfo SubExp NoUniqueness MemBind)]
-> [KernelResult]
-> [(PatElem (MemInfo SubExp NoUniqueness MemBind), KernelResult)]
forall a b. [a] -> [b] -> [(a, b)]
zip (Int
-> [PatElem (MemInfo SubExp NoUniqueness MemBind)]
-> [PatElem (MemInfo SubExp NoUniqueness MemBind)]
forall a. Int -> [a] -> [a]
takeLast ([KernelResult] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [KernelResult]
map_res) [PatElem (MemInfo SubExp NoUniqueness MemBind)]
all_pes) [KernelResult]
map_res) (((PatElem (MemInfo SubExp NoUniqueness MemBind), KernelResult)
  -> ImpM GPUMem KernelEnv KernelOp ())
 -> ImpM GPUMem KernelEnv KernelOp ())
-> ((PatElem (MemInfo SubExp NoUniqueness MemBind), KernelResult)
    -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(PatElem (MemInfo SubExp NoUniqueness MemBind)
dest, KernelResult
src) ->
                    VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix (PatElem (MemInfo SubExp NoUniqueness MemBind) -> VName
forall dec. PatElem dec -> VName
patElemName PatElem (MemInfo SubExp NoUniqueness MemBind)
dest) ((VName -> TExp Int64) -> [VName] -> [TExp Int64]
forall a b. (a -> b) -> [a] -> [b]
map VName -> TExp Int64
forall a. a -> TPrimExp Int64 a
Imp.le64 [VName]
gtids) (KernelResult -> SubExp
kernelResultSubExp KernelResult
src) []

                  -- Write to-scan results to private memory.
                  [(VName, SubExp)]
-> ((VName, SubExp) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> [SubExp] -> [(VName, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
private_chunks ([SubExp] -> [(VName, SubExp)]) -> [SubExp] -> [(VName, SubExp)]
forall a b. (a -> b) -> a -> b
$ (KernelResult -> SubExp) -> [KernelResult] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map KernelResult -> SubExp
kernelResultSubExp [KernelResult]
all_scan_res) (((VName, SubExp) -> ImpM GPUMem KernelEnv KernelOp ())
 -> ImpM GPUMem KernelEnv KernelOp ())
-> ((VName, SubExp) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(VName
dest, SubExp
src) ->
                    VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
dest [TExp Int64
i] SubExp
src []

              out_of_bounds =
                [(VName, SubExp)]
-> ((VName, SubExp) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> [SubExp] -> [(VName, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
private_chunks [SubExp]
scanop_nes) (((VName, SubExp) -> ImpM GPUMem KernelEnv KernelOp ())
 -> ImpM GPUMem KernelEnv KernelOp ())
-> ((VName, SubExp) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(VName
dest, SubExp
ne) ->
                  VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
dest [TExp Int64
i] SubExp
ne []

          sIf (virt_tid .<. n) in_bounds out_of_bounds

      sOp $ Imp.ErrorSync Imp.FenceLocal
      sComment "Transpose scan inputs" $ do
        forM_ (zip transposedArrays private_chunks) $ \(VName
trans, VName
priv) -> do
          String
-> TExp Int64
-> (TExp Int64 -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall {k} (t :: k) rep r op.
String
-> TExp t -> (TExp t -> ImpM rep r op ()) -> ImpM rep r op ()
sFor String
"i" TExp Int64
chunk ((TExp Int64 -> ImpM GPUMem KernelEnv KernelOp ())
 -> ImpM GPUMem KernelEnv KernelOp ())
-> (TExp Int64 -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \TExp Int64
i -> do
            sharedIdx <- String -> TExp Int64 -> ImpM GPUMem KernelEnv KernelOp (TExp Int64)
forall {k} (t :: k) rep r op.
String -> TExp t -> ImpM rep r op (TExp t)
dPrimVE String
"sharedIdx" (TExp Int64 -> ImpM GPUMem KernelEnv KernelOp (TExp Int64))
-> TExp Int64 -> ImpM GPUMem KernelEnv KernelOp (TExp Int64)
forall a b. (a -> b) -> a -> b
$ TExp Int64
ltid TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
+ TExp Int64
i TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* TExp Int64
tblock_size_e
            copyDWIMFix trans [sharedIdx] (Var priv) [i]
          KernelOp -> ImpM GPUMem KernelEnv KernelOp ()
forall op rep r. op -> ImpM rep r op ()
sOp KernelOp
local_barrier
          String
-> TExp Int64
-> (TExp Int64 -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall {k} (t :: k) rep r op.
String
-> TExp t -> (TExp t -> ImpM rep r op ()) -> ImpM rep r op ()
sFor String
"i" TExp Int64
chunk ((TExp Int64 -> ImpM GPUMem KernelEnv KernelOp ())
 -> ImpM GPUMem KernelEnv KernelOp ())
-> (TExp Int64 -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \TExp Int64
i -> do
            sharedIdx <- String -> TExp Int64 -> ImpM GPUMem KernelEnv KernelOp (TV Int64)
forall {k} (t :: k) rep r op.
String -> TExp t -> ImpM rep r op (TV t)
dPrimV String
"sharedIdx" (TExp Int64 -> ImpM GPUMem KernelEnv KernelOp (TV Int64))
-> TExp Int64 -> ImpM GPUMem KernelEnv KernelOp (TV Int64)
forall a b. (a -> b) -> a -> b
$ TExp Int64
ltid TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* TExp Int64
chunk TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
+ TExp Int64
i
            copyDWIMFix priv [sExt64 i] (Var trans) [sExt64 $ tvExp sharedIdx]
          KernelOp -> ImpM GPUMem KernelEnv KernelOp ()
forall op rep r. op -> ImpM rep r op ()
sOp KernelOp
local_barrier

      sComment "Per thread scan" $ do
        -- We don't need to touch the first element, so only chunk-1
        -- iterations here.
        sFor "i" (chunk - 1) $ \TExp Int64
i -> do
          let xs :: [VName]
xs = (LParam GPUMem -> VName) -> [LParam GPUMem] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map LParam GPUMem -> VName
Param (MemInfo SubExp NoUniqueness MemBind) -> VName
forall dec. Param dec -> VName
paramName ([LParam GPUMem] -> [VName]) -> [LParam GPUMem] -> [VName]
forall a b. (a -> b) -> a -> b
$ SegBinOp GPUMem -> [LParam GPUMem]
xParams SegBinOp GPUMem
scan_op
              ys :: [VName]
ys = (LParam GPUMem -> VName) -> [LParam GPUMem] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map LParam GPUMem -> VName
Param (MemInfo SubExp NoUniqueness MemBind) -> VName
forall dec. Param dec -> VName
paramName ([LParam GPUMem] -> [VName]) -> [LParam GPUMem] -> [VName]
forall a b. (a -> b) -> a -> b
$ SegBinOp GPUMem -> [LParam GPUMem]
yParams SegBinOp GPUMem
scan_op
          -- determine if start of segment
          new_sgm <-
            if Bool
segmented
              then do
                gidx <- String -> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32)
forall {k} (t :: k) rep r op.
String -> TExp t -> ImpM rep r op (TExp t)
dPrimVE String
"gidx" (TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32))
-> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TExp Int32)
forall a b. (a -> b) -> a -> b
$ (TExp Int32
ltid32 TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
* TExp Int32
chunk32) TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
+ TExp Int32
1
                dPrimVE "new_sgm" $ (gidx + sExt32 i - boundary) `mod` segsize_compact .==. 0
              else TPrimExp Bool VName
-> ImpM GPUMem KernelEnv KernelOp (TPrimExp Bool VName)
forall a. a -> ImpM GPUMem KernelEnv KernelOp a
forall (f :: * -> *) a. Applicative f => a -> f a
pure TPrimExp Bool VName
forall v. TPrimExp Bool v
false
          -- skip scan of first element in segment
          sUnless new_sgm $ do
            forM_ (zip4 private_chunks xs ys tys) $ \(VName
src, VName
x, VName
y, PrimType
ty) -> do
              VName -> PrimType -> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op. VName -> PrimType -> ImpM rep r op ()
dPrim_ VName
x PrimType
ty
              VName -> PrimType -> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op. VName -> PrimType -> ImpM rep r op ()
dPrim_ VName
y PrimType
ty
              VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
x [] (VName -> SubExp
Var VName
src) [TExp Int64
i]
              VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
y [] (VName -> SubExp
Var VName
src) [TExp Int64
i TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
+ TExp Int64
1]

            compileStms mempty (bodyStms $ lambdaBody $ segBinOpLambda scan_op) $
              forM_ (zip private_chunks $ map resSubExp $ bodyResult $ lambdaBody $ segBinOpLambda scan_op) $ \(VName
dest, SubExp
res) ->
                VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
dest [TExp Int64
i TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
+ TExp Int64
1] SubExp
res []

      sComment "Publish results in shared memory" $ do
        forM_ (zip prefixArrays private_chunks) $ \(VName
dest, VName
src) ->
          VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
dest [TExp Int64
ltid] (VName -> SubExp
Var VName
src) [TExp Int64
chunk TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
- TExp Int64
1]
        sOp local_barrier

      let crossesSegment = do
            Bool -> Maybe ()
forall (f :: * -> *). Alternative f => Bool -> f ()
guard Bool
segmented
            (TExp Int32 -> TExp Int32 -> TPrimExp Bool VName)
-> Maybe (TExp Int32 -> TExp Int32 -> TPrimExp Bool VName)
forall a. a -> Maybe a
Just ((TExp Int32 -> TExp Int32 -> TPrimExp Bool VName)
 -> Maybe (TExp Int32 -> TExp Int32 -> TPrimExp Bool VName))
-> (TExp Int32 -> TExp Int32 -> TPrimExp Bool VName)
-> Maybe (TExp Int32 -> TExp Int32 -> TPrimExp Bool VName)
forall a b. (a -> b) -> a -> b
$ \TExp Int32
from TExp Int32
to ->
              let from' :: TExp Int32
from' = (TExp Int32
from TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
+ TExp Int32
1) TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
* TExp Int32
chunk32 TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
- TExp Int32
1
                  to' :: TExp Int32
to' = (TExp Int32
to TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
+ TExp Int32
1) TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
* TExp Int32
chunk32 TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
- TExp Int32
1
               in (TExp Int32
to' TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
- TExp Int32
from') TExp Int32 -> TExp Int32 -> TPrimExp Bool VName
forall {k} v (t :: k).
Eq v =>
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.>. (TExp Int32
to' TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
+ TExp Int32
segsize_compact TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
- TExp Int32
boundary) TExp Int32 -> TExp Int32 -> TExp Int32
forall e. IntegralExp e => e -> e -> e
`mod` TExp Int32
segsize_compact

      scan_op1 <- renameLambda $ segBinOpLambda scan_op

      accs <- mapM (dPrimSV "acc") tys
      sComment "Scan results (with warp scan)" $ do
        blockScan
          crossesSegment
          tblock_size_e
          num_virt_threads
          scan_op1
          prefixArrays

        sOp $ Imp.ErrorSync Imp.FenceLocal

        let firstThread TV (ZonkAny 1)
acc VName
prefixes =
              VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix (TV (ZonkAny 1) -> VName
forall {k} (t :: k). TV t -> VName
tvVar TV (ZonkAny 1)
acc) [] (VName -> SubExp
Var VName
prefixes) [TExp Int64 -> TExp Int64
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int64
tblock_size_e TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
- TExp Int64
1]
            notFirstThread TV (ZonkAny 1)
acc VName
prefixes =
              VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix (TV (ZonkAny 1) -> VName
forall {k} (t :: k). TV t -> VName
tvVar TV (ZonkAny 1)
acc) [] (VName -> SubExp
Var VName
prefixes) [TExp Int64
ltid TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
- TExp Int64
1]
        sIf
          (ltid32 .==. 0)
          (zipWithM_ firstThread accs prefixArrays)
          (zipWithM_ notFirstThread accs prefixArrays)

        sOp local_barrier

      prefixes <- forM (zip scanop_nes tys) $ \(SubExp
ne, PrimType
ty) ->
        String
-> TExp (ZonkAny 3)
-> ImpM GPUMem KernelEnv KernelOp (TV (ZonkAny 3))
forall {k} (t :: k) rep r op.
String -> TExp t -> ImpM rep r op (TV t)
dPrimV String
"prefix" (TExp (ZonkAny 3)
 -> ImpM GPUMem KernelEnv KernelOp (TV (ZonkAny 3)))
-> TExp (ZonkAny 3)
-> ImpM GPUMem KernelEnv KernelOp (TV (ZonkAny 3))
forall a b. (a -> b) -> a -> b
$ PrimExp VName -> TExp (ZonkAny 3)
forall {k} (t :: k) v. PrimExp v -> TPrimExp t v
TPrimExp (PrimExp VName -> TExp (ZonkAny 3))
-> PrimExp VName -> TExp (ZonkAny 3)
forall a b. (a -> b) -> a -> b
$ PrimType -> SubExp -> PrimExp VName
forall a. ToExp a => PrimType -> a -> PrimExp VName
toExp' PrimType
ty SubExp
ne
      blockNewSgm <- dPrimVE "block_new_sgm" $ sgm_idx .==. 0
      sComment "Perform lookback" $ do
        sWhen (blockNewSgm .&&. ltid32 .==. 0) $ do
          everythingVolatile $
            forM_ (zip accs incprefixArrays) $ \(TV (ZonkAny 1)
acc, VName
incprefixArray) ->
              VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
incprefixArray [TV Int64 -> TExp Int64
forall {k} (t :: k). TV t -> TExp t
tvExp TV Int64
dyn_id] (TV (ZonkAny 1) -> SubExp
forall {k} (t :: k). TV t -> SubExp
tvSize TV (ZonkAny 1)
acc) []
          sOp global_fence
          everythingVolatile $
            copyDWIMFix statusFlags [tvExp dyn_id] (intConst Int8 statusP) []
          forM_ (zip scanop_nes accs) $ \(SubExp
ne, TV (ZonkAny 1)
acc) ->
            VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix (TV (ZonkAny 1) -> VName
forall {k} (t :: k). TV t -> VName
tvVar TV (ZonkAny 1)
acc) [] SubExp
ne []
        -- end sWhen

        let warp_size = KernelConstants -> TExp Int32
kernelWaveSize KernelConstants
constants
        sWhen (bNot blockNewSgm .&&. ltid32 .<. warp_size) $ do
          sWhen (ltid32 .==. 0) $ do
            sIf
              (not_segmented_e .||. boundary .==. sExt32 (tblock_size_e * chunk))
              ( do
                  everythingVolatile $
                    forM_ (zip aggregateArrays accs) $ \(VName
aggregateArray, TV (ZonkAny 1)
acc) ->
                      VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
aggregateArray [TV Int64 -> TExp Int64
forall {k} (t :: k). TV t -> TExp t
tvExp TV Int64
dyn_id] (TV (ZonkAny 1) -> SubExp
forall {k} (t :: k). TV t -> SubExp
tvSize TV (ZonkAny 1)
acc) []
                  sOp global_fence
                  everythingVolatile $
                    copyDWIMFix statusFlags [tvExp dyn_id] (intConst Int8 statusA) []
              )
              ( do
                  everythingVolatile $
                    forM_ (zip incprefixArrays accs) $ \(VName
incprefixArray, TV (ZonkAny 1)
acc) ->
                      VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
incprefixArray [TV Int64 -> TExp Int64
forall {k} (t :: k). TV t -> TExp t
tvExp TV Int64
dyn_id] (TV (ZonkAny 1) -> SubExp
forall {k} (t :: k). TV t -> SubExp
tvSize TV (ZonkAny 1)
acc) []
                  sOp global_fence
                  everythingVolatile $
                    copyDWIMFix statusFlags [tvExp dyn_id] (intConst Int8 statusP) []
              )
            everythingVolatile $
              copyDWIMFix warpscan [0] (Var statusFlags) [tvExp dyn_id - 1]
          -- sWhen
          sOp local_fence

          status :: TV Int8 <- dPrim "status"
          copyDWIMFix (tvVar status) [] (Var warpscan) [0]

          sIf
            (tvExp status .==. statusP)
            ( sWhen (ltid32 .==. 0) $
                everythingVolatile $
                  forM_ (zip prefixes incprefixArrays) $ \(TV (ZonkAny 3)
prefix, VName
incprefixArray) ->
                    VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix (TV (ZonkAny 3) -> VName
forall {k} (t :: k). TV t -> VName
tvVar TV (ZonkAny 3)
prefix) [] (VName -> SubExp
Var VName
incprefixArray) [TV Int64 -> TExp Int64
forall {k} (t :: k). TV t -> TExp t
tvExp TV Int64
dyn_id TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
- TExp Int64
1]
            )
            ( do
                readOffset <-
                  dPrimV "readOffset" $
                    sExt32 $
                      tvExp dyn_id - sExt64 (kernelWaveSize constants)
                let loopStop = TExp Int32
warp_size TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
* (-TExp Int32
1)
                    sameSegment TV Int32
readIdx
                      | Bool
segmented =
                          let startIdx :: TExp Int64
startIdx = TExp Int32 -> TExp Int64
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (TV Int32 -> TExp Int32
forall {k} (t :: k). TV t -> TExp t
tvExp TV Int32
readIdx TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
+ TExp Int32
1) TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* TExp Int64
tblock_size_e TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* TExp Int64
chunk TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
- TExp Int64
1
                           in TExp Int64
block_offset TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
- TExp Int64
startIdx TExp Int64 -> TExp Int64 -> TPrimExp Bool VName
forall {k} v (t :: k).
Eq v =>
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<=. TExp Int64
sgm_idx
                      | Bool
otherwise = TPrimExp Bool VName
forall v. TPrimExp Bool v
true
                sWhile (tvExp readOffset .>. loopStop) $ do
                  readI <- dPrimV "read_i" $ tvExp readOffset + ltid32
                  aggrs <- forM (zip scanop_nes tys) $ \(SubExp
ne, PrimType
ty) ->
                    String
-> TExp (ZonkAny 5)
-> ImpM GPUMem KernelEnv KernelOp (TV (ZonkAny 5))
forall {k} (t :: k) rep r op.
String -> TExp t -> ImpM rep r op (TV t)
dPrimV String
"aggr" (TExp (ZonkAny 5)
 -> ImpM GPUMem KernelEnv KernelOp (TV (ZonkAny 5)))
-> TExp (ZonkAny 5)
-> ImpM GPUMem KernelEnv KernelOp (TV (ZonkAny 5))
forall a b. (a -> b) -> a -> b
$ PrimExp VName -> TExp (ZonkAny 5)
forall {k} (t :: k) v. PrimExp v -> TPrimExp t v
TPrimExp (PrimExp VName -> TExp (ZonkAny 5))
-> PrimExp VName -> TExp (ZonkAny 5)
forall a b. (a -> b) -> a -> b
$ PrimType -> SubExp -> PrimExp VName
forall a. ToExp a => PrimType -> a -> PrimExp VName
toExp' PrimType
ty SubExp
ne
                  flag <- dPrimV "flag" (statusX :: Imp.TExp Int8)
                  everythingVolatile . sWhen (tvExp readI .>=. 0) $ do
                    sIf
                      (sameSegment readI)
                      ( do
                          copyDWIMFix (tvVar flag) [] (Var statusFlags) [sExt64 $ tvExp readI]
                          sIf
                            (tvExp flag .==. statusP)
                            ( forM_ (zip incprefixArrays aggrs) $ \(VName
incprefix, TV (ZonkAny 5)
aggr) ->
                                VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix (TV (ZonkAny 5) -> VName
forall {k} (t :: k). TV t -> VName
tvVar TV (ZonkAny 5)
aggr) [] (VName -> SubExp
Var VName
incprefix) [TExp Int32 -> TExp Int64
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (TExp Int32 -> TExp Int64) -> TExp Int32 -> TExp Int64
forall a b. (a -> b) -> a -> b
$ TV Int32 -> TExp Int32
forall {k} (t :: k). TV t -> TExp t
tvExp TV Int32
readI]
                            )
                            ( sWhen (tvExp flag .==. statusA) $ do
                                forM_ (zip aggrs aggregateArrays) $ \(TV (ZonkAny 5)
aggr, VName
aggregate) ->
                                  VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix (TV (ZonkAny 5) -> VName
forall {k} (t :: k). TV t -> VName
tvVar TV (ZonkAny 5)
aggr) [] (VName -> SubExp
Var VName
aggregate) [TExp Int32 -> TExp Int64
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (TExp Int32 -> TExp Int64) -> TExp Int32 -> TExp Int64
forall a b. (a -> b) -> a -> b
$ TV Int32 -> TExp Int32
forall {k} (t :: k). TV t -> TExp t
tvExp TV Int32
readI]
                            )
                      )
                      (copyDWIMFix (tvVar flag) [] (intConst Int8 statusP) [])
                  -- end sIf
                  -- end sWhen

                  forM_ (zip exchanges aggrs) $ \(VName
exchange, TV (ZonkAny 5)
aggr) ->
                    VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
exchange [TExp Int64
ltid] (TV (ZonkAny 5) -> SubExp
forall {k} (t :: k). TV t -> SubExp
tvSize TV (ZonkAny 5)
aggr) []
                  copyDWIMFix warpscan [ltid] (tvSize flag) []

                  -- execute warp-parallel reduction but only if the last read flag is not STATUS_P
                  copyDWIMFix (tvVar flag) [] (Var warpscan) [sExt64 warp_size - 1]
                  sWhen (tvExp flag .<. statusP) $ do
                    lam' <- renameLambda scan_op1
                    inBlockScanLookback
                      constants
                      num_virt_threads
                      warpscan
                      exchanges
                      lam'

                  -- all threads of the warp read the result of reduction
                  copyDWIMFix (tvVar flag) [] (Var warpscan) [sExt64 warp_size - 1]
                  forM_ (zip aggrs exchanges) $ \(TV (ZonkAny 5)
aggr, VName
exchange) ->
                    VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix (TV (ZonkAny 5) -> VName
forall {k} (t :: k). TV t -> VName
tvVar TV (ZonkAny 5)
aggr) [] (VName -> SubExp
Var VName
exchange) [TExp Int32 -> TExp Int64
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
warp_size TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
- TExp Int64
1]
                  -- update read offset
                  sIf
                    (tvExp flag .==. statusP)
                    (readOffset <-- loopStop)
                    ( sWhen (tvExp flag .==. statusA) $ do
                        readOffset <-- tvExp readOffset - zExt32 warp_size
                    )

                  -- update prefix if flag different than STATUS_X:
                  sWhen (tvExp flag .>. statusX) $ do
                    lam <- renameLambda scan_op1
                    let (xs, ys) = splitAt (length tys) $ map paramName $ lambdaParams lam
                    forM_ (zip xs aggrs) $ \(VName
x, TV (ZonkAny 5)
aggr) -> VName -> TExp (ZonkAny 5) -> ImpM GPUMem KernelEnv KernelOp ()
forall {k} (t :: k) rep r op. VName -> TExp t -> ImpM rep r op ()
dPrimV_ VName
x (TV (ZonkAny 5) -> TExp (ZonkAny 5)
forall {k} (t :: k). TV t -> TExp t
tvExp TV (ZonkAny 5)
aggr)
                    forM_ (zip ys prefixes) $ \(VName
y, TV (ZonkAny 3)
prefix) -> VName -> TExp (ZonkAny 3) -> ImpM GPUMem KernelEnv KernelOp ()
forall {k} (t :: k) rep r op. VName -> TExp t -> ImpM rep r op ()
dPrimV_ VName
y (TV (ZonkAny 3) -> TExp (ZonkAny 3)
forall {k} (t :: k). TV t -> TExp t
tvExp TV (ZonkAny 3)
prefix)
                    compileStms mempty (bodyStms $ lambdaBody lam) $
                      forM_ (zip3 prefixes tys $ map resSubExp $ bodyResult $ lambdaBody lam) $
                        \(TV (ZonkAny 3)
prefix, PrimType
ty, SubExp
res) -> TV (ZonkAny 3)
prefix TV (ZonkAny 3)
-> TExp (ZonkAny 3) -> ImpM GPUMem KernelEnv KernelOp ()
forall {k} (t :: k) rep r op. TV t -> TExp t -> ImpM rep r op ()
<-- PrimExp VName -> TExp (ZonkAny 3)
forall {k} (t :: k) v. PrimExp v -> TPrimExp t v
TPrimExp (PrimType -> SubExp -> PrimExp VName
forall a. ToExp a => PrimType -> a -> PrimExp VName
toExp' PrimType
ty SubExp
res)
                  sOp local_fence
            )

          -- end sWhile
          -- end sIf
          sWhen (ltid32 .==. 0) $ do
            scan_op2 <- renameLambda scan_op1
            let xs = (Param (MemInfo SubExp NoUniqueness MemBind) -> VName)
-> [Param (MemInfo SubExp NoUniqueness MemBind)] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map Param (MemInfo SubExp NoUniqueness MemBind) -> VName
forall dec. Param dec -> VName
paramName ([Param (MemInfo SubExp NoUniqueness MemBind)] -> [VName])
-> [Param (MemInfo SubExp NoUniqueness MemBind)] -> [VName]
forall a b. (a -> b) -> a -> b
$ Int
-> [Param (MemInfo SubExp NoUniqueness MemBind)]
-> [Param (MemInfo SubExp NoUniqueness MemBind)]
forall a. Int -> [a] -> [a]
take ([PrimType] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [PrimType]
tys) ([Param (MemInfo SubExp NoUniqueness MemBind)]
 -> [Param (MemInfo SubExp NoUniqueness MemBind)])
-> [Param (MemInfo SubExp NoUniqueness MemBind)]
-> [Param (MemInfo SubExp NoUniqueness MemBind)]
forall a b. (a -> b) -> a -> b
$ Lambda GPUMem -> [LParam GPUMem]
forall rep. Lambda rep -> [LParam rep]
lambdaParams Lambda GPUMem
scan_op2
                ys = (Param (MemInfo SubExp NoUniqueness MemBind) -> VName)
-> [Param (MemInfo SubExp NoUniqueness MemBind)] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map Param (MemInfo SubExp NoUniqueness MemBind) -> VName
forall dec. Param dec -> VName
paramName ([Param (MemInfo SubExp NoUniqueness MemBind)] -> [VName])
-> [Param (MemInfo SubExp NoUniqueness MemBind)] -> [VName]
forall a b. (a -> b) -> a -> b
$ Int
-> [Param (MemInfo SubExp NoUniqueness MemBind)]
-> [Param (MemInfo SubExp NoUniqueness MemBind)]
forall a. Int -> [a] -> [a]
drop ([PrimType] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [PrimType]
tys) ([Param (MemInfo SubExp NoUniqueness MemBind)]
 -> [Param (MemInfo SubExp NoUniqueness MemBind)])
-> [Param (MemInfo SubExp NoUniqueness MemBind)]
-> [Param (MemInfo SubExp NoUniqueness MemBind)]
forall a b. (a -> b) -> a -> b
$ Lambda GPUMem -> [LParam GPUMem]
forall rep. Lambda rep -> [LParam rep]
lambdaParams Lambda GPUMem
scan_op2
            sWhen (boundary .==. sExt32 (tblock_size_e * chunk)) $ do
              forM_ (zip xs prefixes) $ \(VName
x, TV (ZonkAny 3)
prefix) -> VName -> TExp (ZonkAny 3) -> ImpM GPUMem KernelEnv KernelOp ()
forall {k} (t :: k) rep r op. VName -> TExp t -> ImpM rep r op ()
dPrimV_ VName
x (TExp (ZonkAny 3) -> ImpM GPUMem KernelEnv KernelOp ())
-> TExp (ZonkAny 3) -> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ TV (ZonkAny 3) -> TExp (ZonkAny 3)
forall {k} (t :: k). TV t -> TExp t
tvExp TV (ZonkAny 3)
prefix
              forM_ (zip ys accs) $ \(VName
y, TV (ZonkAny 1)
acc) -> VName -> TExp (ZonkAny 1) -> ImpM GPUMem KernelEnv KernelOp ()
forall {k} (t :: k) rep r op. VName -> TExp t -> ImpM rep r op ()
dPrimV_ VName
y (TExp (ZonkAny 1) -> ImpM GPUMem KernelEnv KernelOp ())
-> TExp (ZonkAny 1) -> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ TV (ZonkAny 1) -> TExp (ZonkAny 1)
forall {k} (t :: k). TV t -> TExp t
tvExp TV (ZonkAny 1)
acc
              compileStms mempty (bodyStms $ lambdaBody scan_op2) $
                everythingVolatile $
                  forM_ (zip incprefixArrays $ map resSubExp $ bodyResult $ lambdaBody scan_op2) $
                    \(VName
incprefixArray, SubExp
res) -> VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
incprefixArray [TV Int64 -> TExp Int64
forall {k} (t :: k). TV t -> TExp t
tvExp TV Int64
dyn_id] SubExp
res []
              sOp global_fence
              everythingVolatile $ copyDWIMFix statusFlags [tvExp dyn_id] (intConst Int8 statusP) []
            forM_ (zip exchanges prefixes) $ \(VName
exchange, TV (ZonkAny 3)
prefix) ->
              VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
exchange [TExp Int64
0] (TV (ZonkAny 3) -> SubExp
forall {k} (t :: k). TV t -> SubExp
tvSize TV (ZonkAny 3)
prefix) []
            forM_ (zip3 accs tys scanop_nes) $ \(TV (ZonkAny 1)
acc, PrimType
ty, SubExp
ne) ->
              TV (ZonkAny 1) -> VName
forall {k} (t :: k). TV t -> VName
tvVar TV (ZonkAny 1)
acc VName -> PrimExp VName -> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op. VName -> PrimExp VName -> ImpM rep r op ()
<~~ PrimType -> SubExp -> PrimExp VName
forall a. ToExp a => PrimType -> a -> PrimExp VName
toExp' PrimType
ty SubExp
ne
        -- end sWhen
        -- end sWhen

        sWhen (bNot $ tvExp dyn_id .==. 0) $ do
          sOp local_barrier
          forM_ (zip exchanges prefixes) $ \(VName
exchange, TV (ZonkAny 3)
prefix) ->
            VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix (TV (ZonkAny 3) -> VName
forall {k} (t :: k). TV t -> VName
tvVar TV (ZonkAny 3)
prefix) [] (VName -> SubExp
Var VName
exchange) [TExp Int64
0]
          sOp local_barrier
      -- end sWhen
      -- end sComment

      scan_op3 <- renameLambda scan_op1
      scan_op4 <- renameLambda scan_op1

      sComment "Distribute results" $ do
        let (xs, ys) = splitAt (length tys) $ map paramName $ lambdaParams scan_op3
            (xs', ys') = splitAt (length tys) $ map paramName $ lambdaParams scan_op4

        forM_ (zip7 prefixes accs xs xs' ys ys' tys) $
          \(TV (ZonkAny 3)
prefix, TV (ZonkAny 1)
acc, VName
x, VName
x', VName
y, VName
y', PrimType
ty) -> do
            VName -> PrimType -> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op. VName -> PrimType -> ImpM rep r op ()
dPrim_ VName
x PrimType
ty
            VName -> PrimType -> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op. VName -> PrimType -> ImpM rep r op ()
dPrim_ VName
y PrimType
ty
            VName -> TExp (ZonkAny 3) -> ImpM GPUMem KernelEnv KernelOp ()
forall {k} (t :: k) rep r op. VName -> TExp t -> ImpM rep r op ()
dPrimV_ VName
x' (TExp (ZonkAny 3) -> ImpM GPUMem KernelEnv KernelOp ())
-> TExp (ZonkAny 3) -> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ TV (ZonkAny 3) -> TExp (ZonkAny 3)
forall {k} (t :: k). TV t -> TExp t
tvExp TV (ZonkAny 3)
prefix
            VName -> TExp (ZonkAny 1) -> ImpM GPUMem KernelEnv KernelOp ()
forall {k} (t :: k) rep r op. VName -> TExp t -> ImpM rep r op ()
dPrimV_ VName
y' (TExp (ZonkAny 1) -> ImpM GPUMem KernelEnv KernelOp ())
-> TExp (ZonkAny 1) -> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ TV (ZonkAny 1) -> TExp (ZonkAny 1)
forall {k} (t :: k). TV t -> TExp t
tvExp TV (ZonkAny 1)
acc

        sIf
          (ltid32 * chunk32 .<. boundary .&&. bNot blockNewSgm)
          ( compileStms mempty (bodyStms $ lambdaBody scan_op4) $
              forM_ (zip3 xs tys $ map resSubExp $ bodyResult $ lambdaBody scan_op4) $
                \(VName
x, PrimType
ty, SubExp
res) -> VName
x VName -> PrimExp VName -> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op. VName -> PrimExp VName -> ImpM rep r op ()
<~~ PrimType -> SubExp -> PrimExp VName
forall a. ToExp a => PrimType -> a -> PrimExp VName
toExp' PrimType
ty SubExp
res
          )
          (forM_ (zip xs accs) $ \(VName
x, TV (ZonkAny 1)
acc) -> VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
x [] (VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ TV (ZonkAny 1) -> VName
forall {k} (t :: k). TV t -> VName
tvVar TV (ZonkAny 1)
acc) [])
        -- calculate where previous thread stopped, to determine number of
        -- elements left before new segment.
        stop <-
          dPrimVE "stopping_point" $
            segsize_compact - (ltid32 * chunk32 - 1 + segsize_compact - boundary) `rem` segsize_compact
        sFor "i" chunk $ \TExp Int64
i -> do
          TPrimExp Bool VName
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
TPrimExp Bool VName -> ImpM rep r op () -> ImpM rep r op ()
sWhen (TExp Int64 -> TExp Int32
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 TExp Int64
i TExp Int32 -> TExp Int32 -> TPrimExp Bool VName
forall {k} v (t :: k).
Eq v =>
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. TExp Int32
stop TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
- TExp Int32
1) (ImpM GPUMem KernelEnv KernelOp ()
 -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ do
            [(VName, VName)]
-> ((VName, VName) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> [VName] -> [(VName, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
private_chunks [VName]
ys) (((VName, VName) -> ImpM GPUMem KernelEnv KernelOp ())
 -> ImpM GPUMem KernelEnv KernelOp ())
-> ((VName, VName) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \(VName
src, VName
y) ->
              -- only include prefix for the first segment part per thread
              VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
y [] (VName -> SubExp
Var VName
src) [TExp Int64
i]
            Names
-> Stms GPUMem
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
Names -> Stms rep -> ImpM rep r op () -> ImpM rep r op ()
compileStms Names
forall a. Monoid a => a
mempty (Body GPUMem -> Stms GPUMem
forall rep. Body rep -> Stms rep
bodyStms (Body GPUMem -> Stms GPUMem) -> Body GPUMem -> Stms GPUMem
forall a b. (a -> b) -> a -> b
$ Lambda GPUMem -> Body GPUMem
forall rep. Lambda rep -> Body rep
lambdaBody Lambda GPUMem
scan_op3) (ImpM GPUMem KernelEnv KernelOp ()
 -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
              [(VName, SubExp)]
-> ((VName, SubExp) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> [SubExp] -> [(VName, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
private_chunks ([SubExp] -> [(VName, SubExp)]) -> [SubExp] -> [(VName, SubExp)]
forall a b. (a -> b) -> a -> b
$ (SubExpRes -> SubExp) -> [SubExpRes] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map SubExpRes -> SubExp
resSubExp ([SubExpRes] -> [SubExp]) -> [SubExpRes] -> [SubExp]
forall a b. (a -> b) -> a -> b
$ Body GPUMem -> [SubExpRes]
forall rep. Body rep -> [SubExpRes]
bodyResult (Body GPUMem -> [SubExpRes]) -> Body GPUMem -> [SubExpRes]
forall a b. (a -> b) -> a -> b
$ Lambda GPUMem -> Body GPUMem
forall rep. Lambda rep -> Body rep
lambdaBody Lambda GPUMem
scan_op3) (((VName, SubExp) -> ImpM GPUMem KernelEnv KernelOp ())
 -> ImpM GPUMem KernelEnv KernelOp ())
-> ((VName, SubExp) -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$
                \(VName
dest, SubExp
res) ->
                  VName
-> [TExp Int64]
-> SubExp
-> [TExp Int64]
-> ImpM GPUMem KernelEnv KernelOp ()
forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
dest [TExp Int64
i] SubExp
res []

      sComment "Transpose scan output and Write it to global memory in coalesced fashion" $ do
        forM_ (zip3 transposedArrays private_chunks $ map patElemName all_pes) $ \(VName
locmem, VName
priv, VName
dest) -> do
          -- sOp local_barrier
          String
-> TExp Int64
-> (TExp Int64 -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall {k} (t :: k) rep r op.
String
-> TExp t -> (TExp t -> ImpM rep r op ()) -> ImpM rep r op ()
sFor String
"i" TExp Int64
chunk ((TExp Int64 -> ImpM GPUMem KernelEnv KernelOp ())
 -> ImpM GPUMem KernelEnv KernelOp ())
-> (TExp Int64 -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \TExp Int64
i -> do
            sharedIdx <-
              String -> TExp Int64 -> ImpM GPUMem KernelEnv KernelOp (TV Int64)
forall {k} (t :: k) rep r op.
String -> TExp t -> ImpM rep r op (TV t)
dPrimV String
"sharedIdx" (TExp Int64 -> ImpM GPUMem KernelEnv KernelOp (TV Int64))
-> TExp Int64 -> ImpM GPUMem KernelEnv KernelOp (TV Int64)
forall a b. (a -> b) -> a -> b
$
                TExp Int64 -> TExp Int64
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (TExp Int64
ltid TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* TExp Int64
chunk) TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
+ TExp Int64
i
            copyDWIMFix locmem [tvExp sharedIdx] (Var priv) [i]
          KernelOp -> ImpM GPUMem KernelEnv KernelOp ()
forall op rep r. op -> ImpM rep r op ()
sOp KernelOp
local_barrier
          String
-> TExp Int64
-> (TExp Int64 -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall {k} (t :: k) rep r op.
String
-> TExp t -> (TExp t -> ImpM rep r op ()) -> ImpM rep r op ()
sFor String
"i" TExp Int64
chunk ((TExp Int64 -> ImpM GPUMem KernelEnv KernelOp ())
 -> ImpM GPUMem KernelEnv KernelOp ())
-> (TExp Int64 -> ImpM GPUMem KernelEnv KernelOp ())
-> ImpM GPUMem KernelEnv KernelOp ()
forall a b. (a -> b) -> a -> b
$ \TExp Int64
i -> do
            flat_idx <- String -> TExp Int64 -> ImpM GPUMem KernelEnv KernelOp (TExp Int64)
forall {k} (t :: k) rep r op.
String -> TExp t -> ImpM rep r op (TExp t)
dPrimVE String
"flat_idx" (TExp Int64 -> ImpM GPUMem KernelEnv KernelOp (TExp Int64))
-> TExp Int64 -> ImpM GPUMem KernelEnv KernelOp (TExp Int64)
forall a b. (a -> b) -> a -> b
$ TExp Int64
thd_offset TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
+ TExp Int64
i TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* TExp Int64
tblock_size_e
            dIndexSpace (zip gtids dims') flat_idx
            sWhen (flat_idx .<. n) $ do
              copyDWIMFix
                dest
                (map Imp.le64 gtids)
                (Var locmem)
                [sExt64 $ flat_idx - block_offset]
          KernelOp -> ImpM GPUMem KernelEnv KernelOp ()
forall op rep r. op -> ImpM rep r op ()
sOp KernelOp
local_barrier
{-# NOINLINE compileSegScan #-}