Skip to content

Commit d89e1f3

Browse files
committed
bug fix first value is -1.0, threshold for first split point
1 parent 1c29b0b commit d89e1f3

5 files changed

Lines changed: 51 additions & 10 deletions

File tree

code/DataStructures/src/dataset.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,12 +42,14 @@ void Dataset::compute_unique_value_indices() {
4242
std::sort(idx.begin(), idx.end(),
4343
[&cur_feature_data](size_t i1, size_t i2) {return cur_feature_data[i1].value < cur_feature_data[i2].value;});
4444
double prev = -1.0f;
45+
bool first = true;
4546
int cur_unique_value_index = -1;
4647
for (size_t ix : idx) {
4748
auto& cur_feature_element = cur_feature_data[ix];
48-
if (prev == -1.0f || cur_feature_element.value - prev >= EPSILON) cur_unique_value_index++;
49+
if (first || cur_feature_element.value - prev >= EPSILON) cur_unique_value_index++;
4950
cur_feature_element.unique_value_index = cur_unique_value_index;
5051
prev = cur_feature_element.value;
52+
first = false;
5153
}
5254

5355
}

code/DataStructures/src/tree.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ int Tree::get_num_branching_nodes() const {
3232
}
3333

3434
void Tree::make_leaf(int label, int misclassifications) {
35+
RUNTIME_ASSERT(misclassifications >= 0, "Leaf should have non-negative misclassifications.");
36+
RUNTIME_ASSERT(label >= 0, "Leaf label should be non-negative.");
3537
this->label = label;
3638
this->misclassification_score = misclassifications;
3739
left = nullptr;

code/Engine/src/general_solver.cpp

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ void GeneralSolver::create_optimal_decision_tree(const Dataview& dataview, const
5353
while(!unsearched_intervals.empty()) {
5454
if (!solution_configuration.stopwatch.IsWithinTimeLimit()) return;
5555
auto current_interval = unsearched_intervals.front(); unsearched_intervals.pop();
56-
56+
5757
if (interval_pruner.subinterval_pruning(current_interval, current_optimal_decision_tree->misclassification_score)) {
5858
continue;
5959
}
@@ -70,7 +70,7 @@ void GeneralSolver::create_optimal_decision_tree(const Dataview& dataview, const
7070
const int interval_half_distance = std::max(split_point - possible_split_indices[left], possible_split_indices[right] - split_point);
7171

7272
const float threshold = mid > 0 ? (current_feature[possible_split_indices[mid - 1]].value + current_feature[split_point].value) / 2.0f
73-
: current_feature[split_point].value - 2 * EPSILON;
73+
: (current_feature[split_point].value + current_feature[0].value) / 2.0f;
7474

7575
Dataview left_dataview = Dataview(dataview.get_class_number(), dataview.should_sort_by_gini_index());
7676
Dataview right_dataview = Dataview(dataview.get_class_number(), dataview.should_sort_by_gini_index());
@@ -101,11 +101,16 @@ void GeneralSolver::create_optimal_decision_tree(const Dataview& dataview, const
101101
statistics::total_number_of_general_solver_calls += 1;
102102
const Configuration right_solution_configuration = solution_configuration.GetRightSubtreeConfig(left_solution_configuration.max_gap);
103103
GeneralSolver::create_optimal_decision_tree(smaller_data, right_solution_configuration, smaller_optimal_dt, smaller_ub);
104-
104+
RUNTIME_ASSERT(left_optimal_dt->misclassification_score >= 0, "Left tree should have non-negative misclassification score.");
105+
RUNTIME_ASSERT(right_optimal_dt->misclassification_score >= 0, "Right tree should have non-negative misclassification score.");
106+
105107
const int current_best_score = left_optimal_dt->misclassification_score + right_optimal_dt->misclassification_score;
108+
106109
if (current_best_score < current_optimal_decision_tree->misclassification_score) {
107-
current_optimal_decision_tree->misclassification_score = current_best_score;
110+
RUNTIME_ASSERT(left_optimal_dt->is_initialized(), "Left tree should be initialized.");
111+
RUNTIME_ASSERT(right_optimal_dt->is_initialized(), "Right tree should be initialized.");
108112

113+
current_optimal_decision_tree->misclassification_score = current_best_score;
109114
current_optimal_decision_tree->update_split(feature_index, threshold, left_optimal_dt, right_optimal_dt);
110115

111116
if (current_best_score == 0) {
@@ -155,6 +160,7 @@ void GeneralSolver::calculate_leaf_node(int class_number, int instance_number, c
155160
const int best_misclassification_score = instance_number - best_classification_score;
156161

157162
if (best_misclassification_score < current_optimal_decision_tree->misclassification_score) {
163+
RUNTIME_ASSERT(best_classification_label != -1, "Cannot assign negative leaf label.");
158164
current_optimal_decision_tree->make_leaf(best_classification_label, best_misclassification_score);
159165
}
160166
}

code/Engine/src/specialized_solver.cpp

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -51,23 +51,25 @@ void SpecializedSolver::get_best_left_right_scores(const Dataview& dataview, int
5151
const auto& split_feature = dataview.get_sorted_dataset_feature(feature_index);
5252
const auto& unsorted_split_feature = dataview.get_unsorted_dataset_feature(feature_index);
5353
std::vector<int> split_feature_split_indices(unsorted_split_feature.size());
54-
double prev = -1.0f;
5554
int split_index = -1;
5655
for (const auto& split_feature_data : split_feature) {
5756
split_feature_split_indices[split_feature_data.data_point_index] = split_feature_data.unique_value_index;
5857
if (split_index == -1 && split_feature_data.value >= threshold) {
5958
split_index = split_feature_data.unique_value_index;
6059
}
6160
}
61+
RUNTIME_ASSERT(split_index != -1, "Split index not found.");
6262

6363
const int dataset_size = dataview.get_dataset_size();
6464
const int class_number = dataview.get_class_number();
6565

66+
RUNTIME_ASSERT(split_point > 0 && split_point < dataset_size, "left and right subtree need to be non-empty.");
6667
Depth1ScoreHelper left_tree(split_point, class_number);
6768
Depth1ScoreHelper right_tree(dataset_size - split_point, class_number);
69+
6870

69-
left_tree.classification_score = left_tree.size - upper_bound;
70-
right_tree.classification_score = right_tree.size - upper_bound;
71+
left_tree.classification_score = std::max(0, left_tree.size - upper_bound);
72+
right_tree.classification_score = std::max(0, right_tree.size - upper_bound);
7173

7274
Dataview::initialize_split_parameters(split_feature, class_number, dataview.get_label_frequency(), split_point, left_tree.label_frequency, right_tree.label_frequency);
7375

@@ -106,15 +108,21 @@ void SpecializedSolver::get_best_left_right_scores(const Dataview& dataview, int
106108
left_optimal_dt->make_leaf(left_tree.max_label, left_tree.size - left_tree.classification_score);
107109
} else {
108110
left_optimal_dt->update_split(left_tree.best_feature_index, left_tree.best_threshold, std::make_shared<Tree>(left_tree.best_left_label, -1), std::make_shared<Tree>(left_tree.best_right_label, -1));
111+
RUNTIME_ASSERT(left_tree.best_left_label != -1, "Left tree left label should be initialized.");
112+
RUNTIME_ASSERT(left_tree.best_right_label != -1, "Left tree right label should be initialized.");
109113
}
110114
left_optimal_dt->misclassification_score = left_tree.size - left_tree.classification_score;
115+
RUNTIME_ASSERT(left_optimal_dt->misclassification_score >= 0, "LR - Left tree misclassification score should be non-negative.");
111116

112117
if (right_tree.classification_score == right_tree.max_label_frequency) {
113118
right_optimal_dt->make_leaf(right_tree.max_label, right_tree.size - right_tree.classification_score);
114119
} else {
115120
right_optimal_dt->update_split(right_tree.best_feature_index, right_tree.best_threshold, std::make_shared<Tree>(right_tree.best_left_label, -1), std::make_shared<Tree>(right_tree.best_right_label, -1));
121+
RUNTIME_ASSERT(right_tree.best_left_label != -1, "Right tree left label should be initialized.");
122+
RUNTIME_ASSERT(right_tree.best_right_label != -1, "Right tree right label should be initialized.");
116123
}
117124
right_optimal_dt->misclassification_score = right_tree.size - right_tree.classification_score;
125+
RUNTIME_ASSERT(right_optimal_dt->misclassification_score >= 0, "LR - Right tree misclassification score should be non-negative.");
118126
}
119127

120128
template <bool is_same_feature>
@@ -177,6 +185,7 @@ void SpecializedSolver::process_depth_one_feature(const Dataview& dataview,
177185
}
178186

179187
if (left_classification_score + right_classification_score > tree.classification_score) {
188+
RUNTIME_ASSERT(tree.classification_score <= tree.size, "LR - Classification score cannot exceed the number of instances.");
180189
tree.classification_score = left_classification_score + right_classification_score;
181190
tree.best_feature_index = current_feature_index;
182191
tree.best_threshold = (current_feature_data.value + tree.previous_value) / 2.0f;
@@ -233,14 +242,16 @@ void SpecializedSolver::create_optimal_decision_tree(const Dataview& dataview, c
233242
const int mid = (left + right) / 2;
234243
const int split_point = possible_split_indices[mid];
235244

236-
const float threshold = mid > 0 ? (current_feature[possible_split_indices[mid - 1]].value + current_feature[split_point].value) / 2.0f
237-
: current_feature[split_point].value - 2 * EPSILON;
245+
const float threshold = mid > 0 ? (current_feature[possible_split_indices[mid - 1]].value + current_feature[split_point].value) / 2.0f
246+
: (current_feature[split_point].value + current_feature[0].value) / 2.0f;
238247

239248
std::shared_ptr<Tree> left_optimal_dt = std::make_shared<Tree>();
240249
std::shared_ptr<Tree> right_optimal_dt = std::make_shared<Tree>();
241250

242251
statistics::total_number_of_specialized_solver_calls += 1;
243252
get_best_left_right_scores(dataview, feature_index, split_point, threshold, left_optimal_dt, right_optimal_dt, current_optimal_decision_tree->misclassification_score);
253+
RUNTIME_ASSERT(left_optimal_dt->misclassification_score >= 0, "D2 - Left tree should have non-negative misclassification score.");
254+
RUNTIME_ASSERT(right_optimal_dt->misclassification_score >= 0, "D2 - Right tree should have non-negative misclassification score.");
244255

245256
const int current_best_score = left_optimal_dt->misclassification_score + right_optimal_dt->misclassification_score;
246257

code/Utilities/include/configuration.h

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,33 @@
33

44
#include <chrono>
55
#include "stopwatch.h"
6+
#include <cassert>
7+
#include <stdexcept>
8+
#include <iostream>
69

710
#define EPSILON 0.0000001f
811

912
#define PRINT_INTERMEDIARY_TIME_SOLUTIONS 0
1013
extern std::chrono::high_resolution_clock::time_point starting_time;
1114

1215

16+
17+
18+
#ifdef NDEBUG
19+
#define RUNTIME_ASSERT(cond, msg) ((void)0)
20+
#else
21+
#define RUNTIME_ASSERT(cond, msg) \
22+
do { \
23+
if (!(cond)) { \
24+
std::cerr << "Assertion failed: " << #cond \
25+
<< "\nMessage: " << msg \
26+
<< "\nFile: " << __FILE__ \
27+
<< "\nLine: " << __LINE__ << std::endl; \
28+
std::terminate(); \
29+
} \
30+
} while (0)
31+
#endif
32+
1333
struct Configuration {
1434
int max_depth;
1535
int max_gap;

0 commit comments

Comments
 (0)