{-# LANGUAGE TypeFamilies #-}

-- | VJP transformation for Map SOACs.  This is a pretty complicated
-- case due to the possibility of free variables.
module Futhark.AD.Rev.Map (vjpMap) where

import Control.Monad
import Data.Bifunctor (first)
import Futhark.AD.Rev.Monad
import Futhark.Analysis.PrimExp.Convert
import Futhark.Builder
import Futhark.IR.SOACS
import Futhark.Tools
import Futhark.Transform.Rename
import Futhark.Util (splitAt3)

-- | A classification of a free variable based on its adjoint.  The
-- 'VName' stored is *not* the adjoint, but the primal variable.
-- Values are produced by 'classifyAdjVars' from the type of the
-- variable's current adjoint, and grouped by 'partitionAdjVars'.
data AdjVar
  = -- | Adjoint is already an accumulator.
    FreeAcc VName
  | -- | Adjoint is currently an (ordinary) array with this shape and
    -- element type; it should be given an accumulator adjoint instead.
    FreeArr VName Shape PrimType
  | -- | Does not need an accumulator adjoint.  'classifyAdjVars' picks
    -- this when the adjoint is neither an array nor an accumulator.
    FreeNonAcc VName

-- | Classify each variable by the type of its current adjoint (obtained
-- with 'lookupAdjVal'):
--
-- * an array adjoint gives 'FreeArr', recording its shape and element type;
-- * an accumulator adjoint gives 'FreeAcc';
-- * anything else gives 'FreeNonAcc'.
--
-- The result is in the same order as the input.
classifyAdjVars :: [VName] -> ADM [AdjVar]
classifyAdjVars = traverse classify
  where
    classify :: VName -> ADM AdjVar
    classify v = do
      adj_t <- lookupType =<< lookupAdjVal v
      pure $ case adj_t of
        Array pt shape _ -> FreeArr v shape pt
        Acc {} -> FreeAcc v
        _ -> FreeNonAcc v

-- | Split classified variables into three lists, each preserving the
-- relative input order:
--
-- 1. 'FreeArr' variables, paired with their adjoint shape and element type;
-- 2. 'FreeAcc' variables;
-- 3. 'FreeNonAcc' variables.
partitionAdjVars :: [AdjVar] -> ([(VName, (Shape, PrimType))], [VName], [VName])
partitionAdjVars = foldr place ([], [], [])
  where
    -- The irrefutable pattern keeps the tail of the partition lazy.
    place adj_var ~(arrs, accs, others) =
      case adj_var of
        FreeArr v shape t -> ((v, (shape, t)) : arrs, accs, others)
        FreeAcc v -> (arrs, v : accs, others)
        FreeNonAcc v -> (arrs, accs, v : others)

-- | Like 'buildBody', but the constructed body is additionally passed
-- through 'renameBody' before being returned together with the extra
-- value produced by the action.  Presumably this lets the same code
-- (e.g. a lambda body) be instantiated more than once without
-- duplicate bindings.
buildRenamedBody ::
  (MonadBuilder m) =>
  m (Result, a) ->
  m (Body (Rep m), a)
buildRenamedBody m =
  buildBody m >>= \(body, x) -> fmap (\body' -> (body', x)) (renameBody body)

-- | Run an action inside a 'WithAcc' built from the given inputs.  For
-- each input a certificate parameter and an accumulator parameter are
-- created; the action receives the accumulator parameter names and
-- returns the lambda result.  The names bound to the 'WithAcc' results
-- are returned.
--
-- With no inputs no 'WithAcc' is created: the action is run directly
-- on an empty list and each of its results is bound to a fresh name.
withAcc ::
  [(Shape, [VName], Maybe (Lambda SOACS, [SubExp]))] ->
  ([VName] -> ADM Result) ->
  ADM [VName]
withAcc [] m = do
  res <- m []
  forM res $ \r -> letExp "withacc_res" $ BasicOp $ SubExp $ resSubExp r
withAcc inputs m = do
  (cert_params, acc_params) <- unzip <$> mapM accParams inputs
  acc_lam <-
    subAD . mkLambda (cert_params ++ acc_params) . m $ map paramName acc_params
  -- NOTE(review): "withhacc_res" looks like a typo for "withacc_res";
  -- presumably it is only a name hint for the bound results.
  letTupExp "withhacc_res" $ WithAcc inputs acc_lam
  where
    -- A certificate parameter, and an accumulator parameter whose
    -- element types are the input arrays' types with the outer
    -- 'shapeRank shape' dimensions stripped.
    accParams ::
      (Shape, [VName], Maybe (Lambda SOACS, [SubExp])) ->
      ADM (Param Type, Param Type)
    accParams (shape, arrs, _) = do
      cert_param <- newParam "acc_cert_p" $ Prim Unit
      elem_ts <- forM arrs $ fmap (stripArray (shapeRank shape)) . lookupType
      acc_param <-
        newParam "acc_p" $ Acc (paramName cert_param) shape elem_ts NoUniqueness
      pure (cert_param, acc_param)

-- | Perform VJP on a Map.  The 'Adj' list is the adjoints of the
-- result of the map.
--
-- If every result adjoint is sparse along the mapped dimension, the
-- return sweep is produced by 'vjpMapSparse' without any Map;
-- otherwise 'vjpMapDense' builds a reverse Map.
vjpMap :: VjpOps -> [Adj] -> StmAux () -> SubExp -> Lambda SOACS -> [VName] -> ADM ()
vjpMap ops res_adjs aux w map_lam as =
  case traverse sparseAlongMap res_adjs of
    Just res_ivs -> vjpMapSparse ops res_ivs w map_lam as
    Nothing -> vjpMapDense ops res_adjs aux w map_lam as
  where
    -- Only a sparse adjoint whose shape is exactly the map width
    -- qualifies.
    sparseAlongMap :: Adj -> Maybe [(InBounds, SubExp, SubExp)]
    sparseAlongMap (AdjSparse (Sparse shape _ ivs))
      | shapeDims shape == [w] = Just ivs
    sparseAlongMap _ = Nothing

-- | VJP of a Map whose result adjoints are all sparse.  Element @i@ of
-- the list holds the (bounds check, index, value) triples of result
-- number @i@.  Since only a constant number of adjoints are nonzero,
-- the return sweep needs no Map at all: the reverse of the map
-- function is evaluated once per triple.
vjpMapSparse ::
  VjpOps ->
  [[(InBounds, SubExp, SubExp)]] ->
  SubExp ->
  Lambda SOACS ->
  [VName] ->
  ADM ()
vjpMapSparse ops res_ivs w map_lam as = returnSweepCode $ do
  -- NOTE(review): unlike 'vjpMapDense', the mapped arrays are removed
  -- from the free variables here.  If an array in 'as' is also used
  -- freely inside the lambda, confirm that the adjoint contribution of
  -- that free use is not dropped.
  free <-
    filterM isActive . namesToList $
      freeIn map_lam `namesSubtract` namesFromList as
  free_ts <- mapM lookupType free
  let lam_params = lambdaParams map_lam
      adjs_for = map paramName lam_params ++ free
      adjs_ts = map paramType lam_params ++ free_ts
      tracked = as <> free

      -- Result adjoints where only result number 'res_i' is nonzero.
      oneHot res_i adj_v =
        [ if j == res_i then adj_v else AdjZero (arrayShape t) (elemType t)
        | (j, t) <- zip [0 :: Int ..] (lambdaReturnType map_lam)
        ]

      -- The adjoints of 'tracked' as a result, together with a function
      -- for rebuilding the 'Adj's from the bound values.
      trackedResult = first subExpsRes . adjsReps <$> mapM lookupAdj tracked

      -- Branch for an out-of-bounds index.  The values written do not
      -- matter, as writes to an out-of-bounds index are ignored.
      ooBounds adj_i = subAD . buildRenamedBody $ do
        forM_ (zip as adjs_ts) $ \(a, t) ->
          updateAdjIndex a (OutOfBounds, adj_i)
            =<< letSubExp "oo_scratch"
            =<< eBlank t
        -- Free variables must have the same representation in this
        -- branch as in the in-bounds branch, so their adjoints are
        -- manifested.  This is probably efficient, since the adjoint
        -- of a free variable is probably a scalar or an accumulator.
        forM_ free $ \v -> insAdj v =<< adjVal =<< lookupAdj v
        trackedResult

      -- Branch for an in-bounds index: bind the lambda parameters to
      -- the input elements, run the reverse of the map function, and
      -- write the resulting adjoints back.
      inBounds res_i adj_i adj_v = subAD . buildRenamedBody $ do
        forM_ (zip lam_params as) $ \(p, a) -> do
          a_t <- lookupType a
          letBindNames [paramName p] . BasicOp . Index a $
            fullSlice a_t [DimFix adj_i]
        rev_lam <- vjpLambda ops (oneHot res_i (AdjVal adj_v)) adjs_for map_lam
        adj_elems <- map resSubExp <$> bodyBind (lambdaBody rev_lam)
        let (as_adj_elems, free_adj_elems) = splitAt (length as) adj_elems
        forM_ (zip as as_adj_elems) $ \(a, se) ->
          updateAdjIndex a (AssumeBounds, adj_i) se
        forM_ (zip free free_adj_elems) $ \(v, se) ->
          insAdj v =<< letExp "adj_v" (BasicOp $ SubExp se)
        trackedResult

      -- One evaluation of the reverse map function per triple.  This is
      -- a bit inefficient - probably some deduplication is possible.
      forPos res_i (check, adj_i, adj_v) = do
        adjs <- case check of
          CheckBounds b -> do
            (obbranch, mkadjs) <- ooBounds adj_i
            (ibbranch, _) <- inBounds res_i adj_i adj_v
            let in_bounds =
                  maybe (eDimInBounds (eSubExp w) (eSubExp adj_i)) eSubExp b
            fmap mkadjs . letTupExp' "map_adj_elem"
              =<< eIf in_bounds (pure ibbranch) (pure obbranch)
          AssumeBounds -> do
            (body, mkadjs) <- inBounds res_i adj_i adj_v
            mkadjs . map resSubExp <$> bodyBind body
          OutOfBounds ->
            mapM lookupAdj as
        zipWithM_ setAdj tracked adjs

  forM_ (zip [0 ..] res_ivs) $ \(res_i, ivs) ->
    mapM_ (forPos res_i) ivs

-- | General VJP of a Map: build a reverse Map over the inputs, the
-- result adjoints, and the adjoints of the free variables, then
-- distribute the contributions it returns.
--
-- See Note [Adjoints of accumulators] for how we deal with
-- accumulators - it's a bit tricky here.
vjpMapDense ::
  VjpOps ->
  [Adj] ->
  StmAux () ->
  SubExp ->
  Lambda SOACS ->
  [VName] ->
  ADM ()
vjpMapDense ops pat_adj aux w map_lam as = returnSweepCode $ do
  pat_adj_vals <- zipWithM adjRep pat_adj (lambdaReturnType map_lam)
  pat_adj_params <-
    forM pat_adj_vals $ \v -> newParam "map_adj_p" . rowType =<< lookupType v

  map_lam' <- renameLambda map_lam
  free <- filterM isActive . namesToList $ freeIn map_lam'

  accAdjoints free $ \free_with_adjs free_without_adjs -> do
    free_adjs <- mapM lookupAdjVal free_with_adjs
    free_adjs_params <- mapM (newParam "free_adj_p" <=< lookupType) free_adjs
    let fwd_params = lambdaParams map_lam'
        lam_rev_params = fwd_params ++ pat_adj_params ++ free_adjs_params
        adjs_for = map paramName fwd_params ++ free
    lam_rev <-
      mkLambda lam_rev_params . subAD . noAdjsFor free_without_adjs $ do
        zipWithM_ insAdj free_with_adjs $ map paramName free_adjs_params
        vjp_lam <- vjpLambda ops (map adjFromParam pat_adj_params) adjs_for map_lam'
        bodyBind $ lambdaBody vjp_lam

    map_adjs <-
      auxing aux . letTupExp "map_adjs" . Op $
        Screma w (as ++ pat_adj_vals ++ free_adjs) (mapSOAC lam_rev)
    let (param_contribs, free_contribs) = splitAt (length fwd_params) map_adjs

    -- Crucial that we handle the free contribs first in case 'free'
    -- and 'as' intersect.
    zipWithM_ freeContrib free free_contribs
    forM_ (zip3 fwd_params as param_contribs) $ \(p, a, contrib) ->
      case paramType p of
        Acc {} -> freeContrib a contrib
        _ -> updateAdj a contrib
  where
    -- The adjoint of an accumulator-typed result is replicated across
    -- the map width so that it can be passed to the reverse Map.
    adjRep :: Adj -> Type -> ADM VName
    adjRep adj t = do
      adj_v <- adjVal adj
      case t of
        Acc {} -> letExp "acc_adj_rep" . BasicOp . Replicate (Shape [w]) $ Var adj_v
        _ -> pure adj_v

    -- 'withAcc' input for a free variable with an array adjoint: its
    -- current adjoint, the 'addLambda' operator for the element type
    -- with one int64 index parameter prepended per dimension of the
    -- shape, and a zero of the element type.
    withAccInput ::
      (VName, (Shape, PrimType)) ->
      ADM (Shape, [VName], Maybe (Lambda SOACS, [SubExp]))
    withAccInput (v, (shape, pt)) = do
      v_adj <- lookupAdjVal v
      add_lam <- addLambda $ Prim pt
      idxs <- replicateM (shapeRank shape) . newParam "idx" $ Prim int64
      zero <- letSubExp "zero" . zeroExp $ Prim pt
      let upd_lam = add_lam {lambdaParams = idxs ++ lambdaParams add_lam}
      pure (shape, [v_adj], Just (upd_lam, [zero]))

    -- Run 'inner' with the free variables that have an accumulator
    -- adjoint (after turning array adjoints into accumulators with
    -- 'withAcc') and the names of the remaining free variables.
    -- Afterwards every adjoint is rebound to the 'withAcc' results.
    accAdjoints :: [VName] -> ([VName] -> Names -> ADM ()) -> ADM ()
    accAdjoints free inner = do
      (arr_free, acc_free, nonacc_free) <-
        partitionAdjVars <$> classifyAdjVars free
      acc_inputs <- mapM withAccInput arr_free
      let arr_free_vs = map fst arr_free
          -- Only those input arrays that are not also free in the lambda.
          as_nonfree = filter (`notElem` free) as
          rest_vs = nonacc_free ++ as_nonfree
      (arr_adjs, acc_adjs, rest_adjs) <-
        fmap (splitAt3 (length arr_free) (length acc_free)) . withAcc acc_inputs $ \accs -> do
          zipWithM_ insAdj arr_free_vs accs
          inner (acc_free ++ arr_free_vs) (namesFromList nonacc_free)
          acc_free_adj <- mapM lookupAdjVal acc_free
          arr_free_adj <- mapM lookupAdjVal arr_free_vs
          rest_adj <- mapM lookupAdjVal rest_vs
          pure . varsRes $ arr_free_adj ++ acc_free_adj ++ rest_adj
      zipWithM_ insAdj acc_free acc_adjs
      zipWithM_ insAdj arr_free_vs arr_adjs
      -- 'rest_adjs' holds the adjoints of 'nonacc_free' followed by
      -- those of 'as_nonfree', matching 'rest_vs' positionally.
      zipWithM_ insAdj rest_vs rest_adjs

    -- Apply the per-iteration contributions 'contribs' to the adjoint
    -- of 'v'.  Accumulator contributions are installed directly as the
    -- new adjoint; anything else is first summed over the map width.
    freeContrib :: VName -> VName -> ADM ()
    freeContrib v contribs = do
      contribs_t <- lookupType contribs
      case rowType contribs_t of
        Acc {} -> insAdj v contribs
        row_t -> do
          add_lam <- addLambda row_t
          zero <- letSubExp "zero" $ zeroExp row_t
          red_form <- reduceSOAC [Reduce Commutative add_lam [zero]]
          contrib_sum <-
            letExp (baseString v <> "_contrib_sum") . Op $
              Screma w [contribs] red_form
          updateAdj v contrib_sum