module Futhark.Internalise.Lambdas
( InternaliseLambda,
internaliseFoldLambda,
internalisePartitionLambda,
)
where
import Data.Maybe (listToMaybe)
import Futhark.IR.SOACS as I
import Futhark.Internalise.AccurateSizes
import Futhark.Internalise.Monad
import Language.Futhark as E
-- | A function that internalises a source-language lambda.
--
-- It is given the lambda expression and a list of internal types. The
-- callers in this module pass the types its parameters should take. It
-- produces the internalised parameters, the internalised body, and a
-- list of result types.
type InternaliseLambda =
  E.Exp ->
  [I.Type] ->
  InternaliseM ([I.LParam SOACS], I.Body SOACS, [I.Type])
-- | Internalise the operator lambda of a fold-like construct.
--
-- The lambda is internalised with parameter types consisting of the
-- accumulator types (@acctypes@) followed by the row types of the input
-- array types (@arrtypes@).
--
-- Each result type reported by the internalised lambda is given the
-- array shape of the accumulator type at the same position. The body's
-- result is then passed through 'ensureResultShape' with those types
-- and the message "shape of result does not match shape of initial
-- value".
--
-- NOTE(review): result types and accumulator types are paired
-- positionally with 'zipWith', so any surplus on either side is
-- dropped. Presumably earlier type checking guarantees equal lengths;
-- confirm.
internaliseFoldLambda ::
  InternaliseLambda ->
  E.Exp ->
  [I.Type] ->
  [I.Type] ->
  InternaliseM (I.Lambda SOACS)
internaliseFoldLambda internaliseLambda lam acctypes arrtypes = do
  -- The operator receives one row of each input array after the accumulators.
  let rowtypes = map I.rowType arrtypes
  (params, body, rettype) <- internaliseLambda lam $ acctypes ++ rowtypes
  -- Result i must have the same shape as accumulator i.
  let withAccShape t acc = t `I.setArrayShape` I.arrayShape acc
      rettype' = zipWith withAccShape rettype acctypes
  mkLambda params $
    ensureResultShape
      (ErrorMsg [ErrorString "shape of result does not match shape of initial value"])
      rettype'
      =<< bodyBind body
-- | Internalise the classification lambda of a partition into @k@
-- equivalence classes.
--
-- The user lambda is internalised with the row types of @args@ as its
-- parameter types. Its first result is taken as the class @c@; if the
-- body has no results, the constant 0 is used instead. The produced
-- lambda returns @k + 2@ values, all of type @i64@:
--
-- * the first is @c@ when @c@ is one of @0 .. k-1@, and @k@ otherwise;
--
-- * the remaining @k + 1@ values form a one-hot vector. It has a single
--   1 at the position given by the first value and 0 everywhere else.
internalisePartitionLambda ::
  InternaliseLambda ->
  Int ->
  E.Exp ->
  [I.SubExp] ->
  InternaliseM (I.Lambda SOACS)
internalisePartitionLambda internaliseLambda k lam args = do
  argtypes <- mapM I.subExpType args
  let rowtypes = map I.rowType argtypes
  (params, body, _) <- internaliseLambda lam rowtypes
  -- The parameters must be in scope while we extend the body.
  body' <-
    localScope (scopeOfLParams params) $
      lambdaWithIncrement body
  pure $ I.Lambda params rettype body'
  where
    -- k + 2 results, each an i64.
    rettype :: [I.Type]
    rettype = replicate (k + 2) $ I.Prim int64

    -- Constant result for class i: the index i, followed by i zeros, a
    -- one, and (k - i) zeros. That is k + 2 values in total.
    result :: Int -> [I.SubExp]
    result i =
      map constant $
        fromIntegral i
          : (replicate i 0 ++ [1 :: Int64] ++ replicate (k - i) 0)

    -- Emit a chain of tests eq_class == i, eq_class == i + 1, ... .
    -- When no class below k matches, fall through to the result for k.
    mkResult _ i | i >= k = pure $ result i
    mkResult eq_class i = do
      is_i <-
        letSubExp "is_i" . BasicOp $
          CmpOp (CmpEq int64) eq_class (intConst Int64 (toInteger i))
      letTupExp' "part_res"
        =<< eIf
          (eSubExp is_i)
          (pure $ resultBody $ result i)
          (resultBody <$> mkResult eq_class (i + 1))

    -- Bind the user's body, read its first result as the class (0 if
    -- there is none), and return the (k + 2)-element encoding above.
    lambdaWithIncrement :: I.Body SOACS -> InternaliseM (I.Body SOACS)
    lambdaWithIncrement lam_body = runBodyBuilder $ do
      lam_res <- bodyBind lam_body
      let eq_class = maybe (intConst Int64 0) resSubExp $ listToMaybe lam_res
      subExpsRes <$> mkResult eq_class 0