Skip to content

Commit ae28f88

Browse files
committed
implement breadth-first sorting of compute graphs
1 parent ab7bef7 commit ae28f88

4 files changed

Lines changed: 42 additions & 17 deletions

File tree

madspace/include/madspace/madcode/function.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,8 @@ class Function {
5353
ValueVec _locals;
5454
std::unordered_map<std::string, Value> _globals;
5555
std::vector<InstructionCall> _instructions;
56+
57+
friend Function sort_breadth_first(const Function& function);
5658
};
5759

5860
std::ostream& operator<<(std::ostream& out, const Value& value);

madspace/include/madspace/madcode/optimizer.h

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,24 +10,27 @@ class InstructionDependencies {
1010
public:
1111
InstructionDependencies(const Function& function);
1212
bool depends(std::size_t test_index, std::size_t dependency_index) {
13-
return matrix[test_index * size + dependency_index];
13+
return _matrix[test_index * _size + dependency_index];
1414
}
15+
const std::vector<int>& ranks() const { return _ranks; }
1516

1617
private:
17-
std::size_t size;
18-
std::vector<bool> matrix;
19-
std::vector<int> ranks;
18+
std::size_t _size;
19+
std::vector<bool> _matrix;
20+
std::vector<int> _ranks;
2021
};
2122

2223
class LastUseOfLocals {
2324
public:
2425
LastUseOfLocals(const Function& function);
2526
std::vector<std::size_t>& local_indices(std::size_t index) {
26-
return last_used.at(index);
27+
return _last_used.at(index);
2728
}
2829

2930
private:
30-
std::vector<std::vector<std::size_t>> last_used;
31+
std::vector<std::vector<std::size_t>> _last_used;
3132
};
3233

34+
Function sort_breadth_first(const Function& function);
35+
3336
} // namespace madspace

madspace/src/madcode/optimizer.cpp

Lines changed: 27 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
using namespace madspace;
1111

1212
InstructionDependencies::InstructionDependencies(const Function& function) :
13-
size(function.instructions().size()), matrix(size * size) {
13+
_size(function.instructions().size()), _matrix(_size * _size) {
1414
std::vector<int> local_source(function.locals().size(), -1);
1515
int index = 0;
1616
for (auto& instr : function.instructions()) {
@@ -20,26 +20,26 @@ InstructionDependencies::InstructionDependencies(const Function& function) :
2020
if (source_index == -1) {
2121
continue;
2222
}
23-
matrix.at(index * size + source_index) = true;
24-
for (int i = 0; i < size; ++i) {
25-
matrix.at(index * size + i) =
26-
matrix.at(index * size + i) | matrix.at(source_index * size + i);
23+
_matrix.at(index * _size + source_index) = true;
24+
for (int i = 0; i < _size; ++i) {
25+
_matrix.at(index * _size + i) =
26+
_matrix.at(index * _size + i) | _matrix.at(source_index * _size + i);
2727
}
28-
int source_rank = ranks.at(source_index);
28+
int source_rank = _ranks.at(source_index);
2929
if (rank < source_rank) {
3030
rank = source_rank;
3131
}
3232
}
3333
for (auto& output : instr.outputs) {
3434
local_source.at(output.local_index) = index;
3535
}
36-
ranks.push_back(rank + 1);
36+
_ranks.push_back(rank + 1);
3737
++index;
3838
}
3939
}
4040

4141
LastUseOfLocals::LastUseOfLocals(const Function& function) :
42-
last_used(function.instructions().size()) {
42+
_last_used(function.instructions().size()) {
4343
std::vector<bool> seen_locals;
4444
for (auto& local : function.locals()) {
4545
seen_locals.push_back(
@@ -50,7 +50,7 @@ LastUseOfLocals::LastUseOfLocals(const Function& function) :
5050
seen_locals.at(output.local_index) = true;
5151
}
5252
auto instr = function.instructions().rbegin();
53-
auto indices = last_used.begin();
53+
auto indices = _last_used.begin();
5454
for (; instr != function.instructions().rend(); ++instr, ++indices) {
5555
for (auto& input : instr->inputs) {
5656
auto index = input.local_index;
@@ -60,5 +60,22 @@ LastUseOfLocals::LastUseOfLocals(const Function& function) :
6060
}
6161
}
6262
}
63-
std::reverse(last_used.begin(), last_used.end());
63+
std::reverse(_last_used.begin(), _last_used.end());
64+
}
65+
66+
Function madspace::sort_breadth_first(const Function& function) {
67+
Function func_out = function;
68+
InstructionDependencies dependencies(function);
69+
auto order = dependencies.ranks();
70+
std::vector<std::size_t> instruction_perm(function.instructions().size());
71+
std::iota(instruction_perm.begin(), instruction_perm.end(), 0);
72+
std::stable_sort(
73+
instruction_perm.begin(), instruction_perm.end(),
74+
[&](std::size_t i, std::size_t j) { return order.at(i) < order.at(j); }
75+
);
76+
func_out._instructions.clear();
77+
for (std::size_t index : instruction_perm) {
78+
func_out._instructions.push_back(function._instructions.at(index));
79+
}
80+
return func_out;
6481
}

madspace/src/phasespace/integrand.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -651,13 +651,16 @@ ValueVec MultiChannelIntegrand::build_function_impl(
651651
common_args.max_weight = args.at(1);
652652
}
653653
std::vector<Integrand::ChannelResult> results;
654-
for (auto [integrand, chan_size] : zip(_integrands, all_batch_sizes)) {
654+
for (std::size_t index = 0; auto [integrand, chan_size] : zip(_integrands, all_batch_sizes)) {
655+
fb.set_current_stream(index + 1);
655656
auto channel_args = common_args;
656657
channel_args.r = fb.random(chan_size, integrand->_random_dim);
657658
channel_args.batch_size = chan_size;
658659
channel_args.has_permutations = integrand->_mapping.channel_count() > 1;
659660
results.push_back(integrand->build_channel_part(fb, channel_args));
661+
++index;
660662
}
663+
fb.set_current_stream(0);
661664

662665
Integrand::ChannelResult common_results;
663666
for (std::size_t i = 0; i < common_results.values.size(); ++i) {

0 commit comments

Comments
 (0)