@@ -77,9 +77,9 @@ struct NumberNodeStateData : public ArrayNodeStateData {
7777 // / bound_axes_sums[i][j] = "sum of the values within the jth slice along
7878 // / the ith bound axis"
7979 // / Note that "ith bound axis" does not necessarily mean the ith axis.
80- std::vector<std::vector<double >> bound_axes_sums;
80+ std::optional<std:: vector<std::vector<double >>> bound_axes_sums = std:: nullopt ;
8181 // Store a copy for NumberNode::revert() and commit()
82- std::vector<std::vector<double >> prior_bound_axes_sums;
82+ std::optional<std:: vector<std::vector<double >>> prior_bound_axes_sums = std:: nullopt ;
8383};
8484
8585double const * NumberNode::buff (const State& state) const noexcept {
@@ -350,7 +350,8 @@ void NumberNode::initialize_state(State& state) const {
350350
351351void NumberNode::propagate (State& state) const {
352352 // Should only propagate states that obey the axis-wise bounds.
353- assert (satisfies_axis_wise_bounds (bound_axes_info_, bound_axis_sums (state)));
353+ assert (bound_axes_info_.size () == 0 ||
354+ satisfies_axis_wise_bounds (bound_axes_info_, bound_axis_sums (state)));
354355 // Technically vestigial but will keep it for forms sake.
355356 for (const auto & sv : successors ()) {
356357 sv->update (state, sv.index );
@@ -359,16 +360,22 @@ void NumberNode::propagate(State& state) const {
359360
360361void NumberNode::commit (State& state) const noexcept {
361362 auto node_data = data_ptr<NumberNodeStateData>(state);
362- // Manually store a copy of bound_axes_sums.
363- node_data->prior_bound_axes_sums = node_data->bound_axes_sums ;
364363 node_data->commit ();
364+ if (node_data->prior_bound_axes_sums .has_value ()) {
365+ assert (node_data->bound_axes_sums .has_value ());
366+ // Manually store a copy of bound_axes_sums.
367+ node_data->prior_bound_axes_sums = node_data->bound_axes_sums ;
368+ }
365369}
366370
367371void NumberNode::revert (State& state) const noexcept {
368372 auto node_data = data_ptr<NumberNodeStateData>(state);
369- // Manually reset bound_axes_sums.
370- node_data->bound_axes_sums = node_data->prior_bound_axes_sums ;
371373 node_data->revert ();
374+ if (node_data->prior_bound_axes_sums .has_value ()) {
375+ assert (node_data->bound_axes_sums .has_value ());
376+ // Manually reset bound_axes_sums.
377+ node_data->bound_axes_sums = node_data->prior_bound_axes_sums ;
378+ }
372379}
373380
374381void NumberNode::exchange (State& state, ssize_t i, ssize_t j) const {
@@ -449,7 +456,8 @@ const std::vector<NumberNode::AxisBound>& NumberNode::axis_wise_bounds() const {
449456}
450457
451458const std::vector<std::vector<double >>& NumberNode::bound_axis_sums (const State& state) const {
452- return data_ptr<NumberNodeStateData>(state)->bound_axes_sums ;
459+ assert (data_ptr<NumberNodeStateData>(state)->bound_axes_sums .has_value ());
460+ return *data_ptr<NumberNodeStateData>(state)->bound_axes_sums ;
453461}
454462
455463template <bool maximum>
@@ -587,7 +595,8 @@ void NumberNode::update_bound_axis_slice_sums(State& state, const ssize_t index,
587595 const std::vector<ssize_t > multi_index = unravel_index (index, this ->shape ());
588596 assert (bound_axes_info.size () <= multi_index.size ());
589597 // Get the slice sums of all bound axes.
590- auto & bound_axes_sums = data_ptr<NumberNodeStateData>(state)->bound_axes_sums ;
598+ assert (data_ptr<NumberNodeStateData>(state)->bound_axes_sums .has_value ());
599+ auto & bound_axes_sums = data_ptr<NumberNodeStateData>(state)->bound_axes_sums .value ();
591600 assert (bound_axes_info.size () == bound_axes_sums.size ());
592601
593602 // For each bound axis
0 commit comments