-
Notifications
You must be signed in to change notification settings - Fork 132
Expand file tree
/
Copy pathSmart.hs
More file actions
1394 lines (1147 loc) · 53.6 KB
/
Smart.hs
File metadata and controls
1394 lines (1147 loc) · 53.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE PolyKinds #-}
{-# OPTIONS_HADDOCK hide #-}
-- |
-- Module : Data.Array.Accelerate.Smart
-- Copyright : [2008..2020] The Accelerate Team
-- License : BSD3
--
-- Maintainer : Trevor L. McDonell <trevor.mcdonell@gmail.com>
-- Stability : experimental
-- Portability : non-portable (GHC extensions)
--
-- This module defines the AST of the user-visible embedded language using more
-- convenient higher-order abstract syntax (instead of de Bruijn indices).
-- Moreover, it defines smart constructors to construct programs.
--
module Data.Array.Accelerate.Smart (
-- * HOAS AST
-- ** Array computations
Acc(..), SmartAcc(..), PreSmartAcc(..),
Level, Direction(..), Message(..),
-- ** Scalar expressions
Exp(..), SmartExp(..), PreSmartExp(..),
Stencil(..),
Boundary(..), PreBoundary(..),
PrimBool,
PrimMaybe,
-- ** Extracting type information
HasArraysR(..),
HasTypeR(..),
-- ** Smart constructors for literals
constant, undef,
-- ** Smart destructors for shapes
indexHead, indexTail,
-- ** Smart constructors for constants
mkMinBound, mkMaxBound, mkPi,
mkSin, mkCos, mkTan,
mkAsin, mkAcos, mkAtan,
mkSinh, mkCosh, mkTanh,
mkAsinh, mkAcosh, mkAtanh,
mkExpFloating, mkSqrt, mkLog,
mkFPow, mkLogBase,
mkTruncate, mkRound, mkFloor, mkCeiling,
mkAtan2,
-- ** Smart constructors for primitive functions
mkAdd, mkSub, mkMul, mkNeg, mkAbs, mkSig, mkQuot, mkRem, mkQuotRem, mkIDiv, mkMod, mkDivMod,
mkBAnd, mkBOr, mkBXor, mkBNot, mkBShiftL, mkBShiftR, mkBRotateL, mkBRotateR, mkPopCount, mkCountLeadingZeros, mkCountTrailingZeros,
mkFDiv, mkRecip, mkLt, mkGt, mkLtEq, mkGtEq, mkEq, mkNEq, mkMax, mkMin,
mkLAnd, mkLOr, mkLNot, mkIsNaN, mkIsInfinite,
-- ** Smart constructors for type coercion functions
mkFromIntegral, mkToFloating, mkBitcast, mkCoerce, Coerce(..),
-- ** Smart constructors for vector operations
mkVectorCreate,
mkVectorIndex,
mkVectorWrite,
-- ** Auxiliary functions
($$), ($$$), ($$$$), ($$$$$),
ApplyAcc(..),
unAcc, unAccFunction, mkExp, unExp, unExpFunction, unExpBinaryFunction, unPair, mkPairToTuple,
-- ** Miscellaneous
formatPreAccOp,
formatPreExpOp,
) where
import Data.Proxy
import Data.Array.Accelerate.AST.Idx
import Data.Array.Accelerate.Error
import Data.Array.Accelerate.Representation.Array
import Data.Array.Accelerate.Representation.Elt
import Data.Array.Accelerate.Representation.Shape
import Data.Array.Accelerate.Representation.Slice
import Data.Array.Accelerate.Representation.Stencil hiding ( StencilR, stencilR )
import Data.Array.Accelerate.Representation.Tag
import Data.Array.Accelerate.Representation.Type
import Data.Array.Accelerate.Representation.Vec
import Data.Array.Accelerate.Sugar.Array ( Arrays )
import Data.Array.Accelerate.Sugar.Elt
import Data.Array.Accelerate.Sugar.Vec
import Data.Array.Accelerate.Sugar.Foreign
import Data.Array.Accelerate.Sugar.Shape ( (:.)(..) )
import Data.Array.Accelerate.Type
import qualified Data.Array.Accelerate.Representation.Stencil as R
import qualified Data.Array.Accelerate.Sugar.Array as Sugar
import qualified Data.Array.Accelerate.Sugar.Shape as Sugar
import Data.Array.Accelerate.AST ( Direction(..), Message(..)
, PrimBool, PrimMaybe
, PrimFun(..), primFunType
, PrimConst(..), primConstType )
import Data.Primitive.Vec
import Data.Kind
import Data.Text.Lazy.Builder
import Formatting
import GHC.TypeLits
-- Array computations
-- ------------------
-- | Accelerate is an /embedded language/ that distinguishes between vanilla
-- arrays (e.g. in Haskell memory on the CPU) and embedded arrays (e.g. in
-- device memory on a GPU), as well as the computations on both of these. Since
-- Accelerate is an embedded language, programs written in Accelerate are not
-- compiled by the Haskell compiler (GHC). Rather, each Accelerate backend is
-- a /runtime compiler/ which generates and executes parallel SIMD code of the
-- target language at application /runtime/.
--
-- The type constructor 'Acc' represents embedded collective array operations.
-- A term of type @Acc a@ is an Accelerate program which, once executed, will
-- produce a value of type 'a' (an 'Array' or a tuple of 'Arrays'). Collective
-- operations of type @Acc a@ comprise many /scalar expressions/, wrapped in
-- type constructor 'Exp', which will be executed in parallel. Although
-- collective operations comprise many scalar operations executed in parallel,
-- scalar operations /cannot/ initiate new collective operations: this
-- stratification between scalar operations in 'Exp' and array operations in
-- 'Acc' helps statically exclude /nested data parallelism/, which is difficult
-- to execute efficiently on constrained hardware such as GPUs.
--
-- [/A simple example/]
--
-- As a simple example, to compute a vector dot product we can write:
--
-- > dotp :: Num a => Vector a -> Vector a -> Acc (Scalar a)
-- > dotp xs ys =
-- > let
-- > xs' = use xs
-- > ys' = use ys
-- > in
-- > fold (+) 0 (zipWith (*) xs' ys')
--
-- The function @dotp@ consumes two one-dimensional arrays ('Vector's) of
-- values, and produces a single ('Scalar') result as output. As the return type
-- is wrapped in the type 'Acc', we see that it is an embedded Accelerate
-- computation - it will be evaluated in the /object/ language of dynamically
-- generated parallel code, rather than the /meta/ language of vanilla Haskell.
--
-- As the arguments to @dotp@ are plain Haskell arrays, to make these available
-- to Accelerate computations they must be embedded with the
-- 'Data.Array.Accelerate.Language.use' function.
--
-- An Accelerate backend is used to evaluate the embedded computation and return
-- the result back to vanilla Haskell. Calling the 'run' function of a backend
-- will generate code for the target architecture, compile, and execute it. For
-- example, the following backends are available:
--
-- * <http://hackage.haskell.org/package/accelerate-llvm-native accelerate-llvm-native>: for execution on multicore CPUs
-- * <http://hackage.haskell.org/package/accelerate-llvm-ptx accelerate-llvm-ptx>: for execution on NVIDIA CUDA-capable GPUs
--
-- See also 'Exp', which encapsulates embedded /scalar/ computations.
--
-- [/Avoiding nested parallelism/]
--
-- As mentioned above, embedded scalar computations of type 'Exp' can not
-- initiate further collective operations.
--
-- Suppose we wanted to extend our above @dotp@ function to matrix-vector
-- multiplication. First, let's rewrite our @dotp@ function to take 'Acc' arrays
-- as input (which is typically what we want):
--
-- > dotp :: Num a => Acc (Vector a) -> Acc (Vector a) -> Acc (Scalar a)
-- > dotp xs ys = fold (+) 0 $ zipWith (*) xs ys
--
-- We might then be inclined to lift our dot-product program to the following
-- (incorrect) matrix-vector product, by applying @dotp@ to each row of the
-- input matrix:
--
-- > mvm_ndp :: Num a => Acc (Matrix a) -> Acc (Vector a) -> Acc (Vector a)
-- > mvm_ndp mat vec =
-- > let I2 rows cols = shape mat
-- > in generate (I1 rows)
-- > (\(I1 row) -> the $ dotp vec (slice mat (I2 row All_)))
--
-- Here, we use 'Data.Array.Accelerate.generate' to create a one-dimensional
-- vector by applying at each index a function to 'Data.Array.Accelerate.slice'
-- out the corresponding @row@ of the matrix to pass to the @dotp@ function.
-- However, since both 'Data.Array.Accelerate.generate' and
-- 'Data.Array.Accelerate.slice' are data-parallel operations, and moreover that
-- 'Data.Array.Accelerate.slice' /depends on/ the argument @row@ given to it by
-- the 'Data.Array.Accelerate.generate' function, this definition requires
-- nested data-parallelism, and is thus not permitted. The clue that this
-- definition is invalid is that in order to create a program which will be
-- accepted by the type checker, we must use the function
-- 'Data.Array.Accelerate.the' to retrieve the result of the @dotp@ operation,
-- effectively concealing that @dotp@ is a collective array computation in order
-- to match the type expected by 'Data.Array.Accelerate.generate', which is that
-- of scalar expressions. Additionally, since we have fooled the type-checker,
-- this problem will only be discovered at program runtime.
--
-- In order to avoid this problem, we can make use of the fact that operations
-- in Accelerate are /rank polymorphic/. The 'Data.Array.Accelerate.fold'
-- operation reduces along the innermost dimension of an array of arbitrary
-- rank, reducing the rank (dimensionality) of the array by one. Thus, we can
-- 'Data.Array.Accelerate.replicate' the input vector to as many @rows@ there
-- are in the input matrix, and perform the dot-product of the vector with every
-- row simultaneously:
--
-- > mvm :: Num a => Acc (Matrix a) -> Acc (Vector a) -> Acc (Vector a)
-- > mvm mat vec =
-- > let I2 rows cols = shape mat
-- > vec' = replicate (I2 rows All_) vec
-- > in
-- > fold (+) 0 $ zipWith (*) mat vec'
--
-- Note that the intermediate, replicated array @vec'@ is never actually created
-- in memory; it will be fused directly into the operation which consumes it. We
-- discuss fusion next.
--
-- [/Fusion/]
--
-- Array computations of type 'Acc' will be subject to /array fusion/;
-- Accelerate will combine individual 'Acc' computations into a single
-- computation, which reduces the number of traversals over the input data and
-- thus improves performance. As such, it is often useful to have some intuition
-- on when fusion should occur.
--
-- The main idea is to first partition array operations into two categories:
--
-- 1. Element-wise operations, such as 'Data.Array.Accelerate.map',
-- 'Data.Array.Accelerate.generate', and
-- 'Data.Array.Accelerate.backpermute'. Each element of these operations
-- can be computed independently of all others.
--
-- 2. Collective operations such as 'Data.Array.Accelerate.fold',
-- 'Data.Array.Accelerate.scanl', and 'Data.Array.Accelerate.stencil'. To
-- compute each output element of these operations requires reading
-- multiple elements from the input array(s).
--
-- Element-wise operations fuse together whenever the consumer operation uses
-- a single element of the input array. Element-wise operations can both fuse
-- their inputs into themselves, as well be fused into later operations. Both
-- these examples should fuse into a single loop:
--
-- <<images/fusion_example_1.png>>
--
-- <<images/fusion_example_2.png>>
--
-- If the consumer operation uses more than one element of the input array
-- (typically, via 'Data.Array.Accelerate.generate' indexing an array multiple
-- times), then the input array will be completely evaluated first; no fusion
-- occurs in this case, because fusing the first operation into the second
-- implies duplicating work.
--
-- On the other hand, collective operations can fuse their input arrays into
-- themselves, but on output always evaluate to an array; collective operations
-- will not be fused into a later step. For example:
--
-- <<images/fusion_example_3.png>>
--
-- Here the element-wise sequence ('Data.Array.Accelerate.use'
-- + 'Data.Array.Accelerate.generate' + 'Data.Array.Accelerate.zipWith') will
-- fuse into a single operation, which then fuses into the collective
-- 'Data.Array.Accelerate.fold' operation. At this point in the program the
-- 'Data.Array.Accelerate.fold' must now be evaluated. In the final step the
-- 'Data.Array.Accelerate.map' reads in the array produced by
-- 'Data.Array.Accelerate.fold'. As there is no fusion between the
-- 'Data.Array.Accelerate.fold' and 'Data.Array.Accelerate.map' steps, this
-- program consists of two "loops"; one for the 'Data.Array.Accelerate.use'
-- + 'Data.Array.Accelerate.generate' + 'Data.Array.Accelerate.zipWith'
-- + 'Data.Array.Accelerate.fold' step, and one for the final
-- 'Data.Array.Accelerate.map' step.
--
-- You can see how many operations will be executed in the fused program by
-- 'Show'-ing the 'Acc' program, or by using the debugging option @-ddump-dot@
-- to save the program as a graphviz DOT file.
--
-- As a special note, the operations 'Data.Array.Accelerate.unzip' and
-- 'Data.Array.Accelerate.reshape', when applied to a real array, are executed
-- in constant time, so in this situation these operations will not be fused.
--
-- [/Tips/]
--
-- * Since 'Acc' represents embedded computations that will only be executed
-- when evaluated by a backend, we can programmatically generate these
-- computations using the meta language Haskell; for example, unrolling loops
-- or embedding input values into the generated code.
--
-- * It is usually best to keep all intermediate computations in 'Acc', and
-- only 'run' the computation at the very end to produce the final result.
-- This enables optimisations between intermediate results (e.g. array
-- fusion) and, if the target architecture has a separate memory space, as is
-- the case of GPUs, to prevent excessive data transfers.
--
newtype Acc a = Acc (SmartAcc (Sugar.ArraysR a)) -- indexed by the representation of the surface type 'a'
-- | Ties the recursive knot of 'PreSmartAcc': array-valued subterms are
-- themselves 'SmartAcc' and scalar subterms are 'SmartExp'.
newtype SmartAcc a = SmartAcc (PreSmartAcc SmartAcc SmartExp a)
-- The level of lambda-bound variables. The root has level 0; then it
-- increases with each bound variable — i.e., it is the same as the size of
-- the environment at the defining occurrence.
--
-- Used by the 'Atag' and 'Tag' constructors below during the conversion to
-- de Bruijn form.
--
type Level = Int
-- | Array-valued collective computations without a recursive knot
--
data PreSmartAcc acc exp as where
  -- Needed for conversion to de Bruijn form
  Atag        :: ArraysR as
              -> Level                  -- environment size at defining occurrence
              -> PreSmartAcc acc exp as

  -- Composition of two array computations; the intermediate result of type
  -- 'bs' is produced by the first stage and consumed by the second
  Pipe        :: ArraysR as
              -> ArraysR bs
              -> ArraysR cs
              -> (SmartAcc as -> acc bs)
              -> (SmartAcc bs -> acc cs)
              -> acc as
              -> PreSmartAcc acc exp cs

  -- Backend-specific foreign function 'asm', together with a pure Accelerate
  -- fallback implementation
  Aforeign    :: Foreign asm
              => ArraysR bs
              -> asm (as -> bs)
              -> (SmartAcc as -> SmartAcc bs)
              -> acc as
              -> PreSmartAcc acc exp bs

  -- Array-level conditional on a scalar predicate
  Acond       :: exp PrimBool
              -> acc as
              -> acc as
              -> PreSmartAcc acc exp as

  -- Array-level value recursion: iterate the body while the predicate holds
  Awhile      :: ArraysR arrs
              -> (SmartAcc arrs -> acc (Scalar PrimBool))
              -> (SmartAcc arrs -> acc arrs)
              -> acc arrs
              -> PreSmartAcc acc exp arrs

  -- Tuples of array computations: unit, pairing, and projection
  Anil        :: PreSmartAcc acc exp ()

  Apair       :: acc arrs1
              -> acc arrs2
              -> PreSmartAcc acc exp (arrs1, arrs2)

  Aprj        :: PairIdx (arrs1, arrs2) arrs
              -> acc (arrs1, arrs2)
              -> PreSmartAcc acc exp arrs

  -- Debug tracing: the first computation is reported via the 'Message';
  -- the second computation is the result that is passed through
  Atrace      :: Message arrs1
              -> acc arrs1
              -> acc arrs2
              -> PreSmartAcc acc exp arrs2

  -- Embed a concrete host-side array into the embedded computation
  Use         :: ArrayR (Array sh e)
              -> Array sh e
              -> PreSmartAcc acc exp (Array sh e)

  -- Lift a scalar expression into a singleton array
  Unit        :: TypeR e
              -> exp e
              -> PreSmartAcc acc exp (Scalar e)

  -- Create an array of the given shape by applying a function at every index
  Generate    :: ArrayR (Array sh e)
              -> exp sh
              -> (SmartExp sh -> exp e)
              -> PreSmartAcc acc exp (Array sh e)

  -- Impose a new shape on the elements of an array
  Reshape     :: ShapeR sh
              -> exp sh
              -> acc (Array sh' e)
              -> PreSmartAcc acc exp (Array sh e)

  -- Replicate an array across the dimensions given by the slice specification
  Replicate   :: SliceIndex slix sl co sh
              -> exp slix
              -> acc (Array sl e)
              -> PreSmartAcc acc exp (Array sh e)

  -- Extract a sub-array as described by the slice specification
  Slice       :: SliceIndex slix sl co sh
              -> acc (Array sh e)
              -> exp slix
              -> PreSmartAcc acc exp (Array sl e)

  -- Apply a scalar function to every element of an array
  Map         :: TypeR e
              -> TypeR e'
              -> (SmartExp e -> exp e')
              -> acc (Array sh e)
              -> PreSmartAcc acc exp (Array sh e')

  -- Combine two arrays element-wise with a binary scalar function
  ZipWith     :: TypeR e1
              -> TypeR e2
              -> TypeR e3
              -> (SmartExp e1 -> SmartExp e2 -> exp e3)
              -> acc (Array sh e1)
              -> acc (Array sh e2)
              -> PreSmartAcc acc exp (Array sh e3)

  -- Reduce along the innermost dimension, with an optional initial element;
  -- the result has one dimension fewer than the input
  Fold        :: TypeR e
              -> (SmartExp e -> SmartExp e -> exp e)
              -> Maybe (exp e)
              -> acc (Array (sh, Int) e)
              -> PreSmartAcc acc exp (Array sh e)

  -- Segmented reduction along the innermost dimension; segment descriptors
  -- are supplied in the 'Segments' array
  FoldSeg     :: IntegralType i
              -> TypeR e
              -> (SmartExp e -> SmartExp e -> exp e)
              -> Maybe (exp e)
              -> acc (Array (sh, Int) e)
              -> acc (Segments i)
              -> PreSmartAcc acc exp (Array (sh, Int) e)

  -- Prefix scan along the innermost dimension in the given 'Direction',
  -- with an optional initial element
  Scan        :: Direction
              -> TypeR e
              -> (SmartExp e -> SmartExp e -> exp e)
              -> Maybe (exp e)
              -> acc (Array (sh, Int) e)
              -> PreSmartAcc acc exp (Array (sh, Int) e)

  -- As 'Scan', but the array of final reduction values is returned
  -- separately alongside the scanned array
  Scan'       :: Direction
              -> TypeR e
              -> (SmartExp e -> SmartExp e -> exp e)
              -> exp e
              -> acc (Array (sh, Int) e)
              -> PreSmartAcc acc exp (Array (sh, Int) e, Array sh e)

  -- Forward permutation: combination function, array of default values,
  -- index mapping (a 'Nothing' target index drops the element — presumably;
  -- semantics are fixed by the consumer of this AST), and the source array
  Permute     :: ArrayR (Array sh e)
              -> (SmartExp e -> SmartExp e -> exp e)
              -> acc (Array sh' e)
              -> (SmartExp sh -> exp (PrimMaybe sh'))
              -> acc (Array sh e)
              -> PreSmartAcc acc exp (Array sh' e)

  -- Backwards permutation: build an array of shape sh' by reading the source
  -- array through the index mapping
  Backpermute :: ShapeR sh'
              -> exp sh'
              -> (SmartExp sh' -> exp sh)
              -> acc (Array sh e)
              -> PreSmartAcc acc exp (Array sh' e)

  -- Stencil computation with a boundary condition for out-of-bounds access
  Stencil     :: R.StencilR sh a stencil
              -> TypeR b
              -> (SmartExp stencil -> exp b)
              -> PreBoundary acc exp (Array sh a)
              -> acc (Array sh a)
              -> PreSmartAcc acc exp (Array sh b)

  -- Binary stencil computation over two input arrays, each with its own
  -- boundary condition
  Stencil2    :: R.StencilR sh a stencil1
              -> R.StencilR sh b stencil2
              -> TypeR c
              -> (SmartExp stencil1 -> SmartExp stencil2 -> exp c)
              -> PreBoundary acc exp (Array sh a)
              -> acc (Array sh a)
              -> PreBoundary acc exp (Array sh b)
              -> acc (Array sh b)
              -> PreSmartAcc acc exp (Array sh c)
-- Embedded expressions of the surface language
-- --------------------------------------------
-- HOAS expressions mirror the constructors of 'AST.OpenExp', but with the 'Tag'
-- constructor instead of variables in the form of de Bruijn indices.
--
-- | The type 'Exp' represents embedded scalar expressions. The collective
-- operations of Accelerate 'Acc' consist of many scalar expressions executed in
-- data-parallel.
--
-- Note that scalar expressions can not initiate new collective operations:
-- doing so introduces /nested data parallelism/, which is difficult to execute
-- efficiently on constrained hardware such as GPUs, and is thus currently
-- unsupported.
--
newtype Exp t = Exp (SmartExp (EltR t)) -- indexed by the representation of the surface type 't'
-- | Ties the recursive knot of 'PreSmartExp': scalar subterms are themselves
-- 'SmartExp' and array-valued subterms are 'SmartAcc'.
newtype SmartExp t = SmartExp (PreSmartExp SmartAcc SmartExp t)
-- | Scalar expressions to parameterise collective array operations, themselves
-- parameterised over the type of collective array operations.
--
data PreSmartExp acc exp t where
  -- Needed for conversion to de Bruijn form
  Tag         :: TypeR t
              -> Level                  -- environment size at defining occurrence
              -> PreSmartExp acc exp t

  -- Needed for embedded pattern matching
  Match       :: TagR t
              -> exp t
              -> PreSmartExp acc exp t

  -- All the same constructors as 'AST.Exp', plus projection

  -- Scalar constant
  Const       :: ScalarType t
              -> t
              -> PreSmartExp acc exp t

  -- Tuples of scalar expressions: unit, pairing, and projection
  Nil         :: PreSmartExp acc exp ()

  Pair        :: exp t1
              -> exp t2
              -> PreSmartExp acc exp (t1, t2)

  Prj         :: PairIdx (t1, t2) t
              -> exp (t1, t2)
              -> PreSmartExp acc exp t

  -- Pack a (nested) tuple of scalars into a SIMD vector, and unpack again
  VecPack     :: KnownNat n
              => VecR n s tup
              -> exp tup
              -> PreSmartExp acc exp (Vec n s)

  VecUnpack   :: KnownNat n
              => VecR n s tup
              -> exp (Vec n s)
              -> PreSmartExp acc exp tup

  -- Conversion between multi-dimensional and linear indices, with respect
  -- to a given array extent
  ToIndex     :: ShapeR sh
              -> exp sh
              -> exp sh
              -> PreSmartExp acc exp Int

  FromIndex   :: ShapeR sh
              -> exp sh
              -> exp Int
              -> PreSmartExp acc exp sh

  -- Multi-way branch on the tag of a sum-type value; every alternative
  -- shares the result type 'b'
  Case        :: exp a
              -> [(TagR a, exp b)]
              -> PreSmartExp acc exp b

  -- Two-way scalar conditional
  Cond        :: exp PrimBool
              -> exp t
              -> exp t
              -> PreSmartExp acc exp t

  -- Scalar value recursion: iterate the body while the predicate holds
  While       :: TypeR t
              -> (SmartExp t -> exp PrimBool)
              -> (SmartExp t -> exp t)
              -> exp t
              -> PreSmartExp acc exp t

  -- Primitive constants and application of primitive scalar functions
  PrimConst   :: PrimConst t
              -> PreSmartExp acc exp t

  PrimApp     :: PrimFun (a -> r)
              -> exp a
              -> PreSmartExp acc exp r

  -- Array indexing, by multi-dimensional or linear index
  Index       :: TypeR t
              -> acc (Array sh t)
              -> exp sh
              -> PreSmartExp acc exp t

  LinearIndex :: TypeR t
              -> acc (Array sh t)
              -> exp Int
              -> PreSmartExp acc exp t

  -- Extent of an array, and the number of elements in a shape
  Shape       :: ShapeR sh
              -> acc (Array sh e)
              -> PreSmartExp acc exp sh

  ShapeSize   :: ShapeR sh
              -> exp sh
              -> PreSmartExp acc exp Int

  -- Backend-specific foreign scalar function with a pure Accelerate fallback
  Foreign     :: Foreign asm
              => TypeR y
              -> asm (x -> y)
              -> (SmartExp x -> SmartExp y)  -- RCE: Using SmartExp instead of exp to aid in sharing recovery.
              -> exp x
              -> PreSmartExp acc exp y

  -- An undefined (uninitialised) scalar value
  Undef       :: ScalarType t
              -> PreSmartExp acc exp t

  -- Reinterpret the bits of a value as another type of equal bit width
  Coerce      :: BitSizeEq a b
              => ScalarType a
              -> ScalarType b
              -> exp a
              -> PreSmartExp acc exp b
-- Smart constructors for stencils
-- -------------------------------
-- | Boundary condition specification for stencil operations
--
data Boundary t where
  -- Wraps the representation-level boundary specification, mediating between
  -- surface ('Sugar') array types and their representations
  Boundary :: PreBoundary SmartAcc SmartExp (Array (EltR sh) (EltR e))
           -> Boundary (Sugar.Array sh e)
-- Boundary conditions without the recursive knot. How out-of-bounds stencil
-- accesses are handled is ultimately decided by the consumer of this AST.
data PreBoundary acc exp t where
  -- clamp out-of-bounds indices to the array edge
  Clamp    :: PreBoundary acc exp t
  -- mirror out-of-bounds indices about the array edge
  Mirror   :: PreBoundary acc exp t
  -- wrap out-of-bounds indices around to the opposite edge
  Wrap     :: PreBoundary acc exp t
  -- out-of-bounds accesses yield this constant element value
  Constant :: e
           -> PreBoundary acc exp (Array sh e)
  -- out-of-bounds accesses are computed by this function of the index
  Function :: (SmartExp sh -> exp e)
           -> PreBoundary acc exp (Array sh e)
-- Stencil reification
-- -------------------
--
-- In the AST representation, we turn the stencil type from nested tuples
-- of Accelerate expressions into an Accelerate expression whose type is
-- a tuple nested in the same manner. This enables us to represent the
-- stencil function as a unary function (which also only needs one de
-- Bruijn index). The various positions in the stencil are accessed via
-- tuple indices (i.e., projections).
--
class Stencil sh e stencil where
  -- | Representation type of the stencil: the same nesting structure as the
  -- surface stencil tuple, but over representation element types
  type StencilR sh stencil :: Type
  -- | Reified descriptor of the stencil's shape and element type
  stencilR   :: R.StencilR (EltR sh) (EltR e) (StencilR sh stencil)
  -- | Unpack one expression of the representation tuple type into the
  -- surface-level nested tuple of expressions
  stencilPrj :: SmartExp (StencilR sh stencil) -> stencil
-- DIM1
-- Three-point stencil along a one-dimensional array
instance Elt e => Stencil Sugar.DIM1 e (Exp e, Exp e, Exp e) where
  type StencilR Sugar.DIM1 (Exp e, Exp e, Exp e)
    = EltR (e, e, e)
  -- The explicit @(EltR e) type application used here previously was
  -- redundant ('eltR @e' already determines that type) and inconsistent
  -- with the 5-, 7- and 9-point instances below; it has been dropped.
  stencilR = StencilRunit3 $ eltR @e
  -- Components are projected from prj2 down to prj0: the representation
  -- nests pairs to the left, so the last tuple component is prj0
  stencilPrj s = (Exp $ prj2 s,
                  Exp $ prj1 s,
                  Exp $ prj0 s)
-- Five-point stencil along a one-dimensional array; components are projected
-- from prj4 down to prj0 because the representation nests pairs to the left
instance Elt e => Stencil Sugar.DIM1 e (Exp e, Exp e, Exp e, Exp e, Exp e) where
  type StencilR Sugar.DIM1 (Exp e, Exp e, Exp e, Exp e, Exp e)
    = EltR (e, e, e, e, e)
  stencilR = StencilRunit5 $ eltR @e
  stencilPrj s = (Exp $ prj4 s,
                  Exp $ prj3 s,
                  Exp $ prj2 s,
                  Exp $ prj1 s,
                  Exp $ prj0 s)
-- Seven-point stencil along a one-dimensional array; components are projected
-- from prj6 down to prj0 because the representation nests pairs to the left
instance Elt e => Stencil Sugar.DIM1 e (Exp e, Exp e, Exp e, Exp e, Exp e, Exp e, Exp e) where
  type StencilR Sugar.DIM1 (Exp e, Exp e, Exp e, Exp e, Exp e, Exp e, Exp e)
    = EltR (e, e, e, e, e, e, e)
  stencilR = StencilRunit7 $ eltR @e
  stencilPrj s = (Exp $ prj6 s,
                  Exp $ prj5 s,
                  Exp $ prj4 s,
                  Exp $ prj3 s,
                  Exp $ prj2 s,
                  Exp $ prj1 s,
                  Exp $ prj0 s)
-- Nine-point stencil along a one-dimensional array; components are projected
-- from prj8 down to prj0 because the representation nests pairs to the left
instance Elt e => Stencil Sugar.DIM1 e (Exp e, Exp e, Exp e, Exp e, Exp e, Exp e, Exp e, Exp e, Exp e)
  where
  type StencilR Sugar.DIM1 (Exp e, Exp e, Exp e, Exp e, Exp e, Exp e, Exp e, Exp e, Exp e)
    = EltR (e, e, e, e, e, e, e, e, e)
  stencilR = StencilRunit9 $ eltR @e
  stencilPrj s = (Exp $ prj8 s,
                  Exp $ prj7 s,
                  Exp $ prj6 s,
                  Exp $ prj5 s,
                  Exp $ prj4 s,
                  Exp $ prj3 s,
                  Exp $ prj2 s,
                  Exp $ prj1 s,
                  Exp $ prj0 s)
-- DIM(n+1)
-- A stencil in dimension n+1 is a tuple of three rows, each of which is a
-- stencil in dimension n
instance (Stencil (sh:.Int) a row2,
          Stencil (sh:.Int) a row1,
          Stencil (sh:.Int) a row0) => Stencil (sh:.Int:.Int) a (row2, row1, row0) where
  type StencilR (sh:.Int:.Int) (row2, row1, row0)
    = Tup3 (StencilR (sh:.Int) row2) (StencilR (sh:.Int) row1) (StencilR (sh:.Int) row0)
  stencilR = StencilRtup3 (stencilR @(sh:.Int) @a @row2) (stencilR @(sh:.Int) @a @row1) (stencilR @(sh:.Int) @a @row0)
  -- project the rows right-to-left, then recursively unpack each row
  stencilPrj s = (stencilPrj @(sh:.Int) @a $ prj2 s,
                  stencilPrj @(sh:.Int) @a $ prj1 s,
                  stencilPrj @(sh:.Int) @a $ prj0 s)
-- A stencil in dimension n+1 as a tuple of five rows of dimension-n stencils
instance (Stencil (sh:.Int) a row4,
          Stencil (sh:.Int) a row3,
          Stencil (sh:.Int) a row2,
          Stencil (sh:.Int) a row1,
          Stencil (sh:.Int) a row0) => Stencil (sh:.Int:.Int) a (row4, row3, row2, row1, row0) where
  type StencilR (sh:.Int:.Int) (row4, row3, row2, row1, row0)
    = Tup5 (StencilR (sh:.Int) row4) (StencilR (sh:.Int) row3) (StencilR (sh:.Int) row2)
           (StencilR (sh:.Int) row1) (StencilR (sh:.Int) row0)
  stencilR = StencilRtup5 (stencilR @(sh:.Int) @a @row4) (stencilR @(sh:.Int) @a @row3)
                          (stencilR @(sh:.Int) @a @row2) (stencilR @(sh:.Int) @a @row1) (stencilR @(sh:.Int) @a @row0)
  -- project the rows right-to-left, then recursively unpack each row
  stencilPrj s = (stencilPrj @(sh:.Int) @a $ prj4 s,
                  stencilPrj @(sh:.Int) @a $ prj3 s,
                  stencilPrj @(sh:.Int) @a $ prj2 s,
                  stencilPrj @(sh:.Int) @a $ prj1 s,
                  stencilPrj @(sh:.Int) @a $ prj0 s)
-- A stencil in dimension n+1 as a tuple of seven rows of dimension-n stencils
instance (Stencil (sh:.Int) a row6,
          Stencil (sh:.Int) a row5,
          Stencil (sh:.Int) a row4,
          Stencil (sh:.Int) a row3,
          Stencil (sh:.Int) a row2,
          Stencil (sh:.Int) a row1,
          Stencil (sh:.Int) a row0)
    => Stencil (sh:.Int:.Int) a (row6, row5, row4, row3, row2, row1, row0) where
  type StencilR (sh:.Int:.Int) (row6, row5, row4, row3, row2, row1, row0)
    = Tup7 (StencilR (sh:.Int) row6) (StencilR (sh:.Int) row5) (StencilR (sh:.Int) row4)
           (StencilR (sh:.Int) row3) (StencilR (sh:.Int) row2) (StencilR (sh:.Int) row1)
           (StencilR (sh:.Int) row0)
  stencilR = StencilRtup7 (stencilR @(sh:.Int) @a @row6)
                          (stencilR @(sh:.Int) @a @row5) (stencilR @(sh:.Int) @a @row4) (stencilR @(sh:.Int) @a @row3)
                          (stencilR @(sh:.Int) @a @row2) (stencilR @(sh:.Int) @a @row1) (stencilR @(sh:.Int) @a @row0)
  -- project the rows right-to-left, then recursively unpack each row
  stencilPrj s = (stencilPrj @(sh:.Int) @a $ prj6 s,
                  stencilPrj @(sh:.Int) @a $ prj5 s,
                  stencilPrj @(sh:.Int) @a $ prj4 s,
                  stencilPrj @(sh:.Int) @a $ prj3 s,
                  stencilPrj @(sh:.Int) @a $ prj2 s,
                  stencilPrj @(sh:.Int) @a $ prj1 s,
                  stencilPrj @(sh:.Int) @a $ prj0 s)
-- A stencil in dimension n+1 as a tuple of nine rows of dimension-n stencils
instance (Stencil (sh:.Int) a row8,
          Stencil (sh:.Int) a row7,
          Stencil (sh:.Int) a row6,
          Stencil (sh:.Int) a row5,
          Stencil (sh:.Int) a row4,
          Stencil (sh:.Int) a row3,
          Stencil (sh:.Int) a row2,
          Stencil (sh:.Int) a row1,
          Stencil (sh:.Int) a row0)
    => Stencil (sh:.Int:.Int) a (row8, row7, row6, row5, row4, row3, row2, row1, row0) where
  type StencilR (sh:.Int:.Int) (row8, row7, row6, row5, row4, row3, row2, row1, row0)
    = Tup9 (StencilR (sh:.Int) row8) (StencilR (sh:.Int) row7) (StencilR (sh:.Int) row6)
           (StencilR (sh:.Int) row5) (StencilR (sh:.Int) row4) (StencilR (sh:.Int) row3)
           (StencilR (sh:.Int) row2) (StencilR (sh:.Int) row1) (StencilR (sh:.Int) row0)
  stencilR = StencilRtup9
    (stencilR @(sh:.Int) @a @row8) (stencilR @(sh:.Int) @a @row7) (stencilR @(sh:.Int) @a @row6)
    (stencilR @(sh:.Int) @a @row5) (stencilR @(sh:.Int) @a @row4) (stencilR @(sh:.Int) @a @row3)
    (stencilR @(sh:.Int) @a @row2) (stencilR @(sh:.Int) @a @row1) (stencilR @(sh:.Int) @a @row0)
  -- project the rows right-to-left, then recursively unpack each row
  stencilPrj s = (stencilPrj @(sh:.Int) @a $ prj8 s,
                  stencilPrj @(sh:.Int) @a $ prj7 s,
                  stencilPrj @(sh:.Int) @a $ prj6 s,
                  stencilPrj @(sh:.Int) @a $ prj5 s,
                  stencilPrj @(sh:.Int) @a $ prj4 s,
                  stencilPrj @(sh:.Int) @a $ prj3 s,
                  stencilPrj @(sh:.Int) @a $ prj2 s,
                  stencilPrj @(sh:.Int) @a $ prj1 s,
                  stencilPrj @(sh:.Int) @a $ prj0 s)
-- | Discard the rightmost component of a left-nested representation pair.
prjTail :: SmartExp (t, a) -> SmartExp t
prjTail e = SmartExp (Prj PairIdxLeft e)
-- | Extract the rightmost component of a left-nested representation pair.
prj0 :: SmartExp (t, a) -> SmartExp a
prj0 e = SmartExp (Prj PairIdxRight e)
-- Projections into successively deeper left-nested representation tuples:
-- 'prjN' skips N pairs from the right before extracting a component.
prj1 :: SmartExp ((t, a), s0) -> SmartExp a
prj1 = prj0 . prjTail

prj2 :: SmartExp (((t, a), s1), s0) -> SmartExp a
prj2 = prj0 . prjTail . prjTail

prj3 :: SmartExp ((((t, a), s2), s1), s0) -> SmartExp a
prj3 = prj0 . prjTail . prjTail . prjTail

prj4 :: SmartExp (((((t, a), s3), s2), s1), s0) -> SmartExp a
prj4 = prj0 . prjTail . prjTail . prjTail . prjTail

prj5 :: SmartExp ((((((t, a), s4), s3), s2), s1), s0) -> SmartExp a
prj5 = prj0 . prjTail . prjTail . prjTail . prjTail . prjTail

prj6 :: SmartExp (((((((t, a), s5), s4), s3), s2), s1), s0) -> SmartExp a
prj6 = prj0 . prjTail . prjTail . prjTail . prjTail . prjTail . prjTail

prj7 :: SmartExp ((((((((t, a), s6), s5), s4), s3), s2), s1), s0) -> SmartExp a
prj7 = prj0 . prjTail . prjTail . prjTail . prjTail . prjTail . prjTail . prjTail

prj8 :: SmartExp (((((((((t, a), s7), s6), s5), s4), s3), s2), s1), s0) -> SmartExp a
prj8 = prj0 . prjTail . prjTail . prjTail . prjTail . prjTail . prjTail . prjTail . prjTail
-- Extracting type information
-- ---------------------------
-- | Terms from which the representation of their array result type(s) can be
-- extracted.
class HasArraysR f where
  arraysR :: f a -> ArraysR a
-- | Delegate to the representation of the unwrapped 'PreSmartAcc' term.
instance HasArraysR SmartAcc where
  arraysR = \case
    SmartAcc acc -> arraysR acc
-- | Extract the 'ArrayR' of a term that yields a single array.
--
-- The match cannot fail: the representation of a single 'Array' type can
-- only be built with 'TupRsingle' ('TupRunit' and 'TupRpair' are ruled out
-- by the GADT index).
arrayR :: HasArraysR f => f (Array sh e) -> ArrayR (Array sh e)
arrayR acc
  | TupRsingle repr <- arraysR acc = repr
-- | Read the result representation off each constructor: either directly from
-- a stored reified type, or reconstructed from the shape/element type of an
-- argument array.
instance HasArraysR acc => HasArraysR (PreSmartAcc acc exp) where
  arraysR = \case
    Atag repr _               -> repr
    Pipe _ _ repr _ _ _       -> repr
    Aforeign repr _ _ _       -> repr
    Acond _ a _               -> arraysR a
    Awhile _ _ _ a            -> arraysR a
    Anil                      -> TupRunit
    Apair a1 a2               -> arraysR a1 `TupRpair` arraysR a2
    -- a pair type is always represented with 'TupRpair', so this pattern
    -- guard cannot fail and the fall-through alternative is unreachable
    Aprj idx a | TupRpair t1 t2 <- arraysR a
                              -> case idx of
                                   PairIdxLeft  -> t1
                                   PairIdxRight -> t2
    Aprj _ _                  -> error "Ejector seat? You're joking!"
    Atrace _ _ a              -> arraysR a
    Use repr _                -> TupRsingle repr
    Unit tp _                 -> TupRsingle $ ArrayR ShapeRz $ tp
    Generate repr _ _         -> TupRsingle repr
    Reshape shr _ a           -> let ArrayR _ tp = arrayR a
                                 in TupRsingle $ ArrayR shr tp
    Replicate si _ a          -> let ArrayR _ tp = arrayR a
                                 in TupRsingle $ ArrayR (sliceDomainR si) tp
    Slice si a _              -> let ArrayR _ tp = arrayR a
                                 in TupRsingle $ ArrayR (sliceShapeR si) tp
    Map _ tp _ a              -> let ArrayR shr _ = arrayR a
                                 in TupRsingle $ ArrayR shr tp
    ZipWith _ _ tp _ a _      -> let ArrayR shr _ = arrayR a
                                 in TupRsingle $ ArrayR shr tp
    -- fold removes the innermost dimension of its argument
    Fold _ _ _ a              -> let ArrayR (ShapeRsnoc shr) tp = arrayR a
                                 in TupRsingle (ArrayR shr tp)
    FoldSeg _ _ _ _ a _       -> arraysR a
    Scan _ _ _ _ a            -> arraysR a
    -- scan' additionally yields the final reduction values, an array with
    -- the innermost dimension removed
    Scan' _ _ _ _ a           -> let repr@(ArrayR (ShapeRsnoc shr) tp) = arrayR a
                                 in TupRsingle repr `TupRpair` TupRsingle (ArrayR shr tp)
    Permute _ _ a _ _         -> arraysR a
    Backpermute shr _ _ a     -> let ArrayR _ tp = arrayR a
                                 in TupRsingle (ArrayR shr tp)
    Stencil s tp _ _ _        -> TupRsingle $ ArrayR (stencilShapeR s) tp
    Stencil2 s _ tp _ _ _ _ _ -> TupRsingle $ ArrayR (stencilShapeR s) tp
-- | Terms from which the representation of the scalar (tuple) expression
-- type can be recovered.
class HasTypeR f where
  -- | Extract the scalar-tuple type representation of the term.
  typeR :: HasCallStack => f t -> TypeR t
-- Unwrap the 'SmartExp' constructor and defer to the underlying open
-- expression.
instance HasTypeR SmartExp where
  typeR se = case se of
    SmartExp e -> typeR e
-- | Compute the scalar-tuple type for each constructor of the surface
-- scalar AST. A node either carries its type directly, inherits it from a
-- sub-expression, or derives it from a shape/primitive descriptor.
instance HasTypeR exp => HasTypeR (PreSmartExp acc exp) where
  typeR = \case
    Tag tp _ -> tp
    Match _ e -> typeR e
    Const tp _ -> TupRsingle tp
    Nil -> TupRunit
    Pair e1 e2 -> typeR e1 `TupRpair` typeR e2
    -- projection: the scrutinee's type must be a pair, from which one
    -- half is selected
    Prj idx e
      | TupRpair t1 t2 <- typeR e -> case idx of
          PairIdxLeft -> t1
          PairIdxRight -> t2
    -- fallthrough, only reachable if the projected term is not a pair
    Prj _ _ -> error "I never joke about my work"
    -- packing yields a single SIMD vector; unpacking yields its tuple form
    VecPack vecR _ -> TupRsingle $ VectorScalarType $ vecRvector vecR
    VecUnpack vecR _ -> vecRtuple vecR
    ToIndex _ _ _ -> TupRsingle scalarTypeInt
    FromIndex shr _ _ -> shapeType shr
    -- all case alternatives share a type; take it from the first one
    Case _ ((_,c):_) -> typeR c
    Case{} -> internalError "encountered empty case"
    Cond _ e _ -> typeR e
    While t _ _ _ -> t
    PrimConst c -> TupRsingle $ primConstType c
    -- result type is the codomain of the primitive function
    PrimApp f _ -> snd $ primFunType f
    Index tp _ _ -> tp
    LinearIndex tp _ _ -> tp
    Shape shr _ -> shapeType shr
    ShapeSize _ _ -> TupRsingle scalarTypeInt
    Foreign tp _ _ _ -> tp
    Undef tp -> TupRsingle tp
    Coerce _ tp _ -> TupRsingle tp
-- Smart constructors
-- ------------------
-- | Scalar expression inlet: make a Haskell value available for processing in
-- an Accelerate scalar expression.
--
-- Note that this embeds the value directly into the expression. Depending on
-- the backend used to execute the computation, this might not always be
-- desirable. For example, a backend that does external code generation may
-- embed this constant directly into the generated code, which means new code
-- will need to be generated and compiled every time the value changes. In such
-- cases, consider instead lifting scalar values into (singleton) arrays so that
-- they can be passed as an input to the computation and thus the value can
-- change without the need to generate fresh code.
--
constant :: forall e. (HasCallStack, Elt e) => e -> Exp e
constant v = Exp (build (eltR @e) (fromElt v))
  where
    -- Walk the representation type, emitting one 'Const' leaf per scalar
    -- component of the embedded value.
    build :: HasCallStack => TypeR t -> t -> SmartExp t
    build (TupRpair ta tb) (a, b) = SmartExp (build ta a `Pair` build tb b)
    build (TupRsingle tp)  s      = SmartExp (Const tp s)
    build TupRunit         ()     = SmartExp Nil
-- | 'undef' can be used anywhere a constant is expected, and indicates that the
-- consumer of the value can receive an unspecified bit pattern.
--
-- This is useful because a store of an undefined value can be assumed to not
-- have any effect; we can assume that the value is overwritten with bits that
-- happen to match what was already there. However, a store /to/ an undefined
-- location could clobber arbitrary memory, therefore, its use in such a context
-- would introduce undefined /behaviour/.
--
-- There are (at least) two cases where you may want to use this:
--
-- 1. The 'Data.Array.Accelerate.Language.permute' function requires an array
-- of default values, into which the new values are combined. However, if
-- you are sure the default values are not used, and will (eventually) be
-- completely overwritten, then 'Data.Array.Accelerate.Prelude.fill'ing an
-- array with this value will give you a new uninitialised array.
--
-- 2. In the definition of sum data types. See for example
-- "Data.Array.Accelerate.Data.Maybe" and
-- "Data.Array.Accelerate.Data.Either".
--
-- @since 1.2.0.0
--
undef :: forall e. Elt e => Exp e
undef = Exp (mk (eltR @e))
  where
    -- Walk the representation type, emitting one 'Undef' leaf per scalar
    -- component.
    mk :: TypeR t -> SmartExp t
    mk (TupRpair t1 t2) = SmartExp (mk t1 `Pair` mk t2)
    mk (TupRsingle t)   = SmartExp (Undef t)
    mk TupRunit         = SmartExp Nil
-- | Get the innermost dimension of a shape.
--
-- The innermost dimension (right-most component of the shape) is the index of
-- the array which varies most rapidly, and corresponds to elements of the array
-- which are adjacent in memory.
--
-- Another way to think of this is, for example when writing nested loops over
-- an array in C, this index corresponds to the index iterated over by the
-- innermost nested loop.
--
indexHead :: (Elt sh, Elt a) => Exp (sh :. a) -> Exp a
indexHead (Exp sh) = mkExp (Prj PairIdxRight sh)
-- | Get all but the innermost element of a shape
--
indexTail :: (Elt sh, Elt a) => Exp (sh :. a) -> Exp sh
indexTail (Exp sh) = mkExp (Prj PairIdxLeft sh)
-- Smart constructor for constants
--
-- Lower bound of a bounded type, as a constant scalar expression.
mkMinBound :: (Elt t, IsBounded (EltR t)) => Exp t
mkMinBound = mkExp (PrimConst (PrimMinBound boundedType))
-- Upper bound of a bounded type, as a constant scalar expression.
mkMaxBound :: (Elt t, IsBounded (EltR t)) => Exp t
mkMaxBound = mkExp (PrimConst (PrimMaxBound boundedType))
-- The constant pi at a floating-point type.
mkPi :: (Elt r, IsFloating (EltR r)) => Exp r
mkPi = mkExp (PrimConst (PrimPi floatingType))
-- Smart constructors for primitive applications
--
-- Operators from Floating
-- Scalar sine, via the 'PrimSin' primitive.
mkSin :: (Elt t, IsFloating (EltR t)) => Exp t -> Exp t
mkSin = mkPrimUnary (PrimSin floatingType)
-- Scalar cosine, via the 'PrimCos' primitive.
mkCos :: (Elt t, IsFloating (EltR t)) => Exp t -> Exp t
mkCos = mkPrimUnary (PrimCos floatingType)
-- Scalar tangent, via the 'PrimTan' primitive.
mkTan :: (Elt t, IsFloating (EltR t)) => Exp t -> Exp t
mkTan = mkPrimUnary (PrimTan floatingType)
-- Scalar arc sine, via the 'PrimAsin' primitive.
mkAsin :: (Elt t, IsFloating (EltR t)) => Exp t -> Exp t
mkAsin = mkPrimUnary (PrimAsin floatingType)
-- Scalar arc cosine, via the 'PrimAcos' primitive.
mkAcos :: (Elt t, IsFloating (EltR t)) => Exp t -> Exp t
mkAcos = mkPrimUnary (PrimAcos floatingType)
-- Scalar arc tangent, via the 'PrimAtan' primitive.
mkAtan :: (Elt t, IsFloating (EltR t)) => Exp t -> Exp t
mkAtan = mkPrimUnary (PrimAtan floatingType)
-- Scalar hyperbolic sine, via the 'PrimSinh' primitive.
mkSinh :: (Elt t, IsFloating (EltR t)) => Exp t -> Exp t
mkSinh = mkPrimUnary (PrimSinh floatingType)
-- Scalar hyperbolic cosine, via the 'PrimCosh' primitive.
mkCosh :: (Elt t, IsFloating (EltR t)) => Exp t -> Exp t
mkCosh = mkPrimUnary (PrimCosh floatingType)
-- Scalar hyperbolic tangent, via the 'PrimTanh' primitive.
mkTanh :: (Elt t, IsFloating (EltR t)) => Exp t -> Exp t
mkTanh = mkPrimUnary (PrimTanh floatingType)
-- Scalar inverse hyperbolic sine, via the 'PrimAsinh' primitive.
mkAsinh :: (Elt t, IsFloating (EltR t)) => Exp t -> Exp t
mkAsinh = mkPrimUnary (PrimAsinh floatingType)