{-# LANGUAGE TypeFamilies #-}
module Futhark.AD.Rev (revVJP) where
import Control.Monad
import Data.List ((\\))
import Data.List.NonEmpty (NonEmpty (..))
import Data.Map qualified as M
import Futhark.AD.Derivatives
import Futhark.AD.Rev.Loop
import Futhark.AD.Rev.Monad
import Futhark.AD.Rev.SOAC
import Futhark.Analysis.PrimExp.Convert
import Futhark.Builder
import Futhark.IR.SOACS
import Futhark.Tools
import Futhark.Transform.Rename
import Futhark.Transform.Substitute
import Futhark.Util (takeLast)
-- | Return the name bound by a single-element pattern.
--
-- Calls 'error' (including the pretty-printed pattern in the message)
-- when the pattern does not consist of exactly one element.
patName :: Pat Type -> ADM VName
patName (Pat [pe]) = pure $ patElemName pe
patName pat = error $ "Expected single-element pattern: " ++ prettyString pat
-- | If the given variable has array type, bind a new variable (named
-- after the original with a @_copy@ suffix) to a 'Replicate' of it with
-- an empty shape and return that new name.  For any other type the
-- original name is returned unchanged and no statement is added.
--
-- NOTE(review): the empty-shape 'Replicate' presumably acts as an array
-- copy, going by the @_copy@ suffix -- confirm against the IR semantics.
copyIfArray :: VName -> ADM VName
copyIfArray v = do
  v_t <- lookupType v
  case v_t of
    Array {} ->
      letExp (baseString v <> "_copy") . BasicOp $
        Replicate mempty (Var v)
    _ -> pure v
-- | Shared handling for a 'BasicOp' that binds a single value.
--
-- Adds the statement @Let pat aux (BasicOp op)@ as-is, runs the
-- continuation @m@, and only then looks up the adjoint of the bound
-- name.  Returns the pair @(bound name, adjoint of bound name)@.
--
-- Goes through 'patName', so a pattern that does not have exactly one
-- element triggers that function's 'error'.
commonBasicOp :: Pat Type -> StmAux () -> BasicOp -> ADM () -> ADM (VName, VName)
commonBasicOp pat aux op m = do
  addStm $ Let pat aux $ BasicOp op
  m
  pat_v <- patName pat
  pat_adj <- lookupAdjVal pat_v
  pure (pat_v, pat_adj)
-- | Differentiate a single 'BasicOp' statement.
--
-- Arguments are the pattern, auxiliary information and operation of the
-- statement, followed by the continuation @m@ that is run after the
-- statement itself has been added.  Almost every case defers to
-- 'commonBasicOp' to add the statement, run @m@ and fetch the adjoint
-- of the result; the adjoint contributions to the operands are then
-- produced inside 'returnSweepCode'.
--
-- 'FlatIndex' and 'FlatUpdate' are not supported and call 'error'.
diffBasicOp :: Pat Type -> StmAux () -> BasicOp -> ADM () -> ADM ()
diffBasicOp pat aux e m =
  case e of
    -- No adjoint contribution is produced for comparisons.
    CmpOp {} ->
      void $ commonBasicOp pat aux e m
    --
    -- The contribution is the result adjoint converted with the flipped
    -- conversion operator.
    ConvOp op x -> do
      (_pat_v, pat_adj) <- commonBasicOp pat aux e m
      returnSweepCode $ do
        contrib <-
          letExp "contrib" $ BasicOp $ ConvOp (flipConvOp op) $ Var pat_adj
        updateSubExpAdj x contrib
    --
    -- The contribution is the result adjoint times the partial
    -- derivative given by 'pdUnOp'.
    UnOp op x -> do
      (_pat_v, pat_adj) <- commonBasicOp pat aux e m
      returnSweepCode $ do
        let t = unOpType op
        contrib <- do
          let x_pe = primExpFromSubExp t x
              pat_adj' = primExpFromSubExp t (Var pat_adj)
              dx = pdUnOp op x_pe
          letExp "contrib" <=< toExp $ pat_adj' ~*~ dx
        updateSubExpAdj x contrib
    --
    -- Each operand receives the result adjoint times its own partial
    -- derivative as given by 'pdBinOp'.
    BinOp op x y -> do
      (_pat_v, pat_adj) <- commonBasicOp pat aux e m
      returnSweepCode $ do
        let t = binOpType op
            (wrt_x, wrt_y) =
              pdBinOp op (primExpFromSubExp t x) (primExpFromSubExp t y)
            pat_adj' = primExpFromSubExp t $ Var pat_adj
        adj_x <- letExp "binop_x_adj" <=< toExp $ pat_adj' ~*~ wrt_x
        adj_y <- letExp "binop_y_adj" <=< toExp $ pat_adj' ~*~ wrt_y
        updateSubExpAdj x adj_x
        updateSubExpAdj y adj_y
    --
    -- The result adjoint is passed straight on to the operand.
    SubExp se -> do
      (_pat_v, pat_adj) <- commonBasicOp pat aux e m
      returnSweepCode $ updateSubExpAdj se pat_adj
    --
    Assert {} ->
      void $ commonBasicOp pat aux e m
    --
    ArrayVal {} ->
      void $ commonBasicOp pat aux e m
    --
    -- Element number i of the literal receives element i of the result
    -- adjoint.
    ArrayLit elems _ -> do
      (_pat_v, pat_adj) <- commonBasicOp pat aux e m
      t <- lookupType pat_adj
      returnSweepCode $ do
        forM_ (zip [(0 :: Int64) ..] elems) $ \(i, se) -> do
          let slice :: Slice SubExp
              slice = fullSlice t [DimFix (constant i)]
          updateSubExpAdj se <=< letExp "elem_adj" $ BasicOp $ Index pat_adj slice
    --
    Index arr slice -> do
      (_pat_v, pat_adj) <- commonBasicOp pat aux e m
      returnSweepCode $ void $ updateAdjSlice slice arr pat_adj
    --
    FlatIndex {} -> error "FlatIndex not handled by AD yet."
    FlatUpdate {} -> error "FlatUpdate not handled by AD yet."
    --
    Opaque _ se -> do
      (_pat_v, pat_adj) <- commonBasicOp pat aux e m
      returnSweepCode $ updateSubExpAdj se pat_adj
    --
    -- The result adjoint is reshaped back to the shape of the source
    -- array before being added to its adjoint.
    Reshape arr newshape -> do
      (_pat_v, pat_adj) <- commonBasicOp pat aux e m
      returnSweepCode $ do
        arr_shape <- arrayShape <$> lookupType arr
        void $
          updateAdj arr <=< letExp "adj_reshape" . BasicOp $
            Reshape pat_adj (reshapeAll (newShape newshape) arr_shape)
    --
    -- The result adjoint is rearranged with the inverse permutation.
    Rearrange arr perm -> do
      (_pat_v, pat_adj) <- commonBasicOp pat aux e m
      returnSweepCode $
        void $
          updateAdj arr <=< letExp "adj_rearrange" . BasicOp $
            Rearrange pat_adj (rearrangeInverse perm)
    --
    -- Replication of a variable with an empty shape: the adjoint is
    -- passed on unchanged.  Must come before the general case below.
    Replicate (Shape []) (Var se) -> do
      (_pat_v, pat_adj) <- commonBasicOp pat aux e m
      returnSweepCode $ void $ updateAdj se pat_adj
    --
    -- General replication: the replicated dimensions of the adjoint are
    -- flattened into one dimension of size n (the product of ns) and
    -- reduced with the lambda from 'addLambda' and a zero neutral
    -- element.
    Replicate (Shape ns) x -> do
      (_pat_v, pat_adj) <- commonBasicOp pat aux e m
      returnSweepCode $ do
        x_t <- subExpType x
        lam <- addLambda x_t
        ne <- letSubExp "zero" $ zeroExp x_t
        n <- letSubExp "rep_size" =<< foldBinOp (Mul Int64 OverflowUndef) (intConst Int64 1) ns
        pat_adj_flat <-
          letExp (baseString pat_adj <> "_flat") . BasicOp $
            Reshape pat_adj (reshapeAll (Shape ns) (Shape $ n : arrayDims x_t))
        reduce <- reduceSOAC [Reduce Commutative lam [ne]]
        updateSubExpAdj x
          =<< letExp "rep_contrib" (Op $ Screma n [pat_adj_flat] reduce)
    --
    -- Each concatenated array receives the slice of the result adjoint
    -- that corresponds to its position along dimension d.
    Concat d (arr :| arrs) _ -> do
      (_pat_v, pat_adj) <- commonBasicOp pat aux e m
      returnSweepCode $ do
        let sliceAdj _ [] = pure []
            sliceAdj start (v : vs) = do
              v_t <- lookupType v
              -- The extent must be measured along the concatenation
              -- dimension d, since that is the dimension being sliced
              -- by 'sliceAt' below (identical to dimension 0 when
              -- d == 0).
              let w = arraySize d v_t
                  slice = DimSlice start w (intConst Int64 1)
              pat_adj_slice <-
                letExp (baseString pat_adj <> "_slice") $
                  BasicOp $
                    Index pat_adj (sliceAt v_t d [slice])
              start' <- letSubExp "start" $ BasicOp $ BinOp (Add Int64 OverflowUndef) start w
              slices <- sliceAdj start' vs
              pure $ pat_adj_slice : slices
        slices <- sliceAdj (intConst Int64 0) $ arr : arrs
        zipWithM_ updateAdj (arr : arrs) slices
    --
    Manifest se _ -> do
      (_pat_v, pat_adj) <- commonBasicOp pat aux e m
      returnSweepCode $ void $ updateAdj se pat_adj
    --
    Scratch {} ->
      void $ commonBasicOp pat aux e m
    --
    -- The size operand n receives the reduction of the result adjoint
    -- (using the lambda from 'addLambda' and a zero neutral element).
    Iota n _ _ t -> do
      (_pat_v, pat_adj) <- commonBasicOp pat aux e m
      returnSweepCode $ do
        ne <- letSubExp "zero" $ zeroExp $ Prim $ IntType t
        lam <- addLambda $ Prim $ IntType t
        reduce <- reduceSOAC [Reduce Commutative lam [ne]]
        updateSubExpAdj n
          =<< letExp "iota_contrib" (Op $ Screma n [pat_adj] reduce)
    --
    -- The written value receives the corresponding slice of the result
    -- adjoint (copied if it is an array); the source array receives the
    -- result adjoint with that slice overwritten by zeroes.
    Update safety arr slice v -> do
      (_pat_v, pat_adj) <- commonBasicOp pat aux e m
      returnSweepCode $ do
        v_adj <- letExp "update_val_adj" $ BasicOp $ Index pat_adj slice
        v_adj_copy <- copyIfArray v_adj
        updateSubExpAdj v v_adj_copy
        zeroes <- letSubExp "update_zero" . zeroExp =<< subExpType v
        void $
          updateAdj arr
            =<< letExp "update_src_adj" (BasicOp $ Update safety pat_adj slice zeroes)
    --
    -- Does not use 'commonBasicOp': adjoints are looked up for every
    -- name in the pattern, and each written value receives the adjoint
    -- indexed at the update position.
    UpdateAcc _ _ is vs -> do
      addStm $ Let pat aux $ BasicOp e
      m
      pat_adjs <- mapM lookupAdjVal (patNames pat)
      returnSweepCode $ do
        forM_ (zip pat_adjs vs) $ \(adj, v) -> do
          adj_i <- letExp "updateacc_val_adj" $ BasicOp $ Index adj $ Slice $ map DimFix is
          updateSubExpAdj v adj_i
-- | The operations record passed to 'vjpSOAC', wiring lambda
-- differentiation to 'diffLambda' and statement differentiation to
-- 'diffStm'.
vjpOps :: VjpOps
vjpOps =
  VjpOps
    { vjpLambda = diffLambda,
      vjpStm = diffStm
    }
-- | Differentiate a single statement.  The second argument is the
-- continuation to run once the statement itself has been handled.
--
-- Dispatches on the expression of the statement:
--
-- * 'BasicOp' is delegated to 'diffBasicOp'.
--
-- * 'Apply' is handled only for functions found in 'builtInFunctions';
--   a missing partial derivative from 'pdBuiltin' is reported with
--   'error'.
--
-- * 'Match' differentiates every branch with 'diffBody' and records the
--   resulting adjoints of the variables free in the branches.
--
-- * 'Op' is delegated to 'vjpSOAC'.
--
-- * 'Loop' is delegated to 'diffLoop'.
--
-- * 'WithAcc' differentiates the lambda and records the results as
--   adjoints of the accumulator arrays and the active, non-accumulator
--   free variables.
--
-- Anything else (including 'Apply' of a non-builtin function) calls
-- 'error' with the pretty-printed statement.
diffStm :: Stm SOACS -> ADM () -> ADM ()
diffStm (Let pat aux (BasicOp e)) m =
  diffBasicOp pat aux e m
diffStm stm@(Let pat _ (Apply f args _ _)) m
  | Just (ret, argts) <- M.lookup f builtInFunctions = do
      addStm stm
      m

      pat_adj <- lookupAdjVal =<< patName pat
      let arg_pes = zipWith primExpFromSubExp argts (map fst args)
          pat_adj' = primExpFromSubExp ret (Var pat_adj)
          -- Convert a contribution from the return type of the function
          -- to the type of the argument it belongs to.
          convert ft tt
            | ft == tt = id
          convert (IntType ft) (IntType tt) = ConvOpExp (SExt ft tt)
          convert (FloatType ft) (FloatType tt) = ConvOpExp (FPConv ft tt)
          convert Bool (FloatType tt) = ConvOpExp (BToF tt)
          convert (FloatType ft) Bool = ConvOpExp (FToB ft)
          convert ft tt = error $ "diffStm.convert: " ++ prettyString (f, ft, tt)

      contribs <-
        case pdBuiltin f arg_pes of
          Nothing ->
            error $ "No partial derivative defined for builtin function: " ++ prettyString f
          Just derivs ->
            forM (zip derivs argts) $ \(deriv, argt) ->
              letExp "contrib" <=< toExp . convert ret argt $ pat_adj' ~*~ deriv

      zipWithM_ updateSubExpAdj (map fst args) contribs
diffStm stm@(Let pat _ (Match ses cases defbody _)) m = do
  addStm stm
  m
  returnSweepCode $ do
    let cases_free :: [Names]
        cases_free = map freeIn cases
        defbody_free :: Names
        defbody_free = freeIn defbody
        branches_free :: [VName]
        branches_free = namesToList $ mconcat $ defbody_free : cases_free

    adjs <- mapM lookupAdj $ patNames pat

    -- 'diffBody' appends one adjoint per entry of branches_free to the
    -- results of each branch, so only the trailing results are kept.
    branches_free_adj <-
      ( pure . takeLast (length branches_free)
          <=< letTupExp "branch_adj"
          <=< renameExp
        )
        =<< eMatch
          ses
          (map (fmap $ diffBody adjs branches_free) cases)
          (diffBody adjs branches_free defbody)

    forM_ (zip branches_free branches_free_adj) $ \(v, v_adj) ->
      insAdj v =<< copyIfArray v_adj
diffStm (Let pat aux (Op soac)) m =
  censorStms (fmap addAttrs) $ vjpSOAC vjpOps pat aux soac m
  where
    -- Attach the attributes of the original statement to every
    -- produced statement whose expression is an 'Op'.
    addAttrs :: Stm SOACS -> Stm SOACS
    addAttrs stm
      | Op _ <- stmExp stm =
          attribute (stmAuxAttrs aux) stm
      | otherwise = stm
diffStm (Let pat aux loop@Loop {}) m =
  diffLoop diffStms pat aux loop m
diffStm stm@(Let pat _aux (WithAcc inputs lam)) m = do
  addStm stm
  m
  returnSweepCode $ do
    adjs <- mapM lookupAdj $ patNames pat
    lam' <- renameLambda lam
    free_vars <- filterM isActive $ namesToList $ freeIn lam'
    free_accs <- filterM (fmap isAcc . lookupType) free_vars
    let free_vars' = free_vars \\ free_accs
    lam'' <- diffLambda' adjs free_vars' lam'
    inputs' <- mapM renameInputLambda inputs
    free_adjs <- letTupExp "with_acc_contrib" $ WithAcc inputs' lam''
    zipWithM_ insAdj (arrs <> free_vars') free_adjs
  where
    -- All arrays named by the accumulator inputs.
    arrs :: [VName]
    arrs = concatMap (\(_, as, _) -> as) inputs

    -- Rename the combining lambda of an input, if it has one.
    renameInputLambda (shape, as, Just (f, nes)) = do
      f' <- renameLambda f
      pure (shape, as, Just (f', nes))
    renameInputLambda input = pure input

    -- Like 'diffLambda', but the first @length inputs@ results and
    -- return types of the original lambda are kept in front of the
    -- adjoint results.
    diffLambda' :: [Adj] -> [VName] -> Lambda SOACS -> ADM (Lambda SOACS)
    diffLambda' res_adjs get_adjs_for (Lambda params ts body) =
      localScope (scopeOfLParams params) $ do
        Body () stms res <- diffBody res_adjs get_adjs_for body
        let body' =
              Body () stms $
                take (length inputs) res <> takeLast (length get_adjs_for) res
        ts' <- mapM lookupType get_adjs_for
        pure $ Lambda params (take (length inputs) ts <> ts') body'
diffStm stm _ = error $ "diffStm unhandled:\n" ++ prettyString stm
-- | Differentiate a sequence of statements, one statement at a time.
--
-- For the first statement, 'copyConsumedArrsInStm' yields a name
-- substitution and extra copy statements.  The copy statements are
-- differentiated first; the substitution is applied to the statement
-- and to all remaining statements; and the substituted statement is
-- then handed to 'diffStm' with the differentiation of the remaining
-- statements as its continuation.  Afterwards, each substituted-away
-- name is given the adjoint of the name that replaced it.
--
-- An empty sequence does nothing.
diffStms :: Stms SOACS -> ADM ()
diffStms all_stms
  | Just (stm, stms) <- stmsHead all_stms = do
      (subst, copy_stms) <- copyConsumedArrsInStm stm
      let (stm', stms') = substituteNames subst (stm, stms)
      diffStms copy_stms >> diffStm stm' (diffStms stms')
      forM_ (M.toList subst) $ \(from, to) ->
        setAdj from =<< lookupAdj to
  | otherwise =
      pure ()
-- | Preprocessing applied to statements before they are
-- differentiated.  Currently this is exactly 'stripmineStms'.
preprocess :: Stms SOACS -> ADM (Stms SOACS)
preprocess = stripmineStms
-- | Differentiate a body.
--
-- @res_adjs@ are the adjoints of the trailing @length res_adjs@ results
-- of the body, and @get_adjs_for@ names the variables whose adjoints
-- are wanted.  The returned body consists of the statements collected
-- during differentiation, and has as results the original results
-- followed by one adjoint per entry of @get_adjs_for@.
--
-- The whole computation runs inside 'subAD' and 'subSubsts'.
diffBody :: [Adj] -> [VName] -> Body SOACS -> ADM (Body SOACS)
diffBody res_adjs get_adjs_for (Body () stms res) = subAD $
  subSubsts $ do
    -- Seed the adjoint of a result.  Constant results have no adjoint
    -- to update.
    let onResult :: SubExpRes -> Adj -> ADM ()
        onResult (SubExpRes _ (Constant _)) _ = pure ()
        onResult (SubExpRes _ (Var v)) v_adj = void $ updateAdj v =<< adjVal v_adj
    (adjs, stms') <- collectStms $ do
      zipWithM_ onResult (takeLast (length res_adjs) res) res_adjs
      diffStms =<< preprocess stms
      mapM lookupAdjVal get_adjs_for
    pure $ Body () stms' $ res <> varsRes adjs
-- | Differentiate a lambda.
--
-- The body is differentiated with 'diffBody' with the lambda parameters
-- in scope.  Of the results 'diffBody' produces, only the trailing
-- @length get_adjs_for@ (the requested adjoints) are kept, and the
-- return types of the resulting lambda are the types of the variables
-- in @get_adjs_for@.  The original return types are discarded; the
-- parameters are unchanged.
diffLambda :: [Adj] -> [VName] -> Lambda SOACS -> ADM (Lambda SOACS)
diffLambda res_adjs get_adjs_for (Lambda params _ body) =
  localScope (scopeOfLParams params) $ do
    Body () stms res <- diffBody res_adjs get_adjs_for body
    let body' = Body () stms $ takeLast (length get_adjs_for) res
    ts' <- mapM lookupType get_adjs_for
    pure $ Lambda params ts' body'
-- | Construct the reverse-mode (vector-Jacobian product) version of a
-- lambda, given the scope in which the lambda occurs.
--
-- One extra parameter is created per result of the lambda, carrying
-- the return type of that result: for a variable result the parameter
-- name comes from 'adjVName', and for a constant result a fresh name
-- based on @const_adj@ is used.  The body is differentiated with
-- 'diffBody', using these parameters as result adjoints and requesting
-- adjoints for all original parameters.
--
-- The returned lambda takes the original parameters followed by the
-- new adjoint parameters, and its return types are the original return
-- types followed by the types of the original parameters.
revVJP :: (MonadFreshNames m) => Scope SOACS -> Lambda SOACS -> m (Lambda SOACS)
revVJP scope (Lambda params ts body) =
  runADM . localScope (scope <> scopeOfLParams params) $ do
    params_adj <- forM (zip (map resSubExp (bodyResult body)) ts) $ \(se, t) ->
      Param mempty <$> maybe (newVName "const_adj") adjVName (subExpVar se) <*> pure t

    body' <-
      localScope (scopeOfLParams params_adj) $
        diffBody
          (map adjFromParam params_adj)
          (map paramName params)
          body

    pure $ Lambda (params ++ params_adj) (ts <> map paramType params) body'