{-# LANGUAGE QuasiQuotes #-}

-- | Code generation for public API types.
module Futhark.CodeGen.Backends.GenericC.Types
  ( generateAPITypes,
    valueTypeToCType,
    opaqueToCType,
  )
where

import Control.Monad
import Control.Monad.Reader (asks)
import Control.Monad.State (gets, modify)
import Data.List qualified as L
import Data.Map.Strict qualified as M
import Data.Maybe
import Data.Text qualified as T
import Futhark.CodeGen.Backends.GenericC.Monad
import Futhark.CodeGen.Backends.GenericC.Pretty
import Futhark.CodeGen.ImpCode
import Futhark.Manifest qualified as Manifest
import Futhark.Util (chunks, mapAccumLM, zEncodeText)
import Language.C.Quote.OpenCL qualified as C
import Language.C.Syntax qualified as C

-- | The C type of the public struct that represents the opaque type with
-- the given name, i.e. @struct <public name of desc>@.
opaqueToCType :: Name -> CompilerM op s C.Type
opaqueToCType desc = structOf <$> publicName (opaqueName desc)
  where
    structOf name = [C.cty|struct $id:name|]

-- | The C type used for a value of the given type in the C API.
--
-- Rank-0 values map directly to their primitive API type. Arrays map to a
-- @struct@ named after their element type, signedness and rank; as a side
-- effect, the array type is recorded in 'compArrayTypes' together with the
-- given 'Publicness' (combined with any existing entry via 'max').
valueTypeToCType :: Publicness -> ValueType -> CompilerM op s C.Type
valueTypeToCType pub (ValueType signed (Rank rank) pt)
  | rank == 0 = pure $ primAPIType signed pt
  | otherwise = do
      name <- publicName $ arrayName pt signed rank
      modify $ \st ->
        st {compArrayTypes = M.insertWith max key pub (compArrayTypes st)}
      pure [C.cty|struct $id:name|]
  where
    key = (signed, pt, rank)

-- | Emit code that resets the memory block of the array struct pointed to
-- by @arr@, allocates room for @product shape@ elements of type @pt@ in
-- @space@ (emitting @err = 1;@ if the allocation fails), and stores each
-- dimension of @shape@ in @arr->shape@.
prepareNewMem ::
  (C.ToExp arr, C.ToExp dim) =>
  arr ->
  Space ->
  [dim] ->
  PrimType ->
  CompilerM op s ()
prepareNewMem arr space shape pt = do
  resetMem mem space
  allocMem mem num_bytes space [C.cstm|err = 1;|]
  zipWithM_ storeDim [0 :: Int ..] shape
  where
    mem = [C.cexp|$exp:arr->mem|]
    num_elems = cproduct [[C.cexp|$exp:d|] | d <- shape]
    num_bytes = [C.cexp|$exp:num_elems * $int:(primByteSize pt::Int)|]
    storeDim i d = stm [C.cstm|$exp:arr->shape[$int:i] = $exp:d;|]

-- | Generate the C library functions for the array type with the given
-- element type, signedness and rank:
--
-- * @new_T@ allocates an array and copies @data@ (from the default space)
--   into it;
-- * @new_raw_T@ wraps an existing memory block without copying it;
-- * @free_T@ releases the array's reference to its memory and frees the
--   struct;
-- * @values_T@ copies the array contents out to @data@;
-- * @values_raw_T@ returns the underlying memory block;
-- * @shape_T@ returns a pointer to the shape;
-- * @index_T@ bounds-checks the indices and copies one element to @out@.
--
-- Prototypes are put in the public header when @pub@ is 'Public', and are
-- only library declarations otherwise. The returned record carries the
-- generated function names for the manifest.
arrayLibraryFunctions ::
  Publicness ->
  Space ->
  PrimType ->
  Signedness ->
  Int ->
  CompilerM op s Manifest.ArrayOps
arrayLibraryFunctions pub space pt signed rank = do
  let pt' = primAPIType signed pt
      name = arrayName pt signed rank
      arr_name = "futhark_" <> name
      array_type = [C.cty|struct $id:arr_name|]

  new_array <- publicName $ "new_" <> name
  new_raw_array <- publicName $ "new_raw_" <> name
  free_array <- publicName $ "free_" <> name
  values_array <- publicName $ "values_" <> name
  values_raw_array <- publicName $ "values_raw_" <> name
  shape_array <- publicName $ "shape_" <> name
  index_array <- publicName $ "index_" <> name

  -- Every per-dimension name is derived from 'dims' and 'shape_names', so
  -- the generated parameters and the code that refers to them cannot drift
  -- apart.
  let dims = [0 .. rank - 1]
      shape_names = ["dim" <> prettyText i | i <- dims]
      shape_params = [[C.cparam|typename int64_t $id:k|] | k <- shape_names]
      shape = [[C.cexp|$id:k|] | k <- shape_names]
      index_names = ["i" <> prettyText i | i <- dims]
      index_params = [[C.cparam|typename int64_t $id:k|] | k <- index_names]
      arr_size = cproduct shape
      arr_size_array = cproduct [[C.cexp|arr->shape[$int:i]|] | i <- dims]

  copy <- asks $ opsCopy . envOperations

  memty <- rawMemCType space

  new_body <- collect $ do
    prepareNewMem [C.cexp|arr|] space shape pt
    copy
      CopyNoBarrier
      [C.cexp|arr->mem.mem|]
      [C.cexp|0|]
      space
      [C.cexp|(const unsigned char*)data|]
      [C.cexp|0|]
      DefaultSpace
      [C.cexp|((size_t)$exp:arr_size) * $int:(primByteSize pt::Int)|]

  new_raw_body <- collect $ do
    resetMem [C.cexp|arr->mem|] space
    stm [C.cstm|arr->mem.mem = data;|]
    -- Reuse the parameter expressions rather than rebuilding the names.
    forM_ (zip dims shape) $ \(i, dim_s) ->
      stm [C.cstm|arr->shape[$int:i] = $exp:dim_s;|]

  free_body <- collect $ unRefMem [C.cexp|arr->mem|] space

  values_body <-
    collect $
      copy
        CopyNoBarrier
        [C.cexp|(unsigned char*)data|]
        [C.cexp|0|]
        DefaultSpace
        [C.cexp|arr->mem.mem|]
        [C.cexp|0|]
        space
        [C.cexp|((size_t)$exp:arr_size_array) * $int:(primByteSize pt::Int)|]

  -- Row-major strides (in elements): the stride of dimension r is the
  -- product of all later dimensions.
  let arr_strides =
        [ cproduct [[C.cexp|arr->shape[$int:i]|] | i <- [r + 1 .. rank - 1]]
          | r <- dims
        ]
      index_exp =
        cproduct
          [ [C.cexp|$int:(primByteSize pt::Int)|],
            csum (zipWith (\x y -> [C.cexp|$id:x * $exp:y|]) index_names arr_strides)
          ]
      in_bounds =
        allTrue
          [ [C.cexp|$id:p >= 0 && $id:p < arr->shape[$int:i]|]
            | (p, i) <- zip index_names dims
          ]
  index_body <-
    collect $
      copy
        CopyNoBarrier
        [C.cexp|(unsigned char*)out|]
        [C.cexp|0|]
        DefaultSpace
        [C.cexp|arr->mem.mem|]
        index_exp
        space
        [C.cexp|$int:(primByteSize pt::Int)|]

  ctx_ty <- contextType
  ops <- asks envOperations

  let proto = case pub of
        Public -> headerDecl (ArrayDecl (nameFromText name))
        Private -> libDecl

  proto
    [C.cedecl|struct $id:arr_name;|]
  proto
    [C.cedecl|$ty:array_type* $id:new_array($ty:ctx_ty *ctx, const $ty:pt' *data, $params:shape_params);|]
  proto
    [C.cedecl|$ty:array_type* $id:new_raw_array($ty:ctx_ty *ctx, $ty:memty data, $params:shape_params);|]
  proto
    [C.cedecl|int $id:free_array($ty:ctx_ty *ctx, $ty:array_type *arr);|]
  proto
    [C.cedecl|int $id:values_array($ty:ctx_ty *ctx, $ty:array_type *arr, $ty:pt' *data);|]
  proto
    [C.cedecl|int $id:index_array($ty:ctx_ty *ctx, $ty:pt' *out, $ty:array_type *arr,
                                  $params:index_params);|]
  proto
    [C.cedecl|$ty:memty $id:values_raw_array($ty:ctx_ty *ctx, $ty:array_type *arr);|]
  proto
    [C.cedecl|const typename int64_t* $id:shape_array($ty:ctx_ty *ctx, $ty:array_type *arr);|]

  mapM_
    libDecl
    [C.cunit|
          $ty:array_type* $id:new_array($ty:ctx_ty *ctx, const $ty:pt' *data, $params:shape_params) {
            int err = 0;
            $ty:array_type* bad = NULL;
            $ty:array_type *arr = ($ty:array_type*) malloc(sizeof($ty:array_type));
            if (arr == NULL) {
              return bad;
            }
            $items:(criticalSection ops new_body)
            if (err != 0) {
              free(arr);
              return bad;
            }
            return arr;
          }

          $ty:array_type* $id:new_raw_array($ty:ctx_ty *ctx, $ty:memty data, $params:shape_params) {
            int err = 0;
            $ty:array_type* bad = NULL;
            $ty:array_type *arr = ($ty:array_type*) malloc(sizeof($ty:array_type));
            if (arr == NULL) {
              return bad;
            }
            $items:(criticalSection ops new_raw_body)
            return arr;
          }

          int $id:free_array($ty:ctx_ty *ctx, $ty:array_type *arr) {
            $items:(criticalSection ops free_body)
            free(arr);
            return 0;
          }

          int $id:values_array($ty:ctx_ty *ctx, $ty:array_type *arr, $ty:pt' *data) {
            int err = 0;
            $items:(criticalSection ops values_body)
            return err;
          }

          int $id:index_array($ty:ctx_ty *ctx, $ty:pt' *out, $ty:array_type *arr,
                              $params:index_params) {
            int err = 0;
            if ($exp:in_bounds) {
              $items:(criticalSection ops index_body)
            } else {
              err = 1;
              set_error(ctx, strdup("Index out of bounds."));
            }
            return err;
          }

          $ty:memty $id:values_raw_array($ty:ctx_ty *ctx, $ty:array_type *arr) {
            (void)ctx;
            return arr->mem.mem;
          }

          const typename int64_t* $id:shape_array($ty:ctx_ty *ctx, $ty:array_type *arr) {
            (void)ctx;
            return arr->shape;
          }
          |]

  pure $
    Manifest.ArrayOps
      { Manifest.arrayFree = free_array,
        Manifest.arrayShape = shape_array,
        Manifest.arrayValues = values_array,
        Manifest.arrayNew = new_array,
        Manifest.arrayNewRaw = new_raw_array,
        Manifest.arrayValuesRaw = values_raw_array,
        Manifest.arrayIndex = index_array
      }

-- | Find the definition of the opaque type with the given name. Calls
-- 'error' if no opaque type of that name is present.
lookupOpaqueType :: Name -> OpaqueTypes -> OpaqueType
lookupOpaqueType v (OpaqueTypes types) =
  fromMaybe unknown (lookup v types)
  where
    unknown = error $ "Unknown opaque type: " ++ show v

-- | The flat list of value types that make up the representation of an
-- opaque type. For records and arrays of records, the payloads of the
-- fields are concatenated in order, and opaque fields are expanded
-- recursively via 'lookupOpaqueType'.
opaquePayload :: OpaqueTypes -> OpaqueType -> [ValueType]
opaquePayload types opaque =
  case opaque of
    OpaqueType ts -> ts
    OpaqueSum ts _ -> ts
    OpaqueArray _ _ ts -> ts
    OpaqueRecord fs -> fieldsPayload fs
    OpaqueRecordArray _ _ fs -> fieldsPayload fs
  where
    fieldsPayload = concatMap (fieldPayload . snd)
    fieldPayload (TypeTransparent v) = [v]
    fieldPayload (TypeOpaque s) = opaquePayload types $ lookupOpaqueType s types

-- | The C type of an entry point type: the opaque struct type for opaque
-- types, or the value type's C type (with the given publicness) otherwise.
entryPointTypeToCType :: Publicness -> EntryPointType -> CompilerM op s C.Type
entryPointTypeToCType pub et =
  case et of
    TypeOpaque desc -> opaqueToCType desc
    TypeTransparent vt -> valueTypeToCType pub vt

-- | The textual type name used in the manifest for an entry point type:
-- the opaque type's name, or the pretty-printed value type.
entryTypeName :: EntryPointType -> Manifest.TypeName
entryTypeName et =
  case et of
    TypeOpaque desc -> nameToText desc
    TypeTransparent vt -> prettyText vt

-- | Figure out which of the members of an opaque type corresponds to
-- which fields. The members are split into consecutive chunks, one per
-- field: a transparent field takes one member, and an opaque field takes
-- as many members as its payload has value types.
recordFieldPayloads :: OpaqueTypes -> [EntryPointType] -> [a] -> [[a]]
recordFieldPayloads types fields = chunks (map fieldWidth fields)
  where
    fieldWidth (TypeTransparent _) = 1
    fieldWidth (TypeOpaque desc) =
      length . opaquePayload types $ lookupOpaqueType desc types

-- | Generate the C code that extracts one record field from @obj@ into a
-- local variable @v@. The field is described by its entry point type and
-- by the (index, value type) components of the record representation that
-- belong to it.
--
-- Scalars are copied by value. Arrays, and the array components of opaque
-- fields, are copied into fresh @malloc@ed structs whose memory reference
-- count is incremented. Returns the C type of @v@ together with the
-- statements. Calls 'error' if a transparent field is not made up of
-- exactly one component.
projectField ::
  Operations op s ->
  EntryPointType ->
  [(Int, ValueType)] ->
  CompilerM op s (C.Type, [C.BlockItem])
projectField _ (TypeTransparent (ValueType sign (Rank 0) pt)) [(i, _)] =
  pure (primAPIType sign pt, [C.citems|v = obj->$id:(tupleField i);|])
projectField ops (TypeTransparent vt) [(i, _)] = do
  ct <- valueTypeToCType Public vt
  let body =
        [C.citems|v = malloc(sizeof($ty:ct));
                  memcpy(v, obj->$id:(tupleField i), sizeof($ty:ct));
                  (void)(*(v->mem.references))++;|]
  pure ([C.cty|$ty:ct *|], criticalSection ops body)
projectField _ (TypeTransparent _) rep =
  error $ "projectField: invalid representation of transparent type: " ++ show rep
projectField ops (TypeOpaque f_desc) components = do
  ct <- opaqueToCType f_desc
  let fill = concat $ zipWith copyComponent [0 :: Int ..] components
      body =
        [C.citems|v = malloc(sizeof($ty:ct));
                  $items:fill|]
  pure ([C.cty|$ty:ct *|], criticalSection ops body)
  where
    -- Copy component i of obj into component j of the new struct.
    copyComponent j (i, ValueType _ (Rank r) _)
      | r == 0 =
          [C.citems|v->$id:(tupleField j) = obj->$id:(tupleField i);|]
      | otherwise =
          [C.citems|v->$id:(tupleField j) = malloc(sizeof(*v->$id:(tupleField j)));
                    *v->$id:(tupleField j) = *obj->$id:(tupleField i);
                    (void)(*(v->$id:(tupleField j)->mem.references))++;|]

-- | Generate one @project_<desc>_<field>@ function per record field. Each
-- function stores the field of @obj@ into @*out@, using the code from
-- 'projectField', and returns 0. Its prototype is added to the header
-- section of the opaque type.
--
-- The field name is used verbatim when it yields a valid C name, and is
-- z-encoded otherwise. Returns the manifest entry for each field, in order.
recordProjectFunctions ::
  OpaqueTypes ->
  Name ->
  [(Name, EntryPointType)] ->
  [ValueType] ->
  CompilerM op s [Manifest.RecordField]
recordProjectFunctions types desc fs vds = do
  opaque_type <- opaqueToCType desc
  ctx_ty <- contextType
  ops <- asks envOperations
  let payloads = recordFieldPayloads types (map snd fs) (zip [0 ..] vds)

      fieldSuffix f
        | isValidCName (opaqueName desc <> "_" <> nameToText f) = nameToText f
        | otherwise = zEncodeText (nameToText f)

      genProjection (f, et) elems = do
        project <-
          publicName $ "project_" <> opaqueName desc <> "_" <> fieldSuffix f
        (et_ty, project_items) <- projectField ops et elems
        headerDecl
          (OpaqueDecl desc)
          [C.cedecl|int $id:project($ty:ctx_ty *ctx, $ty:et_ty *out, const $ty:opaque_type *obj);|]
        libDecl
          [C.cedecl|int $id:project($ty:ctx_ty *ctx, $ty:et_ty *out, const $ty:opaque_type *obj) {
                      (void)ctx;
                      $ty:et_ty v;
                      $items:project_items
                      *out = v;
                      return 0;
                    }|]
        pure $ Manifest.RecordField (nameToText f) (entryTypeName et) project

  zipWithM genProjection fs payloads

-- | A C statement that stores @e@ into component @i@ of the struct @v@.
-- Scalars (rank 0) are assigned directly. Arrays are copied into a fresh
-- @malloc@ed struct whose memory reference count is then incremented.
setFieldField :: (C.ToExp a) => Int -> a -> ValueType -> C.Stm
setFieldField i e (ValueType _ (Rank r) _) =
  if r == 0
    then [C.cstm|v->$id:field = $exp:e;|]
    else
      [C.cstm|{v->$id:field = malloc(sizeof(*$exp:e));
               *v->$id:field = *$exp:e;
               (void)(*(v->$id:field->mem.references))++;}|]
  where
    field = tupleField i

-- | For each record field, generate a C parameter named @f_<field>@ and
-- the statement that stores it into the matching component(s) of the
-- record struct @v@.
--
-- * Rank-0 transparent fields are passed and stored by value.
-- * Array fields are passed by pointer and copied into a fresh struct
--   whose memory reference count is incremented.
-- * Opaque fields are passed by pointer and contribute one component per
--   element of their payload, each stored with 'setFieldField'.
--
-- Returns the parameter identifiers, the parameter declarations and the
-- statements, all in field order.
recordNewSetFields ::
  OpaqueTypes ->
  [(Name, EntryPointType)] ->
  [ValueType] ->
  CompilerM op s ([C.Id], [C.Param], [C.BlockItem])
recordNewSetFields types fs vds =
  L.unzip3 . snd <$> mapAccumLM setOne (0 :: Int) (zip fs payloads)
  where
    payloads = recordFieldPayloads types (map snd fs) vds

    -- 'offset' is the index of the first struct component for this field;
    -- the returned accumulator is the index for the next field.
    setOne offset ((f, et), f_vts) = do
      let param_name = C.toIdent ("f_" <> f) mempty
          component = tupleField offset
      case et of
        TypeTransparent (ValueType sign (Rank 0) pt) -> do
          let ct = primAPIType sign pt
              param = [C.cparam|const $ty:ct $id:param_name|]
              store = [C.citem|v->$id:component = $id:param_name;|]
          pure (offset + 1, (param_name, param, store))
        TypeTransparent vt -> do
          ct <- valueTypeToCType Public vt
          let param = [C.cparam|const $ty:ct* $id:param_name|]
              store =
                [C.citem|{v->$id:component = malloc(sizeof($ty:ct));
                          *v->$id:component = *$id:param_name;
                          (void)(*(v->$id:component->mem.references))++;}|]
          pure (offset + 1, (param_name, param, store))
        TypeOpaque f_desc -> do
          ct <- opaqueToCType f_desc
          let param_fields =
                [[C.cexp|$id:param_name->$id:(tupleField i)|] | i <- [0 ..]]
              param = [C.cparam|const $ty:ct* $id:param_name|]
              store =
                [C.citem|{$stms:(zipWith3 setFieldField [offset ..] param_fields f_vts)}|]
          pure (offset + length f_vts, (param_name, param, store))

-- | Generate the @new_<desc>@ function for a record type. It takes one
-- parameter per field (see 'recordNewSetFields'), allocates the record
-- struct, fills it in, stores it in @*out@, and returns @FUTHARK_SUCCESS@.
-- The prototype goes in the header section of the opaque type. Returns the
-- function's name.
recordNewFunctions ::
  OpaqueTypes ->
  Name ->
  [(Name, EntryPointType)] ->
  [ValueType] ->
  CompilerM op s Manifest.CFuncName
recordNewFunctions types desc fs vds = do
  opaque_type <- opaqueToCType desc
  ctx_ty <- contextType
  ops <- asks envOperations
  new <- publicName $ "new_" <> opaqueName desc

  (_, params, new_stms) <- recordNewSetFields types fs vds
  let fill = criticalSection ops new_stms

  headerDecl
    (OpaqueDecl desc)
    [C.cedecl|int $id:new($ty:ctx_ty *ctx, $ty:opaque_type** out, $params:params);|]
  libDecl
    [C.cedecl|int $id:new($ty:ctx_ty *ctx, $ty:opaque_type** out, $params:params) {
                $ty:opaque_type* v = malloc(sizeof($ty:opaque_type));
                $items:fill
                *out = v;
                return FUTHARK_SUCCESS;
              }|]
  pure new

-- Because records and arrays-of-records are very similar in their
-- actual representation, we can reuse most of the code. Only indexing
-- requires something special.

-- | Field projection functions for an array of records. These are generated
-- exactly as for a plain record, by delegating to 'recordProjectFunctions'.
recordArrayProjectFunctions ::
  OpaqueTypes ->
  Name ->
  [(Name, EntryPointType)] ->
  [ValueType] ->
  CompilerM op s [Manifest.RecordField]
recordArrayProjectFunctions types desc fs vds =
  recordProjectFunctions types desc fs vds

-- | Generate the @zip_<desc>@ function, which builds an array of records
-- from one array (or opaque value containing arrays) per field.
--
-- The generated function first checks that every field argument has the
-- same shape in each of the @rank@ outer dimensions. If not, it sets an
-- error and returns 1. Otherwise it allocates the struct, fills it in
-- (see 'recordNewSetFields'), and stores it in @*out@. Returns the
-- function's name.
recordArrayZipFunctions ::
  OpaqueTypes ->
  Name ->
  [(Name, EntryPointType)] ->
  [ValueType] ->
  Int ->
  CompilerM op s Manifest.CFuncName
recordArrayZipFunctions types desc fs vds rank = do
  opaque_type <- opaqueToCType desc
  ctx_ty <- contextType
  ops <- asks envOperations
  new <- publicName $ "zip_" <> opaqueName desc

  (param_names, params, new_stms) <- recordNewSetFields types fs vds
  -- For each dimension, all field shapes must agree.
  let shapes_agree =
        allTrue . map allEqual . L.transpose $
          zipWith fieldShape (map snd fs) param_names

  headerDecl
    (OpaqueDecl desc)
    [C.cedecl|int $id:new($ty:ctx_ty *ctx, $ty:opaque_type** out, $params:params);|]
  libDecl
    [C.cedecl|int $id:new($ty:ctx_ty *ctx, $ty:opaque_type** out, $params:params) {
                if (!$exp:shapes_agree) {
                  set_error(ctx, strdup("Cannot zip arrays with different shapes."));
                  return 1;
                }
                $ty:opaque_type* v = malloc(sizeof($ty:opaque_type));
                $items:(criticalSection ops new_stms)
                *out = v;
                return FUTHARK_SUCCESS;
              }|]
  pure new
  where
    dims = [0 .. rank - 1]

    fieldShape :: EntryPointType -> C.Id -> [C.Exp]
    fieldShape TypeTransparent {} p =
      [[C.cexp|$id:p->shape[$int:i]|] | i <- dims]
    -- We know that the opaque value must contain arrays, so its first
    -- component is used as the witness for its shape.
    fieldShape TypeOpaque {} p =
      [[C.cexp|$id:p->$id:(tupleField 0)->shape[$int:i]|] | i <- dims]

-- | Generate the @index_@ function for an array of records.  The same
-- generator is reused for opaque arrays by 'opaqueArrayIndexFunctions'.
--
-- The generated C function:
--
-- * takes @rank@ indices of type @int64_t@;
-- * checks them against the outer shape of the first component array;
-- * on success, allocates a fresh element of opaque type @elemtype@ and
--   copies each component out of the corresponding component array.
--
-- It returns 0 on success. When an index is out of bounds, it returns a
-- non-zero value and sets an error message on the context.
--
-- The result is the public name of the generated C function.
recordArrayIndexFunctions ::
  Space ->
  OpaqueTypes ->
  Name ->
  Int ->
  Name ->
  [ValueType] ->
  CompilerM op s Manifest.CFuncName
recordArrayIndexFunctions space _types desc rank elemtype vds = do
  index_f <- publicName $ "index_" <> opaqueName desc
  ctx_ty <- contextType
  array_ct <- opaqueToCType desc
  obj_ct <- opaqueToCType elemtype
  copy <- asks $ opsCopy . envOperations

  index_items <- collect $ zipWithM_ (setField copy) [0 ..] vds

  headerDecl
    (OpaqueDecl desc)
    [C.cedecl|int $id:index_f($ty:ctx_ty *ctx, $ty:obj_ct **out, $ty:array_ct *arr,
                              $params:index_params);|]
  libDecl
    [C.cedecl|int $id:index_f($ty:ctx_ty *ctx, $ty:obj_ct **out, $ty:array_ct *arr,
                              $params:index_params) {
                int err = 0;
                if ($exp:in_bounds) {
                  $ty:obj_ct* v = malloc(sizeof($ty:obj_ct));
                  $items:index_items
                  if (err == 0) {
                    *out = v;
                  }
                } else {
                  err = 1;
                  set_error(ctx, strdup("Index out of bounds."));
                }
                return err;
              }|]

  pure index_f
  where
    -- Names of the C index parameters: i0, i1, ..., i(rank-1).
    index_names = ["i" <> prettyText i | i <- [0 .. rank - 1]]
    index_params = [[C.cparam|typename int64_t $id:k|] | k <- index_names]

    -- Byte offset of the position selected by the index parameters in
    -- a row-major array with element type @pt@, rank @r@, and shape
    -- given by the C expression @shape@.  Only the outer @rank@ strides
    -- take part (zipWith truncates to the index names).  When
    -- @r > rank@, the result is therefore the offset of the selected
    -- subarray.
    indexExp pt r shape =
      cproduct
        [ [C.cexp|$int:(primByteSize pt::Int)|],
          csum (zipWith (\x y -> [C.cexp|$id:x * $exp:y|]) index_names strides)
        ]
      where
        strides = do
          d <- [0 .. r - 1]
          pure $ cproduct [[C.cexp|$exp:shape[$int:i]|] | i <- [d + 1 .. r - 1]]

    -- Each index must be non-negative and smaller than the matching
    -- outer dimension of the first component array.
    in_bounds =
      allTrue
        [ [C.cexp|$id:p >= 0 && $id:p < arr->$id:(tupleField 0)->shape[$int:i]|]
          | (p, i) <- zip index_names [0 .. rank - 1]
        ]

    -- Copy component @j@ of the selected element into the new value @v@.
    setField copy j (ValueType _ (Rank r) pt)
      | r == rank =
          -- Easy case: just copy the scalar from the array into the
          -- variable.
          copy
            CopyNoBarrier
            [C.cexp|(unsigned char*)&v->$id:(tupleField j)|]
            [C.cexp|0|]
            DefaultSpace
            [C.cexp|arr->$id:(tupleField j)->mem.mem|]
            (indexExp pt rank [C.cexp|arr->$id:(tupleField j)->shape|])
            space
            [C.cexp|$int:(primByteSize pt::Int)|]
      | otherwise = do
          -- Tricky case: the component is itself an array of the inner
          -- (r - rank) dimensions, so we first have to allocate memory.
          let shape = do
                i <- [rank .. r - 1]
                pure [C.cexp|arr->$id:(tupleField j)->shape[$int:i]|]
          stm [C.cstm|v->$id:(tupleField j) = malloc(sizeof(*v->$id:(tupleField j)));|]
          prepareNewMem [C.cexp|v->$id:(tupleField j)|] space shape pt
          -- Now we can copy into the freshly allocated memory.
          copy
            CopyNoBarrier
            [C.cexp|v->$id:(tupleField j)->mem.mem|]
            [C.cexp|0|]
            space
            [C.cexp|arr->$id:(tupleField j)->mem.mem|]
            (indexExp pt r [C.cexp|arr->$id:(tupleField j)->shape|])
            space
            $ cproduct ([C.cexp|$int:(primByteSize pt::Int)|] : shape)

-- | Generate the @shape_@ function for an array of records.  The same
-- generator is reused for opaque arrays by 'opaqueArrayShapeFunctions'.
--
-- The generated C function returns a pointer to the shape array of the
-- first component of the value. The result is the public name of that
-- function.
recordArrayShapeFunctions :: Name -> CompilerM op s Manifest.CFuncName
recordArrayShapeFunctions desc = do
  shape_f <- publicName $ "shape_" <> opaqueName desc
  ctx_ty <- contextType
  array_ct <- opaqueToCType desc

  -- We know that the opaque value consists of arrays of at least the
  -- expected rank, and which have the same outer shape, so we just
  -- return the shape of the first one.
  headerDecl
    (OpaqueDecl desc)
    [C.cedecl|const typename int64_t* $id:shape_f($ty:ctx_ty *ctx, $ty:array_ct *arr);|]
  libDecl
    [C.cedecl|const typename int64_t* $id:shape_f($ty:ctx_ty *ctx, $ty:array_ct *arr) {
                (void)ctx;
                return arr->$id:(tupleField 0)->shape;
              }|]

  pure shape_f

-- | Generate the @index_@ function for an opaque array.  The generated
-- code is exactly that produced for arrays of records, so this is an
-- alias of 'recordArrayIndexFunctions'.
opaqueArrayIndexFunctions ::
  Space ->
  OpaqueTypes ->
  Name ->
  Int ->
  Name ->
  [ValueType] ->
  CompilerM op s Manifest.CFuncName
opaqueArrayIndexFunctions = recordArrayIndexFunctions

-- | Generate the @shape_@ function for an opaque array.  The generated
-- code is exactly that produced for arrays of records, so this is an
-- alias of 'recordArrayShapeFunctions'.
opaqueArrayShapeFunctions :: Name -> CompilerM op s Manifest.CFuncName
opaqueArrayShapeFunctions = recordArrayShapeFunctions

-- | For every variant of the sum type @desc@, generate a constructor
-- and a destructor C function. Each is declared in the header and
-- defined in the library. The function names are
-- @new_<type>_<variant>@ and @destruct_<type>_<variant>@, passed
-- through 'publicName'.
--
-- The constructor:
--
-- * allocates a new value;
-- * stores the variant index in its first field;
-- * copies the payload arguments into their fields;
-- * initialises every array field not used by this variant to an empty
--   array.
--
-- The destructor asserts that the value holds the expected variant and
-- writes the projected payload through its pointer parameters.
--
-- Returns the manifest description of each variant, in order.
sumVariants ::
  Name ->
  [(Name, [(EntryPointType, [Int])])] ->
  [ValueType] ->
  CompilerM op s [Manifest.SumVariant]
sumVariants desc variants vds = do
  opaque_ty <- opaqueToCType desc
  ctx_ty <- contextType
  ops <- asks envOperations

  let onVariant i (name, payload) = do
        construct <- publicName $ "new_" <> opaqueName desc <> "_" <> nameToText name
        destruct <- publicName $ "destruct_" <> opaqueName desc <> "_" <> nameToText name

        constructFunction ops ctx_ty opaque_ty i construct payload
        destructFunction ops ctx_ty opaque_ty i destruct payload

        pure $
          Manifest.SumVariant
            { Manifest.sumVariantName = nameToText name,
              Manifest.sumVariantPayload = map (entryTypeName . fst) payload,
              Manifest.sumVariantConstruct = construct,
              Manifest.sumVariantDestruct = destruct
            }

  zipWithM onVariant [0 :: Int ..] variants
  where
    -- Emit the constructor for variant number @i@.
    constructFunction ops ctx_ty opaque_ty i fname payload = do
      (params, new_stms) <- unzip <$> zipWithM constructPayload [0 ..] payload

      let used = concatMap snd payload
      set_unused_stms <-
        mapM setUnused $ filter ((`notElem` used) . fst) (zip [0 ..] vds)

      headerDecl
        (OpaqueDecl desc)
        [C.cedecl|int $id:fname($ty:ctx_ty *ctx,
                                $ty:opaque_ty **out,
                                $params:params);|]

      libDecl
        [C.cedecl|int $id:fname($ty:ctx_ty *ctx,
                                $ty:opaque_ty **out,
                                $params:params) {
                    (void)ctx;
                    $ty:opaque_ty* v = malloc(sizeof($ty:opaque_ty));
                    v->$id:(tupleField 0) = $int:i;
                    { $items:(criticalSection ops new_stms) }
                    // Set other fields
                    { $items:set_unused_stms }
                    *out = v;
                    return FUTHARK_SUCCESS;
                  }|]

    -- We must initialise some of the fields that are unused in this
    -- variant; specifically the ones corresponding to arrays. This
    -- has the unfortunate effect that all arrays in the nonused
    -- constructor are set to have size 0.
    setUnused (_, ValueType _ (Rank 0) _) =
      pure [C.citem|{}|]
    setUnused (i, ValueType signed (Rank rank) pt) = do
      new_array <- publicName $ "new_" <> arrayName pt signed rank
      let dims = replicate rank [C.cexp|0|]
      pure [C.citem|v->$id:(tupleField i) = $id:new_array(ctx, NULL, $args:dims);|]

    -- A transparent payload is stored in a single field of the opaque
    -- value, whose index is the first entry of its index list.  Unlike
    -- a bare 'head', this reports which invariant was broken.
    payloadField is =
      case is of
        i : _ -> i
        [] -> error "sumVariants: transparent payload has no field index."

    -- Parameter and initialisation code for payload argument number @j@.
    constructPayload j (et, is) = do
      let param_name = "v" <> show (j :: Int)
      case et of
        TypeTransparent (ValueType sign (Rank 0) pt) -> do
          let ct = primAPIType sign pt
              i = payloadField is
          pure
            ( [C.cparam|const $ty:ct $id:param_name|],
              [C.citem|v->$id:(tupleField i) = $id:param_name;|]
            )
        TypeTransparent vt -> do
          ct <- valueTypeToCType Public vt
          let i = payloadField is
          -- The stored array header is a copy of the argument's and
          -- shares its memory block, so the reference count is bumped.
          pure
            ( [C.cparam|const $ty:ct* $id:param_name|],
              [C.citem|{v->$id:(tupleField i) = malloc(sizeof($ty:ct));
                        memcpy(v->$id:(tupleField i), $id:param_name, sizeof(const $ty:ct));
                        (void)(*(v->$id:(tupleField i)->mem.references))++;}|]
            )
        TypeOpaque f_desc -> do
          ct <- opaqueToCType f_desc
          let param_fields = do
                i <- [0 ..]
                pure [C.cexp|$id:param_name->$id:(tupleField i)|]
              vts = map (vds !!) is
          pure
            ( [C.cparam|const $ty:ct* $id:param_name|],
              [C.citem|{$stms:(zipWith3 setFieldField is param_fields vts)}|]
            )

    -- Emit the destructor for variant number @i@.
    destructFunction ops ctx_ty opaque_ty i fname payload = do
      (params, destruct_stms) <- unzip <$> zipWithM (destructPayload ops) [0 ..] payload
      headerDecl
        (OpaqueDecl desc)
        [C.cedecl|int $id:fname($ty:ctx_ty *ctx,
                                $params:params,
                                const $ty:opaque_ty *obj);|]

      libDecl
        [C.cedecl|int $id:fname($ty:ctx_ty *ctx,
                                $params:params,
                                const $ty:opaque_ty *obj) {
                    (void)ctx;
                    assert(obj->$id:(tupleField 0) == $int:i);
                    $stms:destruct_stms
                    return FUTHARK_SUCCESS;
                  }|]

    -- Output parameter and projection code for payload element number @j@.
    destructPayload ops j (et, is) = do
      let param_name = "v" <> show (j :: Int)
      (ct, project_items) <- projectField ops et $ zip is $ map (vds !!) is
      pure
        ( [C.cparam|$ty:ct* $id:param_name|],
          [C.cstm|{$ty:ct v;
                   $items:project_items
                   *$id:param_name = v;
                  }|]
        )

-- | Generate the @variant_@ function for the sum type @desc@.  It
-- returns the index of the variant held by a value, read from the
-- value's first field. The result is the public name of the generated
-- C function.
sumVariantFunction :: Name -> CompilerM op s Manifest.CFuncName
sumVariantFunction desc = do
  opaque_ty <- opaqueToCType desc
  ctx_ty <- contextType
  variant <- publicName $ "variant_" <> opaqueName desc
  headerDecl
    (OpaqueDecl desc)
    [C.cedecl|int $id:variant($ty:ctx_ty *ctx, const $ty:opaque_ty* v);|]
  -- This depends on the assumption that the first value always
  -- encodes the variant.
  libDecl
    [C.cedecl|int $id:variant($ty:ctx_ty *ctx, const $ty:opaque_ty* v) {
                (void)ctx;
                return v->$id:(tupleField 0);
              }|]
  pure variant

-- | Generate the type-specific extra API functions for the opaque type
-- @desc@ and return their manifest description.  Plain 'OpaqueType'
-- values have no extra operations and yield 'Nothing'.
--
-- The @vds@ argument lists the value types of the opaque value's
-- components.
opaqueExtraOps ::
  Space ->
  OpaqueTypes ->
  Name ->
  OpaqueType ->
  [ValueType] ->
  CompilerM op s (Maybe Manifest.OpaqueExtraOps)
opaqueExtraOps _ _ _ (OpaqueType _) _ =
  pure Nothing
opaqueExtraOps _ _types desc (OpaqueSum _ cs) vds =
  Just . Manifest.OpaqueSum
    <$> ( Manifest.SumOps
            <$> sumVariants desc cs vds
            <*> sumVariantFunction desc
        )
opaqueExtraOps _ types desc (OpaqueRecord fs) vds =
  Just . Manifest.OpaqueRecord
    <$> ( Manifest.RecordOps
            <$> recordProjectFunctions types desc fs vds
            <*> recordNewFunctions types desc fs vds
        )
opaqueExtraOps space types desc (OpaqueRecordArray rank elemtype fs) vds =
  Just . Manifest.OpaqueRecordArray
    <$> ( Manifest.RecordArrayOps rank (nameToText elemtype)
            <$> recordArrayProjectFunctions types desc fs vds
            <*> recordArrayZipFunctions types desc fs vds rank
            <*> recordArrayIndexFunctions space types desc rank elemtype vds
            <*> recordArrayShapeFunctions desc
        )
opaqueExtraOps space types desc (OpaqueArray rank elemtype _) vds =
  Just . Manifest.OpaqueArray
    <$> ( Manifest.OpaqueArrayOps rank (nameToText elemtype)
            <$> opaqueArrayIndexFunctions space types desc rank elemtype vds
            <*> opaqueArrayShapeFunctions desc
        )

opaqueLibraryFunctions ::
  Space ->
  OpaqueTypes ->
  Name ->
  OpaqueType ->
  CompilerM op s (Manifest.OpaqueOps, Maybe Manifest.OpaqueExtraOps)
opaqueLibraryFunctions :: forall op s.
Space
-> OpaqueTypes
-> Name
-> OpaqueType
-> CompilerM op s (OpaqueOps, Maybe OpaqueExtraOps)
opaqueLibraryFunctions Space
space OpaqueTypes
types Name
desc OpaqueType
ot = do
  name <- Text -> CompilerM op s Text
forall op s. Text -> CompilerM op s Text
publicName (Text -> CompilerM op s Text) -> Text -> CompilerM op s Text
forall a b. (a -> b) -> a -> b
$ Name -> Text
opaqueName Name
desc
  free_opaque <- publicName $ "free_" <> opaqueName desc
  store_opaque <- publicName $ "store_" <> opaqueName desc
  restore_opaque <- publicName $ "restore_" <> opaqueName desc

  let opaque_type = [C.cty|struct $id:name|]

      freeComponent Int
i (ValueType Signedness
signed (Rank Int
rank) PrimType
pt) = Bool -> CompilerM op s () -> CompilerM op s ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
unless (Int
rank Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== Int
0) (CompilerM op s () -> CompilerM op s ())
-> CompilerM op s () -> CompilerM op s ()
forall a b. (a -> b) -> a -> b
$ do
        let field :: [Char]
field = Int -> [Char]
tupleField Int
i
        free_array <- Text -> CompilerM op s Text
forall op s. Text -> CompilerM op s Text
publicName (Text -> CompilerM op s Text) -> Text -> CompilerM op s Text
forall a b. (a -> b) -> a -> b
$ Text
"free_" Text -> Text -> Text
forall a. Semigroup a => a -> a -> a
<> PrimType -> Signedness -> Int -> Text
arrayName PrimType
pt Signedness
signed Int
rank
        -- Protect against NULL here, because we also want to use this
        -- to free partially loaded opaques.
        stm
          [C.cstm|if (obj->$id:field != NULL && (tmp = $id:free_array(ctx, obj->$id:field)) != 0) {
                ret = tmp;
             }|]

      storeComponent Int
i (ValueType Signedness
sign (Rank Int
0) PrimType
pt) =
        let field :: [Char]
field = Int -> [Char]
tupleField Int
i
         in ( PrimType -> Int -> Exp -> Exp
storageSize PrimType
pt Int
0 [C.cexp|NULL|],
              Signedness -> PrimType -> Int -> Exp -> Exp -> [Stm]
storeValueHeader Signedness
sign PrimType
pt Int
0 [C.cexp|NULL|] [C.cexp|out|]
                [Stm] -> [Stm] -> [Stm]
forall a. [a] -> [a] -> [a]
++ [C.cstms|memcpy(out, &obj->$id:field, sizeof(obj->$id:field));
                            out += sizeof(obj->$id:field);|]
            )
      storeComponent Int
i (ValueType Signedness
sign (Rank Int
rank) PrimType
pt) =
        let arr_name :: Text
arr_name = PrimType -> Signedness -> Int -> Text
arrayName PrimType
pt Signedness
sign Int
rank
            field :: [Char]
field = Int -> [Char]
tupleField Int
i
            shape_array :: Text
shape_array = Text
"futhark_shape_" Text -> Text -> Text
forall a. Semigroup a => a -> a -> a
<> Text
arr_name
            values_array :: Text
values_array = Text
"futhark_values_" Text -> Text -> Text
forall a. Semigroup a => a -> a -> a
<> Text
arr_name
            shape' :: Exp
shape' = [C.cexp|$id:shape_array(ctx, obj->$id:field)|]
            num_elems :: Exp
num_elems = [Exp] -> Exp
cproduct [[C.cexp|$exp:shape'[$int:j]|] | Int
j <- [Int
0 .. Int
rank Int -> Int -> Int
forall a. Num a => a -> a -> a
- Int
1]]
         in ( PrimType -> Int -> Exp -> Exp
storageSize PrimType
pt Int
rank Exp
shape',
              Signedness -> PrimType -> Int -> Exp -> Exp -> [Stm]
storeValueHeader Signedness
sign PrimType
pt Int
rank Exp
shape' [C.cexp|out|]
                [Stm] -> [Stm] -> [Stm]
forall a. [a] -> [a] -> [a]
++ [C.cstms|ret |= $id:values_array(ctx, obj->$id:field, (void*)out);
                            out += $exp:num_elems * sizeof($ty:(primStorageType pt));|]
            )

  ctx_ty <- contextType

  let vds = OpaqueTypes -> OpaqueType -> [ValueType]
opaquePayload OpaqueTypes
types OpaqueType
ot
  free_body <- collect $ zipWithM_ freeComponent [0 ..] vds

  store_body <- collect $ do
    let (sizes, stores) = unzip $ zipWith storeComponent [0 ..] vds
        size_vars = (Int -> [Char]) -> [Int] -> [[Char]]
forall a b. (a -> b) -> [a] -> [b]
map (([Char]
"size_" [Char] -> [Char] -> [Char]
forall a. [a] -> [a] -> [a]
++) ([Char] -> [Char]) -> (Int -> [Char]) -> Int -> [Char]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Int -> [Char]
forall a. Show a => a -> [Char]
show) [Int
0 .. [Exp] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [Exp]
sizes Int -> Int -> Int
forall a. Num a => a -> a -> a
- Int
1]
        size_sum = [Exp] -> Exp
csum [[C.cexp|$id:size|] | [Char]
size <- [[Char]]
size_vars]
    forM_ (zip size_vars sizes) $ \([Char]
v, Exp
e) ->
      BlockItem -> CompilerM op s ()
forall op s. BlockItem -> CompilerM op s ()
item [C.citem|typename int64_t $id:v = $exp:e;|]
    stm [C.cstm|*n = $exp:size_sum;|]
    stm [C.cstm|if (p != NULL && *p == NULL) { *p = malloc(*n); }|]
    stm [C.cstm|if (p != NULL) { unsigned char *out = *p; $stms:(concat stores) }|]

  let restoreComponent Int
i (ValueType Signedness
sign (Rank Int
0) PrimType
pt) = do
        let field :: [Char]
field = Int -> [Char]
tupleField Int
i
            dataptr :: [Char]
dataptr = [Char]
"data_" [Char] -> [Char] -> [Char]
forall a. [a] -> [a] -> [a]
++ Int -> [Char]
forall a. Show a => a -> [Char]
show Int
i
        [Stm] -> CompilerM op s ()
forall op s. [Stm] -> CompilerM op s ()
stms ([Stm] -> CompilerM op s ()) -> [Stm] -> CompilerM op s ()
forall a b. (a -> b) -> a -> b
$ Signedness -> PrimType -> Int -> Exp -> Exp -> [Stm]
loadValueHeader Signedness
sign PrimType
pt Int
0 [C.cexp|NULL|] [C.cexp|src|]
        BlockItem -> CompilerM op s ()
forall op s. BlockItem -> CompilerM op s ()
item [C.citem|const void* $id:dataptr = src;|]
        Stm -> CompilerM op s ()
forall op s. Stm -> CompilerM op s ()
stm [C.cstm|src += sizeof(obj->$id:field);|]
        [Stm] -> CompilerM op s [Stm]
forall a. a -> CompilerM op s a
forall (f :: * -> *) a. Applicative f => a -> f a
pure [C.cstms|memcpy(&obj->$id:field, $id:dataptr, sizeof(obj->$id:field));|]
      restoreComponent Int
i (ValueType Signedness
sign (Rank Int
rank) PrimType
pt) = do
        let field :: [Char]
field = Int -> [Char]
tupleField Int
i
            arr_name :: Text
arr_name = PrimType -> Signedness -> Int -> Text
arrayName PrimType
pt Signedness
sign Int
rank
            new_array :: Text
new_array = Text
"futhark_new_" Text -> Text -> Text
forall a. Semigroup a => a -> a -> a
<> Text
arr_name
            dataptr :: Text
dataptr = Text
"data_" Text -> Text -> Text
forall a. Semigroup a => a -> a -> a
<> Int -> Text
forall a. Pretty a => a -> Text
prettyText Int
i
            shapearr :: Text
shapearr = Text
"shape_" Text -> Text -> Text
forall a. Semigroup a => a -> a -> a
<> Int -> Text
forall a. Pretty a => a -> Text
prettyText Int
i
            dims :: [Exp]
dims = [[C.cexp|$id:shapearr[$int:j]|] | Int
j <- [Int
0 .. Int
rank Int -> Int -> Int
forall a. Num a => a -> a -> a
- Int
1]]
            num_elems :: Exp
num_elems = [Exp] -> Exp
cproduct [Exp]
dims
        BlockItem -> CompilerM op s ()
forall op s. BlockItem -> CompilerM op s ()
item [C.citem|typename int64_t $id:shapearr[$int:rank] = {0};|]
        [Stm] -> CompilerM op s ()
forall op s. [Stm] -> CompilerM op s ()
stms ([Stm] -> CompilerM op s ()) -> [Stm] -> CompilerM op s ()
forall a b. (a -> b) -> a -> b
$ Signedness -> PrimType -> Int -> Exp -> Exp -> [Stm]
loadValueHeader Signedness
sign PrimType
pt Int
rank [C.cexp|$id:shapearr|] [C.cexp|src|]
        BlockItem -> CompilerM op s ()
forall op s. BlockItem -> CompilerM op s ()
item [C.citem|const void* $id:dataptr = src;|]
        Stm -> CompilerM op s ()
forall op s. Stm -> CompilerM op s ()
stm [C.cstm|obj->$id:field = NULL;|]
        Stm -> CompilerM op s ()
forall op s. Stm -> CompilerM op s ()
stm [C.cstm|src += $exp:num_elems * sizeof($ty:(primStorageType pt));|]
        [Stm] -> CompilerM op s [Stm]
forall a. a -> CompilerM op s a
forall (f :: * -> *) a. Applicative f => a -> f a
pure
          [C.cstms|
             obj->$id:field = $id:new_array(ctx, $id:dataptr, $args:dims);
             if (obj->$id:field == NULL) { err = 1; }|]

  load_body <- collect $ do
    loads <- concat <$> zipWithM restoreComponent [0 ..] (opaquePayload types ot)
    stm
      [C.cstm|if (err == 0) {
                $stms:loads
              }|]

  headerDecl
    (OpaqueTypeDecl desc)
    [C.cedecl|struct $id:name;|]
  headerDecl
    (OpaqueDecl desc)
    [C.cedecl|int $id:free_opaque($ty:ctx_ty *ctx, $ty:opaque_type *obj);|]
  headerDecl
    (OpaqueDecl desc)
    [C.cedecl|int $id:store_opaque($ty:ctx_ty *ctx, const $ty:opaque_type *obj, void **p, size_t *n);|]
  headerDecl
    (OpaqueDecl desc)
    [C.cedecl|$ty:opaque_type* $id:restore_opaque($ty:ctx_ty *ctx, const void *p);|]

  extra_ops <- opaqueExtraOps space types desc ot vds

  -- We do not need to enclose most bodies in a critical section,
  -- because when we operate on the components of the opaque, we are
  -- calling public API functions that do their own locking.  The
  -- exception is projection, where we fiddle with reference counts.
  mapM_
    libDecl
    [C.cunit|
          int $id:free_opaque($ty:ctx_ty *ctx, $ty:opaque_type *obj) {
            (void)ctx;
            int ret = 0, tmp;
            $items:free_body
            free(obj);
            return ret;
          }

          int $id:store_opaque($ty:ctx_ty *ctx,
                               const $ty:opaque_type *obj, void **p, size_t *n) {
            (void)ctx;
            int ret = 0;
            $items:store_body
            return ret;
          }

          $ty:opaque_type* $id:restore_opaque($ty:ctx_ty *ctx,
                                              const void *p) {
            (void)ctx;
            int err = 0;
            const unsigned char *src = p;
            $ty:opaque_type* obj = malloc(sizeof($ty:opaque_type));
            $items:load_body
            if (err != 0) {
              int ret = 0, tmp;
              $items:free_body
              free(obj);
              obj = NULL;
            }
            return obj;
          }
    |]

  pure
    ( Manifest.OpaqueOps
        { Manifest.opaqueFree = free_opaque,
          Manifest.opaqueStore = store_opaque,
          Manifest.opaqueRestore = restore_opaque
        },
      extra_ops
    )

-- | Emit the C struct definition and library functions for one array
-- type, identified by its signedness, element type, and rank.
--
-- The generated struct holds the array's memory block (of the fat
-- memory type for the given 'Space') and an @int64_t@ shape vector
-- with one entry per dimension.  The struct and library functions are
-- emitted for both public and private arrays, but only 'Public' arrays
-- yield a manifest entry.  That entry is keyed by @rank@ copies of
-- @"[]"@ followed by the element type's pretty name.
generateArray ::
  Space ->
  ((Signedness, PrimType, Int), Publicness) ->
  CompilerM op s (Maybe (T.Text, Manifest.Type))
generateArray space ((signed, pt, rank), pub) = do
  name <- publicName $ arrayName pt signed rank
  let memty = fatMemType space
  libDecl [C.cedecl|struct $id:name { $ty:memty mem; typename int64_t shape[$int:rank]; };|]
  ops <- arrayLibraryFunctions pub space pt signed rank
  let pt_name = prettySigned (signed == Unsigned) pt
      pretty_name = mconcat (replicate rank "[]") <> pt_name
      arr_type = [C.cty|struct $id:name*|]
  case pub of
    Public ->
      pure $
        Just
          ( pretty_name,
            Manifest.TypeArray (typeText arr_type) pt_name rank ops
          )
    -- Private arrays still get their C declarations above (other
    -- generated code may use them), but are not listed in the manifest.
    Private ->
      pure Nothing

-- | Emit the C struct definition and library functions for one opaque
-- type, returning its manifest entry keyed by the opaque's name.
--
-- The struct has one field per payload value of the opaque (as given
-- by 'opaquePayload').  Scalar (rank-0) payloads are stored by value;
-- array payloads are stored as pointers to the corresponding array
-- struct.
generateOpaque ::
  Space ->
  OpaqueTypes ->
  (Name, OpaqueType) ->
  CompilerM op s (T.Text, Manifest.Type)
generateOpaque space types (desc, ot) = do
  name <- publicName $ opaqueName desc
  members <- zipWithM field (opaquePayload types ot) [(0 :: Int) ..]
  libDecl [C.cedecl|struct $id:name { $sdecls:members };|]
  (ops, extra_ops) <- opaqueLibraryFunctions space types desc ot
  let opaque_type = [C.cty|struct $id:name*|]
  pure
    ( nameToText desc,
      Manifest.TypeOpaque (typeText opaque_type) ops extra_ops
    )
  where
    -- Struct field declaration for the @i@th payload value.  The C type
    -- is computed with 'Private' publicness, since the field is an
    -- internal detail of the opaque struct.
    field :: ValueType -> Int -> CompilerM op s C.FieldGroup
    field vt@(ValueType _ (Rank r) _) i = do
      ct <- valueTypeToCType Private vt
      pure $
        if r == 0
          then [C.csdecl|$ty:ct $id:(tupleField i);|]
          else [C.csdecl|$ty:ct *$id:(tupleField i);|]

-- | Generate C declarations and library functions for every
-- API-visible type.
--
-- Array types come first: all array types recorded in the compiler
-- state's 'compArrayTypes' are generated, then every opaque type.  The
-- result maps each type's name to its manifest description.  Private
-- arrays are generated but omitted from the map.
generateAPITypes :: Space -> OpaqueTypes -> CompilerM op s (M.Map T.Text Manifest.Type)
generateAPITypes arr_space types@(OpaqueTypes opaques) = do
  -- This must happen before 'compArrayTypes' is read below, so that
  -- arrays reachable from opaque fields are included in it.
  mapM_ (findNecessaryArrays . snd) opaques
  array_types <- gets compArrayTypes
  array_ts <- mapM (generateArray arr_space) (M.toList array_types)
  opaque_ts <- mapM (generateOpaque arr_space types) opaques
  pure $ M.fromList $ catMaybes array_ts <> opaque_ts
  where
    -- Ensure that array types will be generated before the opaque
    -- types that allow projection of them.  This is necessary because
    -- the projection functions directly manipulate the arrays'
    -- internals to increment reference counts.
    --
    -- 'entryPointTypeToCType' is called only for its effect on the
    -- compiler state; the C types it returns are discarded.  The
    -- assumption is that it records the array types it encounters
    -- (TODO: confirm).
    findNecessaryArrays :: OpaqueType -> CompilerM op s ()
    findNecessaryArrays (OpaqueType _) =
      pure ()
    findNecessaryArrays (OpaqueArray {}) =
      pure ()
    findNecessaryArrays (OpaqueRecordArray _ _ fs) =
      mapM_ (entryPointTypeToCType Public . snd) fs
    findNecessaryArrays (OpaqueSum _ variants) =
      mapM_ (mapM_ (entryPointTypeToCType Public . fst) . snd) variants
    findNecessaryArrays (OpaqueRecord fs) =
      mapM_ (entryPointTypeToCType Public . snd) fs