Skip to content

Commit fa33a3c

Browse files
authored
Merge pull request #363 from wbernoudy/fix/argsort-revert
Fix ArgSortNode::revert
2 parents 2053b11 + 1bc0cf6 commit fa33a3c

2 files changed

Lines changed: 40 additions & 5 deletions

File tree

dwave/optimization/src/nodes/sorting.cpp

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,9 @@ struct ArgSortNodeData : public ArrayNodeStateData {
4343
ArgSortNodeData(ArgSortNodeDataHelper_&& helper)
4444
: ArrayNodeStateData(std::move(helper.indices)), order(std::move(helper.order)) {}
4545

46-
/// First is the value in the original array, second is the index of the value
46+
/// Pairs are <value in the original array, index of the value>
4747
std::set<std::pair<double, ssize_t>> order;
48+
std::vector<Update> predecessor_updates;
4849
};
4950

5051
ArgSortNode::ArgSortNode(ArrayNode* arr_ptr)
@@ -73,16 +74,24 @@ bool ArgSortNode::integral() const { return arr_ptr_->integral(); }
7374
std::pair<double, double> ArgSortNode::minmax(
7475
optional_cache_type<std::pair<double, double>> cache) const {
7576
return memoize(cache, [&]() {
76-
return std::make_pair(0.0, static_cast<double>(arr_ptr_->sizeinfo().max.value_or(
77-
std::numeric_limits<ssize_t>::max()) - 1));
77+
return std::make_pair(0.0,
78+
static_cast<double>(arr_ptr_->sizeinfo().max.value_or(
79+
std::numeric_limits<ssize_t>::max()) -
80+
1));
7881
});
7982
}
8083

8184
void ArgSortNode::propagate(State& state) const {
8285
auto node_data = data_ptr<ArgSortNodeData>(state);
8386

87+
auto pred_diff = arr_ptr_->diff(state);
88+
89+
// Save a copy of the predecessor's updates so we can use them in case we
90+
// need to revert the changes to the ordering
91+
node_data->predecessor_updates.assign(pred_diff.begin(), pred_diff.end());
92+
8493
// Make the modifications to the std::set based on the updates.
85-
for (const Update& update : arr_ptr_->diff(state)) {
94+
for (const Update& update : pred_diff) {
8695
if (!update.placed()) {
8796
node_data->order.erase(std::make_pair(update.old, update.index));
8897
}
@@ -100,7 +109,22 @@ void ArgSortNode::propagate(State& state) const {
100109
std::views::transform([](const std::pair<double, ssize_t>& p) { return p.second; }));
101110
}
102111

103-
void ArgSortNode::revert(State& state) const { data_ptr<ArgSortNodeData>(state)->revert(); }
112+
void ArgSortNode::revert(State& state) const {
113+
auto node_data = data_ptr<ArgSortNodeData>(state);
114+
115+
// Revert the changes to `order` by going over the predecessor's previous updates in reverse
116+
for (const Update& update : node_data->predecessor_updates | std::views::reverse) {
117+
if (!update.placed()) {
118+
node_data->order.insert(std::make_pair(update.old, update.index));
119+
}
120+
if (!update.removed()) {
121+
node_data->order.erase(std::make_pair(update.value, update.index));
122+
}
123+
}
124+
125+
node_data->predecessor_updates.clear();
126+
node_data->revert();
127+
}
104128

105129
std::span<const ssize_t> ArgSortNode::shape(const State& state) const {
106130
return arr_ptr_->shape(state);

tests/cpp/nodes/test_sorting.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,17 @@ TEST_CASE("ArgSortNode") {
137137
CHECK_THAT(argsort_ptr->view(state), RangeEquals({0, 1}));
138138
}
139139
}
140+
141+
AND_WHEN("We revert and propagate again") {
142+
graph.revert(state);
143+
144+
set_ptr->assign(state, std::vector<double>{4, 8, 7, 2});
145+
graph.propagate(state);
146+
147+
THEN("The argsort's state is correct") {
148+
CHECK_THAT(argsort_ptr->view(state), RangeEquals({3, 0, 2, 1}));
149+
}
150+
}
140151
}
141152
}
142153
}

0 commit comments

Comments
 (0)