{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TypeFamilies #-}
module Futhark.AD.Rev.Hist
( diffMinMaxHist,
diffMulHist,
diffAddHist,
diffVecHist,
diffHist,
)
where
import Control.Monad
import Futhark.AD.Rev.Monad
import Futhark.Analysis.PrimExp.Convert
import Futhark.Builder
import Futhark.IR.SOACS
import Futhark.Tools
import Futhark.Transform.Rename
-- | The addition operator for a primitive type: 'Add' with
-- 'OverflowUndef' for integer types and 'FAdd' for floating-point
-- types.  Any other primitive type is unsupported and raises an error.
getBinOpPlus :: PrimType -> BinOp
getBinOpPlus (IntType x) = Add x OverflowUndef
getBinOpPlus (FloatType f) = FAdd f
-- The message used to name a non-existent 'getBinOpMul'; it now names
-- the function that actually failed.
getBinOpPlus _ = error "In getBinOpPlus, Hist.hs: input not supported"
-- | The division operator for a primitive type: 'SDiv' with the
-- 'Unsafe' safety flag for integer types and 'FDiv' for floating-point
-- types.  Any other primitive type is unsupported and raises an error.
getBinOpDiv :: PrimType -> BinOp
getBinOpDiv (IntType t) = SDiv t Unsafe
getBinOpDiv (FloatType t) = FDiv t
getBinOpDiv _ = error "In getBinOpDiv, Hist.hs: input not supported"
-- | Build a boolean expression that holds when every @(q, i)@ pair
-- satisfies @-1 < i < q@.  The empty list yields the constant @True@.
withinBounds :: [(SubExp, VName)] -> TPrimExp Bool VName
withinBounds [] = TPrimExp $ ValueExp (BoolValue True)
withinBounds [(q, i)] = (le64 i .<. pe64 q) .&&. (pe64 (intConst Int64 (-1)) .<. le64 i)
-- Two or more pairs: conjoin the check for the first pair with the
-- check for the remaining ones.
withinBounds (qi : qis) = withinBounds [qi] .&&. withinBounds qis
-- | Build a chain of nested @if@ expressions.
--
-- Each pair @(c1, c2)@ is compared for equality at primitive type @t@.
-- The i-th body is used when the i-th comparison is the first one to
-- succeed; the last body is the final @else@ branch.  The list of
-- bodies must therefore be exactly one longer than the (non-empty) list
-- of comparison pairs; any other combination hits the error case.
elseIf ::
  (MonadBuilder m, BranchType (Rep m) ~ ExtType) =>
  PrimType ->
  [(m (Exp (Rep m)), m (Exp (Rep m)))] ->
  [m (Body (Rep m))] ->
  m (Exp (Rep m))
-- Last comparison: the two remaining bodies are the true and false
-- branches.
elseIf t [(c1, c2)] [bt, bf] =
  eIf
    (eCmpOp (CmpEq t) c1 c2)
    bt
    bf
-- More comparisons left: the false branch is the rest of the chain.
elseIf t ((c1, c2) : cs) (bt : bs) =
  eIf
    (eCmpOp (CmpEq t) c1 c2)
    bt
    $ eBody
    $ pure
    $ elseIf t cs bs
elseIf _ _ _ = error "In elseIf, Hist.hs: input not supported"
-- | Bind every result to a fresh variable whose name is based on @s@,
-- under the certificates carried by that result, and return the new
-- variable names in order.
bindSubExpRes :: (MonadBuilder m) => String -> [SubExpRes] -> m [VName]
bindSubExpRes s =
  traverse
    ( \(SubExpRes cs se) -> do
        bn <- newVName s
        certifying cs $ letBindNames [bn] $ BasicOp $ SubExp se
        pure bn
    )
-- | Wrap a lambda in one @map@ per dimension of the given shape.
--
-- With no dimensions the lambda is returned unchanged.  Otherwise the
-- result takes one array parameter per primitive type in @pt@, each of
-- shape @s@, and maps the lambda built for the remaining dimensions
-- over the outermost dimension @h@.
nestedmap :: [SubExp] -> [PrimType] -> Lambda SOACS -> ADM (Lambda SOACS)
nestedmap [] _ lam = pure lam
nestedmap s@(h : r) pt lam = do
  params <- traverse (\tp -> newParam "x" $ Array tp (Shape s) NoUniqueness) pt
  body <- nestedmap r pt lam
  mkLambda params $
    fmap varsRes . letTupExp "res" . Op $
      Screma h (map paramName params) (mapSOAC body)
-- | Lift @lam@ to a lambda that maps a renamed copy of it over arrays
-- of width @n@.
--
-- Two groups of fresh parameters (@ds_param@ and @hs_param@) are
-- created, one parameter per type in @tps@ in each group.  Returns the
-- names of the first group, the names of the second group, and the
-- mapping lambda whose parameters are the first group followed by the
-- second.
mkF' :: Lambda SOACS -> [Type] -> SubExp -> ADM ([VName], [VName], Lambda SOACS)
mkF' lam tps n = do
  lam' <- renameLambda lam

  ds_params <- traverse (newParam "ds_param") tps
  hs_params <- traverse (newParam "hs_param") tps
  let ds_pars = fmap paramName ds_params
  let hs_pars = fmap paramName hs_params
  lam_map <-
    mkLambda (ds_params <> hs_params) $
      fmap varsRes . letTupExp "map_f'" . Op $
        Screma n (ds_pars <> hs_pars) (mapSOAC lam')

  pure (ds_pars, hs_pars, lam_map)
-- | Build a mapped lambda that applies @lam@ twice in sequence.
--
-- Two renamed copies of @lam@ are made.  With @q@ the number of results
-- of @lam@, the parameters of each copy are split after the first @q@.
-- The inner lambda takes the parameters @lps <> aps <> rps@, runs the
-- body of the first copy, binds its results to the first @q@ parameters
-- of the second copy, and then runs the body of the second copy.
-- This inner lambda is then mapped over arrays of width @n@ using three
-- groups of fresh parameters (@ls_param@, @as_param@, @rs_param@), one
-- per type in @tps@ in each group.
--
-- Returns the names of the @as_param@ group and the mapping lambda.
mkF :: Lambda SOACS -> [Type] -> SubExp -> ADM ([VName], Lambda SOACS)
mkF lam tps n = do
  lam_l <- renameLambda lam
  lam_r <- renameLambda lam
  let q = length $ lambdaReturnType lam
      (lps, aps) = splitAt q $ lambdaParams lam_l
      (ips, rps) = splitAt q $ lambdaParams lam_r
  lam' <- mkLambda (lps <> aps <> rps) $ do
    lam_l_res <- bodyBind $ lambdaBody lam_l
    -- Feed the results of the first application into the leading
    -- parameters of the second one.
    forM_ (zip ips lam_l_res) $ \(ip, SubExpRes cs se) ->
      certifying cs $ letBindNames [paramName ip] $ BasicOp $ SubExp se
    bodyBind $ lambdaBody lam_r

  ls_params <- traverse (newParam "ls_param") tps
  as_params <- traverse (newParam "as_param") tps
  rs_params <- traverse (newParam "rs_param") tps
  let map_params = ls_params <> as_params <> rs_params
  lam_map <-
    mkLambda map_params $
      fmap varsRes . letTupExp "map_f" $
        Op $
          Screma n (map paramName map_params) $
            mapSOAC lam'

  pure (map paramName as_params, lam_map)
-- | Map over the index array @is@ (of width @n@), keeping every index
-- that satisfies 'withinBounds' with respect to @w@ and replacing every
-- other index by @w@ itself.  Returns the name of the resulting array.
mapout :: VName -> SubExp -> SubExp -> ADM VName
mapout is n w = do
  par_is <- newParam "is" $ Prim int64
  is'_lam <-
    mkLambda [par_is] $
      fmap varsRes . letTupExp "is'"
        =<< eIf
          (toExp $ withinBounds $ pure (w, paramName par_is))
          (eBody $ pure $ eParam par_is)
          (eBody $ pure $ eSubExp w)

  letExp "is'" $ Op $ Screma n (pure is) $ mapSOAC is'_lam
-- | Emit a single 'Scatter' of width @n@ that writes each value array
-- in @vs@ into the corresponding destination in @dst@, all at the
-- positions given by the index array @is@.
--
-- The scatter lambda returns the index once per value array, followed
-- by the values themselves.  Each destination is described by the
-- outermost size of the matching value array's type and is written
-- once.  Returns the names of the scatter results.
multiScatter :: SubExp -> [VName] -> VName -> [VName] -> ADM [VName]
multiScatter n dst is vs = do
  tps <- traverse lookupType vs
  par_i <- newParam "i" $ Prim int64
  scatter_params <- traverse (newParam "scatter_param" . rowType) tps
  scatter_lam <-
    mkLambda (par_i : scatter_params) $
      fmap subExpsRes . mapM (letSubExp "scatter_map_res") =<< do
        p1 <- replicateM (length scatter_params) $ eParam par_i
        p2 <- traverse eParam scatter_params
        pure $ p1 <> p2
  let spec = zipWith (\t -> (,,) (Shape $ pure $ arraySize 0 t) 1) tps dst
  letTupExp "scatter_res" . Op $ Scatter n (is : vs) spec scatter_lam
-- | Index every array in @vs@ with the same leading dimension indices
-- @s@ (completed to a full slice for each array's type) and return the
-- names of the results.
multiIndex :: [VName] -> [DimIndex SubExp] -> ADM [VName]
multiIndex vs s = do
  traverse
    ( \x -> do
        t <- lookupType x
        letExp "sorted" $ BasicOp $ Index x (fullSlice t s)
    )
    vs
-- | Reverse-mode differentiation of a histogram whose operator is the
-- binary operator @minmax@.
--
-- Forward part: a 'Hist' is emitted that, next to the histogram value
-- bound to @x@, also produces @x_inds@: for every bin, the position in
-- the input of the element that supplied the value, or @-1@ when the
-- bin still holds the value from @dst@.  On ties the smaller position
-- is kept.
--
-- After the continuation @m@ has run, the adjoint of @x@ is
-- propagated: @dst@ receives it for every bin whose recorded position
-- is @-1@, and @vs@ receives it (added to the current adjoint of @vs@)
-- at the recorded positions.
--
-- Arguments, in order: VJP operations (unused here), the result name
-- @x@, the statement auxiliary information, the input width @n@, the
-- operator, its neutral element @ne@, the index array @is@, the value
-- array @vs@, the histogram width @w@, the @rf@ field passed through to
-- 'HistOp', the destination @dst@, and the continuation @m@.
diffMinMaxHist ::
  VjpOps -> VName -> StmAux () -> SubExp -> BinOp -> SubExp -> VName -> VName -> SubExp -> SubExp -> VName -> ADM () -> ADM ()
diffMinMaxHist _ops x aux n minmax ne is vs w rf dst m = do
  let t = binOpType minmax
  vs_type <- lookupType vs
  let vs_elm_type = elemType vs_type
  let vs_dims = arrayDims vs_type
  let inner_dims = tail vs_dims
  let nr_dims = length vs_dims
  dst_type <- lookupType dst
  let dst_dims = arrayDims dst_type

  dst_cpy <-
    letExp (baseString dst <> "_copy") . BasicOp $
      Replicate mempty (Var dst)

  -- Operator on (value, position) pairs: equal values keep the smaller
  -- position; otherwise the pair whose value is selected by 'minmax'
  -- wins.
  acc_v_p <- newParam "acc_v" $ Prim t
  acc_i_p <- newParam "acc_i" $ Prim int64
  v_p <- newParam "v" $ Prim t
  i_p <- newParam "i" $ Prim int64
  hist_lam_inner <-
    mkLambda [acc_v_p, acc_i_p, v_p, i_p] $
      fmap varsRes . letTupExp "idx_res"
        =<< eIf
          (eCmpOp (CmpEq t) (eParam acc_v_p) (eParam v_p))
          ( eBody
              [ eParam acc_v_p,
                eBinOp (SMin Int64) (eParam acc_i_p) (eParam i_p)
              ]
          )
          ( eBody
              [ eIf
                  ( eCmpOp
                      (CmpEq t)
                      (eParam acc_v_p)
                      (eBinOp minmax (eParam acc_v_p) (eParam v_p))
                  )
                  (eBody [eParam acc_v_p, eParam acc_i_p])
                  (eBody [eParam v_p, eParam i_p])
              ]
          )
  hist_lam <- nestedmap inner_dims [vs_elm_type, int64, vs_elm_type, int64] hist_lam_inner

  dst_minus_ones <-
    letExp "minus_ones" . BasicOp $
      Replicate (Shape dst_dims) (intConst Int64 (-1))
  ne_minus_ones <-
    letSubExp "minus_ones" . BasicOp $
      Replicate (Shape inner_dims) (intConst Int64 (-1))
  iota_n <-
    letExp "red_iota" . BasicOp $
      Iota n (intConst Int64 0) (intConst Int64 1) Int64

  -- For multi-dimensional values every position is replicated to the
  -- shape of one row, so that it can be paired with each element.
  inp_iota <- do
    if nr_dims == 1
      then pure iota_n
      else do
        i <- newParam "i" $ Prim int64
        lam <-
          mkLambda [i] $
            fmap varsRes . letTupExp "res" =<< do
              pure $ BasicOp $ Replicate (Shape inner_dims) $ Var $ paramName i

        letExp "res" $ Op $ Screma n [iota_n] $ mapSOAC lam

  let hist_op = HistOp (Shape [w]) rf [dst_cpy, dst_minus_ones] [ne, if nr_dims == 1 then intConst Int64 (-1) else ne_minus_ones] hist_lam
  f' <- mkIdentityLambda [Prim int64, rowType vs_type, rowType $ Array int64 (Shape vs_dims) NoUniqueness]
  x_inds <- newVName (baseString x <> "_inds")
  auxing aux $
    letBindNames [x, x_inds] $
      Op $
        Hist n [is, vs, inp_iota] [hist_op] f'

  m

  x_bar <- lookupAdjVal x

  -- Adjoint of dst: the adjoint of x where the recorded position is
  -- -1, the blank value of t elsewhere.
  x_ind_dst <- newParam (baseString x <> "_ind_param") $ Prim int64
  x_bar_dst <- newParam (baseString x <> "_bar_param") $ Prim t
  dst_lam_inner <-
    mkLambda [x_ind_dst, x_bar_dst] $
      fmap varsRes . letTupExp "dst_bar"
        =<< eIf
          (toExp $ le64 (paramName x_ind_dst) .==. -1)
          (eBody $ pure $ eParam x_bar_dst)
          (eBody $ pure $ eSubExp $ Constant $ blankPrimValue t)
  dst_lam <- nestedmap inner_dims [int64, vs_elm_type] dst_lam_inner

  dst_bar <-
    letExp (baseString dst <> "_bar") . Op $
      Screma w [x_inds, x_bar] (mapSOAC dst_lam)

  updateAdj dst dst_bar

  vs_bar <- lookupAdjVal vs

  -- Full index tuples into vs: the recorded outer position followed by
  -- the inner coordinates, the latter replicated across the w bins.
  inds' <- traverse (letExp "inds" . BasicOp . Replicate (Shape [w]) . Var) =<< mk_indices inner_dims []
  let inds = x_inds : inds'

  par_x_ind_vs <- replicateM nr_dims $ newParam (baseString x <> "_ind_param") $ Prim int64
  par_x_bar_vs <- newParam (baseString x <> "_bar_param") $ Prim t
  vs_lam_inner <-
    mkLambda (par_x_bar_vs : par_x_ind_vs) $
      fmap varsRes . letTupExp "res"
        =<< eIf
          (toExp $ le64 (paramName $ head par_x_ind_vs) .==. -1)
          (eBody $ pure $ eSubExp $ Constant $ blankPrimValue t)
          ( eBody $
              pure $ do
                -- Add the incoming adjoint to what vs_bar already holds
                -- at the recorded position.
                vs_bar_i <-
                  letSubExp (baseString vs_bar <> "_el") . BasicOp $
                    Index vs_bar . Slice $
                      fmap (DimFix . Var . paramName) par_x_ind_vs
                eBinOp (getBinOpPlus t) (eParam par_x_bar_vs) (eSubExp vs_bar_i)
          )
  vs_lam <- nestedmap inner_dims (vs_elm_type : replicate nr_dims int64) vs_lam_inner

  vs_bar_p <-
    letExp (baseString vs <> "_partial") . Op $
      Screma w (x_bar : inds) (mapSOAC vs_lam)

  -- Total number of elements of dst; indices and values are flattened
  -- to this length before being scattered.
  q <-
    letSubExp "q"
      =<< foldBinOp (Mul Int64 OverflowUndef) (intConst Int64 1) dst_dims

  scatter_inps <- do
    forM (inds ++ [vs_bar_p]) $ \v -> do
      v_t <- lookupType v
      letExp "flat" . BasicOp . Reshape v $
        reshapeAll (arrayShape v_t) (Shape [q])

  f'' <- mkIdentityLambda $ replicate nr_dims (Prim int64) ++ [Prim t]
  vs_bar' <-
    letExp (baseString vs <> "_bar") . Op $
      Scatter q scatter_inps [(Shape vs_dims, 1, vs_bar)] f''
  insAdj vs vs_bar'
  where
    -- Build one index array per given dimension, each holding the
    -- coordinate along that dimension.  The second argument collects
    -- the coordinates of the enclosing dimensions.
    mk_indices :: [SubExp] -> [SubExp] -> ADM [VName]
    mk_indices [] _ = pure []
    mk_indices [d] iotas = do
      reps <- traverse (letExp "rep" . BasicOp . Replicate (Shape [d])) iotas
      iota_d <-
        letExp "red_iota" . BasicOp $
          Iota d (intConst Int64 0) (intConst Int64 1) Int64
      pure $ reps ++ [iota_d]
    mk_indices (d : dims) iotas = do
      iota_d <-
        letExp "red_iota" . BasicOp $
          Iota d (intConst Int64 0) (intConst Int64 1) Int64

      i_param <- newParam "i" $ Prim int64
      lam <-
        mkLambda [i_param] $
          fmap varsRes $
            mk_indices dims $
              iotas ++ [Var $ paramName i_param]

      letTupExp "res" $ Op $ Screma d [iota_d] $ mapSOAC lam
-- | Reverse-mode differentiation of a histogram whose operator is the
-- multiplication operator @mul@.
--
-- Forward part: every element of @vs@ equal to the blank value of its
-- type (treated here as zero) is replaced by one and counted.  A single
-- 'Hist' then computes, per bin, the product of the remaining elements
-- and the number of such zeros.  The histogram contribution of a bin is
-- that product when the count is 0 and the blank value otherwise, and
-- @x@ is bound to @dst@ combined with the contribution using @mul@.
--
-- After the continuation @m@ has run, the adjoint of @x@ is
-- propagated: @dst@ receives contribution * adjoint, and each element
-- @a@ of @vs@ with an in-bounds index receives
--
--   * @pr_bar * (nz_prd / a)@ when its bin has no zeros,
--   * @nz_prd * pr_bar@ when its bin has exactly one zero and @a@ is it,
--   * the blank value otherwise,
--
-- where @pr_bar@ is @dst@ * adjoint and @nz_prd@ the non-zero product
-- of the bin.  Elements with an out-of-bounds index receive a zero row.
--
-- Arguments, in order: VJP operations (unused here), the result name
-- @x@, the statement auxiliary information, the input width @n@, the
-- operator, its neutral element @ne@, the index array @is@, the value
-- array @vs@, the histogram width @w@, the @rf@ field passed through to
-- 'HistOp', the destination @dst@, and the continuation @m@.
diffMulHist ::
  VjpOps -> VName -> StmAux () -> SubExp -> BinOp -> SubExp -> VName -> VName -> SubExp -> SubExp -> VName -> ADM () -> ADM ()
diffMulHist _ops x aux n mul ne is vs w rf dst m = do
  let t = binOpType mul
  vs_type <- lookupType vs
  let vs_dims = arrayDims vs_type
  let vs_elm_type = elemType vs_type
  dst_type <- lookupType dst
  let dst_dims = arrayDims dst_type
  let inner_dims = tail vs_dims

  -- Split vs into ps (zeros replaced by one) and zs (1 where the
  -- element was zero, 0 elsewhere).
  v_param <- newParam "v" $ Prim t
  lam_ps_zs_inner <-
    mkLambda [v_param] $
      fmap varsRes . letTupExp "map_res"
        =<< eIf
          (eCmpOp (CmpEq t) (eParam v_param) (eSubExp $ Constant $ blankPrimValue t))
          (eBody $ fmap eSubExp [Constant $ onePrimValue t, intConst Int64 1])
          (eBody [eParam v_param, eSubExp $ intConst Int64 0])
  lam_ps_zs <- nestedmap vs_dims [vs_elm_type] lam_ps_zs_inner
  ps_zs_res <- eLambda lam_ps_zs [eSubExp $ Var vs]
  ps_zs <- bindSubExpRes "ps_zs" ps_zs_res
  let [ps, zs] = ps_zs

  -- Histogram of products over the non-zero elements.
  lam_mul_inner <- binOpLambda mul t
  lam_mul <- nestedmap inner_dims [vs_elm_type, vs_elm_type] lam_mul_inner
  nz_prods0 <- letExp "nz_prd" $ BasicOp $ Replicate (Shape [w]) ne
  let hist_nzp = HistOp (Shape [w]) rf [nz_prods0] [ne] lam_mul

  -- Histogram counting the zeros.
  lam_add_inner <- binOpLambda (Add Int64 OverflowUndef) int64
  lam_add <- nestedmap inner_dims [int64, int64] lam_add_inner
  zr_counts0 <- letExp "zr_cts" $ BasicOp $ Replicate (Shape dst_dims) (intConst Int64 0)
  zrn_ne <- letSubExp "zr_ne" $ BasicOp $ Replicate (Shape inner_dims) (intConst Int64 0)
  let hist_zrn = HistOp (Shape [w]) rf [zr_counts0] [if length vs_dims == 1 then intConst Int64 0 else zrn_ne] lam_add

  f' <- mkIdentityLambda [Prim int64, Prim int64, rowType vs_type, rowType $ Array int64 (Shape vs_dims) NoUniqueness]
  nz_prods <- newVName "non_zero_prod"
  zr_counts <- newVName "zero_count"
  auxing aux $
    letBindNames [nz_prods, zr_counts] $
      Op $
        Hist n [is, is, ps, zs] [hist_nzp, hist_zrn] f'

  -- Contribution of the histogram alone: the product when no zero was
  -- seen, the blank value otherwise.
  p_param <- newParam "prod" $ Prim t
  c_param <- newParam "count" $ Prim int64
  lam_h_part_inner <-
    mkLambda [p_param, c_param] $
      fmap varsRes . letTupExp "h_part"
        =<< eIf
          (toExp $ 0 .==. le64 (paramName c_param))
          (eBody $ pure $ eParam p_param)
          (eBody $ pure $ eSubExp $ Constant $ blankPrimValue t)
  lam_h_part <- nestedmap dst_dims [vs_elm_type, int64] lam_h_part_inner
  h_part_res <- eLambda lam_h_part $ map (eSubExp . Var) [nz_prods, zr_counts]
  h_part' <- bindSubExpRes "h_part" h_part_res
  let [h_part] = h_part'

  lam_mul_inner' <- binOpLambda mul t
  lam_mul' <- nestedmap dst_dims [vs_elm_type, vs_elm_type] lam_mul_inner'
  x_res <- eLambda lam_mul' $ map (eSubExp . Var) [dst, h_part]
  x' <- bindSubExpRes "x" x_res
  auxing aux $ letBindNames [x] $ BasicOp $ SubExp $ Var $ head x'

  m

  x_bar <- lookupAdjVal x

  -- Adjoint of dst.
  lam_mul'' <- renameLambda lam_mul'
  dst_bar_res <- eLambda lam_mul'' $ map (eSubExp . Var) [h_part, x_bar]
  dst_bar <- bindSubExpRes (baseString dst <> "_bar") dst_bar_res
  updateAdj dst $ head dst_bar

  -- Adjoint of the histogram contribution.
  lam_mul''' <- renameLambda lam_mul'
  part_bar_res <- eLambda lam_mul''' $ map (eSubExp . Var) [dst, x_bar]
  part_bar' <- bindSubExpRes "part_bar" part_bar_res
  let [part_bar] = part_bar'

  inner_params <- zipWithM newParam ["zr_cts", "pr_bar", "nz_prd", "a"] $ map Prim [int64, t, t, t]
  let [zr_cts, pr_bar, nz_prd, a_param] = inner_params
  lam_vsbar_inner <-
    mkLambda inner_params $
      fmap varsRes . letTupExp "vs_bar" =<< do
        eIf
          (eCmpOp (CmpEq int64) (eSubExp $ intConst Int64 0) (eParam zr_cts))
          (eBody $ pure $ eBinOp mul (eParam pr_bar) $ eBinOp (getBinOpDiv t) (eParam nz_prd) $ eParam a_param)
          ( eBody $
              pure $
                eIf
                  ( eBinOp
                      LogAnd
                      (eCmpOp (CmpEq int64) (eSubExp $ intConst Int64 1) (eParam zr_cts))
                      (eCmpOp (CmpEq t) (eSubExp $ Constant $ blankPrimValue t) $ eParam a_param)
                  )
                  (eBody $ pure $ eBinOp mul (eParam nz_prd) (eParam pr_bar))
                  (eBody $ pure $ eSubExp $ Constant $ blankPrimValue t)
          )

  lam_vsbar_middle <- nestedmap inner_dims [int64, t, t, t] lam_vsbar_inner

  i_param <- newParam "i" $ Prim int64
  a_param' <- newParam "a" $ rowType vs_type
  lam_vsbar <-
    mkLambda [i_param, a_param'] $
      fmap varsRes . letTupExp "vs_bar"
        =<< eIf
          (toExp $ withinBounds $ pure (w, paramName i_param))
          ( buildBody_ $ do
              -- Look up the per-bin quantities for the bin this element
              -- was written to.
              let i = fullSlice vs_type [DimFix $ Var $ paramName i_param]
              names <- traverse newVName ["zr_cts", "pr_bar", "nz_prd"]
              zipWithM_ (\name -> letBindNames [name] . BasicOp . flip Index i) names [zr_counts, part_bar, nz_prods]
              eLambda lam_vsbar_middle $ map (eSubExp . Var) names <> [eParam a_param']
          )
          (eBody $ pure $ pure $ zeroExp $ rowType dst_type)

  vs_bar <-
    letExp (baseString vs <> "_bar") $ Op $ Screma n [is, vs] $ mapSOAC lam_vsbar

  updateAdj vs vs_bar
-- | Reverse-mode differentiation of a histogram whose operator is the
-- addition lambda @add@.
--
-- Forward part: the histogram is emitted unchanged, except that it
-- operates on a copy of @dst@, and its result is bound to @x@.
--
-- After the continuation @m@ has run, the adjoint of @x@ is added to
-- the adjoint of @dst@, and every element of @vs@ receives the adjoint
-- of the bin it was written to, or @ne@ when its index is out of
-- bounds.
--
-- Arguments, in order: VJP operations (unused here), the result name
-- @x@, the statement auxiliary information, the input width @n@, the
-- operator, its neutral element @ne@, the index array @is@, the value
-- array @vs@, the histogram width @w@, the @rf@ field passed through to
-- 'HistOp', the destination @dst@, and the continuation @m@.
diffAddHist ::
  VjpOps -> VName -> StmAux () -> SubExp -> Lambda SOACS -> SubExp -> VName -> VName -> SubExp -> SubExp -> VName -> ADM () -> ADM ()
diffAddHist _ops x aux n add ne is vs w rf dst m = do
  -- The element type is taken from the first parameter of the operator.
  let t = paramDec $ head $ lambdaParams add

  dst_cpy <-
    letExp (baseString dst <> "_copy") . BasicOp $
      Replicate mempty (Var dst)

  f <- mkIdentityLambda [Prim int64, t]
  auxing aux . letBindNames [x] . Op $
    Hist n [is, vs] [HistOp (Shape [w]) rf [dst_cpy] [ne] add] f

  m

  x_bar <- lookupAdjVal x

  updateAdj dst x_bar

  x_type <- lookupType x
  i_param <- newParam (baseString vs <> "_i") $ Prim int64
  let i = paramName i_param
  lam_vsbar <-
    mkLambda [i_param] $
      fmap varsRes . letTupExp "vs_bar"
        =<< eIf
          (toExp $ withinBounds $ pure (w, i))
          (eBody $ pure $ pure $ BasicOp $ Index x_bar $ fullSlice x_type [DimFix $ Var i])
          (eBody $ pure $ eSubExp ne)

  vs_bar <- letExp (baseString vs <> "_bar") $ Op $ Screma n [is] $ mapSOAC lam_vsbar
  updateAdj vs vs_bar
-- | Differentiate a histogram over vectors by rewriting it into a map
-- of histograms and differentiating the rewritten statements.
--
-- The two outermost dimensions of @dst@ and @vss@ are swapped, a
-- 'Hist' with operator @op@ is run for each row of the transposed
-- arrays (taking its neutral element from the matching row of @nes@),
-- and the result is transposed back and bound to @x@.  The statements
-- produced this way are collected rather than emitted, and then
-- differentiated with 'vjpStm', with @m@ as the innermost continuation.
--
-- Arguments, in order: VJP operations, the result name @x@, the
-- statement auxiliary information, the input width @n@, the operator,
-- the array of neutral elements @nes@, the index array @is@, the value
-- array @vss@, the histogram width @w@, the @rf@ field passed through
-- to 'HistOp', the destination @dst@, and the continuation @m@.
diffVecHist ::
  VjpOps ->
  VName ->
  StmAux () ->
  SubExp ->
  Lambda SOACS ->
  VName ->
  VName ->
  VName ->
  SubExp ->
  SubExp ->
  VName ->
  ADM () ->
  ADM ()
diffVecHist ops x aux n op nes is vss w rf dst m = do
  stms <- collectStms_ $ do
    rank <- arrayRank <$> lookupType vss
    -- Permutation swapping the two outermost dimensions.
    let dims = [1, 0] ++ drop 2 [0 .. rank - 1]

    dstT <- letExp "dstT" $ BasicOp $ Rearrange dst dims
    vssT <- letExp "vssT" $ BasicOp $ Rearrange vss dims
    t_dstT <- lookupType dstT
    t_vssT <- lookupType vssT
    t_nes <- lookupType nes

    dst_col <- newParam "dst_col" $ rowType t_dstT
    vss_col <- newParam "vss_col" $ rowType t_vssT
    ne <- newParam "ne" $ rowType t_nes

    f <- mkIdentityLambda (Prim int64 : lambdaReturnType op)
    map_lam <-
      mkLambda [dst_col, vss_col, ne] $ do
        dst_col_cpy <-
          letExp "dst_col_cpy" . BasicOp $
            Replicate mempty (Var $ paramName dst_col)
        fmap (varsRes . pure) . letExp "col_res" $
          Op $
            Hist
              n
              [is, paramName vss_col]
              [HistOp (Shape [w]) rf [dst_col_cpy] [Var $ paramName ne] op]
              f
    histT <-
      letExp "histT" . Op $
        Screma (arraySize 0 t_dstT) [dstT, vssT, nes] $
          mapSOAC map_lam
    auxing aux . letBindNames [x] . BasicOp $ Rearrange histT dims
  foldr (vjpStm ops) m stms
-- | Emit the statements for a single radix-sort pass that looks at two
-- adjacent bits of the key (i.e. one base-4 digit) and permutes all of
-- @xs@ accordingly.
--
-- * @xs@: the arrays to permute; the first one is the key array.
-- * @tps@: the types of the arrays in @xs@, used to allocate the
--   scatter destinations.
-- * @bit@: position of the lower of the two bits inspected in this pass.
-- * @n@: outer size of every SOAC constructed here.
-- * @w@: forwarded unchanged to @mapout@.
--
-- The result is whatever @multiScatter@ returns for the freshly
-- allocated destinations, the computed target indices and @xs@.
radixSortStep :: [VName] -> [Type] -> SubExp -> SubExp -> SubExp -> ADM [VName]
radixSortStep xs tps bit n w = do
  -- NOTE(review): 'mapout' is defined elsewhere in this module;
  -- presumably it normalises keys that fall outside [0, w) -- confirm.
  is <- mapout (head xs) n w

  -- Digit of each key: ((k >> bit) & 1) + 2 * ((k >> (bit + 1)) & 1).
  num_param <- newParam "num" $ Prim int64
  num_lam <-
    mkLambda [num_param] $
      fmap varsRes . letTupExp "num_res"
        =<< eBinOp
          (Add Int64 OverflowUndef)
          ( eBinOp
              (And Int64)
              (eBinOp (AShr Int64) (eParam num_param) (eSubExp bit))
              (iConst 1)
          )
          ( eBinOp
              (Mul Int64 OverflowUndef)
              (iConst 2)
              ( eBinOp
                  (And Int64)
                  (eBinOp (AShr Int64) (eParam num_param) (eBinOp (Add Int64 OverflowUndef) (eSubExp bit) (iConst 1)))
                  (iConst 1)
              )
          )

  bins <- letExp "bins" $ Op $ Screma n [is] $ mapSOAC num_lam

  -- Turn each digit into four 0/1 flags, one per bin.
  -- NOTE(review): 'elseIf' (defined elsewhere) receives three conditions
  -- and four bodies, so the last body is presumably the fall-through
  -- case -- confirm.
  flag_param <- newParam "flag" $ Prim int64
  flag_lam <-
    mkLambda [flag_param] $
      fmap varsRes . letTupExp "flag_res"
        =<< elseIf
          int64
          (map ((,) (eParam flag_param) . iConst) [0 .. 2])
          (map (eBody . fmap iConst . (\i -> map (\j -> if i == j then 1 else 0) [0 .. 3])) ([0 .. 3] :: [Integer]))

  flags <- letTupExp "flags" $ Op $ Screma n [bins] $ mapSOAC flag_lam

  -- Element-wise addition scan over the four flag arrays.
  scan_params <- traverse (flip newParam $ Prim int64) ["a1", "b1", "c1", "d1", "a2", "b2", "c2", "d2"]
  scan_lam <-
    mkLambda scan_params $
      fmap subExpsRes . mapM (letSubExp "scan_res") =<< do
        uncurry (zipWithM (eBinOp $ Add Int64 OverflowUndef)) $ splitAt 4 $ map eParam scan_params

  scan <- scanSOAC $ pure $ Scan scan_lam $ map (intConst Int64) [0, 0, 0, 0]
  offsets <- letTupExp "offsets" $ Op $ Screma n flags scan

  -- The element at index n - 1 of each scanned array is bound to
  -- na/nb/nc/nd.
  ind <- letSubExp "ind_last" =<< eBinOp (Sub Int64 OverflowUndef) (eSubExp n) (iConst 1)
  let i = Slice [DimFix ind]
  nabcd <- traverse newVName ["na", "nb", "nc", "nd"]
  zipWithM_ (\abcd -> letBindNames [abcd] . BasicOp . flip Index i) nabcd offsets

  -- Target index for an element in bin j: (its scanned offset - 1) plus
  -- the first j of na/nb/nc/nd.
  let vars = map Var nabcd
  map_params <- traverse (flip newParam $ Prim int64) ["bin", "a", "b", "c", "d"]
  map_lam <-
    mkLambda map_params $
      fmap varsRes . letTupExp "map_res"
        =<< elseIf
          int64
          (map ((,) (eParam $ head map_params) . iConst) [0 .. 2])
          ( zipWith
              ( \j p ->
                  eBody $
                    pure $ do
                      t <- letSubExp "t" =<< eBinOp (Sub Int64 OverflowUndef) (eParam p) (iConst 1)
                      foldBinOp (Add Int64 OverflowUndef) (intConst Int64 0) (t : take j vars)
              )
              [0 .. 3]
              (tail map_params)
          )

  nis <- letExp "nis" $ Op $ Screma n (bins : offsets) $ mapSOAC map_lam

  -- Fresh scratch arrays, one per input array, to scatter into.
  scatter_dst <- traverse (\t -> letExp "scatter_dst" $ BasicOp $ Scratch (elemType t) (arrayDims t)) tps
  multiScatter n scatter_dst nis xs
  where
    -- An i64 constant as an expression; left without a signature so its
    -- type is inferred from use.
    iConst c = eSubExp $ intConst Int64 c
-- | Emit a loop that repeatedly applies 'radixSortStep' to @xs@.
--
-- * @xs@: the arrays to sort; they become the loop-carried values.
-- * @n@, @w@: forwarded to every 'radixSortStep' call.
--
-- The number of iterations is @(log2 (w + 1) + 1)@ divided by two
-- (via '~/~'), where @log2@ is the local helper below; iteration @i@
-- runs the step with @bit = 2 * i@.  Returns the names bound to the
-- loop results.
radixSort :: [VName] -> SubExp -> SubExp -> ADM [VName]
radixSort xs n w = do
  logw <- log2 =<< letSubExp "w1" =<< toExp (pe64 w + 1)
  -- Each step consumes two bits, hence (logw + 1) / 2 iterations.
  iters <- letSubExp "iters" =<< toExp (untyped (pe64 logw + 1) ~/~ untyped (pe64 (intConst Int64 2)))

  types <- traverse lookupType xs
  -- One non-unique loop parameter per input array, named after it.
  params <- zipWithM (\x -> newParam (baseString x) . flip toDecl Nonunique) xs types
  i <- newVName "i"
  loopbody <- buildBody_ . localScope (scopeOfFParams params) $
    fmap varsRes $ do
      bit <- letSubExp "bit" =<< toExp (le64 i * 2)
      radixSortStep (map paramName params) types bit n w

  letTupExp "sorted" $
    Loop
      (zip params $ map Var xs)
      (ForLoop i Int64 iters)
      loopbody
  where
    -- Emit a while-loop that shifts @m@ right by one bit until it becomes
    -- zero and returns the number of shifts performed (the loop is not
    -- entered when @m@ is already zero, leaving the blank initial
    -- counter).
    log2 :: SubExp -> ADM SubExp
    log2 m = do
      params <- zipWithM newParam ["cond", "r", "i"] $ map Prim [Bool, int64, int64]
      -- Exactly three parameters were created just above.
      let [cond, r, i] = params

      body <- buildBody_ . localScope (scopeOfFParams params) $ do
        r' <- letSubExp "r'" =<< toExp (le64 (paramName r) .>>. 1)
        cond' <- letSubExp "cond'" =<< toExp (bNot $ pe64 r' .==. 0)
        i' <- letSubExp "i'" =<< toExp (le64 (paramName i) + 1)
        pure $ subExpsRes [cond', r', i']

      cond_init <- letSubExp "test" =<< toExp (bNot $ pe64 m .==. 0)

      l <-
        letTupExp' "log2res" $
          Loop
            (zip params [cond_init, m, Constant $ blankPrimValue int64])
            (WhileLoop $ paramName cond)
            body

      -- The loop yields (cond, r, i); only the counter is of interest.
      let [_, _, res] = l
      pure res
-- | Variant of 'radixSort' that only runs the sort on the first array of
-- @xs@ paired with an iota of length @n@, and afterwards gathers the
-- remaining arrays of @xs@ through the permuted iota.
--
-- * @xs@: the first array is the key array handed to 'radixSort'; the
--   rest are indexed with the permuted iota.
-- * @n@, @w@: forwarded to 'radixSort'; @n@ is also the size of the
--   iota and of the gathering map.
--
-- Returns the permuted iota, followed by the sorted first array,
-- followed by the gathered remaining arrays.
radixSort' :: [VName] -> SubExp -> SubExp -> ADM [VName]
radixSort' xs n w = do
  iota_n <-
    letExp "red_iota" . BasicOp $
      Iota n (intConst Int64 0) (intConst Int64 1) Int64

  radres <- radixSort [head xs, iota_n] n w
  -- Two arrays went in, so two come back.
  let [is', iota'] = radres

  -- Gather every remaining array at the positions given by the
  -- permuted iota.
  i_param <- newParam "i" $ Prim int64
  let slice = [DimFix $ Var $ paramName i_param]
  map_lam <- mkLambda [i_param] $ varsRes <$> multiIndex (tail xs) slice

  sorted <- letTupExp "sorted" $ Op $ Screma n [iota'] $ mapSOAC map_lam
  pure $ iota' : is' : sorted
diffHist :: VjpOps -> [VName] -> StmAux () -> SubExp -> Lambda SOACS -> [SubExp] -> [VName] -> [SubExp] -> SubExp -> [VName] -> ADM () -> ADM ()
diffHist :: VjpOps
-> [VName]
-> StmAux ()
-> SubExp
-> Lambda SOACS
-> [SubExp]
-> [VName]
-> [SubExp]
-> SubExp
-> [VName]
-> ADM ()
-> ADM ()
diffHist VjpOps
ops [VName]
xs StmAux ()
aux SubExp
n Lambda SOACS
lam0 [SubExp]
ne [VName]
as [SubExp]
w SubExp
rf [VName]
dst ADM ()
m = do
as_type <- (VName -> ADM Type) -> [VName] -> ADM [Type]
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
forall (f :: * -> *) a b.
Applicative f =>
(a -> f b) -> [a] -> f [b]
traverse VName -> ADM Type
forall rep (m :: * -> *). HasScope rep m => VName -> m Type
lookupType ([VName] -> ADM [Type]) -> [VName] -> ADM [Type]
forall a b. (a -> b) -> a -> b
$ [VName] -> [VName]
forall a. HasCallStack => [a] -> [a]
tail [VName]
as
dst_type <- traverse lookupType dst
nes <- traverse (letExp "new_dst" . BasicOp . Replicate (Shape $ pure $ head w)) ne
h_map <- mkIdentityLambda $ Prim int64 : map rowType as_type
h_part <- traverse (newVName . flip (<>) "_h_part" . baseString) xs
auxing aux . letBindNames h_part . Op $
Hist n as [HistOp (Shape w) rf nes ne lam0] h_map
lam0' <- renameLambda lam0
auxing aux . letBindNames xs . Op $
Screma (head w) (dst <> h_part) (mapSOAC lam0')
m
xs_bar <- traverse lookupAdjVal xs
(dst_params, hp_params, f') <- mkF' lam0 dst_type $ head w
f'_adj_dst <- vjpLambda ops (map adjFromVar xs_bar) dst_params f'
f'_adj_hp <- vjpLambda ops (map adjFromVar xs_bar) hp_params f'
dst_bar' <- eLambda f'_adj_dst $ map (eSubExp . Var) $ dst <> h_part
dst_bar <- bindSubExpRes "dst_bar" dst_bar'
zipWithM_ updateAdj dst dst_bar
h_part_bar' <- eLambda f'_adj_hp $ map (eSubExp . Var) $ dst <> h_part
h_part_bar <- bindSubExpRes "h_part_bar" h_part_bar'
lam <- renameLambda lam0
lam' <- renameLambda lam0
sorted <- radixSort' as n $ head w
let siota = [VName] -> VName
forall a. HasCallStack => [a] -> a
head [VName]
sorted
let sis = [VName] -> VName
forall a. HasCallStack => [a] -> a
head ([VName] -> VName) -> [VName] -> VName
forall a b. (a -> b) -> a -> b
$ [VName] -> [VName]
forall a. HasCallStack => [a] -> [a]
tail [VName]
sorted
let sas = Int -> [VName] -> [VName]
forall a. Int -> [a] -> [a]
drop Int
2 [VName]
sorted
iota_n <-
letExp "iota" $ BasicOp $ Iota n (intConst Int64 0) (intConst Int64 1) Int64
par_i <- newParam "i" $ Prim int64
flag_lam <- mkFlagLam par_i sis
flag <- letExp "flag" $ Op $ Screma n [iota_n] $ mapSOAC flag_lam
par_i' <- newParam "i" $ Prim int64
let i' = Param Type -> VName
forall dec. Param dec -> VName
paramName Param Type
par_i'
g_lam <-
mkLambda [par_i'] $
fmap subExpsRes . mapM (letSubExp "scan_inps") =<< do
im1 <- letSubExp "i_1" =<< toExp (le64 i' - 1)
nmi <- letSubExp "n_i" =<< toExp (pe64 n - le64 i')
let s1 = [SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix SubExp
im1]
let s2 = [SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix SubExp
nmi]
f1 <- letSubExp "f1" $ BasicOp $ Index flag $ Slice [DimFix $ Var i']
r1 <-
letTupExp' "r1"
=<< eIf
(eSubExp f1)
(eBody $ fmap eSubExp ne)
(eBody . fmap (eSubExp . Var) =<< multiIndex sas s1)
r2 <-
letTupExp' "r2"
=<< eIf
(toExp $ le64 i' .==. 0)
(eBody $ fmap eSubExp $ Constant (onePrimValue Bool) : ne)
( eBody $
pure $ do
eIf
(pure $ BasicOp $ Index flag $ Slice s2)
(eBody $ fmap eSubExp $ Constant (onePrimValue Bool) : ne)
( eBody . fmap eSubExp . (Constant (blankPrimValue Bool) :) . fmap Var
=<< multiIndex sas s2
)
)
traverse eSubExp $ f1 : r1 ++ r2
scan_lams <-
traverse
( \Lambda SOACS
l -> do
f1 <- String -> Type -> ADM (Param Type)
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"f1" (Type -> ADM (Param Type)) -> Type -> ADM (Param Type)
forall a b. (a -> b) -> a -> b
$ PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
Bool
f2 <- newParam "f2" $ Prim Bool
ps <- lambdaParams <$> renameLambda lam0
let (p1, p2) = splitAt (length ne) ps
mkLambda (f1 : p1 ++ f2 : p2) $
fmap varsRes . letTupExp "scan_res" =<< do
let f = BinOp
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
BinOp -> m (Exp (Rep m)) -> m (Exp (Rep m)) -> m (Exp (Rep m))
eBinOp BinOp
LogOr (Param Type -> ADM (Exp (Rep ADM))
forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
f1) (Param Type -> ADM (Exp (Rep ADM))
forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
f2)
eIf
(eParam f2)
(eBody $ f : fmap eParam p2)
( eBody . (f :) . fmap (eSubExp . Var)
=<< bindSubExpRes "gres"
=<< eLambda l (fmap eParam ps)
)
)
[lam, lam']
let ne' = PrimValue -> SubExp
Constant (Bool -> PrimValue
BoolValue Bool
False) SubExp -> [SubExp] -> [SubExp]
forall a. a -> [a] -> [a]
: [SubExp]
ne
scansres <-
letTupExp "adj_ctrb_scan" . Op $
Screma n [iota_n] (scanomapSOAC (map (`Scan` ne') scan_lams) g_lam)
let (_ : ls_arr, _ : rs_arr_rev) = splitAt (length ne + 1) scansres
par_i'' <- newParam "i" $ Prim int64
let i'' = Param Type -> VName
forall dec. Param dec -> VName
paramName Param Type
par_i''
map_lam <-
mkLambda [par_i''] $
fmap varsRes . letTupExp "scan_res"
=<< eIf
(toExp $ withinBounds $ pure (head w, i''))
(eBody . fmap (eSubExp . Var) =<< multiIndex h_part_bar [DimFix $ Var i''])
( eBody $ do
map (\Type
t -> Exp SOACS -> ADM (Exp SOACS)
forall a. a -> ADM a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Exp SOACS -> ADM (Exp SOACS)) -> Exp SOACS -> ADM (Exp SOACS)
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp SOACS
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp SOACS) -> BasicOp -> Exp SOACS
forall a b. (a -> b) -> a -> b
$ Shape -> SubExp -> BasicOp
Replicate ([SubExp] -> Shape
forall d. [d] -> ShapeBase d
Shape ([SubExp] -> Shape) -> [SubExp] -> Shape
forall a b. (a -> b) -> a -> b
$ [SubExp] -> [SubExp]
forall a. HasCallStack => [a] -> [a]
tail ([SubExp] -> [SubExp]) -> [SubExp] -> [SubExp]
forall a b. (a -> b) -> a -> b
$ Type -> [SubExp]
forall u. TypeBase Shape u -> [SubExp]
arrayDims Type
t) (PrimValue -> SubExp
Constant (PrimValue -> SubExp) -> PrimValue -> SubExp
forall a b. (a -> b) -> a -> b
$ PrimType -> PrimValue
blankPrimValue (PrimType -> PrimValue) -> PrimType -> PrimValue
forall a b. (a -> b) -> a -> b
$ Type -> PrimType
forall shape u. TypeBase shape u -> PrimType
elemType Type
t)) as_type
)
f_bar <- letTupExp "f_bar" $ Op $ Screma n [sis] $ mapSOAC map_lam
(as_params, f) <- mkF lam0 as_type n
f_adj <- vjpLambda ops (map adjFromVar f_bar) as_params f
par_i''' <- newParam "i" $ Prim int64
let i''' = Param Type -> VName
forall dec. Param dec -> VName
paramName Param Type
par_i'''
rev_lam <- mkLambda [par_i'''] $ do
nmim1 <- letSubExp "n_i_1" =<< toExp (pe64 n - le64 i''' - 1)
varsRes <$> multiIndex rs_arr_rev [DimFix nmim1]
rs_arr <- letTupExp "rs_arr" $ Op $ Screma n [iota_n] $ mapSOAC rev_lam
sas_bar <-
bindSubExpRes "sas_bar"
=<< eLambda f_adj (map (eSubExp . Var) $ ls_arr <> sas <> rs_arr)
scatter_dst <- traverse (\Type
t -> String -> Exp (Rep ADM) -> ADM VName
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp String
"scatter_dst" (Exp (Rep ADM) -> ADM VName) -> Exp (Rep ADM) -> ADM VName
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp (Rep ADM)
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp (Rep ADM)) -> BasicOp -> Exp (Rep ADM)
forall a b. (a -> b) -> a -> b
$ PrimType -> [SubExp] -> BasicOp
Scratch (Type -> PrimType
forall shape u. TypeBase shape u -> PrimType
elemType Type
t) (Type -> [SubExp]
forall u. TypeBase Shape u -> [SubExp]
arrayDims Type
t)) as_type
as_bar <- multiScatter n scatter_dst siota sas_bar
zipWithM_ updateAdj (tail as) as_bar
where
mkFlagLam :: LParam SOACS -> VName -> ADM (Lambda SOACS)
mkFlagLam :: LParam SOACS -> VName -> ADM (Lambda SOACS)
mkFlagLam LParam SOACS
par_i VName
sis =
[LParam (Rep ADM)] -> ADM [SubExpRes] -> ADM (Lambda (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[LParam (Rep m)] -> m [SubExpRes] -> m (Lambda (Rep m))
mkLambda [LParam (Rep ADM)
LParam SOACS
par_i] (ADM [SubExpRes] -> ADM (Lambda (Rep ADM)))
-> ADM [SubExpRes] -> ADM (Lambda (Rep ADM))
forall a b. (a -> b) -> a -> b
$
([VName] -> [SubExpRes]) -> ADM [VName] -> ADM [SubExpRes]
forall a b. (a -> b) -> ADM a -> ADM b
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap [VName] -> [SubExpRes]
varsRes (ADM [VName] -> ADM [SubExpRes])
-> (Exp SOACS -> ADM [VName]) -> Exp SOACS -> ADM [SubExpRes]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. String -> Exp (Rep ADM) -> ADM [VName]
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [VName]
letTupExp String
"flag" (Exp SOACS -> ADM [SubExpRes])
-> ADM (Exp SOACS) -> ADM [SubExpRes]
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< do
let i :: VName
i = Param Type -> VName
forall dec. Param dec -> VName
paramName Param Type
LParam SOACS
par_i
ADM (Exp (Rep ADM))
-> ADM (Body (Rep ADM))
-> ADM (Body (Rep ADM))
-> ADM (Exp (Rep ADM))
forall (m :: * -> *).
(MonadBuilder m, BranchType (Rep m) ~ ExtType) =>
m (Exp (Rep m))
-> m (Body (Rep m)) -> m (Body (Rep m)) -> m (Exp (Rep m))
eIf
(TPrimExp Bool VName -> ADM (Exp (Rep ADM))
forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
forall (m :: * -> *).
MonadBuilder m =>
TPrimExp Bool VName -> m (Exp (Rep m))
toExp (VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
i TPrimExp Int64 VName -> TPrimExp Int64 VName -> TPrimExp Bool VName
forall {k} v (t :: k).
Eq v =>
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TPrimExp Int64 VName
0))
([ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody ([ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM)))
-> [ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall a b. (a -> b) -> a -> b
$ ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))]
forall a. a -> [a]
forall (f :: * -> *) a. Applicative f => a -> f a
pure (ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))])
-> ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))]
forall a b. (a -> b) -> a -> b
$ SubExp -> ADM (Exp (Rep ADM))
forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp (SubExp -> ADM (Exp (Rep ADM))) -> SubExp -> ADM (Exp (Rep ADM))
forall a b. (a -> b) -> a -> b
$ PrimValue -> SubExp
Constant (PrimValue -> SubExp) -> PrimValue -> SubExp
forall a b. (a -> b) -> a -> b
$ PrimType -> PrimValue
onePrimValue PrimType
Bool)
( [ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody ([ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM)))
-> [ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall a b. (a -> b) -> a -> b
$
ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))]
forall a. a -> [a]
forall (f :: * -> *) a. Applicative f => a -> f a
pure (ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))])
-> ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))]
forall a b. (a -> b) -> a -> b
$ do
i_p <- String -> Exp (Rep ADM) -> ADM VName
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp String
"i_p" (Exp SOACS -> ADM VName) -> ADM (Exp SOACS) -> ADM VName
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< TPrimExp Int64 VName -> ADM (Exp (Rep ADM))
forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
forall (m :: * -> *).
MonadBuilder m =>
TPrimExp Int64 VName -> m (Exp (Rep m))
toExp (VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
i TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
- TPrimExp Int64 VName
1)
vs <- traverse (letExp "vs" . BasicOp . Index sis . Slice . pure . DimFix . Var) [i, i_p]
let [vs_i, vs_p] = vs
toExp $ bNot $ le64 vs_i .==. le64 vs_p
)