Skip to content

Commit 7ec3faf

Browse files
committed
Merge branch 'refactor_evgen' of github.com:MadGraphTeam/MadGraph7 into refactor_evgen
2 parents e3f59d9 + ae28f88 commit 7ec3faf

5 files changed

Lines changed: 111 additions & 84 deletions

File tree

madspace/include/madspace/madcode/function.h

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -53,6 +53,8 @@ class Function {
5353
ValueVec _locals;
5454
std::unordered_map<std::string, Value> _globals;
5555
std::vector<InstructionCall> _instructions;
56+
57+
friend Function sort_breadth_first(const Function& function);
5658
};
5759

5860
std::ostream& operator<<(std::ostream& out, const Value& value);

madspace/include/madspace/madcode/optimizer.h

Lines changed: 9 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -10,24 +10,27 @@ class InstructionDependencies {
1010
public:
1111
InstructionDependencies(const Function& function);
1212
bool depends(std::size_t test_index, std::size_t dependency_index) {
13-
return matrix[test_index * size + dependency_index];
13+
return _matrix[test_index * _size + dependency_index];
1414
}
15+
const std::vector<int>& ranks() const { return _ranks; }
1516

1617
private:
17-
std::size_t size;
18-
std::vector<bool> matrix;
19-
std::vector<int> ranks;
18+
std::size_t _size;
19+
std::vector<bool> _matrix;
20+
std::vector<int> _ranks;
2021
};
2122

2223
class LastUseOfLocals {
2324
public:
2425
LastUseOfLocals(const Function& function);
2526
std::vector<std::size_t>& local_indices(std::size_t index) {
26-
return last_used.at(index);
27+
return _last_used.at(index);
2728
}
2829

2930
private:
30-
std::vector<std::vector<std::size_t>> last_used;
31+
std::vector<std::vector<std::size_t>> _last_used;
3132
};
3233

34+
Function sort_breadth_first(const Function& function);
35+
3336
} // namespace madspace

madspace/src/gpu/runtime.cu

Lines changed: 69 additions & 67 deletions
Original file line number | Diff line number | Diff line change
@@ -973,20 +973,20 @@ public:
973973
_stream_count(stream_count), _sync_matrix(stream_count * stream_count, true) {}
974974

975975
bool is_in_sync_with(std::size_t this_stream, std::size_t other_stream) const {
976-
return _sync_matrix.at(this_stream * stream_count + other_stream);
976+
return _sync_matrix.at(this_stream * _stream_count + other_stream);
977977
}
978978
void desynchronize(std::size_t this_stream) {
979979
for (std::size_t other_stream = 0; other_stream < _stream_count;
980-
++_other_stream) {
980+
++other_stream) {
981981
if (this_stream != other_stream) {
982-
_sync_matrix.at(other_stream * stream_count + this_stream) = false;
982+
_sync_matrix.at(other_stream * _stream_count + this_stream) = false;
983983
}
984984
}
985985
}
986986
void synchronize(std::size_t this_stream, std::size_t other_stream) {
987987
for (std::size_t i = 0; i < _stream_count; ++i) {
988988
if (is_in_sync_with(other_stream, i)) {
989-
_sync_matrix.at(this_stream * stream_count + i) = true;
989+
_sync_matrix.at(this_stream * _stream_count + i) = true;
990990
}
991991
}
992992
}
@@ -995,14 +995,15 @@ public:
995995
private:
996996
std::size_t _stream_count;
997997
std::vector<bool> _sync_matrix;
998-
}
998+
};
999999

10001000
} // namespace
10011001

10021002
GpuRuntime::GpuRuntime(const Function& function, ContextPtr context) :
10031003
_context(context),
1004-
_input_count(function.inputs().size()) _gpublas_handle(
1005-
context.thread_pool(),
1004+
_input_count(function.inputs().size()),
1005+
_gpublas_handle(
1006+
context->thread_pool(),
10061007
[]() {
10071008
gpublasHandle_t handle;
10081009
check_error(gpublasCreate(&handle));
@@ -1011,7 +1012,7 @@ GpuRuntime::GpuRuntime(const Function& function, ContextPtr context) :
10111012
[](gpublasHandle_t handle) { check_error(gpublasDestroy(handle)); }
10121013
),
10131014
_gpurand_generator(
1014-
context.thread_pool(),
1015+
context->thread_pool(),
10151016
[]() {
10161017
gpurandGenerator_t handle;
10171018
check_error(gpurandCreateGenerator(&handle, GPURAND_RNG_PSEUDO_DEFAULT));
@@ -1035,52 +1036,38 @@ GpuRuntime::GpuRuntime(const Function& function, ContextPtr context) :
10351036
std::size_t stream_count = 0, event_count = 0, backward_event_count = 0;
10361037
for (auto& instr : function.instructions()) {
10371038
if (instr.stream_index >= stream_count) {
1038-
stream_count = instr.stream_index;
1039+
stream_count = instr.stream_index + 1;
10391040
}
10401041
}
10411042
SyncTracker sync_tracker(stream_count);
10421043
std::vector<int> local_source_streams(function.locals().size(), -1);
10431044
SizeVec last_stream_instrs(stream_count);
1044-
nested_vector2<std::size_t> local_consumer_streams(function.locals().size());
10451045
nested_vector2<std::size_t> backward_wait_events(function.instructions().size());
10461046
std::vector<int> backward_record_events(function.instructions().size(), -1);
10471047

1048-
for (auto& instr : function.instructions()) {
1049-
if (instr.stream_index >= stream_count) {
1050-
stream_count = instr.stream_index + 1;
1051-
}
1052-
}
1053-
1054-
auto update_sync = [&](std::size_t local_index,
1055-
std::size_t stream_index,
1056-
SizeVec& wait_events,
1057-
auto get_event) {
1058-
int source_stream = local_source_streams.at(local_index);
1059-
auto& consumer_streams = local_consumer_streams.at(local_index);
1060-
if (std::find(consumer_streams.begin(), consumer_streams.end(), stream_index) ==
1061-
consumer_streams.end()) {
1062-
consumer_streams.push_back(stream_index);
1063-
}
1064-
if (!sync_tracker.is_in_sync_with(stream_index, source_stream)) {
1065-
wait_events.push_back(get_event(source_stream));
1066-
sync_tracker.synchronize(stream_index, source_stream);
1067-
}
1068-
} auto get_event_backward = [&](std::size_t source_stream) -> int {
1069-
int& event = backward_record_events.at(last_stream_instrs.at(source_stream));
1070-
if (event == -1) {
1071-
event = backward_event_count;
1072-
++backward_event_count;
1073-
}
1074-
return event;
1075-
};
1048+
auto update_sync_backward =
1049+
[&](std::size_t local_index, std::size_t stream_index, SizeVec& wait_events) {
1050+
int source_stream = local_source_streams.at(local_index);
1051+
if (source_stream == -1) {
1052+
return;
1053+
}
1054+
if (!sync_tracker.is_in_sync_with(stream_index, source_stream)) {
1055+
int& event =
1056+
backward_record_events.at(last_stream_instrs.at(source_stream));
1057+
if (event == -1) {
1058+
event = backward_event_count;
1059+
++backward_event_count;
1060+
}
1061+
wait_events.push_back(event);
1062+
sync_tracker.synchronize(stream_index, source_stream);
1063+
}
1064+
};
10761065

10771066
for (std::size_t instr_index = 0; auto [instr, bw_wait_events] :
10781067
zip(std::views::reverse(function.instructions()),
10791068
std::views::reverse(backward_wait_events))) {
10801069
for (auto& out : instr.outputs) {
1081-
update_sync(
1082-
out.local_index, instr.stream_index, bw_wait_events, get_event_backward
1083-
);
1070+
update_sync_backward(out.local_index, instr.stream_index, bw_wait_events);
10841071
}
10851072
for (auto& in : instr.inputs) {
10861073
local_source_streams.at(in.local_index) = instr.stream_index;
@@ -1090,22 +1077,35 @@ GpuRuntime::GpuRuntime(const Function& function, ContextPtr context) :
10901077
++instr_index;
10911078
}
10921079
for (auto& in : function.inputs()) {
1093-
update_sync(in.local_index, 0, _wait_events, get_event_forward);
1080+
update_sync_backward(in.local_index, 0, _backward_wait_events);
10941081
}
10951082

10961083
sync_tracker.reset();
1097-
local_source_streams.clear();
1098-
last_stream_instrs.clear();
1099-
local_consumer_streams.clear();
1100-
1101-
auto get_event_forward = [&](std::size_t source_stream) -> int {
1102-
int& event =
1103-
_instructions.at(last_stream_instrs.at(source_stream)).record_event;
1104-
if (event == -1) {
1105-
event = event_count;
1106-
++event_count;
1084+
std::fill(local_source_streams.begin(), local_source_streams.end(), -1);
1085+
nested_vector2<std::size_t> local_consumer_streams(function.locals().size());
1086+
1087+
auto update_sync = [&](std::size_t local_index,
1088+
std::size_t stream_index,
1089+
SizeVec& wait_events) {
1090+
int source_stream = local_source_streams.at(local_index);
1091+
if (source_stream == -1) {
1092+
return;
1093+
}
1094+
auto& consumer_streams = local_consumer_streams.at(local_index);
1095+
if (std::find(consumer_streams.begin(), consumer_streams.end(), stream_index) ==
1096+
consumer_streams.end()) {
1097+
consumer_streams.push_back(stream_index);
1098+
}
1099+
if (!sync_tracker.is_in_sync_with(stream_index, source_stream)) {
1100+
int& event =
1101+
_instructions.at(last_stream_instrs.at(source_stream)).record_event;
1102+
if (event == -1) {
1103+
event = event_count;
1104+
++event_count;
1105+
}
1106+
wait_events.push_back(event);
1107+
sync_tracker.synchronize(stream_index, source_stream);
11071108
}
1108-
return event;
11091109
};
11101110

11111111
std::vector<bool> is_input(function.locals().size());
@@ -1141,9 +1141,7 @@ GpuRuntime::GpuRuntime(const Function& function, ContextPtr context) :
11411141
if (in.type.batch_size != BatchSize::one) {
11421142
batch_size_index = in.local_index;
11431143
}
1144-
update_sync(
1145-
in.local_index, instr.stream_index, bw_wait_events, get_event_forward
1146-
);
1144+
update_sync(in.local_index, instr.stream_index, bw_wait_events);
11471145
}
11481146
SizeVec output_indices;
11491147
std::vector<DataType> output_dtypes;
@@ -1161,6 +1159,7 @@ GpuRuntime::GpuRuntime(const Function& function, ContextPtr context) :
11611159
}
11621160

11631161
sync_tracker.desynchronize(instr.stream_index);
1162+
last_stream_instrs.at(instr.stream_index) = _instructions.size();
11641163
_instructions.push_back({
11651164
instr.instruction->opcode(),
11661165
input_indices,
@@ -1176,16 +1175,18 @@ GpuRuntime::GpuRuntime(const Function& function, ContextPtr context) :
11761175
wait_events,
11771176
-1,
11781177
bw_wait_events,
1179-
bw_record_events,
1178+
bw_record_event,
11801179
});
11811180

1182-
auto locals_to_free = last_use.local_indices(instr_index);
1181+
SizeVec locals_to_free = last_use.local_indices(instr_index);
11831182
free_queue.insert(
11841183
free_queue.end(), locals_to_free.begin(), locals_to_free.end()
11851184
);
11861185
free_queue.erase(
11871186
std::remove_if(
1188-
free_queue.begin(), free_queue.end(), [&](std::size_t local_index) {
1187+
free_queue.begin(),
1188+
free_queue.end(),
1189+
[&](std::size_t local_index) {
11891190
for (std::size_t consumer_stream :
11901191
local_consumer_streams.at(local_index)) {
11911192
if (!sync_tracker.is_in_sync_with(
@@ -1213,7 +1214,8 @@ GpuRuntime::GpuRuntime(const Function& function, ContextPtr context) :
12131214
);
12141215
return true;
12151216
}
1216-
)
1217+
),
1218+
free_queue.end()
12171219
);
12181220

12191221
++instr_index;
@@ -1252,11 +1254,11 @@ GpuRuntime::GpuRuntime(const Function& function, ContextPtr context) :
12521254

12531255
for (auto& out : function.outputs()) {
12541256
_output_indices.push_back(out.local_index);
1255-
update_sync(out.local_index, 0, _wait_events, get_event_forward);
1257+
update_sync(out.local_index, 0, _wait_events);
12561258
}
12571259

12581260
_streams = ThreadResource<std::vector<gpuStream_t>>(
1259-
context.thread_pool(),
1261+
context->thread_pool(),
12601262
[stream_count]() {
12611263
std::vector<gpuStream_t> streams(stream_count);
12621264
for (auto& item : streams) {
@@ -1272,7 +1274,7 @@ GpuRuntime::GpuRuntime(const Function& function, ContextPtr context) :
12721274
);
12731275
std::size_t max_event_count = std::max(event_count, backward_event_count);
12741276
_events = ThreadResource<std::vector<gpuEvent_t>>(
1275-
context.thread_pool(),
1277+
context->thread_pool(),
12761278
[max_event_count]() {
12771279
std::vector<gpuEvent_t> events(max_event_count);
12781280
for (auto& item : events) {
@@ -1388,7 +1390,7 @@ std::tuple<TensorVec, TensorVec, std::vector<bool>> GpuRuntime::run_with_grad(
13881390
for (auto index : _output_indices) {
13891391
outputs.push_back(locals[index]);
13901392
}
1391-
check_error(gpuStreamSynchronize(main_stream);
1393+
check_error(gpuStreamSynchronize(main_stream));
13921394
return {outputs, locals, eval_grad};
13931395
}
13941396

@@ -1411,7 +1413,7 @@ GpuRuntime::run_backward(
14111413
gpuStream_t main_stream = streams.at(0);
14121414
for (auto [instr, instr_eval_grad] :
14131415
zip(std::views::reverse(_instructions), std::views::reverse(eval_grad))) {
1414-
/*gpuStream_t stream = streams.at(instr.backward_stream);
1416+
/*gpuStream_t stream = streams.at(instr.stream);
14151417
for (auto event : instr.backward_wait_events) {
14161418
check_error(gpuStreamWaitEvent(stream, events.at(event)));
14171419
}*/
@@ -1432,8 +1434,8 @@ GpuRuntime::run_backward(
14321434
#include "runtime_backward_mixin.h"
14331435
}
14341436
}
1435-
/*if (instr.backward_record_event) {
1436-
check_error(gpuEventRecord(instr.backward_record_event, stream));
1437+
/*if (instr.backward_record_event != -1) {
1438+
check_error(gpuEventRecord(events.at(instr.backward_record_event), stream));
14371439
}*/
14381440
}
14391441
/*for (auto event : _backward_wait_events) {

madspace/src/madcode/optimizer.cpp

Lines changed: 27 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -10,7 +10,7 @@
1010
using namespace madspace;
1111

1212
InstructionDependencies::InstructionDependencies(const Function& function) :
13-
size(function.instructions().size()), matrix(size * size) {
13+
_size(function.instructions().size()), _matrix(_size * _size) {
1414
std::vector<int> local_source(function.locals().size(), -1);
1515
int index = 0;
1616
for (auto& instr : function.instructions()) {
@@ -20,26 +20,26 @@ InstructionDependencies::InstructionDependencies(const Function& function) :
2020
if (source_index == -1) {
2121
continue;
2222
}
23-
matrix.at(index * size + source_index) = true;
24-
for (int i = 0; i < size; ++i) {
25-
matrix.at(index * size + i) =
26-
matrix.at(index * size + i) | matrix.at(source_index * size + i);
23+
_matrix.at(index * _size + source_index) = true;
24+
for (int i = 0; i < _size; ++i) {
25+
_matrix.at(index * _size + i) =
26+
_matrix.at(index * _size + i) | _matrix.at(source_index * _size + i);
2727
}
28-
int source_rank = ranks.at(source_index);
28+
int source_rank = _ranks.at(source_index);
2929
if (rank < source_rank) {
3030
rank = source_rank;
3131
}
3232
}
3333
for (auto& output : instr.outputs) {
3434
local_source.at(output.local_index) = index;
3535
}
36-
ranks.push_back(rank + 1);
36+
_ranks.push_back(rank + 1);
3737
++index;
3838
}
3939
}
4040

4141
LastUseOfLocals::LastUseOfLocals(const Function& function) :
42-
last_used(function.instructions().size()) {
42+
_last_used(function.instructions().size()) {
4343
std::vector<bool> seen_locals;
4444
for (auto& local : function.locals()) {
4545
seen_locals.push_back(
@@ -50,7 +50,7 @@ LastUseOfLocals::LastUseOfLocals(const Function& function) :
5050
seen_locals.at(output.local_index) = true;
5151
}
5252
auto instr = function.instructions().rbegin();
53-
auto indices = last_used.begin();
53+
auto indices = _last_used.begin();
5454
for (; instr != function.instructions().rend(); ++instr, ++indices) {
5555
for (auto& input : instr->inputs) {
5656
auto index = input.local_index;
@@ -60,5 +60,22 @@ LastUseOfLocals::LastUseOfLocals(const Function& function) :
6060
}
6161
}
6262
}
63-
std::reverse(last_used.begin(), last_used.end());
63+
std::reverse(_last_used.begin(), _last_used.end());
64+
}
65+
66+
Function madspace::sort_breadth_first(const Function& function) {
67+
Function func_out = function;
68+
InstructionDependencies dependencies(function);
69+
auto order = dependencies.ranks();
70+
std::vector<std::size_t> instruction_perm(function.instructions().size());
71+
std::iota(instruction_perm.begin(), instruction_perm.end(), 0);
72+
std::stable_sort(
73+
instruction_perm.begin(), instruction_perm.end(),
74+
[&](std::size_t i, std::size_t j) { return order.at(i) < order.at(j); }
75+
);
76+
func_out._instructions.clear();
77+
for (std::size_t index : instruction_perm) {
78+
func_out._instructions.push_back(function._instructions.at(index));
79+
}
80+
return func_out;
6481
}

0 commit comments

Comments
 (0)