-
Notifications
You must be signed in to change notification settings - Fork 132
Expand file tree
/
Copy pathInterpreter.hs
More file actions
1882 lines (1586 loc) · 63.6 KB
/
Interpreter.hs
File metadata and controls
1882 lines (1586 loc) · 63.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PatternGuards #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE ViewPatterns #-}
{-# OPTIONS_GHC -fno-warn-name-shadowing #-}
{-# OPTIONS_HADDOCK prune #-}
-- |
-- Module : Data.Array.Accelerate.Interpreter
-- Description : Reference backend (interpreted)
-- Copyright : [2008..2020] The Accelerate Team
-- License : BSD3
--
-- Maintainer : Trevor L. McDonell <trevor.mcdonell@gmail.com>
-- Stability : experimental
-- Portability : non-portable (GHC extensions)
--
-- This interpreter is meant to be a reference implementation of the
-- semantics of the embedded array language. The emphasis is on defining
-- the semantics clearly, not on performance.
--
module Data.Array.Accelerate.Interpreter (
Smart.Acc, Sugar.Arrays,
Afunction, AfunctionR,
-- * Interpret an array expression
run, run1, runN,
-- Internal (hidden)
evalPrim, evalPrimConst, evalCoerceScalar, atraceOp,
) where
import Data.Array.Accelerate.AST hiding ( Boundary(..) )
import Data.Array.Accelerate.AST.Environment
import Data.Array.Accelerate.AST.Var
import Data.Array.Accelerate.Array.Data
import Data.Array.Accelerate.Error
import Data.Array.Accelerate.Representation.Array
import Data.Array.Accelerate.Representation.Elt
import Data.Array.Accelerate.Representation.Shape
import Data.Array.Accelerate.Representation.Slice
import Data.Array.Accelerate.Representation.Stencil
import Data.Array.Accelerate.Representation.Tag
import Data.Array.Accelerate.Representation.Type
import Data.Array.Accelerate.Representation.Vec
import Data.Array.Accelerate.Trafo
import Data.Array.Accelerate.Trafo.Delayed ( DelayedOpenAfun, DelayedOpenAcc )
import Data.Array.Accelerate.Trafo.Sharing ( AfunctionR, AfunctionRepr(..), afunctionRepr )
import Data.Array.Accelerate.Type
import Data.Primitive.Vec
import qualified Data.Array.Accelerate.AST as AST
import qualified Data.Array.Accelerate.Debug.Internal.Flags as Debug
import qualified Data.Array.Accelerate.Debug.Internal.Graph as Debug
import qualified Data.Array.Accelerate.Debug.Internal.Stats as Debug
import qualified Data.Array.Accelerate.Debug.Internal.Timed as Debug
import qualified Data.Array.Accelerate.Smart as Smart
import qualified Data.Array.Accelerate.Sugar.Array as Sugar
import qualified Data.Array.Accelerate.Sugar.Elt as Sugar
import qualified Data.Array.Accelerate.Trafo.Delayed as AST
import GHC.TypeLits
import Control.DeepSeq
import Control.Exception
import Control.Monad
import Control.Monad.ST
import Data.Bits
import Data.Primitive.ByteArray
import Data.Primitive.Types
import Data.Text.Lazy.Builder
import Formatting
import System.IO
import System.IO.Unsafe ( unsafePerformIO )
import Unsafe.Coerce
import qualified Data.Text.IO as T
import Prelude hiding ( (!!), sum )
-- Program execution
-- -----------------
-- | Run a complete embedded array program using the reference interpreter.
--
-- The surface program is converted with 'convertAcc'. The debug hooks
-- 'Debug.dumpGraph' and 'Debug.dumpSimplStats' are then invoked, and the
-- converted program is evaluated in the empty array environment inside the
-- timed @execute@ phase.
--
run :: (HasCallStack, Sugar.Arrays a) => Smart.Acc a -> a
run a = unsafePerformIO runProgram
  where
    -- The bang forces conversion as soon as the result of 'run' is demanded.
    !program   = convertAcc a
    runProgram = do
      Debug.dumpGraph $!! program
      Debug.dumpSimplStats
      (_, result) <- phase "execute" Debug.elapsed (evaluate (evalOpenAcc program Empty))
      return (Sugar.toArr result)
-- | Run an array program of exactly one argument.
--
-- This is 'runN' used at the type
-- @(Smart.Acc a -> Smart.Acc b) -> a -> b@.
--
run1 :: (HasCallStack, Sugar.Arrays a, Sugar.Arrays b) => (Smart.Acc a -> Smart.Acc b) -> a -> b
run1 prog = runN prog
-- | Prepare and execute an embedded array program of any arity.
--
-- The program is converted, and debugging output is dumped, once per call
-- to 'runN'. The returned function shares that preparation across all of
-- its applications. Each application binds the arguments and evaluates the
-- body inside the timed @execute@ phase.
--
runN :: forall f. (HasCallStack, Afunction f) => f -> AfunctionR f
runN f = runner
  where
    !converted = convertAfun f
    -- Debug output happens here, at preparation time, not per application.
    !prepared  = unsafePerformIO $ do
      Debug.dumpGraph $!! converted
      Debug.dumpSimplStats
      return converted
    !runner    = evalAfun (afunctionRepr @f) prepared Empty

    -- Walk the surface function representation and the converted function
    -- in lock step, binding one (converted) argument per lambda.
    evalAfun :: AfunctionRepr g (AfunctionR g) (ArraysFunctionR g)
             -> DelayedOpenAfun aenv (ArraysFunctionR g)
             -> Val aenv
             -> AfunctionR g
    evalAfun (AfunctionReprLam reprF) (Alam lhs body) aenv =
      \arg -> evalAfun reprF body (aenv `push` (lhs, Sugar.fromArr arg))
    evalAfun AfunctionReprBody (Abody acc) aenv =
      unsafePerformIO $
        phase "execute" Debug.elapsed (Sugar.toArr . snd <$> evaluate (evalOpenAcc acc aenv))
    evalAfun _ _ _ =
      error "Two men say they're Jesus; one of them must be wrong"
-- -- | Stream a lazily read list of input arrays through the given program,
-- -- collecting results as we go
-- --
-- streamOut :: Arrays a => Sugar.Seq [a] -> [a]
-- streamOut seq = let seq' = convertSeqWith config seq
-- in evalDelayedSeq defaultSeqConfig seq'
-- Debugging
-- ---------
-- | Run an IO action as a named, timed phase. The action, the
-- 'Debug.dump_phases' flag, and a report format are handed to 'Debug.timed'.
-- The format is @fmt@ prefixed by @"phase <label>: "@.
phase :: Builder -> Format Builder (Double -> Double -> Builder) -> IO a -> IO a
phase label fmt = Debug.timed Debug.dump_phases (now ("phase " <> label <> ": ") % fmt)
-- Delayed Arrays
-- --------------
-- Note that in contrast to the representation used in the optimised AST, the
-- delayed array representation used here is _only_ for delayed arrays --- we do
-- not require an optional Manifest|Delayed data type to evaluate the program.
--
-- A 'Delayed' value carries the array's representation, its extent, and two
-- element accessors: one taking a multidimensional index and one taking a
-- linear index. The 'delayed' helper in 'evalOpenAcc' fills the accessors
-- either from an AST.Delayed term's index functions, or from
-- 'indexArray' / 'linearIndexArray' on an evaluated array.
--
data Delayed a where
  Delayed :: ArrayR (Array sh e)   -- representation
          -> sh                    -- extent
          -> (sh -> e)             -- element at a multidimensional index
          -> (Int -> e)            -- element at a linear index
          -> Delayed (Array sh e)
-- Array expression evaluation
-- ---------------------------
-- | An array value paired with the runtime representation ('ArraysR') of its
-- array type(s).
type WithReprs acc = (ArraysR acc, acc)
-- | Build an array from an index function with 'fromFunction', and pair it
-- with its (single-array) representation.
fromFunction' :: ArrayR (Array sh e) -> sh -> (sh -> e) -> WithReprs (Array sh e)
fromFunction' repr sh f =
  let arr = fromFunction repr sh f
  in  (TupRsingle repr, arr)
-- | Evaluate an open array function in the given array environment.
--
-- Each lambda becomes a Haskell lambda that binds its argument through the
-- lambda's left-hand side. The body yields its array value with the
-- representation discarded.
--
evalOpenAfun :: HasCallStack => DelayedOpenAfun aenv f -> Val aenv -> f
evalOpenAfun afun aenv =
  case afun of
    Alam lhs body -> \arg -> evalOpenAfun body (aenv `push` (lhs, arg))
    Abody acc     -> snd (evalOpenAcc acc aenv)
-- The core interpreter for optimised array programs
--
-- Only 'AST.Manifest' terms are evaluated here. Reaching a top-level
-- 'AST.Delayed' term is an internal error. Array arguments of producers and
-- consumers are instead viewed through the local 'delayed' helper, which
-- yields the interpreter's own 'Delayed' representation.
--
evalOpenAcc
    :: forall aenv a. HasCallStack
    => DelayedOpenAcc aenv a
    -> Val aenv
    -> WithReprs a
evalOpenAcc AST.Delayed{} _ = internalError "expected manifest array"
evalOpenAcc (AST.Manifest pacc) aenv =
  let
      -- Evaluate a sub-term, then force its arrays with 'rnfArraysR' before
      -- returning the representation/value pair.
      manifest :: forall a'. HasCallStack => DelayedOpenAcc aenv a' -> WithReprs a'
      manifest acc =
        let (repr, a') = evalOpenAcc acc aenv
        in rnfArraysR repr a' `seq` (repr, a')

      -- View an array argument as a 'Delayed' array. An AST.Delayed term is
      -- turned directly into extent/index functions. Any other term is
      -- evaluated with 'manifest' and then indexed.
      delayed :: DelayedOpenAcc aenv (Array sh e) -> Delayed (Array sh e)
      delayed AST.Delayed{..} = Delayed reprD (evalE extentD) (evalF indexD) (evalF linearIndexD)
      delayed a' = Delayed aR (shape a) (indexArray aR a) (linearIndexArray (arrayRtype aR) a)
        where
          (TupRsingle aR, a) = manifest a'

      -- Scalar expressions, scalar functions and boundary conditions are
      -- evaluated against the current array environment.
      evalE :: Exp aenv t -> t
      evalE exp = evalExp exp aenv

      evalF :: Fun aenv f -> f
      evalF fun = evalFun fun aenv

      evalB :: AST.Boundary aenv t -> Boundary t
      evalB bnd = evalBoundary bnd aenv

      -- Pick the left-to-right or the right-to-left variant of an operation.
      dir :: Direction -> t -> t -> t
      dir LeftToRight l _ = l
      dir RightToLeft _ r = r
  in
  case pacc of
    Avar (Var repr ix) -> (TupRsingle repr, prj ix aenv)
    Alet lhs acc1 acc2 -> evalOpenAcc acc2 $ aenv `push` (lhs, snd $ manifest acc1)
    Apair acc1 acc2 ->
      let (r1, a1) = manifest acc1
          (r2, a2) = manifest acc2
      in
      (TupRpair r1 r2, (a1, a2))
    Anil -> (TupRunit, ())
    -- Emit the trace message for 'as' (an IO side effect), then return 'bs'.
    Atrace msg as bs -> unsafePerformIO $ manifest bs <$ atraceOp msg (snd $ manifest as)
    Apply repr afun acc -> (repr, evalOpenAfun afun aenv $ snd $ manifest acc)
    -- 'afun' is evaluated in the empty array environment.
    Aforeign repr _ afun acc -> (repr, evalOpenAfun afun Empty $ snd $ manifest acc)
    -- Only the selected branch is evaluated.
    Acond p acc1 acc2
      | toBool (evalE p) -> manifest acc1
      | otherwise -> manifest acc2
    -- Apply 'body' while element 0 of the predicate's result, read as a
    -- Word8 and tested with 'toBool', holds.
    Awhile cond body acc -> (repr, go initial)
      where
        (repr, initial) = manifest acc
        p = evalOpenAfun cond aenv
        f = evalOpenAfun body aenv
        go !x
          | toBool (linearIndexArray (Sugar.eltR @Word8) (p x) 0) = go (f x)
          | otherwise = x
    Use repr arr -> (TupRsingle repr, arr)
    Unit tp e -> unitOp tp (evalE e)
    -- Collect s -> evalSeq defaultSeqConfig s aenv
    -- Producers
    -- ---------
    Map tp f acc -> mapOp tp (evalF f) (delayed acc)
    Generate repr sh f -> generateOp repr (evalE sh) (evalF f)
    Transform repr sh p f acc -> transformOp repr (evalE sh) (evalF p) (evalF f) (delayed acc)
    Backpermute shr sh p acc -> backpermuteOp shr (evalE sh) (evalF p) (delayed acc)
    Reshape shr sh acc -> reshapeOp shr (evalE sh) (manifest acc)
    ZipWith tp f acc1 acc2 -> zipWithOp tp (evalF f) (delayed acc1) (delayed acc2)
    Replicate slice slix acc -> replicateOp slice (evalE slix) (manifest acc)
    Slice slice acc slix -> sliceOp slice (manifest acc) (evalE slix)
    -- Consumers
    -- ---------
    Fold f (Just z) acc -> foldOp (evalF f) (evalE z) (delayed acc)
    Fold f Nothing acc -> fold1Op (evalF f) (delayed acc)
    FoldSeg i f (Just z) acc seg -> foldSegOp i (evalF f) (evalE z) (delayed acc) (delayed seg)
    FoldSeg i f Nothing acc seg -> fold1SegOp i (evalF f) (delayed acc) (delayed seg)
    Scan d f (Just z) acc -> dir d scanlOp scanrOp (evalF f) (evalE z) (delayed acc)
    Scan d f Nothing acc -> dir d scanl1Op scanr1Op (evalF f) (delayed acc)
    Scan' d f z acc -> dir d scanl'Op scanr'Op (evalF f) (evalE z) (delayed acc)
    Permute f def p acc -> permuteOp (evalF f) (manifest def) (evalF p) (delayed acc)
    Stencil s tp sten b acc -> stencilOp s tp (evalF sten) (evalB b) (delayed acc)
    Stencil2 s1 s2 tp sten b1 a1 b2 a2
      -> stencil2Op s1 s2 tp (evalF sten) (evalB b1) (delayed a1) (evalB b2) (delayed a2)
-- Array primitives
-- ----------------
-- | Wrap a single element as a rank-0 (scalar) array.
unitOp :: TypeR e -> e -> WithReprs (Scalar e)
unitOp tp e = fromFunction' (ArrayR ShapeRz tp) () (\_ -> e)
-- | Build an array of the given extent whose elements are given by the
-- generator function applied to each index.
generateOp
    :: ArrayR (Array sh e)
    -> sh
    -> (sh -> e)
    -> WithReprs (Array sh e)
generateOp repr extent gen = fromFunction' repr extent gen
-- | Fused backpermute and map. The result element at index @ix@ is @f@
-- applied to the source element at index @p ix@.
transformOp
    :: ArrayR (Array sh' b)
    -> sh'
    -> (sh' -> sh)
    -> (a -> b)
    -> Delayed (Array sh a)
    -> WithReprs (Array sh' b)
transformOp repr sh' p f (Delayed _ _ source _)
  = fromFunction' repr sh' (f . source . p)
-- | Reinterpret an array's data under a new extent, reusing the same
-- underlying array data. The total element count must be unchanged; this is
-- checked with 'boundsCheck'.
reshapeOp
    :: HasCallStack
    => ShapeR sh
    -> sh
    -> WithReprs (Array sh' e)
    -> WithReprs (Array sh e)
reshapeOp newShapeR newShape (TupRsingle (ArrayR oldShapeR eltR), Array oldShape adata) =
  let sameSize = size newShapeR newShape == size oldShapeR oldShape
  in  boundsCheck "shape mismatch" sameSize
        ( TupRsingle (ArrayR newShapeR eltR)
        , Array newShape adata
        )
-- | Replicate an array according to a slice specifier. Each result index is
-- projected back onto the source array by dropping the components marked
-- 'SliceFixed'; those components give the replication counts.
replicateOp
    :: SliceIndex slix sl co sh
    -> slix
    -> WithReprs (Array sl e)
    -> WithReprs (Array sh e)
replicateOp slice slix (TupRsingle repr@(ArrayR _ tp), arr)
  = fromFunction' resultR resultSh (\ix -> (repr, arr) ! toSource ix)
  where
    resultR              = ArrayR (sliceDomainR slice) tp
    (resultSh, toSource) = extend slice slix (shape arr)

    -- Compute the full extent, together with the projection from full
    -- indices back to source indices.
    extend :: SliceIndex slix sl co dim
           -> slix
           -> sl
           -> (dim, dim -> sl)
    extend SliceNil () () = ((), const ())
    extend (SliceAll rest) (slx, ()) (sl, sz) =
      let (dim', proj) = extend rest slx sl
      in  ((dim', sz), \(ix, i) -> (proj ix, i))
    extend (SliceFixed rest) (slx, sz) sl =
      let (dim', proj) = extend rest slx sl
      in  ((dim', sz), \(ix, _) -> proj ix)
-- | Extract a slice of an array. 'SliceAll' components are kept in the
-- result. 'SliceFixed' components are pinned to the index given in the slice
-- specifier, which is checked against the extent with 'indexCheck'.
sliceOp
    :: SliceIndex slix sl co sh
    -> WithReprs (Array sh e)
    -> slix
    -> WithReprs (Array sl e)
sliceOp slice (TupRsingle repr@(ArrayR _ tp), arr) slix
  = fromFunction' resultR resultSh (\ix -> (repr, arr) ! toSource ix)
  where
    resultR              = ArrayR (sliceShapeR slice) tp
    (resultSh, toSource) = restrict slice slix (shape arr)

    -- Compute the slice extent, together with the embedding of slice
    -- indices into source indices.
    restrict
        :: HasCallStack
        => SliceIndex slix sl co sh
        -> slix
        -> sh
        -> (sl, sl -> sh)
    restrict SliceNil () () = ((), const ())
    restrict (SliceAll rest) (slx, ()) (sl, sz) =
      let (sl', embed) = restrict rest slx sl
      in  ((sl', sz), \(ix, i) -> (embed ix, i))
    restrict (SliceFixed rest) (slx, i) (sl, sz) =
      let (sl', embed) = restrict rest slx sl
      in  indexCheck i sz (sl', \ix -> (embed ix, i))
-- | Apply a function to every element, keeping the input's extent.
mapOp :: TypeR b
      -> (a -> b)
      -> Delayed (Array sh a)
      -> WithReprs (Array sh b)
mapOp tp f (Delayed (ArrayR shr _) sh source _)
  = fromFunction' (ArrayR shr tp) sh (f . source)
-- | Combine two arrays element-wise. The result extent is the intersection
-- of the two input extents.
zipWithOp
    :: TypeR c
    -> (a -> b -> c)
    -> Delayed (Array sh a)
    -> Delayed (Array sh b)
    -> WithReprs (Array sh c)
zipWithOp tp f (Delayed (ArrayR shr _) shx xs _) (Delayed _ shy ys _)
  = fromFunction' (ArrayR shr tp) extent (\ix -> f (xs ix) (ys ix))
  where
    extent = intersect shr shx shy
-- | Reduce the innermost dimension of an array with a binary operator and
-- seed @z@. The result has the input's extent with the innermost dimension
-- removed.
foldOp
    :: (e -> e -> e)
    -> e
    -> Delayed (Array (sh, Int) e)
    -> WithReprs (Array sh e)
foldOp f z (Delayed (ArrayR (ShapeRsnoc shr) tp) (sh, n) arr _) =
  fromFunction' (ArrayR shr tp) sh $ \ix ->
    iter (ShapeRsnoc ShapeRz) ((), n) (\((), i) -> arr (ix, i)) f z
-- | Reduce the innermost dimension of an array without a seed. The
-- innermost extent must be non-empty; this is checked with 'boundsCheck'.
fold1Op
    :: HasCallStack
    => (e -> e -> e)
    -> Delayed (Array (sh, Int) e)
    -> WithReprs (Array sh e)
fold1Op f (Delayed (ArrayR (ShapeRsnoc shr) tp) (sh, n) arr _) =
  boundsCheck "empty array" (n > 0) $
    fromFunction' (ArrayR shr tp) sh $ \ix ->
      iter1 (ShapeRsnoc ShapeRz) ((), n) (\((), i) -> arr (ix, i)) f
-- | Segmented reduction (with seed) along the innermost dimension.
--
-- The segment descriptor is read as offsets: segment @k@ covers source
-- positions @[seg k, seg (k+1))@, so @n@ offsets describe @n-1@ segments.
-- The descriptor must be non-empty, and offsets must not decrease.
foldSegOp
    :: HasCallStack
    => IntegralType i
    -> (e -> e -> e)
    -> e
    -> Delayed (Array (sh, Int) e)
    -> Delayed (Segments i)
    -> WithReprs (Array (sh, Int) e)
foldSegOp itp f z (Delayed repr (sh, _) arr _) (Delayed _ ((), n) _ seg)
  | IntegralDict <- integralDict itp
  = boundsCheck "empty segment descriptor" (n > 0) $
      fromFunction' repr (sh, n-1) $ \(sz, ix) ->
        let lo = fromIntegral (seg ix)
            hi = fromIntegral (seg (ix+1))
        in  boundsCheck "empty segment" (hi >= lo)
              (iter (ShapeRsnoc ShapeRz) ((), hi-lo) (\((), i) -> arr (sz, lo+i)) f z)
-- | Segmented reduction (without seed) along the innermost dimension.
--
-- Segments are described by offsets as in 'foldSegOp'. Every segment must
-- be non-empty, because there is no seed to return for an empty one.
fold1SegOp
    :: HasCallStack
    => IntegralType i
    -> (e -> e -> e)
    -> Delayed (Array (sh, Int) e)
    -> Delayed (Segments i)
    -> WithReprs (Array (sh, Int) e)
fold1SegOp itp f (Delayed repr (sh, _) arr _) (Delayed _ ((), n) _ seg)
  | IntegralDict <- integralDict itp
  = boundsCheck "empty segment descriptor" (n > 0) $
      fromFunction' repr (sh, n-1) $ \(sz, ix) ->
        let lo = fromIntegral (seg ix)
            hi = fromIntegral (seg (ix+1))
        in  boundsCheck "empty segment" (hi > lo)
              (iter1 (ShapeRsnoc ShapeRz) ((), hi-lo) (\((), i) -> arr (sz, lo+i)) f)
-- | Left-to-right inclusive scan without a seed, along the innermost
-- dimension. Output element 0 of each row is the input element. Element @i@
-- is @f (out (i-1)) (in i)@. The extent is unchanged.
scanl1Op
    :: forall sh e. HasCallStack
    => (e -> e -> e)
    -> Delayed (Array (sh, Int) e)
    -> WithReprs (Array (sh, Int) e)
scanl1Op f (Delayed (ArrayR shr tp) sh ain _)
  = ( TupRsingle $ ArrayR shr tp
    , adata `seq` Array sh adata
    )
  where
    (adata, _) = runArrayData @e $ do
      out <- newArrayData tp (size shr sh)
      let offset = toIndex shr sh
          step (row, i)
            | i == 0    = writeArrayData tp out (offset (row, 0)) (ain (row, 0))
            | otherwise = do
                prev <- readArrayData tp out (offset (row, i-1))
                writeArrayData tp out (offset (row, i)) (f prev (ain (row, i)))
      iter shr sh step (>>) (return ())
      return (out, undefined)
-- | Left-to-right scan with seed @z@ along the innermost dimension. Each
-- output row has @n+1@ elements. Element 0 is @z@, and element @i@ is
-- @f (out (i-1)) (in (i-1))@.
scanlOp
    :: forall sh e.
       (e -> e -> e)
    -> e
    -> Delayed (Array (sh, Int) e)
    -> WithReprs (Array (sh, Int) e)
scanlOp f z (Delayed (ArrayR shr tp) (sh, n) ain _)
  = ( TupRsingle $ ArrayR shr tp
    , adata `seq` Array extent adata
    )
  where
    -- One more element per row than the input: the leading seed.
    extent = (sh, n+1)
    (adata, _) = runArrayData @e $ do
      out <- newArrayData tp (size shr extent)
      let offset = toIndex shr extent
          step (row, i)
            | i == 0    = writeArrayData tp out (offset (row, 0)) z
            | otherwise = do
                prev <- readArrayData tp out (offset (row, i-1))
                writeArrayData tp out (offset (row, i)) (f prev (ain (row, i-1)))
      iter shr extent step (>>) (return ())
      return (out, undefined)
-- | Left-to-right scan with seed @z@ that returns the final value of each
-- row separately.
--
-- The scanned array keeps the input extent @(sh, n)@: element 0 is @z@, and
-- element @i@ is @f (out (i-1)) (in (i-1))@. The running value after the
-- last input goes to the totals array; for an empty row it is @z@.
scanl'Op
    :: forall sh e.
       (e -> e -> e)
    -> e
    -> Delayed (Array (sh, Int) e)
    -> WithReprs (Array (sh, Int) e, Array sh e)
scanl'Op f z (Delayed (ArrayR shr@(ShapeRsnoc shr') tp) (sh, n) ain _)
  = ( TupRsingle (ArrayR shr tp) `TupRpair` TupRsingle (ArrayR shr' tp)
    , aout `seq` asum `seq` ( Array (sh, n) aout, Array sh asum )
    )
  where
    ((aout, asum), _) = runArrayData @(e, e) $ do
      scanned <- newArrayData tp (size shr  (sh, n))
      totals  <- newArrayData tp (size shr' sh)
      let outIx   = toIndex shr  (sh, n)
          totalIx = toIndex shr' sh
          -- Visit n+1 positions per row. Positions below n fill the scanned
          -- row; position n stores the running total.
          step (row, i)
            | i == 0, n == 0 = writeArrayData tp totals (totalIx row) z
            | i == 0         = writeArrayData tp scanned (outIx (row, 0)) z
            | otherwise      = do
                prev <- readArrayData tp scanned (outIx (row, i-1))
                let next = f prev (ain (row, i-1))
                if i == n
                  then writeArrayData tp totals  (totalIx row)    next
                  else writeArrayData tp scanned (outIx (row, i)) next
      iter shr (sh, n+1) step (>>) (return ())
      return ((scanned, totals), undefined)
-- | Right-to-left scan with seed @z@ along the innermost dimension. Each
-- output row has @n+1@ elements. The last one is @z@, and element @j@ is
-- @f (in j) (out (j+1))@.
scanrOp
    :: forall sh e.
       (e -> e -> e)
    -> e
    -> Delayed (Array (sh, Int) e)
    -> WithReprs (Array (sh, Int) e)
scanrOp f z (Delayed (ArrayR shr tp) (sh, n) ain _)
  = ( TupRsingle (ArrayR shr tp)
    , adata `seq` Array extent adata
    )
  where
    extent = (sh, n+1)
    (adata, _) = runArrayData @e $ do
      out <- newArrayData tp (size shr extent)
      let offset = toIndex shr extent
          -- Iteration position i writes output position n-i, reading the
          -- already-written position n-i+1.
          step (row, i)
            | i == 0    = writeArrayData tp out (offset (row, n)) z
            | otherwise = do
                next <- readArrayData tp out (offset (row, n-i+1))
                writeArrayData tp out (offset (row, n-i)) (f (ain (row, n-i)) next)
      iter shr extent step (>>) (return ())
      return (out, undefined)
-- | Right-to-left inclusive scan without a seed, along the innermost
-- dimension. The last output element of each row is the input's last
-- element. Element @j@ is @f (in j) (out (j+1))@. The extent is unchanged.
scanr1Op
    :: forall sh e. HasCallStack
    => (e -> e -> e)
    -> Delayed (Array (sh, Int) e)
    -> WithReprs (Array (sh, Int) e)
scanr1Op f (Delayed (ArrayR shr tp) sh@(_, n) ain _)
  = ( TupRsingle $ ArrayR shr tp
    , adata `seq` Array sh adata
    )
  where
    (adata, _) = runArrayData @e $ do
      out <- newArrayData tp (size shr sh)
      let offset = toIndex shr sh
          step (row, i)
            | i == 0    = writeArrayData tp out (offset (row, n-1)) (ain (row, n-1))
            | otherwise = do
                next <- readArrayData tp out (offset (row, n-i))
                writeArrayData tp out (offset (row, n-i-1)) (f (ain (row, n-i-1)) next)
      iter shr sh step (>>) (return ())
      return (out, undefined)
-- | Right-to-left scan with seed @z@ that returns the final value of each
-- row separately.
--
-- The scanned array keeps the input extent @(sh, n)@: element @n-1@ is @z@,
-- and element @j@ is @f (in (j+1)) (out (j+1))@. The final value
-- @f (in 0) (out 0)@ goes to the totals array; for an empty row it is @z@.
scanr'Op
    :: forall sh e.
       (e -> e -> e)
    -> e
    -> Delayed (Array (sh, Int) e)
    -> WithReprs (Array (sh, Int) e, Array sh e)
scanr'Op f z (Delayed (ArrayR shr@(ShapeRsnoc shr') tp) (sh, n) ain _)
  = ( TupRsingle (ArrayR shr tp) `TupRpair` TupRsingle (ArrayR shr' tp)
    , aout `seq` asum `seq` ( Array (sh, n) aout, Array sh asum )
    )
  where
    ((aout, asum), _) = runArrayData @(e, e) $ do
      scanned <- newArrayData tp (size shr  (sh, n))
      totals  <- newArrayData tp (size shr' sh)
      let outIx   = toIndex shr  (sh, n)
          totalIx = toIndex shr' sh
          step (row, i)
            | i == 0, n == 0 = writeArrayData tp totals (totalIx row) z
            | i == 0         = writeArrayData tp scanned (outIx (row, n-1)) z
            | otherwise      = do
                next <- readArrayData tp scanned (outIx (row, n-i))
                let acc = f (ain (row, n-i)) next
                if i == n
                  then writeArrayData tp totals  (totalIx row)        acc
                  else writeArrayData tp scanned (outIx (row, n-i-1)) acc
      iter shr (sh, n+1) step (>>) (return ())
      return ((scanned, totals), undefined)
-- | Forward permutation with combination.
--
-- The output starts as a copy of the default array. Each source index is
-- mapped by @p@ to a tagged destination. Tag 0 drops the element. Tag 1
-- combines it into the destination as @f new old@. Any other tag is an
-- internal error.
permuteOp
    :: forall sh sh' e. HasCallStack
    => (e -> e -> e)
    -> WithReprs (Array sh' e)
    -> (sh -> PrimMaybe sh')
    -> Delayed (Array sh e)
    -> WithReprs (Array sh' e)
permuteOp f (TupRsingle (ArrayR shr' _), def@(Array _ adef)) p (Delayed (ArrayR shr tp) sh _ ain)
  = (TupRsingle $ ArrayR shr' tp, adata `seq` Array sh' adata)
  where
    sh' = shape def
    n'  = size shr' sh'
    --
    (adata, _) = runArrayData @e $ do
      out <- newArrayData tp n'
      -- Seed the output with the default values, in index order.
      forM_ [0 .. n'-1] $ \i ->
        readArrayData tp adef i >>= writeArrayData tp out i
      -- Scatter each source element into its destination, if any.
      let scatter src =
            case p src of
              (0, _)         -> return ()
              (1, ((), dst)) -> do
                let j = toIndex shr' sh' dst
                old <- readArrayData tp out j
                writeArrayData tp out j (f (ain (toIndex shr sh src)) old)
              _              -> internalError "unexpected tag"
      iter shr sh scatter (>>) (return ())
      return (out, undefined)
-- | Backward permutation: the result element at @ix@ is the source element
-- at @p ix@.
backpermuteOp
    :: ShapeR sh'
    -> sh'
    -> (sh' -> sh)
    -> Delayed (Array sh e)
    -> WithReprs (Array sh' e)
backpermuteOp shr sh' p (Delayed (ArrayR _ tp) _ source _)
  = fromFunction' (ArrayR shr tp) sh' (source . p)
-- | Apply a stencil function at every index of an array. Neighbours outside
-- the array are read through 'bounded', which applies the boundary
-- condition. The result has the input's extent.
stencilOp
    :: HasCallStack
    => StencilR sh a stencil
    -> TypeR b
    -> (stencil -> b)
    -> Boundary (Array sh a)
    -> Delayed (Array sh a)
    -> WithReprs (Array sh b)
stencilOp stencil tp f bnd arr@(Delayed _ sh _ _)
  = fromFunction' (ArrayR shr tp) sh (\ix -> f (stencilAccess stencil (bounded shr bnd arr) ix))
  where
    shr = stencilShapeR stencil
-- | Apply a binary stencil function over two arrays. Each array is read
-- through 'bounded' with its own boundary condition. The result extent is
-- the intersection of the two input extents.
stencil2Op
    :: HasCallStack
    => StencilR sh a stencil1
    -> StencilR sh b stencil2
    -> TypeR c
    -> (stencil1 -> stencil2 -> c)
    -> Boundary (Array sh a)
    -> Delayed (Array sh a)
    -> Boundary (Array sh b)
    -> Delayed (Array sh b)
    -> WithReprs (Array sh c)
stencil2Op s1 s2 tp stencil bnd1 arr1@(Delayed _ sh1 _ _) bnd2 arr2@(Delayed _ sh2 _ _)
  = fromFunction' (ArrayR shr tp) (intersect shr sh1 sh2) combine
  where
    shr        = stencilShapeR s1
    access1    = stencilAccess s1 (bounded shr bnd1 arr1)
    access2    = stencilAccess s2 (bounded shr bnd2 arr2)
    combine ix = stencil (access1 ix) (access2 ix)
-- | Gather the stencil neighbourhood around an index.
--
-- Every read goes through the supplied accessor function. The callers in
-- this file ('stencilOp', 'stencil2Op') pass a 'bounded' reader, so
-- out-of-range neighbours follow the boundary condition. The result is the
-- nested-tuple value described by the 'StencilR'. Within each tuple, offsets
-- run from the most negative to the most positive.
stencilAccess
    :: StencilR sh e stencil
    -> (sh -> e)
    -> sh
    -> stencil
stencilAccess stencil = goR (stencilShapeR stencil) stencil
  where
    -- Base cases, nothing interesting to do here since we know the lower
    -- dimension is Z.
    --
    -- Each base case reads 3/5/7/9 neighbours along the remaining
    -- dimension, at offsets -k .. +k from the centre index.
    goR :: ShapeR sh -> StencilR sh e stencil -> (sh -> e) -> sh -> stencil
    goR _ (StencilRunit3 _) rf ix =
      let
          (z, i) = ix
          rf' d = rf (z, i+d)
      in
      ((( ()
      , rf' (-1))
      , rf' 0 )
      , rf' 1 )
    goR _ (StencilRunit5 _) rf ix =
      let (z, i) = ix
          rf' d = rf (z, i+d)
      in
      ((((( ()
      , rf' (-2))
      , rf' (-1))
      , rf' 0 )
      , rf' 1 )
      , rf' 2 )
    goR _ (StencilRunit7 _) rf ix =
      let (z, i) = ix
          rf' d = rf (z, i+d)
      in
      ((((((( ()
      , rf' (-3))
      , rf' (-2))
      , rf' (-1))
      , rf' 0 )
      , rf' 1 )
      , rf' 2 )
      , rf' 3 )
    goR _ (StencilRunit9 _) rf ix =
      let (z, i) = ix
          rf' d = rf (z, i+d)
      in
      ((((((((( ()
      , rf' (-4))
      , rf' (-3))
      , rf' (-2))
      , rf' (-1))
      , rf' 0 )
      , rf' 1 )
      , rf' 2 )
      , rf' 3 )
      , rf' 4 )
    -- Recursive cases. Note that because the stencil pattern is defined with
    -- cons ordering, whereas shapes (and indices) are defined as a snoc-list,
    -- when we recurse on the stencil structure we must manipulate the
    -- _left-most_ index component.
    --
    -- 'rf'' offsets the left-most component by d; each sub-stencil then
    -- handles the remaining dimensions of the index.
    goR (ShapeRsnoc shr) (StencilRtup3 s1 s2 s3) rf ix =
      let (i, ix') = uncons shr ix
          rf' d ds = rf (cons shr (i+d) ds)
      in
      ((( ()
      , goR shr s1 (rf' (-1)) ix')
      , goR shr s2 (rf' 0) ix')
      , goR shr s3 (rf' 1) ix')
    goR (ShapeRsnoc shr) (StencilRtup5 s1 s2 s3 s4 s5) rf ix =
      let (i, ix') = uncons shr ix
          rf' d ds = rf (cons shr (i+d) ds)
      in
      ((((( ()
      , goR shr s1 (rf' (-2)) ix')
      , goR shr s2 (rf' (-1)) ix')
      , goR shr s3 (rf' 0) ix')
      , goR shr s4 (rf' 1) ix')
      , goR shr s5 (rf' 2) ix')
    goR (ShapeRsnoc shr) (StencilRtup7 s1 s2 s3 s4 s5 s6 s7) rf ix =
      let (i, ix') = uncons shr ix
          rf' d ds = rf (cons shr (i+d) ds)
      in
      ((((((( ()
      , goR shr s1 (rf' (-3)) ix')
      , goR shr s2 (rf' (-2)) ix')
      , goR shr s3 (rf' (-1)) ix')
      , goR shr s4 (rf' 0) ix')
      , goR shr s5 (rf' 1) ix')
      , goR shr s6 (rf' 2) ix')
      , goR shr s7 (rf' 3) ix')
    goR (ShapeRsnoc shr) (StencilRtup9 s1 s2 s3 s4 s5 s6 s7 s8 s9) rf ix =
      let (i, ix') = uncons shr ix
          rf' d ds = rf (cons shr (i+d) ds)
      in
      ((((((((( ()
      , goR shr s1 (rf' (-4)) ix')
      , goR shr s2 (rf' (-3)) ix')
      , goR shr s3 (rf' (-2)) ix')
      , goR shr s4 (rf' (-1)) ix')
      , goR shr s5 (rf' 0) ix')
      , goR shr s6 (rf' 1) ix')
      , goR shr s7 (rf' 2) ix')
      , goR shr s8 (rf' 3) ix')
      , goR shr s9 (rf' 4) ix')

    -- Add a left-most component to an index
    --
    cons :: ShapeR sh -> Int -> sh -> (sh, Int)
    cons ShapeRz ix () = ((), ix)
    cons (ShapeRsnoc shr) ix (sh, sz) = (cons shr ix sh, sz)

    -- Remove the left-most index of an index, and return the remainder
    --
    uncons :: ShapeR sh -> (sh, Int) -> (Int, sh)
    uncons ShapeRz ((), v) = (v, ())
    uncons (ShapeRsnoc shr) (v1, v2) = let (i, v1') = uncons shr v1
                                       in (i, (v1', v2))
-- | Read an element of a delayed array, applying the boundary condition when
-- the index lies outside the array's extent.
--
-- 'Function' and 'Constant' supply the out-of-range value directly.
-- 'Clamp', 'Mirror' and 'Wrap' remap each out-of-range index component and
-- then read from the array.
--
-- NOTE(review): the remapping assumes the overshoot is smaller than the
-- extent. For example, with 'Mirror' on a dimension of size 1, index -1
-- maps to 1, which is still out of range. Confirm that stencil radii never
-- exceed array extents.
bounded
    :: HasCallStack
    => ShapeR sh
    -> Boundary (Array sh e)
    -> Delayed (Array sh e)
    -> sh
    -> e
bounded shr bnd (Delayed _ sh f _) ix
  | inRange shr sh ix = f ix
  | otherwise =
      case bnd of
        Function g -> g ix
        Constant v -> v
        _          -> f (remap shr sh ix)
  where
    -- Is the index (third argument) within the extent (second argument)?
    inRange :: ShapeR sh -> sh -> sh -> Bool
    inRange ShapeRz () () = True
    inRange (ShapeRsnoc shr') (ext, n) (rest, i) =
      i >= 0 && i < n && inRange shr' ext rest

    -- Move every out-of-range component back into range, according to the
    -- index-remapping boundary condition.
    remap :: HasCallStack => ShapeR sh -> sh -> sh -> sh
    remap ShapeRz () () = ()
    remap (ShapeRsnoc shr') (ext, n) (rest, i) = (remap shr' ext rest, fixup)
      where
        fixup
          | i < 0 =
              case bnd of
                Clamp  -> 0
                Mirror -> -i
                Wrap   -> n + i
                _      -> internalError "unexpected boundary condition"
          | i >= n =
              case bnd of
                Clamp  -> n - 1
                Mirror -> n - (i - n + 2)
                Wrap   -> i - n
                _      -> internalError "unexpected boundary condition"
          | otherwise = i
-- toSeqOp :: forall slix sl dim co e proxy. (Elt slix, Shape sl, Shape dim, Elt e)
-- => SliceIndex (EltRepr slix)
-- (EltRepr sl)
-- co
-- (EltRepr dim)
-- -> proxy slix
-- -> Array dim e
-- -> [Array sl e]
-- toSeqOp sliceIndex _ arr = map (sliceOp sliceIndex arr :: slix -> Array sl e)
-- (enumSlices sliceIndex (shape arr))
-- Stencil boundary conditions
-- ---------------------------
-- | Boundary conditions as used by the interpreter, with payloads already
-- evaluated by 'evalBoundary'. 'Constant' holds the element returned for
-- out-of-range reads. 'Function' holds the evaluated function applied to an
-- out-of-range index. 'Clamp', 'Mirror' and 'Wrap' remap the index instead
-- (see 'bounded').
data Boundary t where
  Clamp :: Boundary t
  Mirror :: Boundary t
  Wrap :: Boundary t
  Constant :: t -> Boundary (Array sh t)
  Function :: (sh -> e) -> Boundary (Array sh e)
-- | Convert an AST boundary condition to the interpreter's 'Boundary'. A
-- 'AST.Function' payload is evaluated in the given array environment.
evalBoundary :: HasCallStack => AST.Boundary aenv t -> Val aenv -> Boundary t
evalBoundary AST.Clamp        _    = Clamp
evalBoundary AST.Mirror       _    = Mirror
evalBoundary AST.Wrap         _    = Wrap
evalBoundary (AST.Constant v) _    = Constant v
evalBoundary (AST.Function f) aenv = Function (evalFun f aenv)
-- | Print an 'Atrace' message to standard error, then flush the handle.
--
-- If the message's rendering function yields an empty string for the traced
-- value, only the message text is printed. Otherwise the text is followed by
-- @": "@ and the rendered value.
atraceOp :: Message as -> as -> IO ()
atraceOp (Message render _ msg) as = do
  let rendered = render as
  if null rendered
    then T.hPutStrLn stderr msg
    else hprint stderr (stext % ": " % string % "\n") msg rendered
  hFlush stderr
-- Scalar expression evaluation
-- ----------------------------
-- | Evaluate a closed scalar expression: one with no free scalar variables,
-- evaluated in an empty scalar environment. It may still refer to arrays in
-- the given array environment.
evalExp :: HasCallStack => Exp aenv t -> Val aenv -> t
evalExp e = evalOpenExp e Empty
-- | Evaluate a closed scalar function, starting from an empty scalar
-- environment, against the given array environment.
evalFun :: HasCallStack => Fun aenv t -> Val aenv -> t
evalFun f = evalOpenFun f Empty
-- | Evaluate an open scalar function. Each lambda becomes a Haskell lambda
-- that extends the scalar environment with its argument through the
-- lambda's left-hand side. The body is evaluated with 'evalOpenExp'.
evalOpenFun :: HasCallStack => OpenFun env aenv t -> Val env -> Val aenv -> t
evalOpenFun fun env aenv =
  case fun of
    Body body  -> evalOpenExp body env aenv
    Lam lhs f' -> \arg -> evalOpenFun f' (env `push` (lhs, arg)) aenv
-- Evaluate an open scalar expression
--
-- NB: The implementation of 'Index' and 'Shape' demonstrate clearly why
-- array expressions must be hoisted out of scalar expressions before code
-- execution. If these operations are in the body of a function that gets
-- mapped over an array, the array argument would be evaluated many times
-- leading to a large amount of wasteful recomputation.
--
evalOpenExp
:: forall env aenv t. HasCallStack
=> OpenExp env aenv t
-> Val env
-> Val aenv
-> t
evalOpenExp pexp env aenv =
let
evalE :: OpenExp env aenv t' -> t'
evalE e = evalOpenExp e env aenv
evalF :: OpenFun env aenv f' -> f'
evalF f = evalOpenFun f env aenv
evalA :: ArrayVar aenv a -> WithReprs a
evalA (Var repr ix) = (TupRsingle repr, prj ix aenv)
in
case pexp of
Let lhs exp1 exp2 -> let !v1 = evalE exp1
env' = env `push` (lhs, v1)
in evalOpenExp exp2 env' aenv
Evar (Var _ ix) -> prj ix env
Const _ c -> c
Undef tp -> undefElt (TupRsingle tp)
PrimConst c -> evalPrimConst c
PrimApp f x -> evalPrim f (evalE x)
Nil -> ()
Pair e1 e2 -> let !x1 = evalE e1
!x2 = evalE e2
in (x1, x2)
VecPack vecR e -> pack vecR $! evalE e
VecUnpack vecR e -> unpack vecR $! evalE e
IndexSlice slice slix sh -> restrict slice (evalE slix)
(evalE sh)
where
restrict :: SliceIndex slix sl co sh -> slix -> sh -> sl
restrict SliceNil () () = ()
restrict (SliceAll sliceIdx) (slx, ()) (sl, sz) =
let sl' = restrict sliceIdx slx sl
in (sl', sz)
restrict (SliceFixed sliceIdx) (slx, _i) (sl, _sz) =
restrict sliceIdx slx sl
IndexFull slice slix sh -> extend slice (evalE slix)
(evalE sh)
where
extend :: SliceIndex slix sl co sh -> slix -> sl -> sh
extend SliceNil () () = ()
extend (SliceAll sliceIdx) (slx, ()) (sl, sz) =
let sh' = extend sliceIdx slx sl
in (sh', sz)
extend (SliceFixed sliceIdx) (slx, sz) sl =
let sh' = extend sliceIdx slx sl
in (sh', sz)
ToIndex shr sh ix -> toIndex shr (evalE sh) (evalE ix)
FromIndex shr sh ix -> fromIndex shr (evalE sh) (evalE ix)
Case e rhs def -> evalE (caseof (evalE e) rhs)
where
caseof :: TAG -> [(TAG, OpenExp env aenv t)] -> OpenExp env aenv t
caseof tag = go
where
go ((t,c):cs)
| tag == t = c
| otherwise = go cs
go []
| Just d <- def = d
| otherwise = internalError "unmatched case"
Cond c t e
| toBool (evalE c) -> evalE t
| otherwise -> evalE e
While cond body seed -> go (evalE seed)
where
f = evalF body
p = evalF cond
go !x
| toBool (p x) = go (f x)
| otherwise = x
Index acc ix -> let (TupRsingle repr, a) = evalA acc
in (repr, a) ! evalE ix
LinearIndex acc i -> let (TupRsingle repr, a) = evalA acc
ix = fromIndex (arrayRshape repr) (shape a) (evalE i)
in (repr, a) ! ix
Shape acc -> shape $ snd $ evalA acc
ShapeSize shr sh -> size shr (evalE sh)