@@ -39,8 +39,8 @@ NumberNode::AxisBound::AxisBound(ssize_t bound_axis, std::vector<Operator> axis_
3939 throw std::invalid_argument (" Axis-wise `operators` and `bounds` must have non-zero size." );
4040 }
4141
42- // If `operators` and `bounds` are both defined PER hyperslice along
43- // `axis`, they must have the same size.
42+ // If `operators` and `bounds` are both defined PER slice along `axis`,
43+ // they must have the same size.
4444 if ((num_operators > 1 ) && (num_bounds > 1 ) && (num_bounds != num_operators)) {
4545 throw std::invalid_argument (
4646 " Axis-wise `operators` and `bounds` should have same size if neither has size 1." );
@@ -70,10 +70,10 @@ struct NumberNodeStateData : public ArrayNodeStateData {
7070 : ArrayNodeStateData(std::move(input)),
7171 bound_axes_sums (std::move(bound_axes_sums)),
7272 prior_bound_axes_sums(this ->bound_axes_sums) {}
73- // / For each bound axis and for each hyperslice along said axis, we
74- // / track the sum of the values within the hyperslice .
75- // / bound_axes_sums[i][j] = "sum of the values within the jth
76- // / hyperslice along the ith bound axis"
73+ // / For each bound axis and for each slice along said axis, we track the
74+ // / sum of the values within the slice .
75+ // / bound_axes_sums[i][j] = "sum of the values within the jth slice along
76+ // / the ith bound axis"
7777 // / Note that "ith bound axis" does not necessarily mean the ith axis.
7878 std::vector<std::vector<double >> bound_axes_sums;
7979 // Store a copy for NumberNode::revert() and commit()
@@ -94,7 +94,7 @@ double NumberNode::max() const { return max_; }
9494
9595// / Given a NumberNode and an assingnment of it's variables (number_data),
9696// / compute and return a vector containing the sum of the values within each
97- // / hyperslice along each bound axis.
97+ // / slice along each bound axis.
9898std::vector<std::vector<double >> get_bound_axes_sums (const NumberNode* node,
9999 const std::vector<double >& number_data) {
100100 std::span<const ssize_t > node_shape = node->shape ();
@@ -105,13 +105,13 @@ std::vector<std::vector<double>> get_bound_axes_sums(const NumberNode* node,
105105 static_cast <ssize_t >(number_data.size ()));
106106
107107 // For each bound axis, initialize the sum of the values contained in each
108- // of it's hyperslice to 0. Define bound_axes_sums[i][j] = "sum of the
109- // values within the jth hyperslice along the ith bound axis".
108+ // of it's slice to 0. Define bound_axes_sums[i][j] = "sum of the values
109+ // within the jth slice along the ith bound axis".
110110 std::vector<std::vector<double >> bound_axes_sums;
111111 bound_axes_sums.reserve (num_bound_axes);
112112 for (const NumberNode::AxisBound& axis_info : bound_axes_info) {
113113 assert (0 <= axis_info.axis && axis_info.axis < static_cast <ssize_t >(node_shape.size ()));
114- // Emplace an all zeros vector of size equal to the number of hyperslice
114+ // Emplace an all zeros vector of size equal to the number of slice
115115 // along the given bound axis (axis_info.axis).
116116 bound_axes_sums.emplace_back (node_shape[axis_info.axis ], 0.0 );
117117 }
@@ -120,7 +120,7 @@ std::vector<std::vector<double>> get_bound_axes_sums(const NumberNode* node,
120120 // NumberNode and iterate over it.
121121 for (BufferIterator<double , double , true > it (number_data.data (), node_shape, node->strides ());
122122 it != std::default_sentinel; ++it) {
123- // Increment the sum of the appropriate hyperslice along each bound axis.
123+ // Increment the sum of the appropriate slice along each bound axis.
124124 for (ssize_t bound_axis = 0 ; bound_axis < num_bound_axes; ++bound_axis) {
125125 const ssize_t axis = bound_axes_info[bound_axis].axis ;
126126 assert (0 <= axis && axis < static_cast <ssize_t >(it.location ().size ()));
@@ -133,8 +133,8 @@ std::vector<std::vector<double>> get_bound_axes_sums(const NumberNode* node,
133133 return bound_axes_sums;
134134}
135135
136- // / Determine whether the sum of the values within each hyperslice along
137- // / each bound axis satisfies the axis-wise bounds.
136+ // / Determine whether the sum of the values within each slice along each bound
137+ // / axis satisfies the axis-wise bounds.
138138bool satisfies_axis_wise_bounds (const std::vector<NumberNode::AxisBound>& bound_axes_info,
139139 const std::vector<std::vector<double >>& bound_axes_sums) {
140140 assert (bound_axes_info.size () == bound_axes_sums.size ());
@@ -178,8 +178,8 @@ void NumberNode::initialize_state(State& state, std::vector<double>&& number_dat
178178 if (bound_axes_info_.size () == 0 ) { // No bound axes to consider.
179179 emplace_data_ptr<NumberNodeStateData>(state, std::move (number_data));
180180 } else {
181- // Given the assingnment to NumberNode `number_data`, compute the sum of the
182- // values within each hyperslice along each bound axis.
181+ // Given the assingnment to NumberNode `number_data`, compute the sum
182+ // of the values within each slice along each bound axis.
183183 std::vector<std::vector<double >> bound_axes_sums = get_bound_axes_sums (this , number_data);
184184
185185 if (!satisfies_axis_wise_bounds (bound_axes_info_, bound_axes_sums)) {
@@ -252,7 +252,7 @@ double compute_bound_axis_slice_delta(const ssize_t slice, const double sum,
252252// / Given a NumberNode and exactly one axis-wise bound, assign values to
253253// / `values` (in-place) to satisfy the axis-wise bound. This method
254254// / A) Initially sets `values[i] = lower_bound(i)` for all i.
255- // / B) Incremements the values within each hyperslice until they satisfy
255+ // / B) Incremements the values within each slice until they satisfy
256256// / the axis-wise bound (should this be possible).
257257void construct_state_given_exactly_one_bound_axis (const NumberNode* node,
258258 std::vector<double >& values) {
@@ -263,31 +263,31 @@ void construct_state_given_exactly_one_bound_axis(const NumberNode* node,
263263 for (ssize_t i = 0 , stop = node->size (); i < stop; ++i) {
264264 values.push_back (node->lower_bound (i));
265265 }
266- // 2) Determine the hyperslice sums for the bound axis. To improve
267- // performance, compute sum during previous loop.
266+ // 2) Determine the slice sums for the bound axis. To improve performance,
267+ // compute sum during previous loop.
268268 assert (node->axis_wise_bounds ().size () == 1 );
269269 const std::vector<double > bound_axis_sums = get_bound_axes_sums (node, values).front ();
270270 // Obtain the stateless bound axis data for node.
271271 const NumberNode::AxisBound& bound_axis_info = node->axis_wise_bounds ().front ();
272272 const ssize_t bound_axis = bound_axis_info.axis ;
273273 assert (0 <= bound_axis && bound_axis < ndim);
274274
275- // We need a way to iterate over each hyperslice along the bound axis and
276- // adjust it`s values until they satisfy the axis-wise bounds. We do this
277- // by defining an iterator of `values` that traverses each hyperslice one
278- // after another. This is equivalent to adjusting the node's shape and
279- // strides such that the data for the bound_axis is moved to position 0.
275+ // We need a way to iterate over each slice along the bound axis and adjust
276+ // it`s values until they satisfy the axis-wise bounds. We do this by
277+ // defining an iterator of `values` that traverses each slice one after
278+ // another. This is equivalent to adjusting the node's shape and strides
279+ // such that the data for the bound_axis is moved to position 0.
280280 const std::vector<ssize_t > buff_shape = shift_axis_data (node_shape, bound_axis);
281281 const std::vector<ssize_t > buff_strides = shift_axis_data (node->strides (), bound_axis);
282282 // Define an iterator for `values` corresponding with the beginning of
283283 // slice 0 along the bound axis.
284284 const BufferIterator<double , double , false > slice_0_it (values.data (), ndim, buff_shape.data (),
285285 buff_strides.data ());
286- // Determine the size of each hyperslice along the bound axis.
286+ // Determine the size of each slice along the bound axis.
287287 const ssize_t slice_size = std::accumulate (buff_shape.begin () + 1 , buff_shape.end (), 1.0 ,
288288 std::multiplies<ssize_t >());
289289
290- // 3) Iterate over each hyperslice and adjust it's values until they
290+ // 3) Iterate over each slice and adjust it's values until they
291291 // satisfy the axis-wise bounds.
292292 for (ssize_t slice = 0 , stop = node_shape[bound_axis]; slice < stop; ++slice) {
293293 // Determine the amount needed to adjust the values within the slice.
@@ -297,8 +297,8 @@ void construct_state_given_exactly_one_bound_axis(const NumberNode* node,
297297 if (delta == 0 ) continue ; // Axis-wise bounds are satisfied for slice.
298298 assert (delta >= 0 ); // Should only increment.
299299
300- // Determine how much we need to offset `slice_0_it` to get to the first
301- // index in the given `slice`.
300+ // Determine how much we need to offset `slice_0_it` to get to the
301+ // first index in the given `slice`.
302302 const ssize_t offset = slice * slice_size;
303303 // Iterate over all indices in the given slice.
304304 for (auto slice_it = slice_0_it + offset, slice_end_it = slice_it + slice_size;
@@ -367,12 +367,15 @@ void NumberNode::exchange(State& state, ssize_t i, ssize_t j) const {
367367 // assert() that i and j are valid indices occurs in ptr->exchange().
368368 // State change occurs IFF (i != j) and (buffer[i] != buffer[j]).
369369 if (ptr->exchange (i, j)) {
370- // If exchange occured, update the bound axis sums.
371- const double difference = ptr->get (i) - ptr->get (j);
372- // Index i changed from (what is now) ptr->get(j) to ptr->get(i)
373- update_bound_axis_slice_sums (state, i, difference);
374- // Index j changed from (what is now) ptr->get(i) to ptr->get(j)
375- update_bound_axis_slice_sums (state, j, -difference);
370+ // No need to update slice sums as they will be unchanged.
371+ if (!bound_axis_ops_all_equals_) {
372+ // If exchange occurred, update the bound axis sums.
373+ const double difference = ptr->get (i) - ptr->get (j);
374+ // Index i changed from (what is now) ptr->get(j) to ptr->get(i)
375+ update_bound_axis_slice_sums (state, i, difference);
376+ // Index j changed from (what is now) ptr->get(i) to ptr->get(j)
377+ update_bound_axis_slice_sums (state, j, -difference);
378+ }
376379 assert (satisfies_axis_wise_bounds (bound_axes_info_, ptr->bound_axes_sums ));
377380 }
378381}
@@ -419,9 +422,9 @@ void NumberNode::clip_and_set_value(State& state, ssize_t index, double value) c
419422 auto ptr = data_ptr<NumberNodeStateData>(state);
420423 value = std::clamp (value, lower_bound (index), upper_bound (index));
421424 // assert() that i is a valid index occurs in ptr->set().
422- // State change occurs IFF `value` != buffer[index] .
425+ // State change occurs IFF `value` != buffer[index].
423426 if (ptr->set (index, value)) {
424- // If change occured , update bound axis sums by differnce .
427+ // If change occurred , update bound axis sums by difference .
425428 update_bound_axis_slice_sums (state, index, value - diff (state).back ().old );
426429 assert (satisfies_axis_wise_bounds (bound_axes_info_, ptr->bound_axes_sums ));
427430 }
@@ -447,6 +450,16 @@ double get_extreme_index_wise_bound(const std::vector<double>& bound) {
447450 return *it;
448451}
449452
453+ bool all_bound_axis_operators_are_equals (std::vector<NumberNode::AxisBound>& bound_axes_info) {
454+ for (const NumberNode::AxisBound& bound_axis_info : bound_axes_info) {
455+ for (const NumberNode::AxisBound::Operator op : bound_axis_info.operators ) {
456+ if (op != NumberNode::AxisBound::Operator::Equal) return false ;
457+ }
458+ }
459+ // Vacuously true if there are no axis-wise bounds.
460+ return true ;
461+ }
462+
450463void check_index_wise_bounds (const NumberNode& node, const std::vector<double >& lower_bounds_,
451464 const std::vector<double >& upper_bounds_) {
452465 bool index_wise_bound = false ;
@@ -534,7 +547,8 @@ NumberNode::NumberNode(std::span<const ssize_t> shape, std::vector<double> lower
534547 max_(get_extreme_index_wise_bound<true >(upper_bound)),
535548 lower_bounds_(std::move(lower_bound)),
536549 upper_bounds_(std::move(upper_bound)),
537- bound_axes_info_(std::move(bound_axes)) {
550+ bound_axes_info_(std::move(bound_axes)),
551+ bound_axis_ops_all_equals_(all_bound_axis_operators_are_equals(bound_axes_info_)) {
538552 if ((shape.size () > 0 ) && (shape[0 ] < 0 )) {
539553 throw std::invalid_argument (" Number array cannot have dynamic size." );
540554 }
@@ -549,14 +563,15 @@ NumberNode::NumberNode(std::span<const ssize_t> shape, std::vector<double> lower
549563
550564void NumberNode::update_bound_axis_slice_sums (State& state, const ssize_t index,
551565 const double value_change) const {
566+ assert (value_change != 0 ); // Should not call when no change occurs.
552567 const auto & bound_axes_info = bound_axes_info_;
553- if (bound_axes_info.size () == 0 ) return ; // No axis-wise bounds to satisfy
568+ if (bound_axes_info.size () == 0 ) return ; // No axis-wise bounds to satisfy.
554569
555570 // Get multidimensional indices for `index` so we can identify the slices
556571 // `index` lies on per bound axis.
557572 const std::vector<ssize_t > multi_index = unravel_index (index, this ->shape ());
558573 assert (bound_axes_info.size () <= multi_index.size ());
559- // Get the hyperslice sums of all bound axes.
574+ // Get the slice sums of all bound axes.
560575 auto & bound_axes_sums = data_ptr<NumberNodeStateData>(state)->bound_axes_sums ;
561576 assert (bound_axes_info.size () == bound_axes_sums.size ());
562577
@@ -676,7 +691,7 @@ void IntegerNode::set_value(State& state, ssize_t index, double value) const {
676691 // assert() that i is a valid index occurs in ptr->set().
677692 // State change occurs IFF `value` != buffer[index].
678693 if (ptr->set (index, value)) {
679- // If change occured , update bound axis sums by differnce .
694+ // If change occurred , update bound axis sums by difference .
680695 update_bound_axis_slice_sums (state, index, value - diff (state).back ().old );
681696 assert (satisfies_axis_wise_bounds (bound_axes_info_, ptr->bound_axes_sums ));
682697 }
0 commit comments