Skip to content

Commit ffdae9a

Browse files
committed
New names for NumberNode bound axis data
`BoundAxisInfo` -> `AxisBound` and `BoundAxisOperator` -> `Operator`. `Operator` is now a nested enum classs of `AxisBound`.
1 parent 23834b6 commit ffdae9a

5 files changed

Lines changed: 173 additions & 178 deletions

File tree

dwave/optimization/include/dwave-optimization/nodes/numbers.hpp

Lines changed: 37 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -28,23 +28,24 @@ namespace dwave::optimization {
2828
/// A contiguous block of numbers.
2929
class NumberNode : public ArrayOutputMixin<ArrayNode>, public DecisionNode {
3030
public:
31-
/// Allowable axis-wise bound operators.
32-
enum BoundAxisOperator { Equal, LessEqual, GreaterEqual };
33-
3431
/// Struct for stateless axis-wise bound information. Given an `axis`, define
3532
/// constraints on the sum of the values in each slice along `axis`.
3633
/// Constraints can be defined for ALL slices along `axis` or PER slice along
37-
/// `axis`. Allowable operators are defined by `BoundAxisOperator`.
38-
struct BoundAxisInfo {
34+
/// `axis`. Allowable operators are defined by `Operator`.
35+
struct AxisBound {
36+
/// Allowable axis-wise bound operators.
37+
enum class Operator { Equal, LessEqual, GreaterEqual };
38+
3939
/// To reduce the # of `IntegerNode` and `BinaryNode` constructors, we
4040
/// allow only one constructor.
41-
BoundAxisInfo(ssize_t axis, std::vector<BoundAxisOperator> axis_operators,
42-
std::vector<double> axis_bounds);
41+
AxisBound(ssize_t axis, std::vector<Operator> axis_operators,
42+
std::vector<double> axis_bounds);
43+
4344
/// The bound axis
4445
ssize_t axis;
4546
/// Operator for ALL axis slices (vector has length one) or operators PER
4647
/// slice (length of vector is equal to the number of slices).
47-
std::vector<BoundAxisOperator> operators;
48+
std::vector<Operator> operators;
4849
/// Bound for ALL axis slices (vector has length one) or bounds PER slice
4950
/// (length of vector is equal to the number of slices).
5051
std::vector<double> bounds;
@@ -53,7 +54,7 @@ class NumberNode : public ArrayOutputMixin<ArrayNode>, public DecisionNode {
5354
double get_bound(const ssize_t slice) const;
5455

5556
/// Obtain the operator associated with a given slice along `axis`.
56-
BoundAxisOperator get_operator(const ssize_t slice) const;
57+
Operator get_operator(const ssize_t slice) const;
5758
};
5859

5960
NumberNode() = delete;
@@ -140,16 +141,15 @@ class NumberNode : public ArrayOutputMixin<ArrayNode>, public DecisionNode {
140141
void clip_and_set_value(State& state, ssize_t index, double value) const;
141142

142143
/// Return the stateless axis-wise bound information i.e. bound_axes_info_.
143-
const std::vector<BoundAxisInfo>& axis_wise_bounds() const;
144+
const std::vector<AxisBound>& axis_wise_bounds() const;
144145

145146
/// Return the state-dependent sum of the values within each hyperslice
146147
/// along each bound axis.
147148
const std::vector<std::vector<double>>& bound_axis_sums(State& state) const;
148149

149150
protected:
150151
explicit NumberNode(std::span<const ssize_t> shape, std::vector<double> lower_bound,
151-
std::vector<double> upper_bound,
152-
std::vector<BoundAxisInfo> bound_axes = {});
152+
std::vector<double> upper_bound, std::vector<AxisBound> bound_axes = {});
153153

154154
// Return truth statement: 'value is valid in a given index'.
155155
virtual bool is_valid(ssize_t index, double value) const = 0;
@@ -171,7 +171,7 @@ class NumberNode : public ArrayOutputMixin<ArrayNode>, public DecisionNode {
171171
std::vector<double> upper_bounds_;
172172

173173
/// Stateless information on each bound axis.
174-
std::vector<BoundAxisInfo> bound_axes_info_;
174+
std::vector<AxisBound> bound_axes_info_;
175175
};
176176

177177
/// A contiguous block of integer numbers.
@@ -191,39 +191,39 @@ class IntegerNode : public NumberNode {
191191
IntegerNode(std::span<const ssize_t> shape,
192192
std::optional<std::vector<double>> lower_bound = std::nullopt,
193193
std::optional<std::vector<double>> upper_bound = std::nullopt,
194-
std::vector<BoundAxisInfo> bound_axes = {});
194+
std::vector<AxisBound> bound_axes = {});
195195
IntegerNode(std::initializer_list<ssize_t> shape,
196196
std::optional<std::vector<double>> lower_bound = std::nullopt,
197197
std::optional<std::vector<double>> upper_bound = std::nullopt,
198-
std::vector<BoundAxisInfo> bound_axes = {});
198+
std::vector<AxisBound> bound_axes = {});
199199
IntegerNode(ssize_t size, std::optional<std::vector<double>> lower_bound = std::nullopt,
200200
std::optional<std::vector<double>> upper_bound = std::nullopt,
201-
std::vector<BoundAxisInfo> bound_axes = {});
201+
std::vector<AxisBound> bound_axes = {});
202202

203203
IntegerNode(std::span<const ssize_t> shape, double lower_bound,
204204
std::optional<std::vector<double>> upper_bound = std::nullopt,
205-
std::vector<BoundAxisInfo> bound_axes = {});
205+
std::vector<AxisBound> bound_axes = {});
206206
IntegerNode(std::initializer_list<ssize_t> shape, double lower_bound,
207207
std::optional<std::vector<double>> upper_bound = std::nullopt,
208-
std::vector<BoundAxisInfo> bound_axes = {});
208+
std::vector<AxisBound> bound_axes = {});
209209
IntegerNode(ssize_t size, double lower_bound,
210210
std::optional<std::vector<double>> upper_bound = std::nullopt,
211-
std::vector<BoundAxisInfo> bound_axes = {});
211+
std::vector<AxisBound> bound_axes = {});
212212

213213
IntegerNode(std::span<const ssize_t> shape, std::optional<std::vector<double>> lower_bound,
214-
double upper_bound, std::vector<BoundAxisInfo> bound_axes = {});
214+
double upper_bound, std::vector<AxisBound> bound_axes = {});
215215
IntegerNode(std::initializer_list<ssize_t> shape,
216216
std::optional<std::vector<double>> lower_bound, double upper_bound,
217-
std::vector<BoundAxisInfo> bound_axes = {});
217+
std::vector<AxisBound> bound_axes = {});
218218
IntegerNode(ssize_t size, std::optional<std::vector<double>> lower_bound, double upper_bound,
219-
std::vector<BoundAxisInfo> bound_axes = {});
219+
std::vector<AxisBound> bound_axes = {});
220220

221221
IntegerNode(std::span<const ssize_t> shape, double lower_bound, double upper_bound,
222-
std::vector<BoundAxisInfo> bound_axes = {});
222+
std::vector<AxisBound> bound_axes = {});
223223
IntegerNode(std::initializer_list<ssize_t> shape, double lower_bound, double upper_bound,
224-
std::vector<BoundAxisInfo> bound_axes = {});
224+
std::vector<AxisBound> bound_axes = {});
225225
IntegerNode(ssize_t size, double lower_bound, double upper_bound,
226-
std::vector<BoundAxisInfo> bound_axes = {});
226+
std::vector<AxisBound> bound_axes = {});
227227

228228
// Overloads needed by the Node ABC ***************************************
229229

@@ -259,38 +259,38 @@ class BinaryNode : public IntegerNode {
259259
BinaryNode(std::span<const ssize_t> shape,
260260
std::optional<std::vector<double>> lower_bound = std::nullopt,
261261
std::optional<std::vector<double>> upper_bound = std::nullopt,
262-
std::vector<BoundAxisInfo> bound_axes = {});
262+
std::vector<AxisBound> bound_axes = {});
263263
BinaryNode(std::initializer_list<ssize_t> shape,
264264
std::optional<std::vector<double>> lower_bound = std::nullopt,
265265
std::optional<std::vector<double>> upper_bound = std::nullopt,
266-
std::vector<BoundAxisInfo> bound_axes = {});
266+
std::vector<AxisBound> bound_axes = {});
267267
BinaryNode(ssize_t size, std::optional<std::vector<double>> lower_bound = std::nullopt,
268268
std::optional<std::vector<double>> upper_bound = std::nullopt,
269-
std::vector<BoundAxisInfo> bound_axes = {});
269+
std::vector<AxisBound> bound_axes = {});
270270

271271
BinaryNode(std::span<const ssize_t> shape, double lower_bound,
272272
std::optional<std::vector<double>> upper_bound = std::nullopt,
273-
std::vector<BoundAxisInfo> bound_axes = {});
273+
std::vector<AxisBound> bound_axes = {});
274274
BinaryNode(std::initializer_list<ssize_t> shape, double lower_bound,
275275
std::optional<std::vector<double>> upper_bound = std::nullopt,
276-
std::vector<BoundAxisInfo> bound_axes = {});
276+
std::vector<AxisBound> bound_axes = {});
277277
BinaryNode(ssize_t size, double lower_bound,
278278
std::optional<std::vector<double>> upper_bound = std::nullopt,
279-
std::vector<BoundAxisInfo> bound_axes = {});
279+
std::vector<AxisBound> bound_axes = {});
280280

281281
BinaryNode(std::span<const ssize_t> shape, std::optional<std::vector<double>> lower_bound,
282-
double upper_bound, std::vector<BoundAxisInfo> bound_axes = {});
282+
double upper_bound, std::vector<AxisBound> bound_axes = {});
283283
BinaryNode(std::initializer_list<ssize_t> shape, std::optional<std::vector<double>> lower_bound,
284-
double upper_bound, std::vector<BoundAxisInfo> bound_axes = {});
284+
double upper_bound, std::vector<AxisBound> bound_axes = {});
285285
BinaryNode(ssize_t size, std::optional<std::vector<double>> lower_bound, double upper_bound,
286-
std::vector<BoundAxisInfo> bound_axes = {});
286+
std::vector<AxisBound> bound_axes = {});
287287

288288
BinaryNode(std::span<const ssize_t> shape, double lower_bound, double upper_bound,
289-
std::vector<BoundAxisInfo> bound_axes = {});
289+
std::vector<AxisBound> bound_axes = {});
290290
BinaryNode(std::initializer_list<ssize_t> shape, double lower_bound, double upper_bound,
291-
std::vector<BoundAxisInfo> bound_axes = {});
291+
std::vector<AxisBound> bound_axes = {});
292292
BinaryNode(ssize_t size, double lower_bound, double upper_bound,
293-
std::vector<BoundAxisInfo> bound_axes = {});
293+
std::vector<AxisBound> bound_axes = {});
294294

295295
// Flip the value (0 -> 1 or 1 -> 0) at index i in the given state.
296296
void flip(State& state, ssize_t i) const;

dwave/optimization/libcpp/nodes/numbers.pxd

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -21,26 +21,26 @@ from dwave.optimization.libcpp.state cimport State
2121
cdef extern from "dwave-optimization/nodes/numbers.hpp" namespace "dwave::optimization" nogil:
2222

2323
cdef cppclass NumberNode(ArrayNode):
24-
enum BoundAxisOperator :
24+
struct AxisBound:
2525
# It appears Cython automatically assumes all (standard) enums are "public".
26-
# Because of this, these very explict overrides are needed per enum item.
27-
Equal "dwave::optimization::NumberNode::BoundAxisOperator::Equal"
28-
LessEqual "dwave::optimization::NumberNode::BoundAxisOperator::LessEqual"
29-
GreaterEqual "dwave::optimization::NumberNode::BoundAxisOperator::GreaterEqual"
26+
# Because of this, we use this very explict override.
27+
enum class Operator "dwave::optimization::NumberNode::AxisBound::Operator":
28+
Equal
29+
LessEqual
30+
GreaterEqual
3031

31-
struct BoundAxisInfo:
32-
BoundAxisInfo(Py_ssize_t axis, vector[BoundAxisOperator] axis_opertors,
32+
AxisBound(Py_ssize_t axis, vector[Operator] axis_opertors,
3333
vector[double] axis_bounds)
3434
Py_ssize_t axis
35-
vector[BoundAxisOperator] operators;
35+
vector[Operator] operators;
3636
vector[double] bounds;
3737

3838
void initialize_state(State&, vector[double]) except+
3939
double lower_bound(Py_ssize_t index)
4040
double upper_bound(Py_ssize_t index)
4141
double lower_bound() except+
4242
double upper_bound() except+
43-
const vector[BoundAxisInfo] axis_wise_bounds()
43+
const vector[AxisBound] axis_wise_bounds()
4444

4545
cdef cppclass IntegerNode(NumberNode):
4646
pass

0 commit comments

Comments
 (0)