forked from PSORLab/SourceCodeMcCormick.jl
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathkernel_write.jl
More file actions
2043 lines (1886 loc) · 100 KB
/
kernel_write.jl
File metadata and controls
2043 lines (1886 loc) · 100 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
include(joinpath(@__DIR__, "math_kernels.jl"))
include(joinpath(@__DIR__, "string_math_kernels.jl"))
# Kernel generator entry points (the GPU-kernel analogue of `fgen`).
# Each convenience method below fills in whatever the caller omitted and then
# forwards to the positional core method
#     kgen(num, gradlist, raw_outputs, constants, overwrite, splitting, affine_quadratic)
# Defaults applied here:
#   - gradlist:    every variable appearing in `num` that is not listed in `constants`
#   - raw_outputs: [:all]
function kgen(num::Num; constants::Vector{Num}=Num[], overwrite::Bool=false,
              splitting::Symbol=:default, affine_quadratic::Bool=true)
    nonconstant_vars = setdiff(pull_vars(num), constants)
    return kgen(num, nonconstant_vars, [:all], constants, overwrite, splitting, affine_quadratic)
end
function kgen(num::Num, gradlist::Vector{Num}; constants::Vector{Num}=Num[],
              overwrite::Bool=false, splitting::Symbol=:default, affine_quadratic::Bool=true)
    return kgen(num, gradlist, [:all], constants, overwrite, splitting, affine_quadratic)
end
function kgen(num::Num, raw_outputs::Vector{Symbol}; constants::Vector{Num}=Num[],
              overwrite::Bool=false, splitting::Symbol=:default, affine_quadratic::Bool=true)
    nonconstant_vars = setdiff(pull_vars(num), constants)
    return kgen(num, nonconstant_vars, raw_outputs, constants, overwrite, splitting, affine_quadratic)
end
function kgen(num::Num, gradlist::Vector{Num}, raw_outputs::Vector{Symbol};
              constants::Vector{Num}=Num[], overwrite::Bool=false,
              splitting::Symbol=:default, affine_quadratic::Bool=true)
    return kgen(num, gradlist, raw_outputs, constants, overwrite, splitting, affine_quadratic)
end
# Core kernel generator. Writes (or reuses) the storage file "storage/f_<hash>.jl"
# containing one or more generated GPU kernels plus a CPU entry point named
# `f_<hash>`, includes it, calls it once on zero-filled CUDA arrays (warm-up /
# compilation), and returns that entry point.
#
# Arguments:
#   num              - expression to generate kernels for
#   gradlist         - variables for which subgradients are produced
#   raw_outputs      - requested outputs; accepted symbols are :cv, :cc, :lo, :hi, :MC,
#                      :cvgrad, :ccgrad, :grad, :all (validated, but all outputs are
#                      currently generated regardless)
#   constants        - passed through to the kernel writers
#   overwrite        - regenerate even if a cached file for this key already exists
#   splitting        - kernel-splitting aggressiveness: :low, :default, :high or :max
#   affine_quadratic - for quadratic `num`, emit SCIP/EAGO-style affine/secant bounds
#                      instead of McCormick relaxations
#
# Errors: invalid output symbol, empty output list, or (on the McCormick path) an
# invalid `splitting` symbol. If anything fails while the file is being generated,
# the partially written file is renamed to "<file>.failed" so that it is never
# mistaken for a cached kernel by a later `overwrite=false` call.
function kgen(num::Num, gradlist::Vector{Num}, raw_outputs::Vector{Symbol}, constants::Vector{Num}, overwrite::Bool, splitting::Symbol, affine_quadratic::Bool)
    # Cache key. `splitting` and `affine_quadratic` change the generated code, so a
    # kernel cached under different options must not be returned. With default
    # options the key is exactly the historical one, so existing caches stay valid.
    cache_key = string(num)*string(gradlist)
    if (splitting != :default) || !affine_quadratic
        cache_key *= "#splitting=$(splitting)#affine_quadratic=$(affine_quadratic)"
    end
    expr_hash = string(hash(cache_key), base=62)
    kernel_file = joinpath(@__DIR__, "storage", "f_"*expr_hash*".jl")

    # Reuse an existing kernel if allowed: either it is already defined in this
    # session, or the stored file must be included (and warmed up) first.
    if (overwrite==false) && isfile(kernel_file)
        try
            return eval(Meta.parse("f_"*expr_hash))
        catch
            include(kernel_file)
            func_name = eval(Meta.parse("f_"*expr_hash))
            @eval $(func_name)(CUDA.zeros(Float64, 2, 4+2*length($gradlist)), [CUDA.zeros(Float64, 2,3) for i = 1:length($gradlist)]...)
            return func_name
        end
    end

    # Parse the list of requested outputs BEFORE touching the storage file, so an
    # invalid request cannot leave a header-only file behind.
    # (Not currently useful, since all outputs are generated by default.)
    func_outputs = Symbol[]
    for output in raw_outputs
        output == :cv && push!(func_outputs, :cv)
        output == :cc && push!(func_outputs, :cc)
        output == :lo && push!(func_outputs, :lo)
        output == :hi && push!(func_outputs, :hi)
        output == :MC && push!(func_outputs, :cv, :cc, :lo, :hi)
        output == :cvgrad && push!(func_outputs, :cvgrad)
        output == :ccgrad && push!(func_outputs, :ccgrad)
        output == :grad && push!(func_outputs, :cvgrad, :ccgrad)
        output == :all && push!(func_outputs, :cv, :cc, :lo, :hi, :cvgrad, :ccgrad)
        if ~(output in [:cv, :cc, :lo, :hi, :MC, :cvgrad, :ccgrad, :grad, :all])
            error("Output list contains an invalid output symbol: :$output. Acceptable symbols
        include [:cv, :cc, :lo, :hi, :MC, :cvgrad, :ccgrad, :grad, :all]")
        end
    end
    if isempty(func_outputs)
        error("No outputs specified.")
    end

    # From here on we create/modify the kernel file. Create it (truncating any old
    # version) and write the header.
    open(kernel_file, "w") do file
        write(file, "# Generated at $(Dates.now())\n\n")
        write(file, "# Kernel(s) generated for the expression: $(string(num))\n\n")
    end

    local func_name
    try
        # Independent variables of the expression. Even if gradlist has 32+
        # elements, an expression with fewer than 32 independent variables can be
        # handled in a single kernel. Constants are still required as inputs, so
        # they are not excluded here.
        indep_vars = get_name.(pull_vars(num))

        # Quadratic expressions are handled with SCIP's method for nonconvex
        # quadratic terms (section 2.4.3.2 of Vigerske, S. and Gleixner, A. "SCIP:
        # global optimization of mixed-integer nonlinear programs in a
        # branch-and-cut framework". Optimization Methods and Software, 33:3,
        # 563-593 (2018). DOI: 10.1080/10556788.2017.1335312), as also used in
        # EAGO's `affine_relax_quadratic!`.
        # NOTE: When switching to MOI variables, quadratics will be easy to detect.
        if affine_quadratic && is_quadratic(num)
            return kgen_affine_quadratic(expr_hash, num, gradlist, func_outputs, constants)
        end

        # Split the expression into several kernels if it is too complicated to
        # handle quickly in one. complexity() (slightly conservatively) estimates
        # the number of kernel lines per factor; var_counts() gives the number of
        # variables involved. Splitting is also forced when >31 variables are needed.
        factored = perform_substitutions(factor(num, split_div=true))
        n_lines = complexity(factored)
        n_vars = var_counts(factored)

        # Per-kernel bookkeeping passed to `outro`
        kernel_nums = Int[]
        inputs = Vector{String}[]
        outputs = String[]

        # Split thresholds for the requested `splitting` level
        if splitting==:low          # Low, formerly none
            split_point = Inf
            max_size = Inf
        elseif splitting==:default  # Default, formerly low
            split_point = 10000
            max_size = 12000
        elseif splitting==:high     # Formerly default
            split_point = 1500
            max_size = 2000
        elseif splitting==:max      # Extremely small
            split_point = 500
            max_size = 750
        else
            error("Splitting must be one of: {:low, :default, :high, :max}")
        end

        # Sparsity information of the factorization
        sparsity = detect_sparsity(factored, gradlist)
        gradnames = get_name.(gradlist)

        if (n_vars[end] < 31) && ((n_lines[end] <= max_size) || (findfirst(x -> x > split_point, n_lines)==length(n_lines)))
            # Complexity is fairly low; only a single kernel is needed
            create_kernel!(expr_hash, 1, num, gradnames, func_outputs, constants, factored, sparsity)
            push!(kernel_nums, 1)
            push!(inputs, string.(indep_vars))
            push!(outputs, "OUT")
        else
            # Complexity is too high; emit multiple kernels, each computing one
            # intermediate term, until the remainder fits in a final kernel.
            # (An experimental, currently disabled idea: reuse a previously
            # generated kernel when a new term has exactly the same structure.)
            complete = false
            kernel_count = 1
            while !complete
                # Break at the first term exceeding the line threshold or reaching
                # 30/31 variables, whichever comes first
                line_ID = findfirst(x -> x > split_point, n_lines)
                vars_ID = findfirst(x -> (x == 30) || (x == 31), n_vars)
                if isnothing(vars_ID)
                    new_ID = line_ID
                elseif isnothing(line_ID)
                    new_ID = vars_ID
                else
                    new_ID = min(line_ID, vars_ID)
                end
                if isnothing(new_ID)
                    error("Unable to choose a kernel split point for expression: $(string(num))")
                end

                # Make the term at `new_ID` into its own kernel
                create_kernel!(expr_hash, kernel_count, extract(factored, new_ID), gradnames, func_outputs, constants, factored, sparsity)
                push!(kernel_nums, kernel_count)
                push!(inputs, string.(get_name.(pull_vars(extract(factored, new_ID)))))
                push!(outputs, string(factored[new_ID].lhs))
                kernel_count += 1

                # That term is now computed by its own kernel, so replace it with an
                # identity and re-estimate the remaining complexity
                factored[new_ID] = factored[new_ID].lhs ~ factored[new_ID].lhs
                n_lines = complexity(factored)
                n_vars = var_counts(factored)

                # If every remaining line count (excluding the final line) is below
                # the max size and fewer than 32 variables remain, emit the final
                # kernel and stop
                if (n_vars[end] < 32) && (all(n_lines[1:end-1] .<= max_size))
                    create_kernel!(expr_hash, kernel_count, extract(factored), gradnames, func_outputs, constants, factored, sparsity)
                    push!(kernel_nums, kernel_count)
                    push!(inputs, string.(get_name.(pull_vars(extract(factored)))))
                    push!(outputs, "OUT")
                    complete = true
                end
            end
        end

        # Include all the kernels that were just generated
        include(kernel_file)

        # The generated kernels are assumed to have high enough register usage
        # that they all run with at most 256 threads; the block count used by
        # `outro` is the number of Streaming Multiprocessors on the current GPU.
        blocks = Int32(CUDA.attribute(CUDA.device(), CUDA.DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT))

        # Append the main CPU function that calls all of the generated kernels
        open(kernel_file, "a") do file
            write(file, outro(expr_hash, kernel_nums, inputs, outputs, blocks, gradnames))
        end

        # Compile the function and kernels
        include(kernel_file)
        func_name = eval(Meta.parse("f_"*expr_hash))
    catch
        # Never leave a partial kernel file where the cache check would find it;
        # keep it under a different name for debugging.
        if isfile(kernel_file)
            mv(kernel_file, kernel_file*".failed"; force=true)
        end
        rethrow()
    end

    # Warm-up call on zero-filled inputs to trigger compilation
    @eval $(func_name)(CUDA.zeros(Float64, 2, 4+2*length($gradlist)), [CUDA.zeros(Float64, 2, 3) for i = 1:length($gradlist)]...)
    return func_name
end
# Return `true` when `num` is at most quadratic in the sense used by
# `kgen_affine_quadratic`, namely when it is one of
#   [const], [var], [coeff]*[var], [coeff]*[var]^2, [coeff]*[var]*[var]
# or a sum (`ADD`) whose terms are each a base-level variable, a product of
# exactly two base-level variables with exponent 1, or a base-level variable
# squared. Once MOI variables are used, the object itself will indicate whether
# it is quadratic and this check should become unnecessary.
function is_quadratic(num::Num)
    ex = num.val

    # Real-valued constants and lone variables are treated as quadratic: the
    # normal McCormick process doesn't make sense for them.
    if ex isa Real || base_term(ex)
        return true
    end

    # True when every (variable => exponent) pair is a base-level variable raised
    # to exactly 1, i.e. the product is [var]*[var]*...
    all_linear_factors(factors) = all(((v, p),) -> base_term(v) && isone(p), factors)

    top = exprtype(ex)
    if top == MUL
        nfactors = length(ex.dict)
        if nfactors == 1
            # [coeff]*[var] or [coeff]*[var]^2
            return all(((v, p),) -> base_term(v) && p in [1, 2], ex.dict)
        elseif nfactors == 2
            # [coeff]*[var]*[var]
            return all_linear_factors(ex.dict)
        end
        # Three or more factors: not quadratic
        return false
    elseif top == ADD
        # Only the terms (keys) matter; their coefficients (values) are irrelevant
        for term in keys(ex.dict)
            base_term(term) && continue
            termtype = exprtype(term)
            if termtype == MUL
                # Must be exactly [var]*[var]
                length(term.dict) != 2 && return false
                all_linear_factors(term.dict) || return false
            elseif termtype == POW
                # Must be exactly [var]^2
                base_term(term.base) || return false
                term.exp != 2 && return false
            else
                return false
            end
        end
        return true
    end
    # Any other expression type is not quadratic
    return false
end
# Specialized `kgen` for quadratic expressions (as detected by `is_quadratic`).
# Instead of McCormick relaxations, the generated kernel returns either affine
# bounds or secant-line bounds depending on where on the quadratic function the
# point of interest lies, following `affine_relax_quadratic!` in EAGO.
#
# Appends the kernel and then the CPU entry point `f_<expr_hash>` to
# "storage/f_<expr_hash>.jl", includes the file, performs one warm-up call, and
# returns the entry point. `func_outputs` and `constants` are accepted for parity
# with `kgen` but are not used here.
function kgen_affine_quadratic(expr_hash::String, num::Num, gradlist::Vector{Num}, func_outputs::Vector{Symbol}, constants::Vector{Num})
    expr = num.val
    # Variables participating in this expression, and the gradient variable names
    vars = get_name.(pull_vars(expr))
    varlist = string.(get_name.(gradlist))
    kernel_file = joinpath(@__DIR__, "storage", "f_"*expr_hash*".jl")

    # Write the kernel. The do-block closes the file even if one of the `error`
    # branches below is reached partway through.
    open(kernel_file, "a") do file
        # Preamble
        if isempty(vars)
            write(file, preamble_string(expr_hash, ["OUT";], 1, 1, length(gradlist)))
        else
            write(file, preamble_string(expr_hash, ["OUT"; string.(vars)], 1, 1, length(gradlist)))
        end

        # Compose the kernel body according to the expression's form
        if expr isa Real
            # A plain number (e.g. Float64): initialize with its value; nothing else to add
            write(file, SCMC_quadaff_initialize(expr))
        elseif base_term(expr)
            # A single variable: start from zero and add it with coefficient 1.0
            write(file, SCMC_quadaff_initialize(0.0))
            write(file, SCMC_quadaff_affine(string(get_name(expr)), 1.0, varlist))
        elseif exprtype(expr)==MUL
            # A product: start from zero. For this to be quadratic it is either one
            # factor with exponent 1 (affine) or 2 (squared), or two factors with
            # exponent 1 (bilinear).
            write(file, SCMC_quadaff_initialize(0.0))
            if length(expr.dict)==1
                for (key, val) in expr.dict
                    if isone(val)
                        write(file, SCMC_quadaff_affine(string(get_name(key)), expr.coeff, varlist))
                    else # Must be something squared
                        write(file, SCMC_quadaff_squared(string(get_name(key)), expr.coeff, varlist))
                    end
                end
            else # There must be two elements in the dictionary
                binary_vars = string.(get_name.(keys(expr.dict)))
                binary_vars = binary_vars[sort_vars(binary_vars)]
                write(file, SCMC_quadaff_binary(binary_vars..., expr.coeff, varlist))
            end
        elseif exprtype(expr)==ADD
            # A sum: initialize with its constant coefficient, then add each term
            # with its coefficient
            write(file, SCMC_quadaff_initialize(expr.coeff))
            for (key, val) in expr.dict
                if base_term(key)
                    write(file, SCMC_quadaff_affine(string(get_name(key)), val, varlist))
                elseif exprtype(key)==MUL
                    binary_vars = string.(get_name.(keys(key.dict)))
                    binary_vars = binary_vars[sort_vars(binary_vars)]
                    write(file, SCMC_quadaff_binary(binary_vars..., val, varlist))
                elseif exprtype(key)==POW
                    write(file, SCMC_quadaff_squared(string(get_name(key.base)), val, varlist))
                else
                    error("How did you get here? Why is $key in a quadratic term?")
                end
            end
        else
            error("How did you get here? Somehow this is marked as quadratic: $num")
        end

        # Wrap up the kernel by calculating bounds and relaxations. (Computing the
        # intercepts of the relaxations would be more useful, and ParBB converts
        # relaxation values to intercepts anyway, but relaxation values are what
        # ParBB expects for compatibility with McCormick-style relaxations.)
        if isempty(varlist)
            write(file, postamble_quadaff(String[], String[]))
        elseif isempty(vars)
            write(file, postamble_quadaff(String[], varlist))
        else
            write(file, postamble_quadaff(string.(vars), varlist))
        end
    end

    # Include this kernel so it is defined before the entry point references it
    include(kernel_file)

    # Append the "main" CPU function that launches the kernel; the block count is
    # the number of Streaming Multiprocessors on the current GPU
    blocks = Int32(CUDA.attribute(CUDA.device(), CUDA.DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT))
    open(kernel_file, "a") do file
        if isempty(gradlist)
            write(file, outro(expr_hash, [1], [String[]], ["OUT"], blocks, Symbol[]))
        elseif isempty(vars)
            write(file, outro(expr_hash, [1], [String[]], ["OUT"], blocks, get_name.(gradlist)))
        else
            write(file, outro(expr_hash, [1], [string.(vars)], ["OUT"], blocks, get_name.(gradlist)))
        end
    end

    # Include the file again to get the final entry point, then warm it up.
    # NOTE(review): the warm-up calls in `kgen` pass 3-column input matrices while
    # this one passes 4 columns; confirm which width the quadratic kernel expects.
    include(kernel_file)
    func_name = eval(Meta.parse("f_"*expr_hash))
    @eval $(func_name)(CUDA.zeros(Float64, 2, 4+2*length($gradlist)), [CUDA.zeros(Float64, 2,4) for i = 1:length($gradlist)]...)
    return func_name
end
# Kernel writer for a single kernel.
#
# Given the storage-file hash, a kernel ID, and the expression that ONE kernel is
# responsible for, append that kernel to "storage/f_<expr_hash>.jl".
#
# Limitations: the requested outputs and the constants are not yet taken into
# account. Outputs could be supported by passing them through to the string
# kernels (which would then, e.g., replace "OUT" with "#OUT"); how constants
# should be handled is still undecided.
#
# This method accepts a `Num` wrapper and forwards its underlying symbolic value
# to the `BasicSymbolic{Real}` method.
function create_kernel!(expr_hash::String, kernel_ID::Int, num::Num, gradlist::Vector{Symbol},
                        func_outputs::Vector{Symbol}, constants::Vector{Num},
                        orig_factored::Vector{Equation}, orig_sparsity::Vector{Vector{Int}})
    unwrapped = num.val
    return create_kernel!(expr_hash, kernel_ID, unwrapped, gradlist, func_outputs, constants,
                          orig_factored, orig_sparsity)
end
# Choose a temporary-variable slot ("temp<k>") for graph node `ID` and record in
# it `dependents`, the nodes that will later read this value. A slot is free once
# every dependent recorded in it has been computed. Returns the chosen slot index.
# Used by both passes of `create_kernel!` so the slot count from the first pass
# matches the slots used while writing code in the second.
function _assign_temp_slot!(temp_endlist::Vector, dependents::AbstractVector, ID::Integer, is_add::Bool)
    # No slots yet: this becomes the first one
    if isempty(temp_endlist)
        push!(temp_endlist, copy(dependents))
        return 1
    end
    tempID = 0
    if is_add
        # For addition, a slot whose ONLY remaining dependent is this node can be
        # overwritten in place with the sum
        for j in eachindex(temp_endlist)
            if (length(temp_endlist[j])==1) && (temp_endlist[j][1]==ID)
                temp_endlist[j] = copy(dependents)
                tempID = j
                break
            end
        end
    end
    # Any slot with no remaining dependents can be reused.
    # NOTE(review): this also runs when the addition branch already picked a slot
    # and may replace that choice; kept as-is so the generated code is unchanged.
    for j in eachindex(temp_endlist)
        if isempty(temp_endlist[j])
            temp_endlist[j] = copy(dependents)
            tempID = j
            break
        end
    end
    # Nothing reusable: open a new slot
    if tempID==0
        push!(temp_endlist, copy(dependents))
        tempID = length(temp_endlist)
    end
    return tempID
end

# Write ONE kernel for the expression `num` into "storage/f_<expr_hash>.jl"
# (appending), using `kernel_ID` to distinguish it from other kernels of the same
# expression. `gradlist` holds the names of the subgradient variables;
# `orig_factored`/`orig_sparsity` are the factorization and sparsity of the full
# expression, used to look up sparsity for intermediate variables produced by
# earlier kernels. `func_outputs` and `constants` are not used yet.
#
# Glossary:
#   varids:   ALL variables, including base variables and aux variables
#   varorder: a topologically sorted list of varids (last entry = full expression)
#   vars:     variables that are inputs to this kernel (NOT aux variables)
#   g.fadjlist[i]: indices of varids that depend on varids[i]
#   g.badjlist[i]: indices of varids needed to compute varids[i]
function create_kernel!(expr_hash::String, kernel_ID::Int, num::BasicSymbolic{Real}, gradlist::Vector{Symbol}, func_outputs::Vector{Symbol}, constants::Vector{Num}, orig_factored::Vector{Equation}, orig_sparsity::Vector{Vector{Int}})
    # Factorize the expression and apply known reformulations
    factorized = perform_substitutions(factor(num, split_div=true))
    vars = get_name.(pull_vars(num))

    # Directed acyclic graph of the factorization; a topological sort gives the
    # order in which quantities must be computed
    edgelist, varids = eqn_edges(factorized) # varids includes all aux variables also
    g = SimpleDiGraph(edgelist)
    varorder = varids[topological_sort(g)]

    # Sparsity (which gradlist entries a quantity depends on) for every entry of
    # varorder. Temporary variables coming from earlier kernels are resolved via
    # the original factorization and sparsity.
    string_gradlist = string.(gradlist)
    sparsity = Vector{Vector{Int}}(undef, length(varorder))
    for i in eachindex(varorder)
        if varorder[i] in string_gradlist
            # The variable itself is in gradlist
            sparsity[i] = [findfirst(==(varorder[i]), string_gradlist)]
        else
            idx = findfirst(x -> isequal(string(x.lhs), varorder[i]), factorized)
            if isnothing(idx)
                sparsity[i] = orig_sparsity[findfirst(x -> isequal(string(x.lhs), varorder[i]), orig_factored)]
            else
                # NOTE(review): this REASSIGNS the outer `vars`. After the loop,
                # `vars` holds the variables of the last processed factorized entry
                # (the full expression, which sorts last), and the input-name list
                # below is derived from it. Kept as-is because later code depends
                # on it.
                vars = pull_vars(extract(factorized, idx))
                curr_sparsity = Int[]
                for var in vars
                    ID = findfirst(==(string(get_name(var))), string_gradlist)
                    if isnothing(ID)
                        # Not a gradlist variable: take sparsity from the original factorization
                        ID = findfirst(x -> isequal(string(x.lhs), string(var)), orig_factored)
                        push!(curr_sparsity, orig_sparsity[ID]...)
                    else
                        push!(curr_sparsity, ID)
                    end
                end
                sparsity[i] = sort(unique(curr_sparsity))
            end
        end
    end

    # Names of this kernel's inputs (loop-invariant from here on)
    input_names = string.(get_name.(vars))

    # Pass 1: count the temporary variables needed. Intermediate results are kept
    # in temporaries (rather than global memory) and copied out only at the end.
    slot_deps = []
    maxtemp = 0
    for i in eachindex(varorder)
        varorder[i] in input_names && continue
        # Exact match: a substring match would confuse e.g. "aux1" with "aux10"
        ID = findfirst(==(varorder[i]), varids)
        is_add = !isempty(slot_deps) && exprtype(factorized[findfirst(x -> isequal(string(x.lhs), varorder[i]), factorized)].rhs) == ADD
        tempID = _assign_temp_slot!(slot_deps, g.fadjlist[ID], ID, is_add)
        # This node is now computed; it no longer blocks any slot
        for deps in slot_deps
            filter!(!=(ID), deps)
        end
        maxtemp = max(maxtemp, tempID)
    end

    # Pass 2: write the kernel, appending to the file so earlier content is kept
    kernel_file = joinpath(@__DIR__, "storage", "f_"*expr_hash*".jl")
    open(kernel_file, "a") do file
        write(file, preamble_string(expr_hash, ["OUT"; input_names], kernel_ID, maxtemp, length(gradlist)))
        temp_endlist = []
        outvar = ""
        name_tracker = copy(varids)  # current (possibly "temp<k>") name of each varid
        for i in eachindex(varorder)
            # Inputs are not calculated
            varorder[i] in input_names && continue
            ID = findfirst(==(varorder[i]), varids)
            factorized_ID = findfirst(x -> isequal(string(x.lhs), varorder[i]), factorized)
            rhs = factorized[factorized_ID].rhs

            # Operands of this operation, under their current names
            participants = get_name.(pull_vars(rhs))
            inputs = []
            for p in string.(participants)
                varids_ID = findfirst(x -> isequal(x, p), varids)
                push!(inputs, name_tracker[varids_ID])
            end

            # Pick the temporary that will hold this result (same rule as pass 1)
            is_add = !isempty(temp_endlist) && exprtype(rhs) == ADD
            tempID = _assign_temp_slot!(temp_endlist, g.fadjlist[ID], ID, is_add)
            name_tracker[ID] = "temp$(tempID)"
            inputs = [name_tracker[ID]; inputs]

            if length(inputs)>2 && inputs[1]==inputs[2]
                # In-place add of inputs[3] into inputs[2]: pass only the sparsity
                # of inputs[3]. inputs[3] may be a "temp<k>" alias, so look it up
                # by its original name (participants[2]) in varorder.
                corrected_i = findfirst(==(string(participants[2])), varorder)
                write_operation(file, rhs, inputs, string.(gradlist), sparsity[corrected_i])
            elseif length(inputs)>2 && inputs[1]==inputs[3]
                # In-place add of inputs[2] into inputs[3]: pass only the sparsity
                # of inputs[2], looked up by its original name (participants[1])
                corrected_i = findfirst(==(string(participants[1])), varorder)
                write_operation(file, rhs, inputs, string.(gradlist), sparsity[corrected_i])
            else
                write_operation(file, rhs, inputs, string.(gradlist), sparsity[i])
            end

            # This node is now computed; remove it from the slots' requirements
            for deps in temp_endlist
                filter!(!=(ID), deps)
            end
            # The last entry of varorder is the full expression: its temp is the output
            if i==length(varorder)
                outvar = name_tracker[ID]
            end
        end
        # Close out the kernel function
        write(file, postamble(outvar))
    end
    return nothing
end
# This function detects subexpressions in the factored list of equations
# that have known/improved relaxations, and replaces them with their
# improved forms. All LHS forms are replaced by the RHS equivalent.
# Some reformulations must be checked before others, since structures
# may change and hide possible reformulations. E.g., if we have the
# expression 2^(log(x*y)), the ideal reformulation is: (x*y)^log(2).
# But, based on the ordering of factors, "log(x*y)" will appear earlier
# than "2^(log(x*y)), and "log(x*y)" reformulates to "log(x) + log(y)".
# This reformulation would disable the ideal reformulation above,
# since we'd then have 2^(log(x) + log(y)). Larger-scale reformulations
# like this example are marked by "HIGH PRIORITY" and are searched for
# first, before any other reformulations.
# Included forms:
# 1) exp(x1)*exp(x2) = exp(x1+x2) [EAGO paper/Misener2014]
# 2) log(x1^a) = a*log(x1) [EAGO paper/Misener2014]
# 3) log(a^x1) = x1*log(a) [EAGO paper]
# 4) log(x1*x2) = log(x1) + log(x2) [EAGO paper]
# 5) (x1^a)^b = x1^(a*b) [EAGO paper]
# 6) a^(log(x1)) = x1^(log(a)) [EAGO paper] (HIGH PRIORITY)
# 7) log(inv(x1)) = -log(x1) [EAGO paper]
# 8) CONST1*CONST2*x1 = (CONST1*CONST2)*x1
# 9) 1 / (1 + exp(-x)) = Sigmoid(x)
# 10) sin(x) = cos(x - pi/2)
#
# Forms that aren't relevant yet:
# 1) (a^x1)^b = (a^b)^x1 [EAGO paper] (Can't do powers besides integers)
function perform_substitutions(old_factored::Vector{Equation})
factored = deepcopy(old_factored)
# Register any terms we want to substitute
@eval @register_symbolic SCMC_sigmoid(x)
scan_flag = true
while scan_flag
scan_flag = false
# Higher priority reformulations
for index0 in eachindex(factored)
# 6) a^(log(x1)) = x1^(log(a)) [EAGO paper]
if exprtype(factored[index0].rhs)==POW
# We only apply this rule if the base of the power is Real-valued
if typeof(factored[index0].rhs.base) <: Real
index1 = findfirst(x -> isequal(x.lhs, factored[index0].rhs.exp), factored)
if !isnothing(index1) && exprtype(factored[index1].rhs)==TERM
if factored[index1].rhs.f==log
# We also only want to apply this rule if the argument of
# log() is a variable
if !(typeof(arguments(factored[index1].rhs)[]) <: Real)
# We have:
# aux1 = log(x1)
# aux2 = CONST^aux1
#
# We will replace with:
# aux1 = log(x1)
# aux2 = x1^[log(CONST)]
scan_flag = true
# Identify the log variable and base constant value
log_var = arguments(factored[index1].rhs)[]
base_val = factored[index0].rhs.base
# Modify index0
ID_1 = findfirst(x -> isequal(x.rhs, log_var^log(base_val)), factored)
if isnothing(ID_1)
@eval $factored[$index0] = $factored[$index0].lhs ~ $log_var^log($base_val)
else
for i in eachindex(factored)
@eval $factored[$i] = $factored[$i].lhs ~ substitute($factored[$i].rhs, Dict($factored[$index0].lhs => $factored[$ID_1].lhs))
end
deleteat!(factored, index0)
end
break
end
end
end
end
end
end
# Lower priority reformulations
if scan_flag==false
for index0 in eachindex(factored)
# 1) exp(x1)*exp(x2) = exp(x1+x2) [EAGO paper/Misener2014]
if exprtype(factored[index0].rhs)==MUL
# We need to check all the multiplication arguments. We create
# a list of all the args using a recursive function.
args = pull_mult(factored, index0)
# Count the number of args that are "exp"
exp_count = 0
exp_args = []
if length(args) > 1
for arg in args
if !(typeof(arg) <: Real) && exprtype(arg)==TERM
if arg.f==exp
exp_count += 1
push!(exp_args, arguments(arg)[])
end
end
end
end
# If we have more than 1 "exp", we'll apply this rule.
# Otherwise, we don't do anything.
if exp_count > 1
# We're going to create one big term and factor it, since
# the arrangement of multiplications in the original factorization
# might not line up with the changes
scan_flag = true
# Create the initial exp()
new_expr = exp(sum(exp_args))
# Multiply by all the non-exp terms
for arg in args
if exprtype(arg)!=TERM || arg.f!=exp
new_expr *= arg
end
end
# Create a factorization of this new expr
new_factorization = factor(new_expr, split_div=true)
# Scan through the new factorization to see if we can merge elements
# with the original factored list
done = false
last_match = nothing
while !done
done = true
for i in eachindex(new_factorization)
ID_1 = findfirst(x -> isequal(x.rhs, new_factorization[i].rhs), factored)
if !isnothing(ID_1)
# Match was found in the main factorization. Remove this element
# from the new factorization and replace all references to it with
# the auxiliary variable of ID_1
for j in eachindex(new_factorization)
new_factorization[j] = new_factorization[j].lhs ~ substitute(new_factorization[j].rhs, Dict(new_factorization[i].lhs => factored[ID_1].lhs))
end
last_match = factored[ID_1].lhs
deleteat!(new_factorization, i)
done = false
break
end
end
end
# Add remaining elements of the factorization to the location of index0
if length(new_factorization) >= 1
for j = length(new_factorization)-1:-1:1
insert!(factored, index0, new_factorization[j])
end
@eval $factored[$index0+length($new_factorization)-1] = $factored[$index0+length($new_factorization)-1].lhs ~ $new_factorization[end].rhs
else
# Or, if no elements exist, replace all instances of this new auxiliary variable with the last known match
for i in eachindex(factored)
@eval $factored[$i] = $factored[$i].lhs ~ substitute($factored[$i].rhs, Dict($factored[$index0].lhs => $last_match))
end
deleteat!(factored, index0)
end
break
end
end
# 2) log(x1^a) = a*log(x1) [EAGO paper/Misener2014]
# Equivalent in this code to:
# 3) log(a^x1) = x1*log(a) [EAGO paper]
if exprtype(factored[index0].rhs)==TERM
if factored[index0].rhs.f==log
index1 = findfirst(x -> isequal(x.lhs, arguments(factored[index0].rhs)[]), factored)
if !isnothing(index1) && exprtype(factored[index1].rhs)==POW
# If the argument of log() is POW type, we can separate based
# on the `base` and `exp`
# start:
# aux1 = log(x^a)
#
# Convert to:
# aux2 = log(x)
# aux1 = a*aux2
scan_flag = true
factor_base = factored[index1].rhs.base
factor_exp = factored[index1].rhs.exp
newsym = gensym(:aux)
newsym = Symbol(string(newsym)[3:5] * string(newsym)[7:end])
newvar = genvar(newsym)
if typeof(factor_base) <: Real
# Check whether the new term exists, and then either add it or only refer to that term
ID_1 = findfirst(x -> isequal(x.rhs, factor_exp*log(factor_base)), factored)
if isnothing(ID_1)
@eval $factored[$index0] = $factored[$index0].lhs ~ $factor_exp*log($factor_base)
else
for i in eachindex(factored)
@eval $factored[$i] = $factored[$i].lhs ~ substitute($factored[$i].rhs, Dict($factored[$index0].lhs => $factored[$ID_1].lhs))
end
deleteat!(factored, index0)
end
else
# Check whether terms already exist, and then either add them or only refer to them
ID_1 = findfirst(x -> isequal(x.rhs, log(factor_base)), factored)
if isnothing(ID_1)
insert!(factored, index0, Equation(Symbolics.value(newvar), log(factor_base)))
@eval $factored[$index0+1] = $factored[$index0+1].lhs ~ $factor_exp*$factored[$index0].lhs
else
# The log term exists... does factor_exp*log_term also exist?
ID_2 = findfirst(x -> isequal(x.rhs, factor_exp*factored[ID_1].lhs), factored)
if isnothing(ID_2)
@eval $factored[$index0] = $factored[$index0].lhs ~ $factor_exp*$factored[$ID_1].lhs
else
# Both the inner and outer terms already exist. All terms that refer to factored[index0]
# should instead refer to the term that already exists
for i in eachindex(factored)
@eval $factored[$i] = $factored[$i].lhs ~ substitute($factored[$i].rhs, Dict($factored[$index0].lhs => $factored[$ID_2].lhs))
end
deleteat!(factored, index0)
end
end
end
break
end
end
end
# 4) log(x1*x2) = log(x1) + log(x2) [EAGO paper]
if exprtype(factored[index0].rhs)==TERM
if factored[index0].rhs.f==log
index1 = findfirst(x -> isequal(x.lhs, arguments(factored[index0].rhs)[]), factored)
if !isnothing(index1) && exprtype(factored[index1].rhs)==MUL
# If the argument of log() is MUL type, we can create two new
# log() terms and add them together
# start:
# aux1 = log(x*y)
#
# Convert to:
# aux2 = log(x)
# aux3 = log(y)
# aux1 = aux2+aux3
#
# Or:
# aux1 = log(CONST*x)
# to:
# aux2 = log(x)
# aux1 = log(CONST) + aux2
scan_flag = true
args = [keys(factored[index1].rhs.dict)...]
coeff = factored[index1].rhs.coeff
if length(args)==1
# If there's only one arg, it's the CONST*x case
# Check whether log(x) already exists
ID_1 = findfirst(x -> isequal(x.rhs, log(args[1])), factored)
if isnothing(ID_1)
# If we don't already have log(x), create a new auxiliary variable
# and update the old expression
newsym = gensym(:aux)
newsym = Symbol(string(newsym)[3:5] * string(newsym)[7:end])
newvar = genvar(newsym)
insert!(factored, index0, Equation(Symbolics.value(newvar), log(args[1])))
@eval $factored[$index0+1] = $factored[$index0+1].lhs ~ log($coeff) + $factored[$index0].lhs
else
# If log(x) already exists, check if log(COEFF) + log(x) exists
ID_2 = findfirst(x -> isequal(x.rhs, log(coeff) + factored[ID_1].lhs), factored)
if isnothing(ID_2)
@eval $factored[$index0] = $factored[$index0].lhs ~ log($coeff) + $factored[$ID_1].lhs
else
for i in eachindex(factored)