Skip to content

Commit a7ff6a0

Browse files
committed
Fixed NumberNodeStateData bug
Previously, `bound_axes_sums` and `prior_bound_axes_sums` on `NumberNodeStateData` resulted in segfaults.
1 parent 116d03b commit a7ff6a0

1 file changed

Lines changed: 18 additions & 9 deletions

File tree

dwave/optimization/src/nodes/numbers.cpp

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -77,9 +77,9 @@ struct NumberNodeStateData : public ArrayNodeStateData {
7777
/// bound_axes_sums[i][j] = "sum of the values within the jth slice along
7878
/// the ith bound axis"
7979
/// Note that "ith bound axis" does not necessarily mean the ith axis.
80-
std::vector<std::vector<double>> bound_axes_sums;
80+
std::optional<std::vector<std::vector<double>>> bound_axes_sums = std::nullopt;
8181
// Store a copy for NumberNode::revert() and commit()
82-
std::vector<std::vector<double>> prior_bound_axes_sums;
82+
std::optional<std::vector<std::vector<double>>> prior_bound_axes_sums = std::nullopt;
8383
};
8484

8585
double const* NumberNode::buff(const State& state) const noexcept {
@@ -350,7 +350,8 @@ void NumberNode::initialize_state(State& state) const {
350350

351351
void NumberNode::propagate(State& state) const {
352352
// Should only propagate states that obey the axis-wise bounds.
353-
assert(satisfies_axis_wise_bounds(bound_axes_info_, bound_axis_sums(state)));
353+
assert(bound_axes_info_.size() == 0 ||
354+
satisfies_axis_wise_bounds(bound_axes_info_, bound_axis_sums(state)));
354355
// Technically vestigial but will keep it for forms sake.
355356
for (const auto& sv : successors()) {
356357
sv->update(state, sv.index);
@@ -359,16 +360,22 @@ void NumberNode::propagate(State& state) const {
359360

360361
void NumberNode::commit(State& state) const noexcept {
361362
auto node_data = data_ptr<NumberNodeStateData>(state);
362-
// Manually store a copy of bound_axes_sums.
363-
node_data->prior_bound_axes_sums = node_data->bound_axes_sums;
364363
node_data->commit();
364+
if (node_data->prior_bound_axes_sums.has_value()) {
365+
assert(node_data->bound_axes_sums.has_value());
366+
// Manually store a copy of bound_axes_sums.
367+
node_data->prior_bound_axes_sums = node_data->bound_axes_sums;
368+
}
365369
}
366370

367371
void NumberNode::revert(State& state) const noexcept {
368372
auto node_data = data_ptr<NumberNodeStateData>(state);
369-
// Manually reset bound_axes_sums.
370-
node_data->bound_axes_sums = node_data->prior_bound_axes_sums;
371373
node_data->revert();
374+
if (node_data->prior_bound_axes_sums.has_value()) {
375+
assert(node_data->bound_axes_sums.has_value());
376+
// Manually reset bound_axes_sums.
377+
node_data->bound_axes_sums = node_data->prior_bound_axes_sums;
378+
}
372379
}
373380

374381
void NumberNode::exchange(State& state, ssize_t i, ssize_t j) const {
@@ -449,7 +456,8 @@ const std::vector<NumberNode::AxisBound>& NumberNode::axis_wise_bounds() const {
449456
}
450457

451458
const std::vector<std::vector<double>>& NumberNode::bound_axis_sums(const State& state) const {
452-
return data_ptr<NumberNodeStateData>(state)->bound_axes_sums;
459+
assert(data_ptr<NumberNodeStateData>(state)->bound_axes_sums.has_value());
460+
return *data_ptr<NumberNodeStateData>(state)->bound_axes_sums;
453461
}
454462

455463
template <bool maximum>
@@ -587,7 +595,8 @@ void NumberNode::update_bound_axis_slice_sums(State& state, const ssize_t index,
587595
const std::vector<ssize_t> multi_index = unravel_index(index, this->shape());
588596
assert(bound_axes_info.size() <= multi_index.size());
589597
// Get the slice sums of all bound axes.
590-
auto& bound_axes_sums = data_ptr<NumberNodeStateData>(state)->bound_axes_sums;
598+
assert(data_ptr<NumberNodeStateData>(state)->bound_axes_sums.has_value());
599+
auto& bound_axes_sums = data_ptr<NumberNodeStateData>(state)->bound_axes_sums.value();
591600
assert(bound_axes_info.size() == bound_axes_sums.size());
592601

593602
// For each bound axis

0 commit comments

Comments
 (0)