-
Notifications
You must be signed in to change notification settings - Fork 73
Expand file tree
/
Copy pathdeterminestrategy.jl
More file actions
1735 lines (1693 loc) · 56.6 KB
/
determinestrategy.jl
File metadata and controls
1735 lines (1693 loc) · 56.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
"""
    check_linear_parents(ls::LoopSet, op::Operation, s::Symbol) -> Bool

Return `true` when `op`'s dependence on the loop symbol `s` is (transitively)
linear: every operation in the parent chain touching `s` is an add/subtract
style compute op. Loads depending on `s` are conservatively rejected.
"""
function check_linear_parents(ls::LoopSet, op::Operation, s::Symbol)
  s ∈ loopdependencies(op) || return true
  if isload(op) # TODO: handle loading from ranges.
    return false
  elseif !iscompute(op)
    return true
  end
  # Only ± style instructions preserve linearity in `s`.
  additive_instrs = (
    :vadd_nsw,
    :vadd_nuw,
    :vadd_nw,
    :vsub_nsw,
    :vsub_nuw,
    :vsub_nw,
    :(+),
    :vadd,
    :add_fast,
    :(-),
    :vsub,
    :sub_fast,
  )
  Base.sym_in(instruction(op).instr, additive_instrs) || return false
  # Recurse depth-first: every parent must also be linear in `s`.
  return all(opp -> check_linear_parents(ls, opp, s), parents(op))
end
# opdict isn't filled when reconstructing, so search the operation list by name.
function findparent(ls::LoopSet, s::Symbol)
  for candidate ∈ operations(ls)
    name(candidate) === s && return candidate
  end
  throw("$s not found")
end
"""
    unitstride(ls::LoopSet, op::Operation, s::Symbol) -> Bool

Return `true` if iterating loop `s` moves through `op`'s memory reference with
unit stride: the reference is contiguous (no `DISCONTIGUOUS`/`CONSTANTZEROINDEX`
first index), the first stride is 1, loop `s` has unit step, any computed first
index is linear in `s`, and `s` does not appear in any later index.
"""
function unitstride(ls::LoopSet, op::Operation, s::Symbol)
  inds = getindices(op)
  li = op.ref.loopedindex
  # The first index is allowed to be indexed by `s`
  fi = first(inds)
  if ((fi === DISCONTIGUOUS) | (fi === CONSTANTZEROINDEX)) ||
     (first(getstrides(op)) ≠ 1) ||
     !unitstep(getloop(ls, s))
    return false
    # elseif !first(li)
    # # We must check if this
    # parent = findparent(ls, fi)
    # indexappearences(parent, s) > 1 && return false
  end
  # First index is computed by another operation: that computation must be
  # linear (adds/subs only) in `s` for the access to remain unit stride.
  if length(li) > 0 && !first(li)
    parent = findparent(ls, first(inds))
    check_linear_parents(ls, parent, s) || return false
  end
  # `s` must not appear (directly or via a computed index) in any later index.
  for i ∈ 2:length(inds)
    if li[i]
      s === inds[i] && return false
    else
      parent = findparent(ls, inds[i])
      s ∈ loopdependencies(parent) && return false
    end
  end
  true
end
# assumes isvectorized and !unitstride
function cannot_shuffle(op::Operation, u₁::Symbol, u₂::Symbol, contigind::Symbol, indices)
  # A curly rejection rules shuffling out immediately.
  rejectcurly(op) && return true
  # NOTE(review): the `|| (indices[2] === u₂)` arm evaluates `indices[2]` even
  # when `length(indices) == 1` — presumably an invariant guarantees ≥ 2 indices
  # whenever `contigind === CONSTANTZEROINDEX`; confirm upstream.
  czi_unrolled =
    (contigind === CONSTANTZEROINDEX) &&
    ((length(indices) > 1) && (indices[2] === u₁) || (indices[2] === u₂))
  shuffleable = czi_unrolled || ((u₁ === contigind) | (u₂ === contigind))
  return !shuffleable
end
"""
    cost(ls, op, (u₁, u₂), vloopsym, Wshift, size_T = op.elementbytes)
        -> (recip_throughput, latency, register_pressure)

Estimate the cost of `op` with `vloopsym` vectorized at width `2^Wshift` and
`u₁`/`u₂` the unroll symbols. Non-unit-stride vectorized memory accesses are
penalized as gather/scatter or shuffle sequences; vectorized stores of
broadcast/reduction values are penalized 3x.
"""
function cost(
  ls::LoopSet,
  op::Operation,
  (u₁, u₂)::Tuple{Symbol,Symbol},
  vloopsym::Symbol,
  Wshift::Int,
  size_T::Int = op.elementbytes,
)
  isconstant(op) && return 0.0, 0, 1.0#Float64(length(loopdependencies(op)) > 0)
  isloopvalue(op) && return 0.0, 0, 0.0
  instr = instruction(op)
  # Unary ± is free: identity/negation folds into the consuming instruction.
  if length(parents(op)) == 1
    if instr == Instruction(:-) ||
       instr === Instruction(:sub_fast) ||
       instr == Instruction(:+) ||
       instr == Instruction(:add_fast)
      return 0.0, 0, 0.0
    end
  elseif iscompute(op) && (
    Base.sym_in(
      instruction(op).instr,
      (:vadd_nsw, :vsub_nsw, :(+), :(-), :add_fast, :sub_fast),
    ) && all(opp -> (isloopvalue(opp)), parents(op))
  )# || (reg_count(ls) == 32) && (instruction(op).instr === :ifelse))
    # all(opp -> (isloopvalue(opp) | isconstant(opp)), parents(op))
    # ± of pure loop values folds into addressing, so it is also free.
    return 0.0, 0, 0.0
  end
  opisvectorized = isvectorized(op)
  srt, sl, srp = opisvectorized ? vector_cost(instr, Wshift, size_T) : scalar_cost(instr)
  if accesses_memory(op)
    # either vbroadcast/reductionstore, vmov(a/u)pd, or gather/scatter
    if opisvectorized
      if !unitstride(ls, op, vloopsym)# || !isdense(op) # need gather/scatter
        indices = getindices(op)
        contigind = first(indices)
        shifter = max(2, Wshift)
        # rejectinterleave false means omop
        # cannot shuffle false means reject curly
        # either false means shuffle
        dont_shuffle =
          (Wshift > 3) ||
          (rejectinterleave(op) && (cannot_shuffle(op, u₁, u₂, contigind, indices)))
        if dont_shuffle
          # gather/scatter: cost scales with the number of lanes (1 << shifter)
          # offset = 0.0 # gather/scatter, alignment doesn't matter
          r = 1 << shifter
          srt = srt * r# + offset
          sl *= r
        else#if rejectinterleave(op) # means omop
          if isload(op) & (length(loopdependencies(op)) > 1)# vmov(a/u)pd
            srt += 0.5reg_size(ls) / cache_lnsze(ls)
          end
          srt += shifter # shifter == number of shuffles
          sl += shifter
        end
      elseif isload(op) & (length(loopdependencies(op)) > 1)# vmov(a/u)pd
        # penalize vectorized loads with more than 1 loopdep
        # heuristic; more than 1 loopdep means that many loads will not be aligned
        # Roughly corresponds to double-counting loads crossing cacheline boundaries
        # TODO: apparently the new ARM A64FX CPU (with 512 bit vectors) is NOT penalized for unaligned loads
        # would be nice to add a check for this CPU, to see if such a penalty is still appropriate.
        # Also, once more SVE (scalable vector extension) CPUs are released, would be nice to know if
        # this feature is common to all of them.
        srt += 0.5reg_size(ls) / cache_lnsze(ls)
        # srt += 0.25reg_size(ls) / cache_lnsze(ls)
      end
    elseif isstore(op)# && isvectorized(first(parents(op))) # broadcast or reductionstore; if store we want to penalize reduction
      srt *= 3
      sl *= 3
    end
  end
  srt, sl, Float64(srp + 1)
end
# Base._return_type()
# Largest element size (bytes) among all operations; used to choose vector width.
biggest_type_size(ls::LoopSet) = mapreduce(elsize, max, operations(ls))
"""
    hasintersection(a, b) -> Bool

Return `true` if any element of `a` is identical (`===`) to any element of `b`.
"""
function hasintersection(a, b)
  for x ∈ a
    any(y -> x === y, b) && return true
  end
  return false
end
# Number of iterations of a loop of length `l` taken `w` at a time: ceiling division.
const num_iterations = cld
"""
    set_vector_width!(ls::LoopSet, vloopsym::Symbol)

If a vector width is already set (non-zero), cap it at the next power of two
of the vectorized loop's length. A zero width is left untouched.
"""
function set_vector_width!(ls::LoopSet, vloopsym::Symbol)
  current = ls.vector_width
  iszero(current) && return nothing
  ls.vector_width = min(current, VectorizationBase.nextpow2(length(ls, vloopsym)))
  return nothing
end
"""
    lsvecwidthshift(ls::LoopSet, vloopsym::Symbol, size_T = nothing) -> (W, Wshift)

Vector width `W` and its log2. When no width is set on `ls`, derive it from
register size divided by the (given or largest) element size; otherwise cap
the stored width at the next power of two of the vectorized loop's length.
"""
function lsvecwidthshift(ls::LoopSet, vloopsym::Symbol, size_T = nothing)
  stored = ls.vector_width
  W = if iszero(stored)
    elbytes = size_T === nothing ? biggest_type_size(ls) : size_T
    reg_size(ls) ÷ elbytes
  else
    min(stored, VectorizationBase.nextpow2(length(ls, vloopsym)))
  end
  return W, VectorizationBase.intlog2(W)
end
# evaluates cost of evaluating loop in given order
"""
    evaluate_cost_unroll(ls, order, vloopsym, max_cost=typemax(Float64), sld=...)

Cost of executing the loop nest in `order` with `vloopsym` vectorized and no
unrolling. Each op is charged at the outermost level where all of its loop
dependencies are defined, weighted by that level's cumulative iteration count.
Returns `Inf` to reject orders violating store→load dependences; aborts early
once the running cost exceeds `max_cost`.
"""
function evaluate_cost_unroll(
  ls::LoopSet,
  order::Vector{Symbol},
  vloopsym::Symbol,
  max_cost::Float64 = typemax(Float64),
  sld::Vector{Vector{Symbol}} = store_load_deps(operations(ls)),
)
  included_vars = fill!(resize!(ls.included_vars, length(operations(ls))), false)
  nested_loop_syms = Symbol[]#Set{Symbol}()
  total_cost = 0.0
  iter = 1.0
  size_T = biggest_type_size(ls)
  W, Wshift = lsvecwidthshift(ls, vloopsym, size_T)
  # Need to check if fusion is possible
  for itersym ∈ order
    cacheunrolled!(ls, itersym, Symbol(""), vloopsym)
    # Add to set of defined symbles
    push!(nested_loop_syms, itersym)
    looplength = length(ls, itersym)
    # The vectorized loop executes ceil(length/W) iterations.
    liter = itersym === vloopsym ? num_iterations(looplength, W) : looplength
    iter *= liter
    # check which vars we can define at this level of loop nest
    for (id, op) ∈ enumerate(operations(ls))
      # won't define if already defined...
      # id = identifier(op)
      included_vars[id] && continue
      # it must also be a subset of defined symbols
      loopdependencies(op) ⊆ nested_loop_syms || continue
      # hasintersection(reduceddependencies(op), nested_loop_syms) && return Inf
      # Reject this order if some already-entered loop is not permitted by the
      # op's store→load dependence set.
      (isassigned(sld, id) && any(s -> (s ∉ sld[id]), nested_loop_syms)) && return Inf
      included_vars[id] = true
      # TODO: use actual unrolls here?
      c = first(cost(ls, op, (Symbol(""), Symbol("")), vloopsym, Wshift, size_T))
      total_cost += iter * c
      0.9total_cost > max_cost && return total_cost # abort if more expensive; we only want to know the cheapest
    end
  end
  0.9total_cost + stride_penalty(ls, order) # 0.9 places a finger on the scale in this strategy's favor
end
# only covers vectorized ops; everything else considered lifted?
"""
    depchain_cost!(ls, skip, op, unrolled, vloopsym, Wshift, size_T, rt=0.0, sl=0)
        -> (rt, sl)

Accumulate (reciprocal throughput, latency) over the dependency chain rooted at
`op`, depth-first. `skip` marks visited ops so shared parents are counted once.
Only compute ops contribute to the totals.
"""
function depchain_cost!(
  ls::LoopSet,
  skip::Vector{Bool},
  op::Operation,
  unrolled::Symbol,
  vloopsym::Symbol,
  Wshift::Int,
  size_T::Int,
  rt::Float64 = 0.0,
  sl::Int = 0,
)
  skip[identifier(op)] = true
  # depth first search
  for opp ∈ parents(op)
    skip[identifier(opp)] && continue
    rt, sl = depchain_cost!(ls, skip, opp, unrolled, vloopsym, Wshift, size_T, rt, sl)
  end
  # Basically assuming memory and compute don't conflict, but everything else does
  # Ie, ignoring the fact that integer and floating point operations likely don't either
  if iscompute(op)
    # cost returns a 3-tuple; the trailing register-pressure value is discarded
    rtᵢ, slᵢ = cost(ls, op, (unrolled, Symbol("")), vloopsym, Wshift, size_T)
    rt += rtᵢ
    sl += slᵢ
  end
  rt, sl
end
# `true` if no parent of `op` is itself a reduction, i.e. `op` heads its chain.
parentsnotreduction(op::Operation) = !any(isreduction, parents(op))
# function roundpow2(i::Integer)
# u = VectorizationBase.nextpow2(i)
# l = u >>> 1
# ud = u - i
# ld = i - l
# ud > ld ? l : u
# end
# function roundpow2(x::Float64)
# 1 << round(Int, log2(x))
# end
"""
    unroll_no_reductions(ls, order, vloopsym) -> (unroll_factor, unrolled_loop)

Pick an unroll factor heuristically when there are no reductions: prefer
unrolling a reorderable, non-vectorized loop; size the factor from the
compute vs. memory throughput balance; and cap it by available registers.
"""
function unroll_no_reductions(ls, order, vloopsym)
  size_T = biggest_type_size(ls)
  W, Wshift = lsvecwidthshift(ls, vloopsym, size_T)
  compute_rt = load_rt = store_rt = 0.0
  unrolled = last(order)
  i = 0
  # Walk outward past loops that must not be reordered.
  while reject_reorder(ls, unrolled, false)
    i += 1
    unrolled = order[end-i]
  end
  # Prefer unrolling a loop other than the vectorized one when possible.
  if unrolled === vloopsym && length(order) > 1
    unrolled_candidate = order[end-1]
    if !reject_reorder(ls, unrolled_candidate, false)
      unrolled = unrolled_candidate
    end
  end
  # latency not a concern, because no depchains
  compute_l = 0.0
  rpp = 0 # register pressure proportional to unrolling
  rpc = 0 # register pressure independent of unroll factor
  for op ∈ operations(ls)
    isu₁unrolled(op) || continue
    rt, sl, rpop = cost(ls, op, (unrolled, Symbol("")), vloopsym, Wshift, size_T)
    if iscompute(op)
      compute_rt += rt
      compute_l += sl
      rpc += max(zero(rpop), rpop - one(rpop)) # constant loads for special functions reused with unrolling
    elseif isload(op)
      load_rt += rt
      rpp += rpop # loads are proportional to unrolling
    elseif isstore(op)
      store_rt += rt
    end
  end
  # heuristic guess
  # roundpow2(min(4, round(Int, (compute_rt + load_rt + 1) / compute_rt)))
  memory_rt = load_rt + store_rt
  u = if compute_rt ≤ 1
    4
  elseif compute_rt > memory_rt
    # if compute_rt > 40
    # max(VectorizationBase.nextpow2( min( 4, round(Int, compute_rt / memory_rt) ) ), 1)
    # else
    # compute-bound: unroll enough to hide latency, less when throughput is huge
    clamp(round(Int, compute_l / compute_rt), 1, Core.ifelse(compute_rt > 80, 2, 4))
    # end
  elseif iszero(load_rt)
    iszero(store_rt) ? 4 : max(1, min(4, round(Int, 2compute_rt / store_rt)))
  else
    max(1, min(4, round(Int, 1.75compute_rt / load_rt)))
  end
  # u = min(u, max(1, (reg_count(ls) ÷ max(1,round(Int,rp)))))
  # commented out here is to decide to align loops
  # if memory_rt > compute_rt && isone(u) && (length(order) > 1) && (last(order) === vloopsym) && length(getloop(ls, last(order))) > 8W
  # ls.align_loops[] = findfirst(operations(ls)) do op
  # isstore(op) && isu₁unrolled(op)
  # end
  # end
  # if unrolled === vloopsym
  # u = demote_unroll_factor(ls, u, vloopsym)
  # end
  remaining_reg = max(8, (reg_count(ls) - round(Int, rpc))) # spilling a few consts isn't so bad
  if compute_l ≥ 4compute_rt ≥ 4rpp
    # motivation for skipping division by loads here: https://github.com/microhh/stencilbuilder/blob/master/julia/stencil_julia_4th.jl
    # Some values:
    # (load_rt, store_rt, compute_rt, compute_l, u, rpc, rpp) = (52.0, 3.0, 92.0, 736.0, 4, 0.0, 52.0)
    # This is fastest when `u = 4`, but `reg_constraint` was restricting it to 1. ## later benchmarks were faster with u = 2?
    # Obviously, this limitation on number of registers didn't seem so important in practice.
    # So, heuristically I check if compute latency dominates the problem, in which case unrolling could be expected to benefit us.
    # Ideally, we'd count the number of loads that actually have to be live at a given time. But this heuristic is hopefully okay for now.
    reg_constraint = max(1, remaining_reg)
  else
    reg_constraint = max(1, remaining_reg ÷ max(1, round(Int, rpp)))
  end
  maybe_demote_unroll(ls, clamp(u, 1, reg_constraint), unrolled, vloopsym), unrolled
  # rt = max(compute_rt, load_rt + store_rt)
  # # (iszero(rt) ? 4 : max(1, roundpow2( min( 4, round(Int, 16 / rt) ) ))), unrolled
  # (iszero(rt) ? 4 : max(1, VectorizationBase.nextpow2( min( 4, round(Int, 8 / rt) ) ))), unrolled
end
"""
    determine_unroll_factor(ls, order, unrolled, vloopsym)
        -> (recip_throughput, latency)

Compute the bottleneck reciprocal throughput (max over compute/load/store) and
the worst reduction dependency-chain latency when loop `unrolled` is unrolled
and `vloopsym` vectorized. The caller sizes the unroll factor from the
latency/throughput ratio to keep the core busy across reduction chains.
"""
function determine_unroll_factor(
  ls::LoopSet,
  order::Vector{Symbol},
  unrolled::Symbol,
  vloopsym::Symbol,
)
  cacheunrolled!(ls, unrolled, Symbol(""), vloopsym)
  size_T = biggest_type_size(ls)
  W, Wshift = lsvecwidthshift(ls, vloopsym, size_T)
  # So if num_reductions > 0, we set the unroll factor to be high enough so that the CPU can be kept busy
  # if there are, U = max(1, round(Int, max(latency) * throughput / num_reductions)) = max(1, round(Int, latency / (recip_throughput * num_reductions)))
  # We also make sure register pressure is not too high.
  latency = 1.0
  # compute_recip_throughput_u = 0.0
  compute_recip_throughput = 0.0
  visited_nodes = fill(false, length(operations(ls)))
  load_recip_throughput = 0.0
  store_recip_throughput = 0.0
  for op ∈ operations(ls)
    if isreduction(op)
      rt, sl = depchain_cost!(ls, visited_nodes, op, unrolled, vloopsym, Wshift, size_T)
      # Only count chain latency when unrolling `unrolled` cannot hide it.
      if isouterreduction(ls, op) ≠ -1 || unrolled ∉ reduceddependencies(op)
        latency = max(sl, latency)
      end
      # if unrolled ∈ loopdependencies(op)
      # compute_recip_throughput_u += rt
      # else
      compute_recip_throughput += rt
      # end
    elseif isload(op)
      load_recip_throughput +=
        first(cost(ls, op, (unrolled, Symbol("")), vloopsym, Wshift, size_T))
    elseif isstore(op)
      store_recip_throughput +=
        first(cost(ls, op, (unrolled, Symbol("")), vloopsym, Wshift, size_T))
    end
  end
  recip_throughput =
    max(compute_recip_throughput, load_recip_throughput, store_recip_throughput)
  recip_throughput, latency
end
"""
    count_reductions(ls::LoopSet) -> Int

Number of compute reductions whose parents are not themselves reductions,
i.e. the number of independent reduction chains.
"""
count_reductions(ls::LoopSet) =
  count(op -> isreduction(op) & iscompute(op) && parentsnotreduction(op), operations(ls))
# Convenience wrapper: resolve the loop symbol, then demote.
demote_unroll_factor(ls::LoopSet, UF, loop::Symbol) =
  demote_unroll_factor(ls, UF, getloop(ls, loop))
# For a static loop with a known vector width, shrink `UF` so that `UF * W`
# divides the trip count more evenly.
function demote_unroll_factor(ls::LoopSet, UF, loop::Loop)
  W = ls.vector_width
  (iszero(W) || !isstaticloop(loop)) && return UF
  return cld(maybedemotesize(UF * W, length(loop)), W)
end
"""
    determine_unroll_factor(ls, order, vloopsym) -> (unroll_factor, unrolled_loop)

Top-level unroll selection. Without reductions, unrolling only serves memory
throughput; with reductions, it breaks loop-carried dependency chains.
`BitArray`-indexed vectorized loops are special-cased so the unroll factor is
a multiple of `8 ÷ vector_width` (whole bytes of the bitmask).
"""
function determine_unroll_factor(ls::LoopSet, order::Vector{Symbol}, vloopsym::Symbol)
  num_reductions = count_reductions(ls)
  # The strategy is to use an unroll factor of 1, unless there appears to be loop carried dependencies (ie, num_reductions > 0)
  # The assumption here is that unrolling provides no real benefit, unless it is needed to enable OOO execution by breaking up these dependency chains
  loopindexesbit = ls.loopindexesbit
  if iszero(length(loopindexesbit)) || ((!loopindexesbit[getloopid(ls, vloopsym)]))
    if iszero(num_reductions)
      return unroll_no_reductions(ls, order, vloopsym)
    else
      return determine_unroll_factor(ls, order, vloopsym, num_reductions)
    end
  elseif iszero(num_reductions) # handle `BitArray` loops w/out reductions
    return 8 ÷ ls.vector_width, vloopsym
  else # handle `BitArray` loops with reductions
    # Dispatches to the 4-arg Symbol method: (recip_throughput, latency) for
    # unrolling the vectorized loop itself.
    rttemp, ltemp = determine_unroll_factor(ls, order, vloopsym, vloopsym)
    UF = min(8, VectorizationBase.nextpow2(max(1, round(Int, ltemp / (rttemp)))))
    UFfactor = 8 ÷ ls.vector_width
    # Round up to a multiple of the byte-granularity factor.
    cld(UF, UFfactor) * UFfactor, vloopsym
    # UF2 = cld(UF, UFfactor)*UFfactor, vloopsym
    # maybe_demote_unroll(ls, UF2, vloopsym, vloopsym), vloopsym
  end
end
# function scale_unrolled()
# end
"""
    determine_unroll_factor(ls, order, vloopsym, num_reductions::Int)
        -> (unroll_factor, unrolled_loop)

Choose which loop to unroll (the one with the lowest adjusted reciprocal
throughput) and size the unroll factor from the latency/throughput ratio of
the reduction chains, demoted to fit static trip counts.
"""
function determine_unroll_factor(
  ls::LoopSet,
  order::Vector{Symbol},
  vloopsym::Symbol,
  num_reductions::Int,
)
  innermost_loop = last(order)
  rt = Inf
  rtcomp = Inf
  latency = Inf
  best_unrolled = Symbol("")
  for unrolled ∈ order
    reject_reorder(ls, unrolled, false) && continue
    rttemp, ltemp = determine_unroll_factor(ls, order, unrolled, vloopsym)
    # NOTE(review): the `- latency` term subtracts the best-so-far latency
    # (`Inf` for the first candidate, which is therefore always accepted) and
    # mixes units with the 0.01 tiebreaker — looks suspicious; confirm intent.
    rtcomptemp =
      rttemp + (0.01 * ((vloopsym === unrolled) + (unrolled === innermost_loop) - latency))
    if rtcomptemp < rtcomp
      rt = rttemp
      rtcomp = rtcomptemp
      latency = ltemp
      best_unrolled = unrolled
    end
  end
  # min(8, roundpow2(max(1, round(Int, latency / (rt * num_reductions) ) ))), best_unrolled
  lrtratio = latency / rt
  if lrtratio ≥ 7.0
    UF = 8
  else
    UF = VectorizationBase.nextpow2(round(Int, clamp(lrtratio, 1.0, 4.0), RoundUp))
  end
  UF = maybe_demote_unroll(ls, UF, best_unrolled, vloopsym)
  UF, best_unrolled
end
# Shrink `UF` to fit a static unrolled loop; vectorized loops get the
# width-aware demotion instead.
function maybe_demote_unroll(ls::LoopSet, UF::Int, unrollsym::Symbol, vloopsym::Symbol)::Int
  unrollsym === vloopsym && return demote_unroll_factor(ls, UF, vloopsym)
  uloop = getloop(ls, unrollsym)
  return isstaticloop(uloop) ? min(length(uloop), UF) : UF
end
# Modeled cost of an unroll pair: X[1] fixed, X[2]/X[3] scale with the number
# of u₂/u₁ blocks per iteration, X[4] with both.
@inline function unroll_cost(X, u₁, u₂, u₁L, u₂L)
  f₂ = num_iterations(u₂L, u₂) / u₂L
  f₁ = num_iterations(u₁L, u₁) / u₁L
  return X[1] + X[2] * f₂ + X[3] * f₁ + X[4] * f₁ * f₂
end
"""
    solve_unroll_iter(X, R, u₁L, u₂L, u₁range, u₂range) -> (u₁, u₂, cost)

Brute-force search over candidate unroll pairs, keeping the cheapest pair that
fits the register budget `R[4] ≥ u₁*u₂*R[1] + u₁*R[2] + u₂*R[3]`.
Returns `(0, 0, Inf)` when no candidate is feasible.
"""
function solve_unroll_iter(X, R, u₁L, u₂L, u₁range, u₂range)
  R₁, R₂, R₃, RR = R[1], R[2], R[3], R[4]
  best₁ = best₂ = 0
  lowest = Inf
  for c₁ ∈ u₁range, c₂ ∈ u₂range
    # register-pressure feasibility check
    (c₁ * c₂ * R₁ + c₁ * R₂ + c₂ * R₃ ≤ RR) || continue
    candidate = unroll_cost(X, c₁, c₂, u₁L, u₂L)
    # `≤` keeps later (range-order) candidates on ties, matching search intent
    if candidate ≤ lowest
      lowest = candidate
      best₁, best₂ = c₁, c₂
    end
  end
  return best₁, best₂, lowest
end
"""
    solve_unroll_lagrange(X, R, u₁L, u₂L, u₁step, u₂step, atleast31registers)
        -> (u₁, u₂, cost)

Approximately minimize `unroll_cost` under the register budget
`R[4] ≥ u₁*u₂*R[1] + u₁*R[2] + u₂*R[3]`: solve the relaxed problem in closed
form (quadratic from the Lagrange stationarity conditions), then refine with a
brute-force search (`solve_unroll_iter`) in a window around the real-valued
optimum.
"""
function solve_unroll_lagrange(
  X,
  R,
  u₁L,
  u₂L,
  u₁step::Int,
  u₂step::Int,
  atleast31registers::Bool,
)
  X₁, X₂, X₃, X₄ = X[1], X[2], X[3], X[4]
  # If we don't have opmask registers, masks probably occupy a vector register (e.g., on CPUs with AVX but not AVX512)
  R₁, R₂, R₃, R₄ = R[1], R[2], R[3], R[4]
  iszero(R₃) || return solve_unroll_iter(X, R, u₁L, u₂L, u₁step:u₁step:10, u₂step:u₂step:10)
  RR = R₄
  # Quadratic coefficients in u₁ from the stationarity conditions.
  a = R₂^2 * X₃ - R₁ * X₄ * R₂ - R₁ * X₂ * RR
  # NOTE(review): the first two terms cancel exactly
  # (`R₁ * X₄ * RR - R₁ * X₄ * RR == 0`) — either dead code or a typo in the
  # derivation; confirm against the underlying math before "fixing".
  b = R₁ * X₄ * RR - R₁ * X₄ * RR - 2X₃ * RR * R₂
  c = X₃ * RR^2
  discriminant = b^2 - 4a * c
  discriminant < 0 && return -1, -1, Inf
  u₁float = max((sqrt(discriminant) + b) / (-2a), float(u₁step)) # must be at least 1
  u₂float = (RR - u₁float * R₂) / (u₁float * R₁)
  u₁float_finite = isfinite(u₁float)
  u₂float_finite = isfinite(u₂float)
  if !(u₁float_finite & u₂float_finite) # brute force
    u₁high = Core.ifelse(iszero(X₃), u₁step, Core.ifelse(atleast31registers, 8, 6))
    u₂high = Core.ifelse(iszero(X₂), u₂step, Core.ifelse(atleast31registers, 8, 6))
    return solve_unroll_iter(X, R, u₁L, u₂L, u₁step:u₁step:u₁high, u₂step:u₂step:u₂high)
  end
  u₁low = floor(Int, u₁float)
  u₂low = max(u₂step, floor(Int, 0.8u₂float)) # must be at least 1
  # Upper bounds from the register budget with the other factor fixed.
  u₁high = solve_unroll_constT(R, u₂low) + u₁step
  u₂high = solve_unroll_constU(R, u₁low) + u₂step
  if u₁low ≥ u₁high
    u₁low = solve_unroll_constT(R, u₂high)
  end
  if u₂low ≥ u₂high
    u₂low = solve_unroll_constU(R, u₁high)
  end
  maxunroll = atleast31registers ? (((X₂ > 0) & (X₃ > 0)) ? 10 : 8) : 6
  # Snap bounds onto the required step grid.
  u₁low = (clamp(u₁low, u₁step, maxunroll) ÷ u₁step) * u₁step
  u₂low = (clamp(u₂low, u₂step, maxunroll) ÷ u₂step) * u₂step
  u₁high = clamp(u₁high, 1, maxunroll)
  u₂high = clamp(u₂high, 1, maxunroll)
  # Reversed ranges: ties in the brute-force search resolve to smaller unrolls.
  solve_unroll_iter(
    X,
    R,
    u₁L,
    u₂L,
    reverse(u₁low:u₁step:u₁high),
    reverse(u₂low:u₂step:u₂high),
  )
end
"""
    solve_unroll_constU(R::AbstractVector, u₁::Int) -> Int

Largest integer `u₂` satisfying the register budget
`R[4] ≥ u₁*u₂*R[1] + u₁*R[2] + u₂*R[3]` for a fixed `u₁`.
Returns 8 when the constraint does not bind (zero denominator).
"""
function solve_unroll_constU(R::AbstractVector, u₁::Int)
  scale = u₁ * R[1] + R[3]
  iszero(scale) && return 8
  return floor(Int, (R[4] - u₁ * R[2]) / scale)
end
"""
    solve_unroll_constT(R::AbstractVector, u₂::Int) -> Int

Largest integer `u₁` satisfying the register budget
`R[4] ≥ u₁*u₂*R[1] + u₁*R[2] + u₂*R[3]` for a fixed `u₂`.
Returns 8 when the constraint does not bind (zero denominator).
"""
function solve_unroll_constT(R::AbstractVector, u₂::Int)
  scale = u₂ * R[1] + R[2]
  iszero(scale) && return 8
  return floor(Int, (R[4] - u₂ * R[3]) / scale)
end
# function solve_unroll_constT(ls::LoopSet, u₂::Int)
# R = @view ls.reg_pres[:,1]
# denom = u₂ * R[1] + R[2]
# iszero(denom) && return 8
# floor(Int, (dynamic_register_count() - R[3] - R[4] - u₂*R[5]) / (u₂ * R[1] + R[2]))
# end
# Tiling here is about alleviating register pressure for the UxT
"""
    solve_unroll(X, R, u₁max, u₂max, u₁L, u₂L, u₁step, u₂step, atleast31registers)
        -> (u₁, u₂, cost)

Solve for the unroll pair via the Lagrange/brute-force solver, then clamp each
factor to its maximum, re-solving the other factor under the register budget
and recomputing the cost when clamping occurred.
"""
function solve_unroll(
  X,
  R,
  u₁max,
  u₂max,
  u₁L,
  u₂L,
  u₁step,
  u₂step,
  atleast31registers::Bool,
)
  u₁, u₂, cost = solve_unroll_lagrange(X, R, u₁L, u₂L, u₁step, u₂step, atleast31registers)
  if u₁ > u₁max
    u₁ = u₁max
    # With u₁ pinned, either pin u₂ as well or re-solve it within the budget.
    u₂ = u₂ > u₂max ? u₂max : min(u₂max, max(1, solve_unroll_constU(R, u₁)))
    cost = unroll_cost(X, u₁, u₂, u₁L, u₂L)
  elseif u₂ > u₂max
    u₂ = u₂max
    u₁ = min(u₁max, max(1, solve_unroll_constT(R, u₂)))
    cost = unroll_cost(X, u₁, u₂, u₁L, u₂L)
  end
  return u₁, u₂, cost
end
# Shrink `U` toward a value that covers the `N` iterations in the same number
# of blocks more evenly: with k = cld(N, U) blocks, use cld(N, k).
maybedemotesize(U::Int, N::Int) = cld(N, cld(N, U))
# Demote `u₂` toward dividing the static trip count `N` evenly; when the u₁
# loop is not fully covered by `U` and `u₂` still doesn't divide `N`, fall
# back to the base cap `maxu₂base`.
function maybedemotesize(u₂::Int, N::Int, U::Int, Uloop::Loop, maxu₂base::Int)
  u₂ ≤ 1 && return 1
  u₂ == N && return u₂
  demoted = maybedemotesize(u₂, N)
  uloop_fully_unrolled = isstaticloop(Uloop) && length(Uloop) == U
  if !uloop_fully_unrolled && N % demoted != 0
    demoted = min(demoted, maxu₂base)
  end
  return demoted
end
"""
    solve_unroll(ls, u₁loopsym, u₂loopsym, cost_vec, reg_pressure, W, vloopsym, rounduᵢ)

Resolve loop objects and unroll-step requirements, then delegate to the core
solver. `rounduᵢ` selects which factor must move in step multiples: ±1 → u₁,
±2 → u₂; positive values use a cacheline-derived step, negative a
bit-packing step (`8 ÷ vector_width`); anything else leaves both steps at 1.
"""
function solve_unroll(
  ls::LoopSet,
  u₁loopsym::Symbol,
  u₂loopsym::Symbol,
  cost_vec::AbstractVector{Float64},
  reg_pressure::AbstractVector{Float64},
  W::Int,
  vloopsym::Symbol,
  rounduᵢ::Int,
)
  if rounduᵢ == 1 # clamp safeguards against unusual cacheline/register ratios
    u₁step, u₂step = clamp(cache_lnsze(ls) ÷ reg_size(ls), 1, 4), 1
  elseif rounduᵢ == 2
    u₁step, u₂step = 1, clamp(cache_lnsze(ls) ÷ reg_size(ls), 1, 4)
  elseif rounduᵢ == -1
    u₁step, u₂step = 8 ÷ ls.vector_width, 1
  elseif rounduᵢ == -2
    u₁step, u₂step = 1, 8 ÷ ls.vector_width
  else
    u₁step, u₂step = 1, 1
  end
  solve_unroll(
    u₁loopsym,
    u₂loopsym,
    cost_vec,
    reg_pressure,
    W,
    vloopsym,
    getloop(ls, u₁loopsym),
    getloop(ls, u₂loopsym),
    u₁step,
    u₂step,
    reg_count(ls) ≥ 31,
  )
end
"""
    solve_unroll(u₁loopsym, u₂loopsym, cost_vec, reg_pressure, W, vloopsym,
                 u₁loop, u₂loop, u₁step, u₂step, atleast31registers)
        -> (u₁, u₂, cost)

Core unroll-factor selection. Short static non-vectorized loops (≤ 4 iters)
are fully unrolled, with the other factor solved under the register budget;
otherwise the Lagrange/brute-force solver runs, after which static trip
counts may demote the chosen factors so they divide more evenly.
"""
function solve_unroll(
  u₁loopsym::Symbol,
  u₂loopsym::Symbol,
  cost_vec::AbstractVector{Float64},
  reg_pressure::AbstractVector{Float64},
  W::Int,
  vloopsym::Symbol,
  u₁loop::Loop,
  u₂loop::Loop,
  u₁step::Int,
  u₂step::Int,
  atleast31registers::Bool,
)
  maxu₂base = maxu₁base = atleast31registers ? 10 : 6#8
  maxu₂ = maxu₂base#8
  maxu₁ = maxu₁base#8
  u₁L = length(u₁loop)
  u₂L = length(u₂loop)
  if isstaticloop(u₂loop)
    # Short static non-vectorized u₂: unroll it completely, solve u₁.
    if u₂loopsym !== vloopsym && u₂L ≤ 4
      if isstaticloop(u₁loop)
        u₁ = max(solve_unroll_constT(reg_pressure, u₂L), 1)
        u₁ = maybedemotesize(u₁, u₁loopsym === vloopsym ? cld(u₁L, W) : u₁L)
      else
        u₁ = clamp(solve_unroll_constT(reg_pressure, u₂L), 1, 8)
      end
      return u₁, u₂L, unroll_cost(cost_vec, u₁, u₂L, u₁L, u₂L)
    end
    u₂Ltemp = u₂loopsym === vloopsym ? cld(u₂L, W) : u₂L
    maxu₂ = min(4maxu₂, u₂Ltemp)
  end
  if isstaticloop(u₁loop)
    # Mirror case: short static non-vectorized u₁ is fully unrolled.
    if u₁loopsym !== vloopsym && u₁L ≤ 4
      if isstaticloop(u₂loop)
        u₂ = max(solve_unroll_constU(reg_pressure, u₁L), 1)
        u₂ = maybedemotesize(u₂, u₂loopsym === vloopsym ? cld(u₂L, W) : u₂L)
      else
        u₂ = clamp(solve_unroll_constU(reg_pressure, u₁L), 1, 8)
      end
      return u₁L, u₂, unroll_cost(cost_vec, u₁L, u₂, u₁L, u₂L)
    end
    u₁Ltemp = u₁loopsym === vloopsym ? cld(u₁L, W) : u₁L
    maxu₁ = min(4maxu₁, u₁Ltemp)
  end
  # Effective lengths for the cost model: vectorized loops count vector iters.
  if u₁loopsym === vloopsym
    u₁Lf = u₁L / W
  else
    u₁Lf = Float64(u₁L)
  end
  if u₂loopsym === vloopsym
    u₂Lf = u₂L / W
  else
    u₂Lf = Float64(u₂L)
  end
  u₁, u₂, cost = solve_unroll(
    cost_vec,
    reg_pressure,
    maxu₁,
    maxu₂,
    u₁Lf,
    u₂Lf,
    u₁step,
    u₂step,
    atleast31registers,
  )
  # heuristic to more evenly divide small numbers of iterations
  if isstaticloop(u₂loop)
    u₂ = maybedemotesize(u₂, length(u₂loop), u₁, u₁loop, maxu₂base)
  end
  if isstaticloop(u₁loop)
    u₁ = maybedemotesize(u₁, length(u₁loop), u₂, u₂loop, maxu₁base)
  end
  u₁, u₂, cost
end
# Propagate `val` to `op` and all of its ancestors in `adal`, stopping early
# at nodes that already carry `val` (prevents re-walking shared subtrees).
function set_upstream_family!(adal::Vector{T}, op::Operation, val::T) where {T}
  idx = identifier(op)
  adal[idx] == val && return # must already have been set
  adal[idx] = val
  foreach(opp -> set_upstream_family!(adal, opp, val), parents(op))
end
"""
    loopdepindices(ls::LoopSet, op::Operation) -> Vector{Symbol}

Loop symbols indexing `op`'s memory reference, in index order. Looped indices
map through directly; computed indices expand to the loop dependencies of the
operation that computes them. `DISCONTIGUOUS` and `CONSTANTZEROINDEX`
markers are skipped.
"""
function loopdepindices(ls::LoopSet, op::Operation)
  loopdeps = loopdependencies(op.ref)
  isdiscontig = first(loopdeps) === DISCONTIGUOUS
  # isdiscontig = isdiscontiguous(op.ref)
  loopedindex = op.ref.loopedindex
  # Fast path: every index is a plain loop symbol with nothing to filter out.
  if !isdiscontig && all(loopedindex) && !(any(==(CONSTANTZEROINDEX), loopdeps))
    return loopdeps
  end
  loopdepsret = Symbol[]
  for i ∈ eachindex(loopedindex)
    if loopedindex[i]
      loopdeps[i+isdiscontig] === CONSTANTZEROINDEX ||
        push!(loopdepsret, loopdeps[i+isdiscontig])
    else
      # Computed index: pull in the computing op's loop deps (deduplicated).
      oploopdeps = loopdependencies(findparent(ls, loopdeps[i+isdiscontig]))
      for ld ∈ oploopdeps
        (ld ∉ loopdepsret) && push!(loopdepsret, ld)
      end
    end
  end
  loopdepsret
end
"""
    stride_penalty(ls::LoopSet, op::Operation, order::Vector{Symbol}, loopfreqs)

Estimate a memory-access penalty for `op` under loop `order`: for each loop in
`order` that indexes `op`, multiply how often that loop level iterates
(`loopfreqs`) by the approximate memory stride of the corresponding index,
then scale by the total iteration space of `op`'s index loops.
"""
function stride_penalty(ls::LoopSet, op::Operation, order::Vector{Symbol}, loopfreqs)
  loopdeps = loopdepindices(ls, op)
  opstrides = Vector{Int}(undef, length(loopdeps))
  # very minor stride assumption here, because we don't really want to base optimization decisions on it...
  # Integer arithmetic throughout: the previous `1.0 + ...` produced a Float64
  # that was converted on every assignment into this `Vector{Int}`; the value
  # (1 plus the two Bool marker checks) is identical.
  firstdep = first(loopdependencies(op.ref))
  opstrides[1] = 1 + (firstdep === DISCONTIGUOUS) + (firstdep === CONSTANTZEROINDEX)
  l = Float64(length(getloop(ls, first(loopdeps))))
  for i ∈ 2:length(loopdeps)
    # Each successive index strides by the product of the preceding loop lengths.
    looplength = length(getloop(ls, loopdeps[i-1]))
    opstrides[i] = opstrides[i-1] * looplength
    l *= looplength
  end
  penalty = 0.0
  for i ∈ eachindex(order)
    id = findfirst(Base.Fix2(===, order[i]), loopdeps)
    if id !== nothing
      penalty += loopfreqs[i] * opstrides[id]
    end
  end
  penalty * l
end
"""
    stride_penalty(ls::LoopSet, order::Vector{Symbol}) -> Float64

Aggregate stride penalty of a loop order: per array, take the worst penalty
over its memory-accessing ops, sum across arrays, and normalize by
`1024^length(order)` (times 10) so nests of different depth are comparable.
"""
function stride_penalty(ls::LoopSet, order::Vector{Symbol})
  # Cumulative iteration frequency of each position in `order`.
  loopfreqs = Vector{Int}(undef, length(order))
  freq = 1
  for i ∈ eachindex(order)
    i > 1 && (freq *= length(getloop(ls, order[i])))
    loopfreqs[i] = freq
  end
  penalties_by_array = Dict{Symbol,Vector{Float64}}()
  for op ∈ operations(ls)
    accesses_memory(op) || continue
    push!(
      get!(() -> Float64[], penalties_by_array, op.ref.ref.array),
      stride_penalty(ls, op, order, loopfreqs),
    )
  end
  isempty(penalties_by_array) && return 0.0
  # 1 / 1024 = 0.0009765625
  return 10.0sum(maximum, values(penalties_by_array)) *
         Base.power_by_squaring(0.0009765625, length(order))
end
"""
    isuniqueinindices(ls::LoopSet, op::Operation, opp::Operation, i::Int) -> Bool

Check that the loops the computed index `opp` (at position `i` of `op`)
depends on do not also appear among `op`'s other indices — a requirement for
the offset-translation load-elimination to be valid.
"""
function isuniqueinindices(ls::LoopSet, op::Operation, opp::Operation, i::Int)
  ld = loopdependencies(opp)
  inds = getindicesonly(op)
  li = op.ref.loopedindex
  for j ∈ eachindex(inds)
    i == j && continue
    if li[j]
      inds[j] ∈ ld && return false
    else
      # NOTE(review): this looks up position `i` (not `j`) for every computed
      # index `j`, shadowing the `opp` argument; `inds[j+...]` would be the
      # more obvious reading — confirm intent upstream before changing.
      opp = findparent(ls, inds[i+(first(inds)===DISCONTIGUOUS)])
      any(in(ld), loopdependencies(opp)) && return false
    end
  end
  true
end
"""
    isoptranslation(ls::LoopSet, op::Operation, unrollsyms::UnrollSymbols)
        -> (index_position, flags)

Detect whether `op` indexes memory through a translation of the two unrolled
loop symbols (`u₁ ± u₂` with equal absolute steps). Flag byte: `0x03` both
positive, `0x01` u₁ positive / u₂ negative, `0x02` the reverse; `(0, 0x00)`
when no translation applies.
"""
function isoptranslation(ls::LoopSet, op::Operation, unrollsyms::UnrollSymbols)
  @unpack u₁loopsym, u₂loopsym, vloopsym = unrollsyms
  # Translation only applies when neither unrolled loop is the vectorized one.
  (vloopsym == u₁loopsym || vloopsym == u₂loopsym) && return 0, 0x00
  (isu₁unrolled(op) && isu₂unrolled(op)) || return 0, 0x00
  u₁step = step(getloop(ls, u₁loopsym))
  u₂step = step(getloop(ls, u₂loopsym))
  # Both steps must be known and of equal magnitude for offsets to line up.
  (isknown(u₁step) & isknown(u₂step)) || return 0, 0x00
  abs(gethint(u₁step)) == abs(gethint(u₂step)) || return 0, 0x00
  istranslation = 0
  inds = getindices(op)
  li = op.ref.loopedindex
  for i ∈ eachindex(li)
    if !li[i]
      # Computed index: check whether it is an add/sub of both unroll symbols.
      opp = findparent(ls, inds[i+(first(inds)===DISCONTIGUOUS)])
      if isu₁unrolled(opp) & isu₂unrolled(opp)
        if Base.sym_in(instruction(opp).instr, (:vadd_nsw, :(+)))
          isuniqueinindices(ls, op, opp, i) || return 0, 0x00
          return i, 0x03 # 00000011 - both positive
        elseif Base.sym_in(instruction(opp).instr, (:vsub_nsw, :(-)))
          isuniqueinindices(ls, op, opp, i) || return 0, 0x00
          oppp1 = parents(opp)[1]
          if isu₁unrolled(oppp1)
            return i, 0x01 # 00000001 - u₁ positive, u₂ negative
          else#isu₂unrolled(oppp1)
            return i, 0x02 # 00000010 - u₂ positive, u₁ negative
          end
        end
      end
    end
  end
  0, 0x00
end
"""
    maxnegativeoffset(ls::LoopSet, op::Operation, u::Symbol) -> (offset, opid)

Among operations in the same offset-load collection as `op`, find the one
whose offsets differ from `op`'s only along loop `u` and by the largest
(closest-to-zero) negative amount. Returns `(typemin(Int), 0)` when no
candidate qualifies.
"""
function maxnegativeoffset(ls::LoopSet, op::Operation, u::Symbol)
  mno::Int = typemin(Int)
  id = 0
  isknown(step(getloop(ls, u))) || return mno, id
  omop = offsetloadcollection(ls)
  collectionid, opind = omop.opidcollectionmap[identifier(op)]
  collectionid == 0 && return mno, id
  @unpack opids = omop
  opidcol = opids[collectionid]
  opid = identifier(op)
  opoffs = getoffsets(op)
  opstrd = getstrides(op)
  ops = operations(ls)
  opindices = getindicesonly(op)
  for oppid ∈ opidcol
    opid == oppid && continue
    opp = ops[oppid]
    oppoffs = getoffsets(opp)
    oppstrd = getstrides(opp)
    mnonew::Int = typemin(Int)
    for i ∈ eachindex(opindices)
      strd = opstrd[i]
      # both accesses must be unit-stride along this index
      strd == oppstrd[i] == 1 || continue
      if opindices[i] === u
        mnonew = (opoffs[i] % Int) - (oppoffs[i] % Int)
        # mnonew_t, mnonew_rem = divrem((opoffs[i] % Int) - (oppoffs[i] % Int), strd % Int)
        # mnonew_rem == 0 || continue
        # mnonew = mnonew_t
      elseif opoffs[i] != oppoffs[i]
        # Offsets differ along a non-`u` axis: set a positive sentinel so the
        # `mnonew < 0` check below disqualifies this candidate.
        mnonew = 1
        break
      end
    end
    if mno < mnonew < 0
      mno = mnonew
      id = identifier(opp)
    end
  end
  mno, id
end
# Best (largest) negative offset over the two unrolled loops, skipping the
# vectorized one. Returns the offset and 1/2 for which loop produced it (0 if neither).
function maxnegativeoffset(ls::LoopSet, op::Operation, unrollsyms::UnrollSymbols)
  @unpack u₁loopsym, u₂loopsym, vloopsym = unrollsyms
  best = typemin(Int)
  which = 0
  if u₁loopsym !== vloopsym
    cand = first(maxnegativeoffset(ls, op, u₁loopsym))
    if cand > best
      best = cand
      which = 1
    end
  end
  if u₂loopsym !== vloopsym
    cand = first(maxnegativeoffset(ls, op, u₂loopsym))
    if cand > best
      best = cand
      which = 2
    end
  end
  return best, which
end
"""
    load_elimination_cost_factor!(cost_vec, reg_pressure, choose_to_inline,
                                  ls, op, iters, unrollsyms, Wshift, size_T) -> Bool

If `op` participates in an index translation across the two unrolled loops
(see `isoptranslation`), credit `cost_vec`/`reg_pressure` to reflect that
unrolling permits eliminating redundant loads, set `choose_to_inline`, and
return `true`. Otherwise return `false`, leaving the vectors untouched.

Fix: the original else branch evaluated a dead, discarded `(1.0, 1.0)` tuple
before returning `false`; it (and a large block of dead commented-out
experiment code) has been removed. Behavior is unchanged.
"""
function load_elimination_cost_factor!(
  cost_vec,
  reg_pressure,
  choose_to_inline,
  ls::LoopSet,
  op::Operation,
  iters,
  unrollsyms::UnrollSymbols,
  Wshift,
  size_T,
)
  @unpack u₁loopsym, u₂loopsym, vloopsym = unrollsyms
  # Early-out: no translation means no elimination opportunity.
  iszero(first(isoptranslation(ls, op, unrollsyms))) && return false
  rt, lat, rp = cost(ls, op, (u₁loopsym, u₂loopsym), vloopsym, Wshift, size_T)
  rt *= iters
  choose_to_inline[] = true
  # Charge only a quarter of the register pressure against the base slot;
  # eliminated loads free registers as unrolling increases.
  reg_pressure[1] += 0.25rp
  cost_vec[2] += rt
  reg_pressure[2] += rp
  cost_vec[3] += rt
  # currently only place `reg_pressure[3]` is updated
  reg_pressure[3] += rp
  return true
end
# `true` if `op` is a load whose memory reference is also written by a store.
function loadintostore(ls::LoopSet, op::Operation)
  isload(op) || return false # leads to bad behavior more than it helps
  return any(opp -> isstore(opp) && opp.ref == op.ref, operations(ls))
end
# Collect (into `deps`) every loop/reduced dependency in `op`'s parent tree,
# returning `true` as soon as a load from `compref` is found (a store→load
# dependence on the same memory reference).
function store_load_deps!(deps::Vector{Symbol}, op::Operation, compref = op.ref)
  for opp ∈ parents(op)
    for dep ∈ loopdependencies(opp)
      dep ∈ deps || push!(deps, dep)
    end
    for dep ∈ reduceddependencies(opp)
      dep ∈ deps || push!(deps, dep)
    end
    if isload(opp)
      opp.ref == compref && return true
    elseif store_load_deps!(deps, opp, compref)
      return true
    end
  end
  return false
end
# For a store whose computation reads the same memory ref it writes, return
# the accumulated dependency set; otherwise `nothing`.
function store_load_deps(op::Operation)::Union{Nothing,Vector{Symbol}}
  isstore(op) || return nothing
  deps = copy(loopdependencies(op))
  return store_load_deps!(deps, op) ? deps : nothing
end
# Per-op store→load dependency sets; entries stay #undef for ops without one,
# so callers must guard with `isassigned`.
function store_load_deps(ops::Vector{Operation})
  sld = Vector{Vector{Symbol}}(undef, length(ops))
  for (i, op) ∈ enumerate(ops)
    deps = store_load_deps(op)
    deps === nothing || (sld[i] = deps)
  end
  return sld
end
function add_constant_offset_load_elmination_cost!(
X,
R,
choose_to_inline,