@@ -43,8 +43,9 @@ struct ArgSortNodeData : public ArrayNodeStateData {
4343 ArgSortNodeData (ArgSortNodeDataHelper_&& helper)
4444 : ArrayNodeStateData(std::move(helper.indices)), order(std::move(helper.order)) {}
4545
46- // / First is the value in the original array, second is the index of the value
46+ // / Pairs are < value in the original array, index of the value>
4747 std::set<std::pair<double , ssize_t >> order;
48+ std::vector<Update> predecessor_updates;
4849};
4950
5051ArgSortNode::ArgSortNode (ArrayNode* arr_ptr)
@@ -73,16 +74,24 @@ bool ArgSortNode::integral() const { return arr_ptr_->integral(); }
7374std::pair<double , double > ArgSortNode::minmax (
7475 optional_cache_type<std::pair<double , double >> cache) const {
7576 return memoize (cache, [&]() {
76- return std::make_pair (0.0 , static_cast <double >(arr_ptr_->sizeinfo ().max .value_or (
77- std::numeric_limits<ssize_t >::max ()) - 1 ));
77+ return std::make_pair (0.0 ,
78+ static_cast <double >(arr_ptr_->sizeinfo ().max .value_or (
79+ std::numeric_limits<ssize_t >::max ()) -
80+ 1 ));
7881 });
7982}
8083
8184void ArgSortNode::propagate (State& state) const {
8285 auto node_data = data_ptr<ArgSortNodeData>(state);
8386
87+ auto pred_diff = arr_ptr_->diff (state);
88+
89+ // Save a copy of the predecessor's updates so we can use them in case we
90+ // need to revert the changes to the ordering
91+ node_data->predecessor_updates .assign (pred_diff.begin (), pred_diff.end ());
92+
8493 // Make the modifications to the std::set based on the updates.
85- for (const Update& update : arr_ptr_-> diff (state) ) {
94+ for (const Update& update : pred_diff ) {
8695 if (!update.placed ()) {
8796 node_data->order .erase (std::make_pair (update.old , update.index ));
8897 }
@@ -100,7 +109,22 @@ void ArgSortNode::propagate(State& state) const {
100109 std::views::transform ([](const std::pair<double , ssize_t >& p) { return p.second ; }));
101110}
102111
103- void ArgSortNode::revert (State& state) const { data_ptr<ArgSortNodeData>(state)->revert (); }
112+ void ArgSortNode::revert (State& state) const {
113+ auto node_data = data_ptr<ArgSortNodeData>(state);
114+
115+ // Revert the changes to `order` by going over the predecessor's previous updates in reverse
116+ for (const Update& update : node_data->predecessor_updates | std::views::reverse) {
117+ if (!update.placed ()) {
118+ node_data->order .insert (std::make_pair (update.old , update.index ));
119+ }
120+ if (!update.removed ()) {
121+ node_data->order .erase (std::make_pair (update.value , update.index ));
122+ }
123+ }
124+
125+ node_data->predecessor_updates .clear ();
126+ node_data->revert ();
127+ }
104128
105129std::span<const ssize_t > ArgSortNode::shape (const State& state) const {
106130 return arr_ptr_->shape (state);
0 commit comments