{-# LANGUAGE TypeFamilies #-}

module Futhark.AD.Rev.Reduce
  ( diffReduce,
    diffMinMaxReduce,
    diffVecReduce,
    diffMulReduce,
  )
where

import Control.Monad
import Futhark.AD.Rev.Monad
import Futhark.Analysis.PrimExp.Convert
import Futhark.Builder
import Futhark.IR.SOACS
import Futhark.Tools
import Futhark.Transform.Rename

-- | Bind and return a reversed version of @arr@ along its outermost
-- dimension.
--
-- The reversal is expressed as an 'Index' whose outermost 'DimSlice'
-- starts at @w - 1@ (where @w@ is the outer size of @arr@), takes @w@
-- elements, and uses a stride of @-1@.  The resulting array is named
-- after @arr@ with a @_rev@ suffix.
eReverse :: (MonadBuilder m) => VName -> m VName
eReverse arr = do
  arr_t <- lookupType arr
  let w = arraySize 0 arr_t
  -- Index of the last element; this is where the reversed slice begins.
  start <-
    letSubExp "rev_start" . BasicOp $
      BinOp (Sub Int64 OverflowUndef) w (intConst Int64 1)
  let stride = intConst Int64 (-1)
      slice = fullSlice arr_t [DimSlice start w stride]
  letExp (baseString arr <> "_rev") $ BasicOp $ Index arr slice

-- | Build an exclusive scan of @arrs@ with the operator @scan@.
--
-- The construction is done in two steps:
--
--   1. an ordinary (inclusive) scan of @arrs@, bound as @desc_incl@;
--
--   2. a map over @iota w@ that yields the neutral element(s) of @scan@
--      at index 0 and element @i - 1@ of the inclusive result at every
--      other index @i@.
--
-- @desc@ is used as the name hint for the final result arrays, whose
-- names are returned.
scanExc ::
  (MonadBuilder m, Rep m ~ SOACS) =>
  String ->
  Scan SOACS ->
  [VName] ->
  m [VName]
scanExc desc scan arrs = do
  w <- arraysSize 0 <$> mapM lookupType arrs
  form <- scanSOAC [scan]
  res_incl <- letTupExp (desc <> "_incl") $ Op $ Screma w arrs form

  iota <-
    letExp "iota" . BasicOp $
      Iota w (intConst Int64 0) (intConst Int64 1) Int64

  iparam <- newParam "iota_param" $ Prim int64

  lam <- mkLambda [iparam] $ do
    let -- True exactly for the first position of the result.
        first_elem =
          eCmpOp
            (CmpEq int64)
            (eSubExp (Var (paramName iparam)))
            (eSubExp (intConst Int64 0))
        -- Index of the preceding element in the inclusive scan.
        prev = toExp $ le64 (paramName iparam) - 1
    fmap subExpsRes . letTupExp' "scan_ex_res"
      =<< eIf
        first_elem
        (resultBodyM $ scanNeutral scan)
        (eBody $ map (`eIndex` [prev]) res_incl)

  letTupExp desc $ Op $ Screma w [iota] (mapSOAC lam)

-- | From a binary operator lambda @lam@, build a lambda that applies
-- the operator twice in sequence (the @f l_i a_i r_i = l_i \odot a_i
-- \odot r_i@ used when differentiating a general reduce).
--
-- Two freshly renamed copies of @lam@ are made.  With @q@ the number of
-- results of @lam@, the parameters of the first copy are split into
-- @(lps, aps)@ and those of the second copy into @(ips, rps)@.  The
-- produced lambda takes @lps <> aps <> rps@ as parameters, runs the
-- body of the first copy, binds its results to the names of @ips@, and
-- then runs the body of the second copy.
--
-- Returns the names of @aps@ together with the new lambda.
mkF :: Lambda SOACS -> ADM ([VName], Lambda SOACS)
mkF lam = do
  lam_l <- renameLambda lam
  lam_r <- renameLambda lam
  let q = length $ lambdaReturnType lam
      (lps, aps) = splitAt q $ lambdaParams lam_l
      (ips, rps) = splitAt q $ lambdaParams lam_r
  lam' <- mkLambda (lps <> aps <> rps) $ do
    lam_l_res <- bodyBind $ lambdaBody lam_l
    -- Feed the results of the first application into the left-hand
    -- parameters of the second application, keeping any certificates.
    forM_ (zip ips lam_l_res) $ \(ip, SubExpRes cs se) ->
      certifying cs $ letBindNames [paramName ip] $ BasicOp $ SubExp se
    bodyBind $ lambdaBody lam_r
  pure (map paramName aps, lam')

-- | Reverse-mode differentiation of a single @reduce@.
--
-- Arguments, in order: the VJP operations, the adjoints of the reduce
-- results, the width of the input arrays, the input arrays, and the
-- reduction operator.
--
-- The first clause handles a reduce over a single array whose operator
-- is a plain addition ('Add' or 'FAdd'): the adjoint of the input is
-- simply the result adjoint replicated @w@ times.  Everything else is
-- handled by the general clause further down.
diffReduce :: VjpOps -> [VName] -> SubExp -> [VName] -> Reduce SOACS -> ADM ()
diffReduce _ops [adj] w [a] red
  | Just [(op, _, _, _)] <- lamIsBinOp $ redLambda red,
    isAdd op = do
      adj_rep <-
        letExp (baseString adj <> "_rep") $
          BasicOp $
            Replicate (Shape [w]) $
              Var adj
      void $ updateAdj a adj_rep
  where
    isAdd :: BinOp -> Bool
    isAdd FAdd {} = True
    isAdd Add {} = True
    isAdd _ = False
--
-- Differentiating a general single reduce:
--    let y = reduce \odot ne as
-- Forward sweep:
--    let ls = scan_exc \odot  ne as
--    let rs = scan_exc \odot' ne (reverse as)
-- Reverse sweep:
--    let as_c = map3 (f_bar y_bar) ls as (reverse rs)
-- where
--   x \odot' y = y \odot x
--   y_bar is the adjoint of the result y
--   f l_i a_i r_i = l_i \odot a_i \odot r_i
--   f_bar = the reverse diff of f with respect to a_i under the adjoint y_bar
-- The code below builds
--   two exclusive scans (via scanExc) which compute ls and rs
--   one map which computes as_c
--
diffReduce ops pat_adj w as red = do
  red' <- renameRed red
  flip_red <- renameRed =<< flipReduce red
  ls <- scanExc "ls" (redToScan red') as
  -- Scan the reversed input with the flipped operator, then reverse the
  -- result again so that it lines up with @as@.
  rs <-
    mapM eReverse
      =<< scanExc "rs" (redToScan flip_red)
      =<< mapM eReverse as

  (as_params, f) <- mkF $ redLambda red

  f_adj <- vjpLambda ops (map adjFromVar pat_adj) as_params f

  as_adj <- letTupExp "adjs" $ Op $ Screma w (ls ++ as ++ rs) (mapSOAC f_adj)

  zipWithM_ updateAdj as as_adj
  where
    -- Copy of a reduce whose lambda has been renamed.
    renameRed :: Reduce SOACS -> ADM (Reduce SOACS)
    renameRed (Reduce comm lam nes) =
      Reduce comm <$> renameLambda lam <*> pure nes

    -- Reuse the operator and neutral elements of a reduce as a scan.
    redToScan :: Reduce SOACS -> Scan SOACS
    redToScan (Reduce _ lam nes) = Scan lam nes

    -- The operator with its two argument groups swapped (\odot' above).
    flipReduce :: Reduce SOACS -> ADM (Reduce SOACS)
    flipReduce (Reduce comm lam nes) = do
      lam' <- renameLambda lam {lambdaParams = flipParams $ lambdaParams lam}
      pure $ Reduce comm lam' nes

    -- Swap the first and second halves of a list.
    flipParams :: [a] -> [a]
    flipParams ps = uncurry (flip (++)) $ splitAt (length ps `div` 2) ps

--
-- Special case of reduce with min/max:
--    let x = reduce minmax ne as
-- Forward trace (assuming w = length as):
--    let (x, x_ind) =
--      reduce (\ acc_v acc_i v i ->
--                 if (acc_v == v) then (acc_v, min acc_i i)
--                 else if (acc_v == minmax acc_v v)
--                      then (acc_v, acc_i)
--                      else (v, i))
--             (ne_min, -1)
--             (zip as (iota w))
-- Reverse trace:
--    num_elems = i64.bool (0 <= x_ind)
--    m_bar_repl = replicate num_elems m_bar
--    as_bar[x_ind:num_elems:1] += m_bar_repl
-- | Differentiate @x = reduce minmax ne as@ where @minmax@ is a min or
-- max operator.
--
-- The forward sweep replaces the reduce by one over @(as, iota w)@
-- that also produces the index @x_ind@ of the selected element; its
-- neutral element is @(ne, -1)@.  The continuation @m@ is then run,
-- after which the adjoint of @x@ is added to the adjoint of @as@ at
-- position @x_ind@.  The index update is guarded by the condition
-- @0 < w@.
diffMinMaxReduce ::
  VjpOps -> VName -> StmAux () -> SubExp -> BinOp -> SubExp -> VName -> ADM () -> ADM ()
diffMinMaxReduce _ops x aux w minmax ne as m = do
  let t = binOpType minmax

  acc_v_p <- newParam "acc_v" $ Prim t
  acc_i_p <- newParam "acc_i" $ Prim int64
  v_p <- newParam "v" $ Prim t
  i_p <- newParam "i" $ Prim int64
  -- Operator on (value, index) pairs: on equal values keep the smaller
  -- index, otherwise keep whichever pair @minmax@ selects.
  red_lam <-
    mkLambda [acc_v_p, acc_i_p, v_p, i_p] $
      fmap varsRes . letTupExp "idx_res"
        =<< eIf
          (eCmpOp (CmpEq t) (eParam acc_v_p) (eParam v_p))
          ( eBody
              [ eParam acc_v_p,
                eBinOp (SMin Int64) (eParam acc_i_p) (eParam i_p)
              ]
          )
          ( eBody
              [ eIf
                  ( eCmpOp
                      (CmpEq t)
                      (eParam acc_v_p)
                      (eBinOp minmax (eParam acc_v_p) (eParam v_p))
                  )
                  (eBody [eParam acc_v_p, eParam acc_i_p])
                  (eBody [eParam v_p, eParam i_p])
              ]
          )

  red_iota <-
    letExp "red_iota" $
      BasicOp $
        Iota w (intConst Int64 0) (intConst Int64 1) Int64
  form <- reduceSOAC [Reduce Commutative red_lam [ne, intConst Int64 (-1)]]
  x_ind <- newVName (baseString x <> "_ind")
  auxing aux $ letBindNames [x, x_ind] $ Op $ Screma w [as, red_iota] form

  -- Everything between the forward and the reverse part of this
  -- statement.
  m

  x_adj <- lookupAdjVal x
  in_bounds <-
    letSubExp "minmax_in_bounds" . BasicOp $
      CmpOp (CmpSlt Int64) (intConst Int64 0) w
  updateAdjIndex as (CheckBounds (Just in_bounds), Var x_ind) (Var x_adj)

--
-- Special case of vectorised reduce:
--    let x = reduce (map2 op) nes as
-- Idea:
--    rewrite to
--      let x = map2 (\as ne -> reduce op ne as) (transpose as) nes
--    and diff
-- | Differentiate a vectorised reduce @x = reduce (map2 op) nes as@.
--
-- Instead of differentiating the reduce directly, equivalent
-- statements are generated in which the two outermost dimensions of
-- @as@ are swapped and a map over the rows of the result (paired with
-- the elements of @ne@) performs an ordinary reduce with @lam@ on each
-- row.  These statements are collected rather than emitted, and are
-- then differentiated with 'vjpStm', with @m@ as the innermost
-- continuation.
diffVecReduce ::
  VjpOps -> Pat Type -> StmAux () -> SubExp -> Commutativity -> Lambda SOACS -> VName -> VName -> ADM () -> ADM ()
diffVecReduce ops x aux w iscomm lam ne as m = do
  stms <- collectStms_ $ do
    rank <- arrayRank <$> lookupType as
    -- Permutation that swaps the two outermost dimensions and leaves
    -- the remaining ones in place.
    let rear = [1, 0] ++ drop 2 [0 .. rank - 1]

    tran_as <- letExp "tran_as" $ BasicOp $ Rearrange as rear
    ts <- lookupType tran_as
    t_ne <- lookupType ne

    as_param <- newParam "as_param" $ rowType ts
    ne_param <- newParam "ne_param" $ rowType t_ne

    reduce_form <- reduceSOAC [Reduce iscomm lam [Var $ paramName ne_param]]

    map_lam <-
      mkLambda [as_param, ne_param] $
        fmap varsRes . letTupExp "idx_res" $
          Op $
            Screma w [paramName as_param] reduce_form
    addStm $ Let x aux $ Op $ Screma (arraySize 0 ts) [tran_as, ne] $ mapSOAC map_lam

  foldr (vjpStm ops) m stms

--
-- Special case of reduce with mul:
--    let x = reduce (*) ne as
-- Forward trace (assuming w = length as):
--    let (p, z) = map (\a -> if a == 0 then (1, 1) else (a, 0)) as
--    non_zero_prod = reduce (*) ne p
--    zr_count = reduce (+) 0 z
--    let x =
--      if 0 == zr_count
--      then non_zero_prod
--      else 0
-- Reverse trace:
--    as_bar = map2
--      (\a a_bar ->
--        if zr_count == 0
--        then a_bar + non_zero_prod/a * x_bar
--        else if zr_count == 1
--        then a_bar + (if a == 0 then non_zero_prod * x_bar else 0)
--        else a_bar
--      ) as as_bar
-- | Differentiate @x = reduce (*) ne as@ where @mul@ is the
-- multiplication operator.
--
-- Forward sweep: a map splits @as@ into @ps@ (each element, with
-- zeroes replaced by one) and @zs@ (1 for a zero element, 0
-- otherwise).  @ps@ is reduced with @mul@ to give the product of the
-- non-zero elements, and @zs@ is summed to give the number of zeroes.
-- @x@ is that product if there are no zeroes, and zero otherwise.
--
-- The continuation @m@ is then run, after which a map over @as@
-- computes the adjoint contribution of every element @a@:
--
--   * no zeroes: @x_adj * (non_zero_prod / a)@;
--
--   * exactly one zero: @x_adj * non_zero_prod@ if @a@ is that zero,
--     and zero otherwise;
--
--   * more than one zero: zero.
--
-- The result is added to the adjoint of @as@.
diffMulReduce ::
  VjpOps -> VName -> StmAux () -> SubExp -> BinOp -> SubExp -> VName -> ADM () -> ADM ()
diffMulReduce _ops x aux w mul ne as m = do
  let t = binOpType mul
  let const_zero = eSubExp $ Constant $ blankPrimValue t

  a_param <- newParam "a" $ Prim t
  map_lam <-
    mkLambda [a_param] $
      fmap varsRes . letTupExp "map_res"
        =<< eIf
          (eCmpOp (CmpEq t) (eParam a_param) const_zero)
          (eBody $ fmap eSubExp [Constant $ onePrimValue t, intConst Int64 1])
          (eBody [eParam a_param, eSubExp $ intConst Int64 0])

  ps <- newVName "ps"
  zs <- newVName "zs"
  auxing aux $
    letBindNames [ps, zs] $
      Op $
        Screma w [as] $
          mapSOAC map_lam

  red_lam_mul <- binOpLambda mul t
  red_lam_add <- binOpLambda (Add Int64 OverflowUndef) int64

  red_form_mul <- reduceSOAC $ pure $ Reduce Commutative red_lam_mul $ pure ne
  red_form_add <- reduceSOAC $ pure $ Reduce Commutative red_lam_add $ pure $ intConst Int64 0

  nz_prods <- newVName "non_zero_prod"
  zr_count <- newVName "zero_count"
  auxing aux $ letBindNames [nz_prods] $ Op $ Screma w [ps] red_form_mul
  auxing aux $ letBindNames [zr_count] $ Op $ Screma w [zs] red_form_add

  auxing aux $
    letBindNames [x]
      =<< eIf
        (toExp $ 0 .==. le64 zr_count)
        (eBody $ pure $ eSubExp $ Var nz_prods)
        (eBody $ pure const_zero)

  -- Everything between the forward and the reverse part of this
  -- statement.
  m

  x_adj <- lookupAdjVal x

  a_param_rev <- newParam "a" $ Prim t
  map_lam_rev <-
    mkLambda [a_param_rev] $
      fmap varsRes . letTupExp "adj_res"
        =<< eIf
          (toExp $ 0 .==. le64 zr_count)
          ( eBody $
              pure $
                eBinOp mul (eSubExp $ Var x_adj) $
                  eBinOp (getDiv t) (eSubExp $ Var nz_prods) $
                    eParam a_param_rev
          )
          ( eBody $
              pure $
                eIf
                  (toExp $ 1 .==. le64 zr_count)
                  ( eBody $
                      pure $
                        eIf
                          (eCmpOp (CmpEq t) (eParam a_param_rev) const_zero)
                          ( eBody $
                              pure $
                                eBinOp mul (eSubExp $ Var x_adj) $
                                  eSubExp $
                                    Var nz_prods
                          )
                          (eBody $ pure const_zero)
                  )
                  (eBody $ pure const_zero)
          )

  as_adjup <- letExp "adjs" $ Op $ Screma w [as] $ mapSOAC map_lam_rev

  updateAdj as as_adjup
  where
    -- Division operator matching the element type; only integer and
    -- floating-point types are supported.
    getDiv :: PrimType -> BinOp
    getDiv (IntType t) = SDiv t Unsafe
    getDiv (FloatType t) = FDiv t
    getDiv _ = error "In getDiv, Reduce.hs: input not supported"