@@ -331,6 +331,12 @@ class QuadraticModelBase {
331331 // / Remove the interaction between variables `u` and `v`.
332332 bool remove_interaction (index_type u, index_type v);
333333
334+ // / Remove all interactions for which `filter` returns `true`.
335+ // / Returns the number of interactions removed.
336+ // / `filter` must be symmetric. That is `filter(u, v, bias)` must equal `filter(v, u, bias)`.
337+ template <class Filter >
338+ size_type remove_interactions (Filter filter);
339+
334340 /* *
335341 * Remove variable `v` from the model.
336342 *
@@ -889,6 +895,35 @@ bool QuadraticModelBase<bias_type, index_type>::remove_interaction(index_type u,
889895 return false ;
890896}
891897
898+ template <class bias_type , class index_type >
899+ template <class Filter >
900+ std::size_t QuadraticModelBase<bias_type, index_type>::remove_interactions(Filter filter) {
901+ if (!has_adj ()) return 0 ; // nothing to filter
902+
903+ std::size_t num_removed = 0 ;
904+
905+ index_type u = 0 ;
906+ for (auto & n : *adj_ptr_) {
907+ auto it = std::remove_if (n.begin (), n.end (),
908+ [&u, &filter](const OneVarTerm<bias_type, index_type>& term) {
909+ const index_type& v = term.v ;
910+ const bias_type& bias = term.bias ;
911+ assert (filter (u, v, bias) == filter (v, u, bias));
912+ return filter (u, v, bias);
913+ });
914+
915+ num_removed += n.end () - it;
916+
917+ n.erase (it, n.end ());
918+
919+ u += 1 ;
920+ }
921+
922+ assert (num_removed % 2 == 0 );
923+
924+ return num_removed / 2 ;
925+ }
926+
892927template <class bias_type , class index_type >
893928void QuadraticModelBase<bias_type, index_type>::remove_variable(index_type v) {
894929 assert (0 <= v && static_cast <size_type>(v) < num_variables ());
0 commit comments