Skip to content

Commit 72c6d96

Browse files
committed
Address 2nd rnd. comments NumberNode axis-wise bounds
Added indicator variable that all bound axis operators are `==` to reduce redundancy in `NumberNode::exchange()` method.
1 parent b76e45a commit 72c6d96

6 files changed

Lines changed: 138 additions & 81 deletions

File tree

dwave/optimization/include/dwave-optimization/nodes/numbers.hpp

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,10 @@ namespace dwave::optimization {
2828
/// A contiguous block of numbers.
2929
class NumberNode : public ArrayOutputMixin<ArrayNode>, public DecisionNode {
3030
public:
31-
/// Struct for stateless axis-wise bound information. Given an `axis`, define
32-
/// constraints on the sum of the values in each slice along `axis`.
33-
/// Constraints can be defined for ALL slices along `axis` or PER slice along
34-
/// `axis`. Allowable operators are defined by `Operator`.
31+
/// Struct for stateless axis-wise bound information. Given an `axis`,
32+
/// define constraints on the sum of the values in each slice along `axis`.
33+
/// Constraints can be defined for ALL slices along `axis` or PER slice
34+
/// along `axis`. Allowable operators are defined by `Operator`.
3535
struct AxisBound {
3636
/// Allowable axis-wise bound operators.
3737
enum class Operator { Equal, LessEqual, GreaterEqual };
@@ -43,11 +43,11 @@ class NumberNode : public ArrayOutputMixin<ArrayNode>, public DecisionNode {
4343

4444
/// The bound axis
4545
ssize_t axis;
46-
/// Operator for ALL axis slices (vector has length one) or operators PER
47-
/// slice (length of vector is equal to the number of slices).
46+
/// Operator for ALL axis slices (vector has length one) or operators
47+
/// PER slice (length of vector is equal to the number of slices).
4848
std::vector<Operator> operators;
49-
/// Bound for ALL axis slices (vector has length one) or bounds PER slice
50-
/// (length of vector is equal to the number of slices).
49+
/// Bound for ALL axis slices (vector has length one) or bounds PER
50+
/// slice (length of vector is equal to the number of slices).
5151
std::vector<double> bounds;
5252

5353
/// Obtain the bound associated with a given slice along `axis`.
@@ -143,7 +143,7 @@ class NumberNode : public ArrayOutputMixin<ArrayNode>, public DecisionNode {
143143
/// Return the stateless axis-wise bound information i.e. bound_axes_info_.
144144
const std::vector<AxisBound>& axis_wise_bounds() const;
145145

146-
/// Return the state-dependent sum of the values within each hyperslice
146+
/// Return the state-dependent sum of the values within each slice
147147
/// along each bound axis. The returned vector is indexed by the
148148
/// bound axes in the same ordering that `axis_wise_bounds()` returns.
149149
const std::vector<std::vector<double>>& bound_axis_sums(State& state) const;
@@ -173,6 +173,8 @@ class NumberNode : public ArrayOutputMixin<ArrayNode>, public DecisionNode {
173173

174174
/// Stateless information on each bound axis.
175175
std::vector<AxisBound> bound_axes_info_;
176+
/// Indicator variable that all axis-wise bound operators are "==".
177+
bool bound_axis_ops_all_equals_;
176178
};
177179

178180
/// A contiguous block of integer numbers.

dwave/optimization/libcpp/nodes/numbers.pxd

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ cdef extern from "dwave-optimization/nodes/numbers.hpp" namespace "dwave::optimi
2929
LessEqual
3030
GreaterEqual
3131

32-
AxisBound(Py_ssize_t axis, vector[Operator] axis_opertors,
32+
AxisBound(Py_ssize_t axis, vector[Operator] axis_operators,
3333
vector[double] axis_bounds)
3434
Py_ssize_t axis
3535
vector[Operator] operators

dwave/optimization/model.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -166,8 +166,8 @@ def objective(self, value: ArraySymbol):
166166
def binary(self, shape: None | _ShapeLike = None,
167167
lower_bound: None | np.typing.ArrayLike = None,
168168
upper_bound: None | np.typing.ArrayLike = None,
169-
subject_to: None | list[tuple(int, str | list[str], float |
170-
list[float])] = None) -> BinaryVariable:
169+
subject_to: None | list[tuple[int, str | list[str], float |
170+
list[float]]] = None) -> BinaryVariable:
171171
r"""Create a binary symbol as a decision variable.
172172
173173
Args:
@@ -509,8 +509,8 @@ def integer(
509509
shape: None | _ShapeLike = None,
510510
lower_bound: None | numpy.typing.ArrayLike = None,
511511
upper_bound: None | numpy.typing.ArrayLike = None,
512-
subject_to: None | list[tuple(int, str | list[str], float |
513-
list[float])] = None) -> IntegerVariable:
512+
subject_to: None | list[tuple[int, str | list[str], float |
513+
list[float]]] = None) -> IntegerVariable:
514514
r"""Create an integer symbol as a decision variable.
515515
516516
Args:

dwave/optimization/src/nodes/numbers.cpp

Lines changed: 54 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,8 @@ NumberNode::AxisBound::AxisBound(ssize_t bound_axis, std::vector<Operator> axis_
3939
throw std::invalid_argument("Axis-wise `operators` and `bounds` must have non-zero size.");
4040
}
4141

42-
// If `operators` and `bounds` are both defined PER hyperslice along
43-
// `axis`, they must have the same size.
42+
// If `operators` and `bounds` are both defined PER slice along `axis`,
43+
// they must have the same size.
4444
if ((num_operators > 1) && (num_bounds > 1) && (num_bounds != num_operators)) {
4545
throw std::invalid_argument(
4646
"Axis-wise `operators` and `bounds` should have same size if neither has size 1.");
@@ -70,10 +70,10 @@ struct NumberNodeStateData : public ArrayNodeStateData {
7070
: ArrayNodeStateData(std::move(input)),
7171
bound_axes_sums(std::move(bound_axes_sums)),
7272
prior_bound_axes_sums(this->bound_axes_sums) {}
73-
/// For each bound axis and for each hyperslice along said axis, we
74-
/// track the sum of the values within the hyperslice.
75-
/// bound_axes_sums[i][j] = "sum of the values within the jth
76-
/// hyperslice along the ith bound axis"
73+
/// For each bound axis and for each slice along said axis, we track the
74+
/// sum of the values within the slice.
75+
/// bound_axes_sums[i][j] = "sum of the values within the jth slice along
76+
/// the ith bound axis"
7777
/// Note that "ith bound axis" does not necessarily mean the ith axis.
7878
std::vector<std::vector<double>> bound_axes_sums;
7979
// Store a copy for NumberNode::revert() and commit()
@@ -94,7 +94,7 @@ double NumberNode::max() const { return max_; }
9494

9595
/// Given a NumberNode and an assingnment of it's variables (number_data),
9696
/// compute and return a vector containing the sum of the values within each
97-
/// hyperslice along each bound axis.
97+
/// slice along each bound axis.
9898
std::vector<std::vector<double>> get_bound_axes_sums(const NumberNode* node,
9999
const std::vector<double>& number_data) {
100100
std::span<const ssize_t> node_shape = node->shape();
@@ -105,13 +105,13 @@ std::vector<std::vector<double>> get_bound_axes_sums(const NumberNode* node,
105105
static_cast<ssize_t>(number_data.size()));
106106

107107
// For each bound axis, initialize the sum of the values contained in each
108-
// of it's hyperslice to 0. Define bound_axes_sums[i][j] = "sum of the
109-
// values within the jth hyperslice along the ith bound axis".
108+
// of it's slice to 0. Define bound_axes_sums[i][j] = "sum of the values
109+
// within the jth slice along the ith bound axis".
110110
std::vector<std::vector<double>> bound_axes_sums;
111111
bound_axes_sums.reserve(num_bound_axes);
112112
for (const NumberNode::AxisBound& axis_info : bound_axes_info) {
113113
assert(0 <= axis_info.axis && axis_info.axis < static_cast<ssize_t>(node_shape.size()));
114-
// Emplace an all zeros vector of size equal to the number of hyperslice
114+
// Emplace an all zeros vector of size equal to the number of slice
115115
// along the given bound axis (axis_info.axis).
116116
bound_axes_sums.emplace_back(node_shape[axis_info.axis], 0.0);
117117
}
@@ -120,7 +120,7 @@ std::vector<std::vector<double>> get_bound_axes_sums(const NumberNode* node,
120120
// NumberNode and iterate over it.
121121
for (BufferIterator<double, double, true> it(number_data.data(), node_shape, node->strides());
122122
it != std::default_sentinel; ++it) {
123-
// Increment the sum of the appropriate hyperslice along each bound axis.
123+
// Increment the sum of the appropriate slice along each bound axis.
124124
for (ssize_t bound_axis = 0; bound_axis < num_bound_axes; ++bound_axis) {
125125
const ssize_t axis = bound_axes_info[bound_axis].axis;
126126
assert(0 <= axis && axis < static_cast<ssize_t>(it.location().size()));
@@ -133,8 +133,8 @@ std::vector<std::vector<double>> get_bound_axes_sums(const NumberNode* node,
133133
return bound_axes_sums;
134134
}
135135

136-
/// Determine whether the sum of the values within each hyperslice along
137-
/// each bound axis satisfies the axis-wise bounds.
136+
/// Determine whether the sum of the values within each slice along each bound
137+
/// axis satisfies the axis-wise bounds.
138138
bool satisfies_axis_wise_bounds(const std::vector<NumberNode::AxisBound>& bound_axes_info,
139139
const std::vector<std::vector<double>>& bound_axes_sums) {
140140
assert(bound_axes_info.size() == bound_axes_sums.size());
@@ -178,8 +178,8 @@ void NumberNode::initialize_state(State& state, std::vector<double>&& number_dat
178178
if (bound_axes_info_.size() == 0) { // No bound axes to consider.
179179
emplace_data_ptr<NumberNodeStateData>(state, std::move(number_data));
180180
} else {
181-
// Given the assingnment to NumberNode `number_data`, compute the sum of the
182-
// values within each hyperslice along each bound axis.
181+
// Given the assingnment to NumberNode `number_data`, compute the sum
182+
// of the values within each slice along each bound axis.
183183
std::vector<std::vector<double>> bound_axes_sums = get_bound_axes_sums(this, number_data);
184184

185185
if (!satisfies_axis_wise_bounds(bound_axes_info_, bound_axes_sums)) {
@@ -252,7 +252,7 @@ double compute_bound_axis_slice_delta(const ssize_t slice, const double sum,
252252
/// Given a NumberNode and exactly one axis-wise bound, assign values to
253253
/// `values` (in-place) to satisfy the axis-wise bound. This method
254254
/// A) Initially sets `values[i] = lower_bound(i)` for all i.
255-
/// B) Incremements the values within each hyperslice until they satisfy
255+
/// B) Incremements the values within each slice until they satisfy
256256
/// the axis-wise bound (should this be possible).
257257
void construct_state_given_exactly_one_bound_axis(const NumberNode* node,
258258
std::vector<double>& values) {
@@ -263,31 +263,31 @@ void construct_state_given_exactly_one_bound_axis(const NumberNode* node,
263263
for (ssize_t i = 0, stop = node->size(); i < stop; ++i) {
264264
values.push_back(node->lower_bound(i));
265265
}
266-
// 2) Determine the hyperslice sums for the bound axis. To improve
267-
// performance, compute sum during previous loop.
266+
// 2) Determine the slice sums for the bound axis. To improve performance,
267+
// compute sum during previous loop.
268268
assert(node->axis_wise_bounds().size() == 1);
269269
const std::vector<double> bound_axis_sums = get_bound_axes_sums(node, values).front();
270270
// Obtain the stateless bound axis data for node.
271271
const NumberNode::AxisBound& bound_axis_info = node->axis_wise_bounds().front();
272272
const ssize_t bound_axis = bound_axis_info.axis;
273273
assert(0 <= bound_axis && bound_axis < ndim);
274274

275-
// We need a way to iterate over each hyperslice along the bound axis and
276-
// adjust it`s values until they satisfy the axis-wise bounds. We do this
277-
// by defining an iterator of `values` that traverses each hyperslice one
278-
// after another. This is equivalent to adjusting the node's shape and
279-
// strides such that the data for the bound_axis is moved to position 0.
275+
// We need a way to iterate over each slice along the bound axis and adjust
276+
// it`s values until they satisfy the axis-wise bounds. We do this by
277+
// defining an iterator of `values` that traverses each slice one after
278+
// another. This is equivalent to adjusting the node's shape and strides
279+
// such that the data for the bound_axis is moved to position 0.
280280
const std::vector<ssize_t> buff_shape = shift_axis_data(node_shape, bound_axis);
281281
const std::vector<ssize_t> buff_strides = shift_axis_data(node->strides(), bound_axis);
282282
// Define an iterator for `values` corresponding with the beginning of
283283
// slice 0 along the bound axis.
284284
const BufferIterator<double, double, false> slice_0_it(values.data(), ndim, buff_shape.data(),
285285
buff_strides.data());
286-
// Determine the size of each hyperslice along the bound axis.
286+
// Determine the size of each slice along the bound axis.
287287
const ssize_t slice_size = std::accumulate(buff_shape.begin() + 1, buff_shape.end(), 1.0,
288288
std::multiplies<ssize_t>());
289289

290-
// 3) Iterate over each hyperslice and adjust it's values until they
290+
// 3) Iterate over each slice and adjust it's values until they
291291
// satisfy the axis-wise bounds.
292292
for (ssize_t slice = 0, stop = node_shape[bound_axis]; slice < stop; ++slice) {
293293
// Determine the amount needed to adjust the values within the slice.
@@ -297,8 +297,8 @@ void construct_state_given_exactly_one_bound_axis(const NumberNode* node,
297297
if (delta == 0) continue; // Axis-wise bounds are satisfied for slice.
298298
assert(delta >= 0); // Should only increment.
299299

300-
// Determine how much we need to offset `slice_0_it` to get to the first
301-
// index in the given `slice`.
300+
// Determine how much we need to offset `slice_0_it` to get to the
301+
// first index in the given `slice`.
302302
const ssize_t offset = slice * slice_size;
303303
// Iterate over all indices in the given slice.
304304
for (auto slice_it = slice_0_it + offset, slice_end_it = slice_it + slice_size;
@@ -367,12 +367,15 @@ void NumberNode::exchange(State& state, ssize_t i, ssize_t j) const {
367367
// assert() that i and j are valid indices occurs in ptr->exchange().
368368
// State change occurs IFF (i != j) and (buffer[i] != buffer[j]).
369369
if (ptr->exchange(i, j)) {
370-
// If exchange occured, update the bound axis sums.
371-
const double difference = ptr->get(i) - ptr->get(j);
372-
// Index i changed from (what is now) ptr->get(j) to ptr->get(i)
373-
update_bound_axis_slice_sums(state, i, difference);
374-
// Index j changed from (what is now) ptr->get(i) to ptr->get(j)
375-
update_bound_axis_slice_sums(state, j, -difference);
370+
// No need to update slice sums as they will be unchanged.
371+
if (!bound_axis_ops_all_equals_) {
372+
// If exchange occurred, update the bound axis sums.
373+
const double difference = ptr->get(i) - ptr->get(j);
374+
// Index i changed from (what is now) ptr->get(j) to ptr->get(i)
375+
update_bound_axis_slice_sums(state, i, difference);
376+
// Index j changed from (what is now) ptr->get(i) to ptr->get(j)
377+
update_bound_axis_slice_sums(state, j, -difference);
378+
}
376379
assert(satisfies_axis_wise_bounds(bound_axes_info_, ptr->bound_axes_sums));
377380
}
378381
}
@@ -419,9 +422,9 @@ void NumberNode::clip_and_set_value(State& state, ssize_t index, double value) c
419422
auto ptr = data_ptr<NumberNodeStateData>(state);
420423
value = std::clamp(value, lower_bound(index), upper_bound(index));
421424
// assert() that i is a valid index occurs in ptr->set().
422-
// State change occurs IFF `value` != buffer[index] .
425+
// State change occurs IFF `value` != buffer[index].
423426
if (ptr->set(index, value)) {
424-
// If change occured, update bound axis sums by differnce.
427+
// If change occurred, update bound axis sums by difference.
425428
update_bound_axis_slice_sums(state, index, value - diff(state).back().old);
426429
assert(satisfies_axis_wise_bounds(bound_axes_info_, ptr->bound_axes_sums));
427430
}
@@ -447,6 +450,16 @@ double get_extreme_index_wise_bound(const std::vector<double>& bound) {
447450
return *it;
448451
}
449452

453+
bool all_bound_axis_operators_are_equals(std::vector<NumberNode::AxisBound>& bound_axes_info) {
454+
for (const NumberNode::AxisBound& bound_axis_info : bound_axes_info) {
455+
for (const NumberNode::AxisBound::Operator op : bound_axis_info.operators) {
456+
if (op != NumberNode::AxisBound::Operator::Equal) return false;
457+
}
458+
}
459+
// Vacuously true if there are no axis-wise bounds.
460+
return true;
461+
}
462+
450463
void check_index_wise_bounds(const NumberNode& node, const std::vector<double>& lower_bounds_,
451464
const std::vector<double>& upper_bounds_) {
452465
bool index_wise_bound = false;
@@ -534,7 +547,8 @@ NumberNode::NumberNode(std::span<const ssize_t> shape, std::vector<double> lower
534547
max_(get_extreme_index_wise_bound<true>(upper_bound)),
535548
lower_bounds_(std::move(lower_bound)),
536549
upper_bounds_(std::move(upper_bound)),
537-
bound_axes_info_(std::move(bound_axes)) {
550+
bound_axes_info_(std::move(bound_axes)),
551+
bound_axis_ops_all_equals_(all_bound_axis_operators_are_equals(bound_axes_info_)) {
538552
if ((shape.size() > 0) && (shape[0] < 0)) {
539553
throw std::invalid_argument("Number array cannot have dynamic size.");
540554
}
@@ -549,14 +563,15 @@ NumberNode::NumberNode(std::span<const ssize_t> shape, std::vector<double> lower
549563

550564
void NumberNode::update_bound_axis_slice_sums(State& state, const ssize_t index,
551565
const double value_change) const {
566+
assert(value_change != 0); // Should not call when no change occurs.
552567
const auto& bound_axes_info = bound_axes_info_;
553-
if (bound_axes_info.size() == 0) return; // No axis-wise bounds to satisfy
568+
if (bound_axes_info.size() == 0) return; // No axis-wise bounds to satisfy.
554569

555570
// Get multidimensional indices for `index` so we can identify the slices
556571
// `index` lies on per bound axis.
557572
const std::vector<ssize_t> multi_index = unravel_index(index, this->shape());
558573
assert(bound_axes_info.size() <= multi_index.size());
559-
// Get the hyperslice sums of all bound axes.
574+
// Get the slice sums of all bound axes.
560575
auto& bound_axes_sums = data_ptr<NumberNodeStateData>(state)->bound_axes_sums;
561576
assert(bound_axes_info.size() == bound_axes_sums.size());
562577

@@ -676,7 +691,7 @@ void IntegerNode::set_value(State& state, ssize_t index, double value) const {
676691
// assert() that i is a valid index occurs in ptr->set().
677692
// State change occurs IFF `value` != buffer[index].
678693
if (ptr->set(index, value)) {
679-
// If change occured, update bound axis sums by differnce.
694+
// If change occurred, update bound axis sums by difference.
680695
update_bound_axis_slice_sums(state, index, value - diff(state).back().old);
681696
assert(satisfies_axis_wise_bounds(bound_axes_info_, ptr->bound_axes_sums));
682697
}

0 commit comments

Comments
 (0)