Skip to content

Commit 0fe6b4d

Browse files
committed
complexity-cost, off-by-one bug fix, question length, refactor
1 parent 55c4349 commit 0fe6b4d

18 files changed

Lines changed: 489 additions & 312 deletions

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ Run the program without any parameters to see a full list of the available param
8686
## Parameters
8787
ConTree can be configured by the following parameters:
8888
* `max_depth` : The maximum depth of the tree. Note that a tree of depth zero has a single leaf node. A tree of depth one has one branching node and two leaf nodes.
89+
* `complexity_cost` : The cost of adding a branching node (between 0 and 1).
8990
* `max_gap` : The maximum permissible gap to the optimal solution.
9091
* `max_gap_decay` : Use this parameter if you want to find solutions iteratively, with each iteration decreasing the `max_gap` by multiplying it with `max_gap_decay`.
9192
* `time_limit` : The run time limit in seconds. If the time limit is exceeded, a possibly non-optimal tree is returned.
@@ -107,3 +108,4 @@ Other related work:
107108
* Lin, Jimmy, et al. "Generalized and scalable optimal sparse decision trees." In _International Conference on Machine Learning_ (2020). [pdf](https://proceedings.mlr.press/v119/lin20g/lin20g.pdf) / [source](https://github.com/Jimmy-Lin/GeneralizedOptimalSparseDecisionTrees)
108109
* Aglin, Gaël, Siegfried Nijssen, and Pierre Schaus. "Learning optimal decision trees using caching branch-and-bound search." In _Proceedings of the AAAI conference on artificial intelligence_ (2020). [pdf](https://ojs.aaai.org/index.php/AAAI/article/download/5711/5567) / [source](https://github.com/aia-uclouvain/pydl8.5)
109110
* Mazumder, Rahul, Xiang Meng, and Haoyue Wang. "Quant-BnB: A scalable branch-and-bound method for optimal decision trees with continuous features." In _International Conference on Machine Learning_ (2022). [pdf](https://proceedings.mlr.press/v162/mazumder22a/mazumder22a.pdf) / [source](https://github.com/mengxianglgal/Quant-BnB)
111+
* Kiossou, Harold, Pierre Schaus, and Siegfried Nijssen. "Anytime Optimal Decision Tree Learning with Continuous Features." _arXiv preprint arXiv:2601.14765_ (2026). [pdf](https://arxiv.org/pdf/2601.14765) / [source](https://anonymous.4open.science/r/contree-rs-C7B8)

code/DataStructures/include/intervals_pruner.h

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#include <unordered_map>
66
#include <utility>
77
#include <vector>
8+
#include <cmath>
89

910
class IntervalsPruner {
1011
public:
@@ -16,14 +17,18 @@ class IntervalsPruner {
1617
*
1718
* @param possible_split_indexes_ref Reference to a vector containing possible split indices.
1819
* @param max_gap The maximum allowable gap for the solution to be considered valid (off by at most max_gap).
20+
* @param complexity_cost The cost of adding a branching node
1921
*/
20-
IntervalsPruner(const std::vector<int>& possible_split_indexes_ref, int max_gap);
22+
IntervalsPruner(const std::vector<int>& possible_split_indexes_ref, int max_gap, float complexity_cost);
2123

24+
/**
25+
* We store the intervals using four values
26+
*/
2227
struct Bound {
23-
int left_bound;
24-
int right_bound;
25-
int last_split_left_index;
26-
int last_split_right_index;
28+
int left_bound; // the left split-index bound (inclusive) of the interval
29+
int right_bound; // the right split-index bound (inclusive) of the interval
30+
int last_split_left_index; // the last split index tested to the left of the left_bound (-1 if not set)
31+
int last_split_right_index; // the last split index tested to the right of the right_bound (-1 if not set)
2732
};
2833

2934
/**
@@ -39,7 +44,7 @@ class IntervalsPruner {
3944
* @param split_index The index at which the split is evaluated.
4045
* @return A pair of integers representing the new pruned interval bounds.
4146
*/
42-
std::pair<int, int> neighbourhood_pruning(int score_difference, int left, int right, int split_index);
47+
std::pair<int, int> neighbourhood_pruning(float score_difference, int left, int right, int split_index);
4348

4449
/**
4550
* Applies subinterval pruning to determine if a given interval can be entirely pruned.
@@ -51,7 +56,7 @@ class IntervalsPruner {
5156
* @param current_best_score The best score obtained so far, used as a reference for pruning.
5257
* @return True if the subinterval can be pruned, otherwise false.
5358
*/
54-
bool subinterval_pruning(const Bound& current_bounds, int current_best_score);
59+
bool subinterval_pruning(const Bound& current_bounds, float current_best_score);
5560

5661
/**
5762
* Performs interval shrinking by narrowing the bounds of the interval based on the current best score.
@@ -63,7 +68,7 @@ class IntervalsPruner {
6368
* @param current_bounds The current bounds of the interval to be updated by shrinking
6469
* @param current_best_score The best score to compare against during the shrinking process.
6570
*/
66-
void interval_shrinking(Bound& current_bounds, int current_best_score);
71+
void interval_shrinking(Bound& current_bounds, float current_best_score);
6772

6873
/**
6974
* Records the result of a split, storing the index and associated scores.
@@ -75,15 +80,16 @@ class IntervalsPruner {
7580
* @param left_score The score associated with the left subinterval.
7681
* @param right_score The score associated with the right subinterval.
7782
*/
78-
void add_result(int index, int left_score, int right_score);
83+
void add_result(int index, float left_score, float right_score);
7984

8085
private:
8186
const std::vector<int>& possible_split_indexes;
8287
int possible_split_size;
8388
int rightmost_zero_index;
8489
int leftmost_zero_index;
85-
int max_gap;
86-
std::unordered_map<int, std::pair<int, int>> evaluated_indices_record;
90+
int max_gap;
91+
float complexity_cost;
92+
std::unordered_map<int, std::pair<float, float>> evaluated_indices_record;
8793
};
8894

8995
#endif // INTERVALS_PRUNER_H

code/DataStructures/include/tree.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,14 @@
99

1010
struct Tree {
1111
Tree();
12-
Tree(int label, int misclassifications);
13-
Tree(int split_feature, float split_threshold, const std::shared_ptr<Tree>& left, const std::shared_ptr<Tree>& right);
12+
Tree(int label, float objective);
13+
Tree(int split_feature, float split_threshold, const std::shared_ptr<Tree>& left, const std::shared_ptr<Tree>& right, float complexity_cost);
1414

1515
std::shared_ptr<Tree> left{ nullptr }, right{ nullptr };
1616
int split_feature = -1, label = -1;
1717
float split_threshold = 0.0;
1818

19-
int misclassification_score = INT_MAX;
19+
float objective = INT_MAX;
2020

2121
bool is_leaf() const;
2222
bool is_internal() const;
@@ -33,7 +33,7 @@ struct Tree {
3333
inline std::shared_ptr<Tree> get_right_tree() const { return right; }
3434

3535
void make_leaf(int label, int misclassifications);
36-
void update_split(int split_feature, float split_threshold, const std::shared_ptr<Tree>& left, const std::shared_ptr<Tree>& right);
36+
void update_split(int split_feature, float split_threshold, const std::shared_ptr<Tree>& left, const std::shared_ptr<Tree>& right, float complexity_cost);
3737

3838
std::string to_string(int indent = 0) const;
3939
};

code/DataStructures/src/dataview.cpp

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,10 @@ Dataview::Dataview(Dataset* sorted_dataset, Dataset* unsorted_dataset, int class
3232
for (int i = 0; i < first_feature.size() - 1; i++) {
3333
right_label_frequency[first_feature[i].label]--;
3434
left_label_frequency[first_feature[i].label]++;
35+
36+
// we cannot compute a gini value if there is no split here
37+
if (first_feature[i].unique_value_index == first_feature[i + 1].unique_value_index) continue;
38+
3539

3640
float left_gini = 1.0f; float right_gini = 1.0f;
3741
int left_count = i + 1; int right_count = int(first_feature.size()) - left_count;
@@ -78,10 +82,13 @@ Dataview::Dataview(Dataset* sorted_dataset, Dataset* unsorted_dataset, int class
7882
right_label_frequency[current_feature[feature_element_idx].label]--;
7983
left_label_frequency[current_feature[feature_element_idx].label]++;
8084

85+
// we cannot compute a gini value if there is no split here
86+
if (feature_element_idx < current_feature.size() - 1
87+
&& current_feature[feature_element_idx].unique_value_index == current_feature[feature_element_idx + 1].unique_value_index) continue;
88+
8189
float left_gini = 1.0f; float right_gini = 1.0f;
8290
int left_count = feature_element_idx + 1; int right_count = int(current_feature.size()) - left_count;
8391

84-
8592
for (int label = 0; label < class_number; label++) {
8693
if (left_count > 0) {
8794
float left_probability = static_cast<float>(left_label_frequency[label]) / static_cast<float>(left_count);
@@ -142,7 +149,8 @@ const std::vector<int>& Dataview::get_possible_split_indices(int feature_index)
142149
return possible_split_indices[feature_index];
143150
}
144151

145-
void Dataview::split_data_points(const Dataview& current_dataview, int feature_index, int split_point, int split_unique_value_index, Dataview& left_dataview, Dataview& right_dataview, int current_max_depth) {
152+
void Dataview::split_data_points(const Dataview& current_dataview, int feature_index, int split_point, int split_unique_value_index,
153+
Dataview& left_dataview, Dataview& right_dataview, int current_max_depth) {
146154
left_dataview.feature_data.resize(current_dataview.get_feature_number());
147155
right_dataview.feature_data.resize(current_dataview.get_feature_number());
148156

0 commit comments

Comments
 (0)