Skip to content

Commit 156481a

Browse files
committed
Reformat NumberNode mutate methods
Removed `asserts()` that axis-wise bounds were satisfied to `NumberNode::propagate()`.
1 parent 0ec4bbb commit 156481a

2 files changed

Lines changed: 41 additions & 22 deletions

File tree

dwave/optimization/include/dwave-optimization/nodes/numbers.hpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,9 @@ class NumberNode : public ArrayOutputMixin<ArrayNode>, public DecisionNode {
128128
return initialize_state(state, std::move(values));
129129
}
130130

131+
/// @copydoc Node::propagate()
132+
void propagate(State& state) const override;
133+
131134
// NumberNode methods *****************************************************
132135

133136
// In the given state, swap the value of index i with the value of index j.

dwave/optimization/src/nodes/numbers.cpp

Lines changed: 38 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -348,6 +348,15 @@ void NumberNode::initialize_state(State& state) const {
348348
}
349349
}
350350

351+
void NumberNode::propagate(State& state) const {
352+
// Should only propagate states that obey the axis-wise bounds.
353+
assert(satisfies_axis_wise_bounds(bound_axes_info_, bound_axis_sums(state)));
354+
// Technically vestigial but will keep it for forms sake.
355+
for (const auto& sv : successors()) {
356+
sv->update(state, sv.index);
357+
}
358+
}
359+
351360
void NumberNode::commit(State& state) const noexcept {
352361
auto node_data = data_ptr<NumberNodeStateData>(state);
353362
// Manually store a copy of bound_axes_sums.
@@ -372,16 +381,15 @@ void NumberNode::exchange(State& state, ssize_t i, ssize_t j) const {
372381
// assert() that i and j are valid indices occurs in ptr->exchange().
373382
// State change occurs IFF (i != j) and (buffer[i] != buffer[j]).
374383
if (ptr->exchange(i, j)) {
375-
// No need to update slice sums as they will be unchanged.
376-
if (!bound_axis_ops_all_equals_) {
377-
// If exchange occurred, update the bound axis sums.
384+
// If change occurred and axis-wise bounds exist, update bound axis sums.
385+
// Nothing to update if all axis bound operators are Equals.
386+
if (!bound_axis_ops_all_equals_ && bound_axes_info_.size() > 0) {
378387
const double difference = ptr->get(i) - ptr->get(j);
379388
// Index i changed from (what is now) ptr->get(j) to ptr->get(i)
380389
update_bound_axis_slice_sums(state, i, difference);
381390
// Index j changed from (what is now) ptr->get(i) to ptr->get(j)
382391
update_bound_axis_slice_sums(state, j, -difference);
383392
}
384-
assert(satisfies_axis_wise_bounds(bound_axes_info_, ptr->bound_axes_sums));
385393
}
386394
}
387395

@@ -429,9 +437,10 @@ void NumberNode::clip_and_set_value(State& state, ssize_t index, double value) c
429437
// assert() that i is a valid index occurs in ptr->set().
430438
// State change occurs IFF `value` != buffer[index].
431439
if (ptr->set(index, value)) {
432-
// If change occurred, update bound axis sums by difference.
433-
update_bound_axis_slice_sums(state, index, value - diff(state).back().old);
434-
assert(satisfies_axis_wise_bounds(bound_axes_info_, ptr->bound_axes_sums));
440+
// If change occurred and axis-wise bounds exist, update bound axis sums.
441+
if (bound_axes_info_.size() > 0) {
442+
update_bound_axis_slice_sums(state, index, value - diff(state).back().old);
443+
}
435444
}
436445
}
437446

@@ -569,9 +578,9 @@ NumberNode::NumberNode(std::span<const ssize_t> shape, std::vector<double> lower
569578

570579
void NumberNode::update_bound_axis_slice_sums(State& state, const ssize_t index,
571580
const double value_change) const {
572-
assert(value_change != 0); // Should not call when no change occurs.
573581
const auto& bound_axes_info = bound_axes_info_;
574-
if (bound_axes_info.size() == 0) return; // No axis-wise bounds to satisfy.
582+
assert(value_change != 0); // Should not call when no change occurs.
583+
assert(bound_axes_info.size() != 0); // Should only call where applicable.
575584

576585
// Get multidimensional indices for `index` so we can identify the slices
577586
// `index` lies on per bound axis.
@@ -698,9 +707,10 @@ void IntegerNode::set_value(State& state, ssize_t index, double value) const {
698707
// assert() that i is a valid index occurs in ptr->set().
699708
// State change occurs IFF `value` != buffer[index].
700709
if (ptr->set(index, value)) {
701-
// If change occurred, update bound axis sums by difference.
702-
update_bound_axis_slice_sums(state, index, value - diff(state).back().old);
703-
assert(satisfies_axis_wise_bounds(bound_axes_info_, ptr->bound_axes_sums));
710+
// If change occurred and axis-wise bounds exist, update bound axis sums.
711+
if (bound_axes_info_.size() > 0) {
712+
update_bound_axis_slice_sums(state, index, value - diff(state).back().old);
713+
}
704714
}
705715
}
706716

@@ -808,10 +818,12 @@ void BinaryNode::flip(State& state, ssize_t i) const {
808818
// assert() that i is a valid index occurs in ptr->set().
809819
// State change occurs IFF `value` != buffer[i].
810820
if (ptr->set(i, !ptr->get(i))) {
811-
// If value changed from 0 -> 1, update the bound axis sums by 1.
812-
// If value changed from 1 -> 0, update the bound axis sums by -1.
813-
update_bound_axis_slice_sums(state, i, (ptr->get(i) == 1) ? 1 : -1);
814-
assert(satisfies_axis_wise_bounds(bound_axes_info_, ptr->bound_axes_sums));
821+
// If change occurred and axis-wise bounds exist, update bound axis sums.
822+
if (bound_axes_info_.size() > 0) {
823+
// If value changed from 0 -> 1, update by 1.
824+
// If value changed from 1 -> 0, update by -1.
825+
update_bound_axis_slice_sums(state, i, (ptr->get(i) == 1) ? 1 : -1);
826+
}
815827
}
816828
}
817829

@@ -822,9 +834,11 @@ void BinaryNode::set(State& state, ssize_t i) const {
822834
// assert() that i is a valid index occurs in ptr->set().
823835
// State change occurs IFF `value` != buffer[i].
824836
if (ptr->set(i, 1.0)) {
825-
// If value changed from 0 -> 1, update the bound axis sums by 1.
826-
update_bound_axis_slice_sums(state, i, 1.0);
827-
assert(satisfies_axis_wise_bounds(bound_axes_info_, ptr->bound_axes_sums));
837+
// If change occurred and axis-wise bounds exist, update bound axis sums.
838+
if (bound_axes_info_.size() > 0) {
839+
// If value changed from 0 -> 1, update by 1.
840+
update_bound_axis_slice_sums(state, i, 1.0);
841+
}
828842
}
829843
}
830844

@@ -835,9 +849,11 @@ void BinaryNode::unset(State& state, ssize_t i) const {
835849
// assert() that i is a valid index occurs in ptr->set().
836850
// State change occurs IFF `value` != buffer[i].
837851
if (ptr->set(i, 0.0)) {
838-
// If value changed from 1 -> 0, update the bound axis sums by -1.
839-
update_bound_axis_slice_sums(state, i, -1.0);
840-
assert(satisfies_axis_wise_bounds(bound_axes_info_, ptr->bound_axes_sums));
852+
// If change occurred and axis-wise bounds exist, update bound axis sums.
853+
if (bound_axes_info_.size() > 0) {
854+
// If value changed from 1 -> 0, update by -1.
855+
update_bound_axis_slice_sums(state, i, -1.0);
856+
}
841857
}
842858
}
843859

0 commit comments

Comments
 (0)