Skip to content

Commit 48008a1

Browse files
authored
Merge pull request #1362 from arcondello/QuadraticModelBase.remove_variables
Add QuadraticModelBase::remove_variables() method
2 parents a37defc + 0744b5b commit 48008a1

8 files changed

Lines changed: 247 additions & 16 deletions

File tree

dimod/include/dimod/abc.h

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#include <utility>
2323
#include <vector>
2424

25+
#include "dimod/utils.h"
2526
#include "dimod/vartypes.h"
2627

2728
namespace dimod {
@@ -338,6 +339,9 @@ class QuadraticModelBase {
338339
*/
339340
virtual void remove_variable(index_type v);
340341

342+
/// Remove multiple variables from the model and reindex accordingly.
343+
virtual void remove_variables(const std::vector<index_type>& variables);
344+
341345
/// Multiply all biases by the value of `scalar`.
342346
void scale(bias_type scalar);
343347

@@ -918,6 +922,58 @@ void QuadraticModelBase<bias_type, index_type>::remove_variable(index_type v) {
918922
}
919923
}
920924

925+
template <class bias_type, class index_type>
926+
void QuadraticModelBase<bias_type, index_type>::remove_variables(
927+
const std::vector<index_type>& variables) {
928+
if (!variables.size()) return; // shortcut
929+
930+
if (!std::is_sorted(variables.begin(), variables.end())) {
931+
// create a copy and sort it
932+
std::vector<index_type> sorted_indices = variables;
933+
std::sort(sorted_indices.begin(), sorted_indices.end());
934+
QuadraticModelBase<bias_type, index_type>::remove_variables(sorted_indices);
935+
return;
936+
}
937+
938+
linear_biases_.erase(utils::remove_by_index(linear_biases_.begin(), linear_biases_.end(),
939+
variables.begin(), variables.end()),
940+
linear_biases_.end());
941+
942+
if (has_adj()) {
943+
// clean up the remaining neighborhoods
944+
// in this case we need a reindexing scheme, so we do the expensive O(num_variables)
945+
// thing once to save time later on
946+
std::vector<int> reindex(adj_ptr_->size());
947+
for (const auto& v : variables) {
948+
if (v > static_cast<int>(reindex.size())) break; // we can break because it's sorted
949+
reindex[v] = -1;
950+
}
951+
int label = 0;
952+
for (auto& v : reindex) {
953+
if (v == -1) continue; // the removed variables
954+
v = label;
955+
++label;
956+
}
957+
958+
// remove the relevant neighborhoods
959+
adj_ptr_->erase(utils::remove_by_index(adj_ptr_->begin(), adj_ptr_->end(), variables.begin(),
960+
variables.end()),
961+
adj_ptr_->end());
962+
963+
// now go through and adjust the remaining neighborhoods
964+
auto pred = [&reindex](OneVarTerm<bias_type, index_type>& term) {
965+
if (reindex[term.v] == -1) return true; // remove
966+
// otherwise apply the new label
967+
term.v = reindex[term.v];
968+
return false;
969+
};
970+
for (auto& n : *adj_ptr_) {
971+
// we modify the indices and remove the variables we need to remove
972+
n.erase(std::remove_if(n.begin(), n.end(), pred), n.end());
973+
}
974+
}
975+
}
976+
921977
template <class bias_type, class index_type>
922978
void QuadraticModelBase<bias_type, index_type>::resize(index_type n) {
923979
assert(n >= 0);

dimod/include/dimod/expression.h

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,15 @@
1414

1515
#pragma once
1616

17+
#include <algorithm>
1718
#include <limits>
1819
#include <unordered_map>
1920
#include <unordered_set>
2021
#include <utility>
2122
#include <vector>
2223

2324
#include "dimod/abc.h"
25+
#include "dimod/utils.h"
2426
#include "dimod/vartypes.h"
2527

2628
namespace dimod {
@@ -272,6 +274,10 @@ class Expression : public abc::QuadraticModelBase<Bias, Index> {
272274
template<class Iter>
273275
void remove_variables(Iter first, Iter last);
274276

277+
void remove_variables(const std::vector<index_type>& variables) {
278+
return remove_variables(variables.begin(), variables.end());
279+
}
280+
275281
/// Set the linear bias of variable `v`.
276282
void set_linear(index_type v, bias_type bias);
277283

@@ -632,27 +638,23 @@ void Expression<bias_type, index_type>::remove_variable(index_type v) {
632638
template <class bias_type, class index_type>
633639
template <class Iter>
634640
void Expression<bias_type, index_type>::remove_variables(Iter first, Iter last) {
635-
std::unordered_set<index_type> to_remove;
641+
// get the indices of any variables that need to be removed
642+
std::vector<index_type> to_remove;
636643
for (auto it = first; it != last; ++it) {
637-
if (indices_.find(*it) != indices_.end()) {
638-
to_remove.emplace(*it);
644+
auto search = indices_.find(*it);
645+
if (search != indices_.end()) {
646+
to_remove.emplace_back(search->second);
639647
}
640648
}
649+
std::sort(to_remove.begin(), to_remove.end());
641650

642-
if (!to_remove.size()) {
643-
return; // nothing to remove
644-
}
651+
// remove the indices from variables_ and the underlying
652+
variables_.erase(utils::remove_by_index(variables_.begin(), variables_.end(), to_remove.begin(),
653+
to_remove.end()),
654+
variables_.end());
645655

646-
// now remove any variables found in to_remove
647-
size_type i = 0;
648-
while (i < this->num_variables()) {
649-
if (to_remove.count(variables_[i])) {
650-
base_type::remove_variable(i);
651-
variables_.erase(variables_.begin() + i);
652-
} else {
653-
++i;
654-
}
655-
}
656+
// remove the indices from the underlying quadratic model
657+
base_type::remove_variables(to_remove);
656658

657659
// finally fix the indices by rebuilding from scratch
658660
indices_.clear();

dimod/include/dimod/quadratic_model.h

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
#pragma once
1616

17+
#include <algorithm>
1718
#include <stdexcept>
1819
#include <utility>
1920
#include <vector>
@@ -87,6 +88,9 @@ class QuadraticModel : public abc::QuadraticModelBase<Bias, Index> {
8788
/// Remove variable `v`.
8889
void remove_variable(index_type v);
8990

91+
/// Remove variables.
92+
void remove_variables(const std::vector<index_type>& variables);
93+
9094
// Resize the model to contain `n` variables.
9195
void resize(index_type n);
9296

@@ -269,6 +273,19 @@ void QuadraticModel<bias_type, index_type>::remove_variable(index_type v) {
269273
varinfo_.erase(varinfo_.begin() + v);
270274
}
271275

276+
template <class bias_type, class index_type>
277+
void QuadraticModel<bias_type, index_type>::remove_variables(const std::vector<index_type>& variables) {
278+
if (!std::is_sorted(variables.begin(), variables.end())) {
279+
// create a copy and sort it
280+
std::vector<index_type> sorted_indices = variables;
281+
std::sort(sorted_indices.begin(), sorted_indices.end());
282+
QuadraticModel<bias_type, index_type>::remove_variables(sorted_indices);
283+
return;
284+
}
285+
base_type::remove_variables(variables);
286+
varinfo_.erase(utils::remove_by_index(varinfo_.begin(), varinfo_.end(), variables.begin(), variables.end()), varinfo_.end());
287+
}
288+
272289
template <class bias_type, class index_type>
273290
void QuadraticModel<bias_type, index_type>::resize(index_type n) {
274291
// we could do this as an assert, but let's be careful since

dimod/include/dimod/utils.h

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,34 @@
2222
namespace dimod {
2323
namespace utils {
2424

25+
// Remove all elements in the range defined by vfirst to vlast at indices
26+
// specified by ifirst to ilast.
27+
// All iterators must be forward iterators
28+
// Indices must be non-negative, sorted, and unique.
29+
template <class ValueIter, class IndexIter>
30+
ValueIter remove_by_index(ValueIter vfirst, ValueIter vlast, IndexIter ifirst, IndexIter ilast) {
31+
assert(std::is_sorted(ifirst, ilast));
32+
assert((ifirst == ilast || *ifirst >= 0));
33+
34+
using value_type = typename std::iterator_traits<ValueIter>::value_type;
35+
36+
typename std::iterator_traits<IndexIter>::value_type loc = 0; // location in the values
37+
IndexIter it = ifirst;
38+
auto pred = [&](const value_type&) {
39+
if (it != ilast && *it == loc) {
40+
++loc;
41+
++it;
42+
return true;
43+
} else {
44+
++loc;
45+
return false;
46+
}
47+
};
48+
49+
// relies on this being executed sequentially
50+
return std::remove_if(vfirst, vlast, pred);
51+
}
52+
2553
// zip_sort is a modification of the code found here :
2654
// https://www.geeksforgeeks.org/iterative-quick-sort/
2755

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
---
2+
features:
3+
- Add C++ ``dimod::abc::QuadraticModelBase::remove_variables()`` method and accompanying overloads.
4+
- Speed up C++ ``dimod::Expression::remove_variables()`` method.

testscpp/tests/test_binary_quadratic_model.cpp

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,19 @@ TEST_CASE("BinaryQuadraticModel tests") {
3838
}
3939
}
4040

41+
WHEN("we use remove_variables()") {
42+
bqm.remove_variables(std::vector<int>{3, 1});
43+
44+
THEN("The variables are removed and the model is reindexed") {
45+
REQUIRE(bqm.num_variables() == 3);
46+
REQUIRE(bqm.num_interactions() == 0);
47+
CHECK(bqm.linear(0) == 0);
48+
CHECK(bqm.linear(1) == 2); // this was reindexed
49+
CHECK(bqm.linear(2) == 4); // this was reindexed twice
50+
CHECK(bqm.offset() == 5);
51+
}
52+
}
53+
4154
WHEN("we use fix_variable()") {
4255
bqm.fix_variable(2, -1);
4356
THEN("the variable is removed, its biases distributed and the model is reindexed") {
@@ -94,6 +107,52 @@ TEST_CASE("BinaryQuadraticModel tests") {
94107
}
95108
}
96109

110+
WHEN("we use remove_variables()") {
111+
bqm.remove_variables(std::vector<int>{3, 1});
112+
113+
THEN("The variables are removed and the model is reindexed") {
114+
REQUIRE(bqm.num_variables() == 3);
115+
CHECK(bqm.linear(0) == 0);
116+
CHECK(bqm.linear(1) == 2); // this was reindexed
117+
CHECK(bqm.linear(2) == 4); // this was reindexed twice
118+
CHECK(bqm.offset() == 5);
119+
CHECK(bqm.num_interactions() == 0); // no remaining quadratic
120+
}
121+
}
122+
123+
WHEN("we use remove_variables() to remove one variable") {
124+
bqm.remove_variables({2});
125+
126+
THEN("everything is reindexed") {
127+
REQUIRE(bqm.num_variables() == 4);
128+
REQUIRE(bqm.num_interactions() == 2);
129+
CHECK(bqm.linear(0) == 0);
130+
CHECK(bqm.linear(1) == -1);
131+
CHECK(bqm.linear(2) == -3); // this was reindexed
132+
CHECK(bqm.linear(3) == 4); // this was reindexed
133+
CHECK(bqm.quadratic(0, 1) == 1);
134+
CHECK(bqm.quadratic(2, 3) == 34); // this was reindexed
135+
CHECK(bqm.offset() == 5);
136+
}
137+
}
138+
139+
WHEN("we use remove_variables() to remove no variables") {
140+
bqm.remove_variables({});
141+
142+
THEN("nothing has changed") {
143+
REQUIRE(bqm.num_variables() == 5);
144+
REQUIRE(bqm.num_interactions() == 4);
145+
CHECK(bqm.linear(0) == 0);
146+
CHECK(bqm.linear(1) == -1);
147+
CHECK(bqm.linear(2) == 2);
148+
CHECK(bqm.linear(3) == -3);
149+
CHECK(bqm.linear(4) == 4);
150+
CHECK(bqm.quadratic(0, 1) == 1);
151+
CHECK(bqm.quadratic(3, 4) == 34);
152+
CHECK(bqm.offset() == 5);
153+
}
154+
}
155+
97156
AND_GIVEN("another identical BQM") {
98157
auto bqm2 = BinaryQuadraticModel<double>(5, Vartype::SPIN);
99158
bqm2.set_offset(5);

testscpp/tests/test_quadratic_model.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,28 @@ TEST_CASE("QuadraticModel tests") {
5353
}
5454
}
5555

56+
WHEN("we use remove_variables()") {
57+
qm.remove_variables({2});
58+
59+
THEN("the variable is removed and the model is reindexed") {
60+
REQUIRE(qm.num_variables() == 4);
61+
REQUIRE(qm.num_interactions() == 0);
62+
CHECK(qm.offset() == 5);
63+
CHECK(qm.linear(0) == 0);
64+
CHECK(qm.linear(1) == -1);
65+
CHECK(qm.linear(2) == -3); // this was reindexed
66+
CHECK(qm.linear(3) == 4); // this was reindexed
67+
CHECK(qm.vartype(0) == Vartype::BINARY);
68+
CHECK(qm.vartype(1) == Vartype::INTEGER);
69+
CHECK(qm.vartype(2) == Vartype::REAL); // this was reindexed
70+
CHECK(qm.vartype(3) == Vartype::SPIN); // this was reindexed
71+
CHECK(qm.lower_bound(1) == -1);
72+
CHECK(qm.lower_bound(2) == -3); // this was reindexed
73+
CHECK(qm.upper_bound(1) == 1);
74+
CHECK(qm.upper_bound(2) == 3); // this was reindexed
75+
}
76+
}
77+
5678
WHEN("we use fix_variable()") {
5779
qm.fix_variable(2, -1);
5880
THEN("the variable is removed, its biases distributed and the model is reindexed") {

testscpp/tests/test_utils.cpp

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,56 @@
1414

1515
#include <iostream>
1616
#include <random>
17+
#include <vector>
1718

1819
#include "catch2/catch.hpp"
1920
#include "dimod/utils.h"
2021

2122
namespace dimod {
2223
namespace utils {
2324

25+
TEST_CASE("remove_by_index()") {
26+
GIVEN("A vector") {
27+
auto v = std::vector<int>{0, 1, 2, 3, 4, 5, 6};
28+
29+
AND_GIVEN("some indices") {
30+
auto i = std::vector<int>{1, 3, 4};
31+
32+
WHEN("We use remove_by_index() to shrink the vector") {
33+
v.erase(remove_by_index(v.begin(), v.end(), i.begin(), i.end()), v.end());
34+
35+
THEN("The vector has the values we expect") {
36+
REQUIRE_THAT(v, Catch::Approx(std::vector<int>{0, 2, 5, 6}));
37+
}
38+
}
39+
}
40+
41+
AND_GIVEN("Some indices that are out-of-range") {
42+
auto i = std::vector<int>{5, 105};
43+
44+
WHEN("We use remove_by_index() to shrink the vector") {
45+
v.erase(remove_by_index(v.begin(), v.end(), i.begin(), i.end()), v.end());
46+
47+
THEN("The vector has the values we expect") {
48+
REQUIRE_THAT(v, Catch::Approx(std::vector<int>{0, 1, 2, 3, 4, 6}));
49+
}
50+
}
51+
}
52+
53+
AND_GIVEN("An empty indices vector") {
54+
auto i = std::vector<int>{};
55+
56+
WHEN("We use remove_by_index() to shrink the vector") {
57+
v.erase(remove_by_index(v.begin(), v.end(), i.begin(), i.end()), v.end());
58+
59+
THEN("The vector has the values we expect") {
60+
REQUIRE_THAT(v, Catch::Approx(std::vector<int>{0, 1, 2, 3, 4, 5, 6}));
61+
}
62+
}
63+
}
64+
}
65+
}
66+
2467
TEST_CASE("Two vectors are zip-sorted", "[utils]") {
2568
std::default_random_engine generator;
2669
std::uniform_int_distribution<int> int_distribution(0, 100);

0 commit comments

Comments
 (0)