-
Notifications
You must be signed in to change notification settings - Fork 132
Expand file tree
/
Copy pathAST.hs
More file actions
1480 lines (1294 loc) · 64.6 KB
/
AST.hs
File metadata and controls
1480 lines (1294 loc) · 64.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# OPTIONS_HADDOCK hide #-}
-- |
-- Module : Data.Array.Accelerate.AST
-- Copyright : [2008..2020] The Accelerate Team
-- License : BSD3
--
-- Maintainer : Trevor L. McDonell <trevor.mcdonell@gmail.com>
-- Stability : experimental
-- Portability : non-portable (GHC extensions)
--
-- /Scalar versus collective operations/
--
-- The embedded array processing language is a two-level language. It
-- combines a language of scalar expressions and functions with a language of
-- collective array operations. Scalar expressions are used to compute
-- arguments for collective operations and scalar functions are used to
-- parametrise higher-order, collective array operations. The two-level
-- structure, in particular, ensures that collective operations cannot be
-- parametrised with collective operations; hence, we are following a flat
-- data-parallel model. The collective operations manipulate
-- multi-dimensional arrays whose shape is explicitly tracked in their types.
-- In fact, collective operations cannot produce any values other than
-- multi-dimensional arrays; when they yield a scalar, this is in the form of
-- a 0-dimensional, singleton array. Similarly, scalar expressions can, as
-- their name indicates, only produce tuples of scalars, but not arrays.
--
-- There are, however, two expression forms that take arrays as arguments. As
-- a result, scalar and array expressions are recursively dependent. As we
-- cannot and don't want to compute arrays in the middle of scalar
-- computations, array computations will always be hoisted out of scalar
-- expressions. So that this is always possible, these array expressions may
-- not contain any free scalar variables. To express that condition in the
-- type structure, we use separate environments for scalar and array variables.
--
-- /Programs/
--
-- Collective array programs comprise closed expressions of array operations.
-- There is no explicit sharing in the initial AST form, but sharing is
-- introduced subsequently by common subexpression elimination and floating
-- of array computations.
--
-- /Functions/
--
-- The array expression language is first-order and only provides limited
-- control structures to ensure that it can be efficiently executed on
-- compute-acceleration hardware, such as GPUs. To restrict functions to
-- first-order, we separate function abstraction from the main expression
-- type. Functions are represented using de Bruijn indices.
--
-- /Parametric and ad-hoc polymorphism/
--
-- The array language features parametric polymorphism (e.g., pairing and
-- projections) as well as ad-hoc polymorphism (e.g., arithmetic operations).
-- All ad-hoc polymorphic constructs include reified dictionaries (cf.
-- module 'Types'). Reified dictionaries also ensure that constants
-- (constructor 'Const') are representable on compute acceleration hardware.
--
-- The AST contains both reified dictionaries and type class constraints.
-- Type classes are used for array-related functionality that is uniformly
-- available for all supported types. In contrast, reified dictionaries are
-- used for functionality that is only available for certain types, such as
-- arithmetic operations.
--
module Data.Array.Accelerate.AST (
-- * Internal AST
-- ** Array computations
Afun, PreAfun, OpenAfun, PreOpenAfun(..),
Acc, OpenAcc(..), PreOpenAcc(..), Direction(..), Message(..),
ALeftHandSide, ArrayVar, ArrayVars,
-- ** Scalar expressions
ELeftHandSide, ExpVar, ExpVars, expVars,
Fun, OpenFun(..),
Exp, OpenExp(..),
Boundary(..),
PrimConst(..),
PrimFun(..),
PrimBool,
PrimMaybe,
-- ** Extracting type information
HasArraysR(..), arrayR,
expType,
primConstType,
primFunType,
-- ** Normal-form
NFDataAcc,
rnfOpenAfun, rnfPreOpenAfun,
rnfOpenAcc, rnfPreOpenAcc,
rnfALeftHandSide,
rnfArrayVar,
rnfOpenFun,
rnfOpenExp,
rnfELeftHandSide,
rnfExpVar,
rnfBoundary,
rnfConst,
rnfPrimConst,
rnfPrimFun,
-- ** Template Haskell
LiftAcc,
liftPreOpenAfun,
liftPreOpenAcc,
liftALeftHandSide,
liftArrayVar,
liftOpenFun,
liftOpenExp,
liftELeftHandSide,
liftExpVar,
liftBoundary,
liftPrimConst,
liftPrimFun,
liftMessage,
-- ** Miscellaneous
formatPreAccOp,
formatExpOp,
) where
import Data.Array.Accelerate.AST.Idx
import Data.Array.Accelerate.AST.LeftHandSide
import Data.Array.Accelerate.AST.Var
import Data.Array.Accelerate.Error
import Data.Array.Accelerate.Representation.Array
import Data.Array.Accelerate.Representation.Elt
import Data.Array.Accelerate.Representation.Shape
import Data.Array.Accelerate.Representation.Slice
import Data.Array.Accelerate.Representation.Stencil
import Data.Array.Accelerate.Representation.Tag
import Data.Array.Accelerate.Representation.Type
import Data.Array.Accelerate.Representation.Vec
import Data.Array.Accelerate.Sugar.Foreign
import Data.Array.Accelerate.Type
import Data.Primitive.Vec
import Data.Primitive.Types
import Control.DeepSeq
import Data.Kind
import Data.Maybe
import Data.Text ( Text )
import Data.Text.Lazy.Builder
import Formatting
import Language.Haskell.TH.Extra ( CodeQ )
import qualified Language.Haskell.TH.Extra as TH
import qualified Language.Haskell.TH.Syntax as TH
import GHC.TypeLits
-- Array expressions
-- -----------------
-- | Function abstraction over parametrised array computations
--
-- A function is a body ('Abody') under zero or more lambdas ('Alam'). Each
-- lambda binds its argument through an 'ALeftHandSide', which extends the
-- array environment from @aenv@ to @aenv'@; the rest of the function lives
-- in that extended environment.
--
data PreOpenAfun acc aenv t where
  Abody :: acc aenv t -> PreOpenAfun acc aenv t
  Alam  :: ALeftHandSide a aenv aenv' -> PreOpenAfun acc aenv' t -> PreOpenAfun acc aenv (a -> t)
-- Function abstraction over vanilla open array computations
--
type OpenAfun = PreOpenAfun OpenAcc

-- | Parametrised array-computation function without free array variables
--
-- The array environment is the empty environment @()@.
--
type PreAfun acc = PreOpenAfun acc ()

-- | Vanilla array-computation function without free array variables
--
type Afun = OpenAfun ()

-- Vanilla open array computations
--
-- 'PreOpenAcc' is parametrised over the type of its sub-computations; this
-- newtype ties that recursive knot by instantiating the parameter with
-- 'OpenAcc' itself.
--
newtype OpenAcc aenv t = OpenAcc (PreOpenAcc OpenAcc aenv t)

-- | Closed array expression aka an array program
--
type Acc = OpenAcc ()
-- Types for array binders
--
-- Array binders and variables range over array representations ('ArrayR').
--
type ALeftHandSide = LeftHandSide ArrayR
type ArrayVar      = Var ArrayR
type ArrayVars aenv = Vars ArrayR aenv

-- Bool is not a primitive type, so the internal AST represents booleans
-- as a 'TAG' value.
type PrimBool = TAG

-- Maybe is represented as a 'TAG' paired with a @((), a)@ payload.
-- NOTE(review): the tag values for Nothing/Just are assigned elsewhere —
-- confirm before relying on a particular encoding.
type PrimMaybe a = (TAG, ((), a))
-- Trace messages
--
-- A 'Message' is attached to an array computation by 'Atrace'. Its fields
-- are, in order:
--
--   * a function that shows a value of type @a@;
--   * optionally, the same function as Template Haskell code, so the
--     message can be lifted;
--   * the message text.
--
data Message a where
  Message :: (a -> String)                  -- embedded show
          -> Maybe (CodeQ (a -> String))    -- lifted version of show, for TH
          -> Text
          -> Message a
-- | Collective array computations parametrised over array variables
-- represented with de Bruijn indices.
--
--   * Scalar functions and expressions embedded in well-formed array
--     computations cannot contain free scalar variable indices. The latter
--     cannot be bound in array computations, and hence, cannot appear in any
--     well-formed program.
--
--   * The let-form is used to represent the sharing discovered by common
--     subexpression elimination as well as to control evaluation order. (We
--     need to hoist array expressions out of scalar expressions---they occur
--     in scalar indexing and in determining an array's shape.)
--
-- The data type is parameterised over the representation types (not the
-- surface type).
--
-- We use a non-recursive variant parametrised over the recursive closure,
-- to facilitate attribute calculation in the backend.
--
-- Several constructors carry an explicit representation ('ArrayR',
-- 'ArraysR', 'TypeR' or 'ShapeR') for the part of their result type that
-- cannot be recovered from their array arguments. The 'HasArraysR'
-- instance for 'PreOpenAcc' uses these to compute the result
-- representation.
--
data PreOpenAcc (acc :: Type -> Type -> Type) aenv a where
  -- Local non-recursive binding to represent sharing and demand
  -- explicitly. Note this is an eager binding!
  --
  Alet        :: ALeftHandSide bndArrs aenv aenv'
              -> acc aenv bndArrs                   -- bound expression
              -> acc aenv' bodyArrs                 -- the bound expression scope
              -> PreOpenAcc acc aenv bodyArrs

  -- Variable bound by an 'Alet' (or by the binder of an 'Alam'),
  -- represented by a de Bruijn index
  --
  Avar        :: ArrayVar aenv (Array sh e)
              -> PreOpenAcc acc aenv (Array sh e)

  -- Tuples of arrays
  --
  Apair       :: acc aenv as
              -> acc aenv bs
              -> PreOpenAcc acc aenv (as, bs)

  Anil        :: PreOpenAcc acc aenv ()

  -- Array-function application.
  --
  -- The array function is not closed at the core level because we need access
  -- to free variables introduced by 'run1' style evaluators. See Issue#95.
  --
  Apply       :: ArraysR arrs2
              -> PreOpenAfun acc aenv (arrs1 -> arrs2)
              -> acc aenv arrs1
              -> PreOpenAcc acc aenv arrs2

  -- Apply a backend-specific foreign function to an array, with a pure
  -- Accelerate version for use with other backends. The functions must be
  -- closed.
  --
  Aforeign    :: Foreign asm
              => ArraysR bs
              -> asm (as -> bs)                     -- The foreign function for a given backend
              -> PreAfun acc (as -> bs)             -- Fallback implementation(s)
              -> acc aenv as                        -- Arguments to the function
              -> PreOpenAcc acc aenv bs

  -- If-then-else for array-level computations
  --
  Acond       :: Exp aenv PrimBool
              -> acc aenv arrs
              -> acc aenv arrs
              -> PreOpenAcc acc aenv arrs

  -- Value-recursion for array-level computations
  --
  Awhile      :: PreOpenAfun acc aenv (arrs -> Scalar PrimBool) -- continue iteration while true
              -> PreOpenAfun acc aenv (arrs -> arrs)            -- function to iterate
              -> acc aenv arrs                                  -- initial value
              -> PreOpenAcc acc aenv arrs

  -- Attach a trace 'Message' to an array computation.
  --
  -- The message describes the value of the first computation (type
  -- @arrs1@). The term as a whole has the type of the second computation
  -- (@arrs2@).
  -- NOTE(review): when the message is emitted is decided by the evaluator —
  -- confirm against the backends.
  --
  Atrace      :: Message arrs1
              -> acc aenv arrs1
              -> acc aenv arrs2
              -> PreOpenAcc acc aenv arrs2

  -- Array inlet. Triggers (possibly) asynchronous host->device transfer if
  -- necessary.
  --
  Use         :: ArrayR (Array sh e)
              -> Array sh e
              -> PreOpenAcc acc aenv (Array sh e)

  -- Capture a scalar (or a tuple of scalars) in a singleton array
  --
  Unit        :: TypeR e
              -> Exp aenv e
              -> PreOpenAcc acc aenv (Scalar e)

  -- Change the shape of an array without altering its contents.
  -- Precondition (this may not be checked!):
  --
  -- > dim == size dim'
  --
  -- That is, the new shape must describe the same number of elements as the
  -- source array.
  --
  Reshape     :: ShapeR sh
              -> Exp aenv sh                        -- new shape
              -> acc aenv (Array sh' e)             -- array to be reshaped
              -> PreOpenAcc acc aenv (Array sh e)

  -- Construct a new array by applying a function to each index.
  --
  Generate    :: ArrayR (Array sh e)
              -> Exp aenv sh                        -- output shape
              -> Fun aenv (sh -> e)                 -- representation function
              -> PreOpenAcc acc aenv (Array sh e)

  -- Hybrid map/backpermute, where we separate the index and value
  -- transformations.
  --
  Transform   :: ArrayR (Array sh' b)
              -> Exp aenv sh'                       -- dimension of the result
              -> Fun aenv (sh' -> sh)               -- index permutation function
              -> Fun aenv (a -> b)                  -- function to apply at each element
              -> acc aenv (Array sh a)              -- source array
              -> PreOpenAcc acc aenv (Array sh' b)

  -- Replicate an array across one or more dimensions as given by the first
  -- argument
  --
  Replicate   :: SliceIndex slix sl co sh           -- slice type specification
              -> Exp aenv slix                      -- slice value specification
              -> acc aenv (Array sl e)              -- data to be replicated
              -> PreOpenAcc acc aenv (Array sh e)

  -- Index a sub-array out of an array; i.e., the dimensions not indexed
  -- are returned whole
  --
  Slice       :: SliceIndex slix sl co sh           -- slice type specification
              -> acc aenv (Array sh e)              -- array to be indexed
              -> Exp aenv slix                      -- slice value specification
              -> PreOpenAcc acc aenv (Array sl e)

  -- Apply the given unary function to all elements of the given array
  --
  Map         :: TypeR e'
              -> Fun aenv (e -> e')
              -> acc aenv (Array sh e)
              -> PreOpenAcc acc aenv (Array sh e')

  -- Apply a given binary function pairwise to all elements of the given
  -- arrays. The length of the result is the length of the shorter of the
  -- two argument arrays.
  --
  ZipWith     :: TypeR e3
              -> Fun aenv (e1 -> e2 -> e3)
              -> acc aenv (Array sh e1)
              -> acc aenv (Array sh e2)
              -> PreOpenAcc acc aenv (Array sh e3)

  -- Fold along the innermost dimension of an array with a given
  -- /associative/ function.
  --
  Fold        :: Fun aenv (e -> e -> e)             -- combination function
              -> Maybe (Exp aenv e)                 -- default value
              -> acc aenv (Array (sh, Int) e)       -- folded array
              -> PreOpenAcc acc aenv (Array sh e)

  -- Segmented fold along the innermost dimension of an array with a given
  -- /associative/ function
  --
  FoldSeg     :: IntegralType i
              -> Fun aenv (e -> e -> e)             -- combination function
              -> Maybe (Exp aenv e)                 -- default value
              -> acc aenv (Array (sh, Int) e)       -- folded array
              -> acc aenv (Segments i)              -- segment descriptor
              -> PreOpenAcc acc aenv (Array (sh, Int) e)

  -- Haskell-style scan of a linear array with a given
  -- /associative/ function and optionally an initial element
  -- (which does not need to be the neutral of the associative operations)
  -- If no initial value is given, this is a scan1
  --
  Scan        :: Direction
              -> Fun aenv (e -> e -> e)             -- combination function
              -> Maybe (Exp aenv e)                 -- initial value
              -> acc aenv (Array (sh, Int) e)
              -> PreOpenAcc acc aenv (Array (sh, Int) e)

  -- Like 'Scan', but produces a rightmost (in case of a left-to-right scan)
  -- fold value and an array with the same length as the input array (the
  -- fold value would be the rightmost element in a Haskell-style scan)
  --
  Scan'       :: Direction
              -> Fun aenv (e -> e -> e)             -- combination function
              -> Exp aenv e                         -- initial value
              -> acc aenv (Array (sh, Int) e)
              -> PreOpenAcc acc aenv (Array (sh, Int) e, Array sh e)

  -- Generalised forward permutation is characterised by a permutation function
  -- that determines for each element of the source array where it should go in
  -- the output. The permutation can be between arrays of varying shape and
  -- dimensionality.
  --
  -- Other characteristics of the permutation function 'f':
  --
  --   1. 'f' is a (morally) partial function: only the elements of the domain
  --      for which the function evaluates to a 'Just' value are mapped in the
  --      result. Other elements are dropped.
  --
  --   2. 'f' is not surjective: positions in the target array need not be
  --      picked up by the permutation function, so the target array must first
  --      be initialised from an array of default values.
  --
  --   3. 'f' is not injective: distinct elements of the domain may map to the
  --      same position in the target array. In this case the combination
  --      function is used to combine elements, which needs to be /associative/
  --      and /commutative/.
  --
  Permute     :: Fun aenv (e -> e -> e)             -- combination function
              -> acc aenv (Array sh' e)             -- default values
              -> Fun aenv (sh -> PrimMaybe sh')     -- permutation function
              -> acc aenv (Array sh e)              -- source array
              -> PreOpenAcc acc aenv (Array sh' e)

  -- Generalised multi-dimensional backwards permutation; the permutation can
  -- be between arrays of varying shape; the permutation function must be total
  --
  Backpermute :: ShapeR sh'
              -> Exp aenv sh'                       -- dimensions of the result
              -> Fun aenv (sh' -> sh)               -- permutation function
              -> acc aenv (Array sh e)              -- source array
              -> PreOpenAcc acc aenv (Array sh' e)

  -- Map a stencil over an array. In contrast to 'map', the domain of
  -- a stencil function is an entire /neighbourhood/ of each array element.
  --
  Stencil     :: StencilR sh e stencil
              -> TypeR e'
              -> Fun aenv (stencil -> e')           -- stencil function
              -> Boundary aenv (Array sh e)         -- boundary condition
              -> acc aenv (Array sh e)              -- source array
              -> PreOpenAcc acc aenv (Array sh e')

  -- Map a binary stencil over an array.
  --
  Stencil2    :: StencilR sh a stencil1
              -> StencilR sh b stencil2
              -> TypeR c
              -> Fun aenv (stencil1 -> stencil2 -> c) -- stencil function
              -> Boundary aenv (Array sh a)           -- boundary condition #1
              -> acc aenv (Array sh a)                -- source array #1
              -> Boundary aenv (Array sh b)           -- boundary condition #2
              -> acc aenv (Array sh b)                -- source array #2
              -> PreOpenAcc acc aenv (Array sh c)
-- | Direction in which a 'Scan' or 'Scan'' traverses its input.
data Direction = LeftToRight | RightToLeft
  deriving Eq
-- | Vanilla boundary condition specification for stencil operations
--
-- A boundary condition determines what a 'Stencil' or 'Stencil2' sees when
-- its neighbourhood reaches past the extent of the source array.
--
data Boundary aenv t where
  -- Clamp coordinates to the extent of the array
  Clamp    :: Boundary aenv t

  -- Mirror coordinates beyond the array extent
  Mirror   :: Boundary aenv t

  -- Wrap coordinates around on each dimension
  Wrap     :: Boundary aenv t

  -- Use a constant value for outlying coordinates. The constant is stored
  -- directly as a value of the element type @e@.
  Constant :: e
           -> Boundary aenv (Array sh e)

  -- Apply the given function to outlying coordinates
  Function :: Fun aenv (sh -> e)
           -> Boundary aenv (Array sh e)
-- Embedded expressions
-- --------------------
-- | Vanilla open function abstraction
--
-- A scalar function is a 'Body' expression under zero or more 'Lam'
-- binders. Each binder is an 'ELeftHandSide' that extends the scalar
-- environment from @env@ to @env'@; the array environment @aenv@ is
-- unchanged.
--
data OpenFun env aenv t where
  Body :: OpenExp env aenv t -> OpenFun env aenv t
  Lam  :: ELeftHandSide a env env' -> OpenFun env' aenv t -> OpenFun env aenv (a -> t)
-- | Vanilla function without free scalar variables
--
type Fun = OpenFun ()

-- | Vanilla expression without free scalar variables
--
type Exp = OpenExp ()

-- Types for scalar bindings
--
-- Scalar binders and variables range over 'ScalarType's, just as array
-- binders range over 'ArrayR's.
--
type ELeftHandSide = LeftHandSide ScalarType
type ExpVar        = Var ScalarType
type ExpVars env   = Vars ScalarType env
-- | Rebuild a tuple of scalar variables as an expression.
--
-- The empty tuple becomes 'Nil', a single variable becomes 'Evar', and a
-- pair becomes a 'Pair' of its converted components.
--
expVars :: ExpVars env t -> OpenExp env aenv t
expVars vars =
  case vars of
    TupRunit         -> Nil
    TupRsingle v     -> Evar v
    TupRpair lhs rhs -> Pair (expVars lhs) (expVars rhs)
-- | Vanilla open expressions using de Bruijn indices for variables ranging
-- over tuples of scalars and arrays of tuples. All code, except Cond, is
-- evaluated eagerly. N-tuples are represented as nested pairs.
--
-- The data type is parametrised over the representation type (not the
-- surface types).
--
-- @env@ is the scalar environment and @aenv@ the array environment. Array
-- variables are only referenced, through an 'ArrayVar' in 'Index',
-- 'LinearIndex' and 'Shape'; no constructor here binds an array variable.
--
data OpenExp env aenv t where
  -- Local binding of a scalar expression
  Let           :: ELeftHandSide bnd_t env env'
                -> OpenExp env  aenv bnd_t
                -> OpenExp env' aenv body_t
                -> OpenExp env  aenv body_t

  -- Variable index, ranging only over tuples or scalars
  Evar          :: ExpVar env t
                -> OpenExp env aenv t

  -- Apply a backend-specific foreign function
  Foreign       :: Foreign asm
                => TypeR y
                -> asm (x -> y)                   -- foreign function
                -> Fun () (x -> y)                -- alternate implementation (for other backends)
                -> OpenExp env aenv x
                -> OpenExp env aenv y

  -- Tuples
  Pair          :: OpenExp env aenv t1
                -> OpenExp env aenv t2
                -> OpenExp env aenv (t1, t2)

  Nil           :: OpenExp env aenv ()

  -- SIMD vectors
  --
  -- 'VecPack' builds a vector from a tuple of its elements, as described
  -- by the 'VecR'; 'VecUnpack' is the inverse.
  VecPack       :: KnownNat n
                => VecR n s tup
                -> OpenExp env aenv tup
                -> OpenExp env aenv (Vec n s)

  VecUnpack     :: KnownNat n
                => VecR n s tup
                -> OpenExp env aenv (Vec n s)
                -> OpenExp env aenv tup

  -- Array indices & shapes
  IndexSlice    :: SliceIndex slix sl co sh
                -> OpenExp env aenv slix
                -> OpenExp env aenv sh
                -> OpenExp env aenv sl

  IndexFull     :: SliceIndex slix sl co sh
                -> OpenExp env aenv slix
                -> OpenExp env aenv sl
                -> OpenExp env aenv sh

  -- Shape and index conversion
  ToIndex       :: ShapeR sh
                -> OpenExp env aenv sh            -- shape of the array
                -> OpenExp env aenv sh            -- index into the array
                -> OpenExp env aenv Int

  FromIndex     :: ShapeR sh
                -> OpenExp env aenv sh            -- shape of the array
                -> OpenExp env aenv Int           -- index into linear representation
                -> OpenExp env aenv sh

  -- Case statement
  --
  -- Each equation pairs a 'TAG' value with the branch taken when the
  -- scrutinee equals it.
  Case          :: OpenExp env aenv TAG
                -> [(TAG, OpenExp env aenv b)]    -- list of equations
                -> Maybe (OpenExp env aenv b)     -- default case
                -> OpenExp env aenv b

  -- Conditional expression (non-strict in 2nd and 3rd argument)
  Cond          :: OpenExp env aenv PrimBool
                -> OpenExp env aenv t
                -> OpenExp env aenv t
                -> OpenExp env aenv t

  -- Value recursion
  While         :: OpenFun env aenv (a -> PrimBool) -- continue while true
                -> OpenFun env aenv (a -> a)        -- function to iterate
                -> OpenExp env aenv a               -- initial value
                -> OpenExp env aenv a

  -- Constant values
  Const         :: ScalarType t
                -> t
                -> OpenExp env aenv t

  PrimConst     :: PrimConst t
                -> OpenExp env aenv t

  -- Primitive scalar operations
  --
  -- Multi-argument primitives receive their arguments as a (nested) pair;
  -- see 'PrimFun'.
  PrimApp       :: PrimFun (a -> r)
                -> OpenExp env aenv a
                -> OpenExp env aenv r

  -- Project a single scalar from an array.
  -- The array expression can not contain any free scalar variables.
  Index         :: ArrayVar aenv (Array dim t)
                -> OpenExp env aenv dim
                -> OpenExp env aenv t

  LinearIndex   :: ArrayVar aenv (Array dim t)
                -> OpenExp env aenv Int
                -> OpenExp env aenv t

  -- Array shape.
  -- The array expression can not contain any free scalar variables.
  Shape         :: ArrayVar aenv (Array dim e)
                -> OpenExp env aenv dim

  -- Number of elements of an array given its shape
  ShapeSize     :: ShapeR dim
                -> OpenExp env aenv dim
                -> OpenExp env aenv Int

  -- Unsafe operations (may fail or result in undefined behaviour)
  -- An unspecified bit pattern
  Undef         :: ScalarType t
                -> OpenExp env aenv t

  -- Reinterpret the bits of a value as a different type
  -- (the 'BitSizeEq' constraint relates the bit sizes of the two types)
  Coerce        :: BitSizeEq a b
                => ScalarType a
                -> ScalarType b
                -> OpenExp env aenv a
                -> OpenExp env aenv b
-- |Primitive constant values
--
-- Each constructor carries the type witness of the value it denotes;
-- 'primConstType' maps it to the corresponding 'ScalarType'.
--
data PrimConst ty where
  -- constants from Bounded
  PrimMinBound     :: BoundedType a -> PrimConst a
  PrimMaxBound     :: BoundedType a -> PrimConst a

  -- constant from Floating
  PrimPi           :: FloatingType a -> PrimConst a

  -- constant for empty Vec
  -- NOTE(review): "empty" presumably means a vector of length n whose
  -- contents are not yet set — confirm against the backends.
  PrimVectorCreate :: (KnownNat n, Prim a) => VectorType (Vec n a) -> PrimConst (Vec n a)
-- |Primitive scalar operations
--
-- Each constructor carries the type witness(es) that select its
-- instantiation. Operations on more than one argument take them as a
-- (nested) pair, e.g. @(a, a) -> a@. Predicates return 'PrimBool' rather
-- than 'Bool'.
--
data PrimFun sig where

  -- operators from Num
  PrimAdd  :: NumType a -> PrimFun ((a, a) -> a)
  PrimSub  :: NumType a -> PrimFun ((a, a) -> a)
  PrimMul  :: NumType a -> PrimFun ((a, a) -> a)
  PrimNeg  :: NumType a -> PrimFun (a      -> a)
  PrimAbs  :: NumType a -> PrimFun (a      -> a)
  PrimSig  :: NumType a -> PrimFun (a      -> a)

  -- operators from Integral
  PrimQuot     :: IntegralType a -> PrimFun ((a, a)   -> a)
  PrimRem      :: IntegralType a -> PrimFun ((a, a)   -> a)
  PrimQuotRem  :: IntegralType a -> PrimFun ((a, a)   -> (a, a))
  PrimIDiv     :: IntegralType a -> PrimFun ((a, a)   -> a)
  PrimMod      :: IntegralType a -> PrimFun ((a, a)   -> a)
  PrimDivMod   :: IntegralType a -> PrimFun ((a, a)   -> (a, a))

  -- operators from Bits & FiniteBits
  -- (shift and rotate amounts, and bit counts, are 'Int')
  PrimBAnd               :: IntegralType a -> PrimFun ((a, a)   -> a)
  PrimBOr                :: IntegralType a -> PrimFun ((a, a)   -> a)
  PrimBXor               :: IntegralType a -> PrimFun ((a, a)   -> a)
  PrimBNot               :: IntegralType a -> PrimFun (a        -> a)
  PrimBShiftL            :: IntegralType a -> PrimFun ((a, Int) -> a)
  PrimBShiftR            :: IntegralType a -> PrimFun ((a, Int) -> a)
  PrimBRotateL           :: IntegralType a -> PrimFun ((a, Int) -> a)
  PrimBRotateR           :: IntegralType a -> PrimFun ((a, Int) -> a)
  PrimPopCount           :: IntegralType a -> PrimFun (a -> Int)
  PrimCountLeadingZeros  :: IntegralType a -> PrimFun (a -> Int)
  PrimCountTrailingZeros :: IntegralType a -> PrimFun (a -> Int)

  -- operators from Fractional and Floating
  PrimFDiv        :: FloatingType a -> PrimFun ((a, a) -> a)
  PrimRecip       :: FloatingType a -> PrimFun (a      -> a)
  PrimSin         :: FloatingType a -> PrimFun (a      -> a)
  PrimCos         :: FloatingType a -> PrimFun (a      -> a)
  PrimTan         :: FloatingType a -> PrimFun (a      -> a)
  PrimAsin        :: FloatingType a -> PrimFun (a      -> a)
  PrimAcos        :: FloatingType a -> PrimFun (a      -> a)
  PrimAtan        :: FloatingType a -> PrimFun (a      -> a)
  PrimSinh        :: FloatingType a -> PrimFun (a      -> a)
  PrimCosh        :: FloatingType a -> PrimFun (a      -> a)
  PrimTanh        :: FloatingType a -> PrimFun (a      -> a)
  PrimAsinh       :: FloatingType a -> PrimFun (a      -> a)
  PrimAcosh       :: FloatingType a -> PrimFun (a      -> a)
  PrimAtanh       :: FloatingType a -> PrimFun (a      -> a)
  PrimExpFloating :: FloatingType a -> PrimFun (a      -> a)
  PrimSqrt        :: FloatingType a -> PrimFun (a      -> a)
  PrimLog         :: FloatingType a -> PrimFun (a      -> a)
  PrimFPow        :: FloatingType a -> PrimFun ((a, a) -> a)
  PrimLogBase     :: FloatingType a -> PrimFun ((a, a) -> a)

  -- FIXME: add missing operations from RealFrac & RealFloat

  -- operators from RealFrac
  PrimTruncate :: FloatingType a -> IntegralType b -> PrimFun (a -> b)
  PrimRound    :: FloatingType a -> IntegralType b -> PrimFun (a -> b)
  PrimFloor    :: FloatingType a -> IntegralType b -> PrimFun (a -> b)
  PrimCeiling  :: FloatingType a -> IntegralType b -> PrimFun (a -> b)
  -- PrimProperFraction :: FloatingType a -> IntegralType b -> PrimFun (a -> (b, a))

  -- operators from RealFloat
  PrimAtan2      :: FloatingType a -> PrimFun ((a, a) -> a)
  PrimIsNaN      :: FloatingType a -> PrimFun (a -> PrimBool)
  PrimIsInfinite :: FloatingType a -> PrimFun (a -> PrimBool)

  -- relational and equality operators
  PrimLt   :: SingleType a -> PrimFun ((a, a) -> PrimBool)
  PrimGt   :: SingleType a -> PrimFun ((a, a) -> PrimBool)
  PrimLtEq :: SingleType a -> PrimFun ((a, a) -> PrimBool)
  PrimGtEq :: SingleType a -> PrimFun ((a, a) -> PrimBool)
  PrimEq   :: SingleType a -> PrimFun ((a, a) -> PrimBool)
  PrimNEq  :: SingleType a -> PrimFun ((a, a) -> PrimBool)
  PrimMax  :: SingleType a -> PrimFun ((a, a) -> a)
  PrimMin  :: SingleType a -> PrimFun ((a, a) -> a)

  -- logical operators
  --
  -- Note that these operators are strict in both arguments. That is, the
  -- second argument of PrimLAnd is always evaluated even when the first
  -- argument is false.
  --
  -- We define (surface level) (&&) and (||) using if-then-else to enable
  -- short-circuiting, while (&&!) and (||!) are strict versions of these
  -- operators, which are defined using PrimLAnd and PrimLOr.
  --
  PrimLAnd :: PrimFun ((PrimBool, PrimBool) -> PrimBool)
  PrimLOr  :: PrimFun ((PrimBool, PrimBool) -> PrimBool)
  PrimLNot :: PrimFun (PrimBool             -> PrimBool)

  -- local array operators
  --
  -- 'PrimVectorIndex' reads the element at an index. 'PrimVectorWrite'
  -- takes @(vector, (index, element))@ and returns a vector.
  PrimVectorIndex :: (KnownNat n, Prim a) => VectorType (Vec n a) -> IntegralType i -> PrimFun ((Vec n a, i) -> a)
  PrimVectorWrite :: (KnownNat n, Prim a) => VectorType (Vec n a) -> IntegralType i -> PrimFun ((Vec n a, (i, a)) -> Vec n a)

  -- general conversion between types
  PrimFromIntegral :: IntegralType a -> NumType b -> PrimFun (a -> b)
  PrimToFloating   :: NumType a -> FloatingType b -> PrimFun (a -> b)
-- Type utilities
-- --------------

-- | Array computations that can report the representation of the arrays
-- they produce.
class HasArraysR f where
  arraysR :: f aenv a -> ArraysR a

instance HasArraysR OpenAcc where
  arraysR = \(OpenAcc pacc) -> arraysR pacc

-- | Representation of a computation that yields exactly one array.
--
-- The result type is @Array sh e@, which is neither @()@ nor a pair. The
-- 'ArraysR' returned by 'arraysR' can therefore only be a 'TupRsingle'.
arrayR :: HasArraysR f => f aenv (Array sh e) -> ArrayR (Array sh e)
arrayR computation =
  case arraysR computation of
    TupRsingle repr -> repr
instance HasArraysR acc => HasArraysR (PreOpenAcc acc) where
  -- If a constructor stores its result representation, that representation
  -- is returned directly. Otherwise it is rebuilt from an array argument's
  -- representation (shape and/or element type) obtained through 'arrayR'.
  -- The @let@ patterns are deliberately lazy: the argument's representation
  -- is only inspected if the corresponding component is demanded.
  arraysR = \case
    Alet _ _ body                 -> arraysR body
    Avar (Var repr _)             -> TupRsingle repr
    Apair lhs rhs                 -> TupRpair (arraysR lhs) (arraysR rhs)
    Anil                          -> TupRunit
    Atrace _ _ result             -> arraysR result
    Apply repr _ _                -> repr
    Aforeign repr _ _ _           -> repr
    Acond _ whenTrue _            -> arraysR whenTrue
    -- the loop-state type is given by the binder of the iteration function
    Awhile _ (Alam lhs _) _       -> lhsToTupR lhs
    Awhile{}                      -> error "I want my, I want my MTV!"
    Use repr _                    -> TupRsingle repr
    Unit eltR _                   -> arraysRarray ShapeRz eltR
    Reshape shR _ src             -> let ArrayR _ eltR = arrayR src in arraysRarray shR eltR
    Generate repr _ _             -> TupRsingle repr
    Transform repr _ _ _ _        -> TupRsingle repr
    Replicate sliceR _ src        -> let ArrayR _ eltR = arrayR src in arraysRarray (sliceDomainR sliceR) eltR
    Slice sliceR src _            -> let ArrayR _ eltR = arrayR src in arraysRarray (sliceShapeR sliceR) eltR
    Map eltR _ src                -> let ArrayR shR _ = arrayR src in arraysRarray shR eltR
    ZipWith eltR _ src _          -> let ArrayR shR _ = arrayR src in arraysRarray shR eltR
    -- folding removes the innermost dimension of the source
    Fold _ _ src                  -> let ArrayR (ShapeRsnoc shR) eltR = arrayR src in arraysRarray shR eltR
    FoldSeg _ _ _ src _           -> arraysR src
    Scan _ _ _ src                -> arraysR src
    Scan' _ _ _ src               ->
      let srcR@(ArrayR (ShapeRsnoc shR) eltR) = arrayR src
      in  TupRsingle srcR `TupRpair` TupRsingle (ArrayR shR eltR)
    Permute _ defaults _ _        -> arraysR defaults
    Backpermute shR _ _ src       -> let ArrayR _ eltR = arrayR src in arraysRarray shR eltR
    Stencil _ eltR _ _ src        -> let ArrayR shR _ = arrayR src in arraysRarray shR eltR
    Stencil2 _ _ eltR _ _ src _ _ -> let ArrayR shR _ = arrayR src in arraysRarray shR eltR
-- | Representation type of the value computed by a scalar expression.
--
-- Most constructors determine the type directly. 'Let', 'Case' and 'Cond'
-- defer to a sub-expression, and 'While' reads it from the binder of its
-- iteration function. Raises an internal error for a 'Case' with neither
-- equations nor a default, and calls 'error' for a 'While' whose iteration
-- function is not a lambda.
--
expType :: HasCallStack => OpenExp env aenv t -> TypeR t
expType (Let _ _ body)               = expType body
expType (Evar (Var tR _))            = TupRsingle tR
expType (Foreign tR _ _ _)           = tR
expType (Pair lhs rhs)               = TupRpair (expType lhs) (expType rhs)
expType Nil                          = TupRunit
expType (VecPack vecR _)             = TupRsingle (VectorScalarType (vecRvector vecR))
expType (VecUnpack vecR _)           = vecRtuple vecR
expType (IndexSlice sliceR _ _)      = shapeType (sliceShapeR sliceR)
expType (IndexFull sliceR _ _)       = shapeType (sliceDomainR sliceR)
expType ToIndex{}                    = TupRsingle scalarTypeInt
expType (FromIndex shR _ _)          = shapeType shR
expType (Case _ ((_, branch) : _) _) = expType branch
expType (Case _ [] (Just fallback))  = expType fallback
expType Case{}                       = internalError "empty case encountered"
expType (Cond _ thenE _)             = expType thenE
expType (While _ (Lam lhs _) _)      = lhsToTupR lhs
expType While{}                      = error "What's the matter, you're running in the shadows"
expType (Const tR _)                 = TupRsingle tR
expType (PrimConst c)                = TupRsingle (primConstType c)
expType (PrimApp f _)                = snd (primFunType f)
expType (Index (Var arrR _) _)       = arrayRtype arrR
expType (LinearIndex (Var arrR _) _) = arrayRtype arrR
expType (Shape (Var arrR _))         = shapeType (arrayRshape arrR)
expType ShapeSize{}                  = TupRsingle scalarTypeInt
expType (Undef tR)                   = TupRsingle tR
expType (Coerce _ tR _)              = TupRsingle tR
-- | Scalar type of a primitive constant.
--
-- Bounded constants are integral scalars, 'PrimPi' is a floating scalar,
-- and 'PrimVectorCreate' has the vector's own scalar type.
--
primConstType :: PrimConst a -> ScalarType a
primConstType c =
  case c of
    PrimMinBound bR     -> fromBounded bR
    PrimMaxBound bR     -> fromBounded bR
    PrimPi fR           -> fromFloating fR
    PrimVectorCreate vR -> VectorScalarType vR
  where
    fromBounded :: BoundedType t -> ScalarType t
    fromBounded (IntegralBoundedType iR) = SingleScalarType (NumSingleType (IntegralNumType iR))

    fromFloating :: FloatingType t -> ScalarType t
    fromFloating fR = SingleScalarType (NumSingleType (FloatingNumType fR))
-- | Representation types of the argument and result of a primitive function.
--
-- A primitive of several arguments receives them as a (nested) pair; for
-- example, 'PrimAdd' at type @t@ yields @(t `TupRpair` t, t)@. Boolean
-- results use the 'PrimBool' representation ('scalarTypeWord8'). Shift and
-- rotate amounts, and bit counts, are 'Int' ('scalarTypeInt').
--
-- Every 'PrimFun' constructor has an alternative below, including
-- 'PrimVectorWrite'. Without that alternative, 'expType' on a 'PrimApp' of
-- 'PrimVectorWrite' would fail with a pattern-match error.
--
primFunType :: PrimFun (a -> b) -> (TypeR a, TypeR b)
primFunType = \case
    -- Num
    PrimAdd t                 -> binary' $ num t
    PrimSub t                 -> binary' $ num t
    PrimMul t                 -> binary' $ num t
    PrimNeg t                 -> unary'  $ num t
    PrimAbs t                 -> unary'  $ num t
    PrimSig t                 -> unary'  $ num t

    -- Integral
    PrimQuot t                -> binary' $ integral t
    PrimRem  t                -> binary' $ integral t
    PrimQuotRem t             -> unary'  $ integral t `TupRpair` integral t
    PrimIDiv t                -> binary' $ integral t
    PrimMod  t                -> binary' $ integral t
    PrimDivMod t              -> unary'  $ integral t `TupRpair` integral t

    -- Bits & FiniteBits
    PrimBAnd t                -> binary' $ integral t
    PrimBOr t                 -> binary' $ integral t
    PrimBXor t                -> binary' $ integral t
    PrimBNot t                -> unary' $ integral t
    PrimBShiftL t             -> (integral t `TupRpair` tint, integral t)
    PrimBShiftR t             -> (integral t `TupRpair` tint, integral t)
    PrimBRotateL t            -> (integral t `TupRpair` tint, integral t)
    PrimBRotateR t            -> (integral t `TupRpair` tint, integral t)
    PrimPopCount t            -> unary (integral t) tint
    PrimCountLeadingZeros t   -> unary (integral t) tint
    PrimCountTrailingZeros t  -> unary (integral t) tint

    -- Fractional, Floating
    PrimFDiv t                -> binary' $ floating t
    PrimRecip t               -> unary'  $ floating t
    PrimSin t                 -> unary'  $ floating t
    PrimCos t                 -> unary'  $ floating t
    PrimTan t                 -> unary'  $ floating t
    PrimAsin t                -> unary'  $ floating t
    PrimAcos t                -> unary'  $ floating t
    PrimAtan t                -> unary'  $ floating t
    PrimSinh t                -> unary'  $ floating t
    PrimCosh t                -> unary'  $ floating t
    PrimTanh t                -> unary'  $ floating t
    PrimAsinh t               -> unary'  $ floating t
    PrimAcosh t               -> unary'  $ floating t
    PrimAtanh t               -> unary'  $ floating t
    PrimExpFloating t         -> unary'  $ floating t
    PrimSqrt t                -> unary'  $ floating t
    PrimLog t                 -> unary'  $ floating t
    PrimFPow t                -> binary' $ floating t
    PrimLogBase t             -> binary' $ floating t

    -- RealFrac
    PrimTruncate a b          -> unary (floating a) (integral b)
    PrimRound a b             -> unary (floating a) (integral b)
    PrimFloor a b             -> unary (floating a) (integral b)
    PrimCeiling a b           -> unary (floating a) (integral b)

    -- RealFloat
    PrimAtan2 t               -> binary' $ floating t
    PrimIsNaN t               -> unary (floating t) tbool
    PrimIsInfinite t          -> unary (floating t) tbool

    -- Relational and equality
    PrimLt t                  -> compare' t
    PrimGt t                  -> compare' t
    PrimLtEq t                -> compare' t
    PrimGtEq t                -> compare' t
    PrimEq t                  -> compare' t
    PrimNEq t                 -> compare' t
    PrimMax t                 -> binary' $ single t
    PrimMin t                 -> binary' $ single t

    -- Logical
    PrimLAnd                  -> binary' tbool
    PrimLOr                   -> binary' tbool
    PrimLNot                  -> unary' tbool

    -- Local Vector operations
    --
    -- Index: (vector, index) -> element
    PrimVectorIndex v'@(VectorType _ a) i' ->
      let v = singleVector v'
          i = integral i'
      in  (v `TupRpair` i, single a)
    -- Write: (vector, (index, element)) -> vector
    PrimVectorWrite v'@(VectorType _ a) i' ->
      let v = singleVector v'
          i = integral i'
      in  (v `TupRpair` (i `TupRpair` single a), v)

    -- general conversion between types
    PrimFromIntegral a b      -> unary (integral a) (num b)
    PrimToFloating   a b      -> unary (num a) (floating b)

  where
    -- one argument, possibly of a different type from the result
    unary a b     = (a, b)
    -- one argument of the same type as the result
    unary' a      = unary a a
    -- two arguments of the same type, passed as a pair
    binary a b    = (a `TupRpair` a, b)
    binary' a     = binary a a
    -- comparisons return a PrimBool
    compare' a    = binary (single a) tbool

    single        = TupRsingle . SingleScalarType
    singleVector  = TupRsingle . VectorScalarType
    num           = TupRsingle . SingleScalarType . NumSingleType
    integral      = num . IntegralNumType
    floating      = num . FloatingNumType

    tbool         = TupRsingle scalarTypeWord8
    tint          = TupRsingle scalarTypeInt
-- Normal form data
-- ================

-- Each 'NFData' instance below delegates to the matching @rnf*@ traversal.

instance NFData (OpenAfun aenv f) where
  rnf afun = rnfOpenAfun afun

instance NFData (OpenAcc aenv t) where
  rnf acc = rnfOpenAcc acc

instance NFData (OpenExp env aenv t) where
  rnf expr = rnfOpenExp expr

instance NFData (OpenFun env aenv t) where
  rnf fun = rnfOpenFun fun
-- | Normal-form evaluator for an array-computation type @acc@.
--
-- It is polymorphic in the environment and the result type, so that it can
-- be applied to every sub-computation regardless of where it is nested.
type NFDataAcc acc = forall aenv t. acc aenv t -> ()

-- | Fully evaluate a vanilla array function.
rnfOpenAfun :: OpenAfun aenv t -> ()
rnfOpenAfun afun = rnfPreOpenAfun rnfOpenAcc afun

-- | Fully evaluate an array function, using the given evaluator for the
-- body. Each lambda's left-hand side is forced before the rest of the
-- function.
rnfPreOpenAfun :: NFDataAcc acc -> PreOpenAfun acc aenv t -> ()
rnfPreOpenAfun forceAcc = \case
  Abody body   -> forceAcc body
  Alam lhs fun -> rnfALeftHandSide lhs `seq` rnfPreOpenAfun forceAcc fun

-- | Fully evaluate a vanilla array computation.
rnfOpenAcc :: OpenAcc aenv t -> ()
rnfOpenAcc = \(OpenAcc pacc) -> rnfPreOpenAcc rnfOpenAcc pacc
rnfPreOpenAcc :: forall acc aenv t. HasArraysR acc => NFDataAcc acc -> PreOpenAcc acc aenv t -> ()
rnfPreOpenAcc rnfA pacc =
let
rnfAF :: PreOpenAfun acc aenv' t' -> ()
rnfAF = rnfPreOpenAfun rnfA
rnfE :: OpenExp env' aenv' t' -> ()