@@ -51,23 +51,25 @@ void SpecializedSolver::get_best_left_right_scores(const Dataview& dataview, int
5151 const auto & split_feature = dataview.get_sorted_dataset_feature (feature_index);
5252 const auto & unsorted_split_feature = dataview.get_unsorted_dataset_feature (feature_index);
5353 std::vector<int > split_feature_split_indices (unsorted_split_feature.size ());
54- double prev = -1 .0f ;
5554 int split_index = -1 ;
5655 for (const auto & split_feature_data : split_feature) {
5756 split_feature_split_indices[split_feature_data.data_point_index ] = split_feature_data.unique_value_index ;
5857 if (split_index == -1 && split_feature_data.value >= threshold) {
5958 split_index = split_feature_data.unique_value_index ;
6059 }
6160 }
61+ RUNTIME_ASSERT (split_index != -1 , " Split index not found." );
6262
6363 const int dataset_size = dataview.get_dataset_size ();
6464 const int class_number = dataview.get_class_number ();
6565
66+ RUNTIME_ASSERT (split_point > 0 && split_point < dataset_size, " left and right subtree need to be non-empty." );
6667 Depth1ScoreHelper left_tree (split_point, class_number);
6768 Depth1ScoreHelper right_tree (dataset_size - split_point, class_number);
69+
6870
69- left_tree.classification_score = left_tree.size - upper_bound;
70- right_tree.classification_score = right_tree.size - upper_bound;
71+ left_tree.classification_score = std::max ( 0 , left_tree.size - upper_bound) ;
72+ right_tree.classification_score = std::max ( 0 , right_tree.size - upper_bound) ;
7173
7274 Dataview::initialize_split_parameters (split_feature, class_number, dataview.get_label_frequency (), split_point, left_tree.label_frequency , right_tree.label_frequency );
7375
@@ -106,15 +108,21 @@ void SpecializedSolver::get_best_left_right_scores(const Dataview& dataview, int
106108 left_optimal_dt->make_leaf (left_tree.max_label , left_tree.size - left_tree.classification_score );
107109 } else {
108110 left_optimal_dt->update_split (left_tree.best_feature_index , left_tree.best_threshold , std::make_shared<Tree>(left_tree.best_left_label , -1 ), std::make_shared<Tree>(left_tree.best_right_label , -1 ));
111+ RUNTIME_ASSERT (left_tree.best_left_label != -1 , " Left tree left label should be initialized." );
112+ RUNTIME_ASSERT (left_tree.best_right_label != -1 , " Left tree right label should be initialized." );
109113 }
110114 left_optimal_dt->misclassification_score = left_tree.size - left_tree.classification_score ;
115+ RUNTIME_ASSERT (left_optimal_dt->misclassification_score >= 0 , " LR - Left tree misclassification score should be non-negative." );
111116
112117 if (right_tree.classification_score == right_tree.max_label_frequency ) {
113118 right_optimal_dt->make_leaf (right_tree.max_label , right_tree.size - right_tree.classification_score );
114119 } else {
115120 right_optimal_dt->update_split (right_tree.best_feature_index , right_tree.best_threshold , std::make_shared<Tree>(right_tree.best_left_label , -1 ), std::make_shared<Tree>(right_tree.best_right_label , -1 ));
121+ RUNTIME_ASSERT (right_tree.best_left_label != -1 , " Right tree left label should be initialized." );
122+ RUNTIME_ASSERT (right_tree.best_right_label != -1 , " Right tree right label should be initialized." );
116123 }
117124 right_optimal_dt->misclassification_score = right_tree.size - right_tree.classification_score ;
125+ RUNTIME_ASSERT (right_optimal_dt->misclassification_score >= 0 , " LR - Right tree misclassification score should be non-negative." );
118126}
119127
120128template <bool is_same_feature>
@@ -177,6 +185,7 @@ void SpecializedSolver::process_depth_one_feature(const Dataview& dataview,
177185 }
178186
179187 if (left_classification_score + right_classification_score > tree.classification_score ) {
188+ RUNTIME_ASSERT (tree.classification_score <= tree.size , " LR - Classification score cannot exceed the number of instances." );
180189 tree.classification_score = left_classification_score + right_classification_score;
181190 tree.best_feature_index = current_feature_index;
182191 tree.best_threshold = (current_feature_data.value + tree.previous_value ) / 2 .0f ;
@@ -233,14 +242,16 @@ void SpecializedSolver::create_optimal_decision_tree(const Dataview& dataview, c
233242 const int mid = (left + right) / 2 ;
234243 const int split_point = possible_split_indices[mid];
235244
236- const float threshold = mid > 0 ? (current_feature[possible_split_indices[mid - 1 ]].value + current_feature[split_point].value ) / 2 .0f
237- : current_feature[split_point].value - 2 * EPSILON;
245+ const float threshold = mid > 0 ? (current_feature[possible_split_indices[mid - 1 ]].value + current_feature[split_point].value ) / 2 .0f
246+ : ( current_feature[split_point].value + current_feature[ 0 ]. value ) / 2 . 0f ;
238247
239248 std::shared_ptr<Tree> left_optimal_dt = std::make_shared<Tree>();
240249 std::shared_ptr<Tree> right_optimal_dt = std::make_shared<Tree>();
241250
242251 statistics::total_number_of_specialized_solver_calls += 1 ;
243252 get_best_left_right_scores (dataview, feature_index, split_point, threshold, left_optimal_dt, right_optimal_dt, current_optimal_decision_tree->misclassification_score );
253+ RUNTIME_ASSERT (left_optimal_dt->misclassification_score >= 0 , " D2 - Left tree should have non-negative misclassification score." );
254+ RUNTIME_ASSERT (right_optimal_dt->misclassification_score >= 0 , " D2 - Right tree should have non-negative misclassification score." );
244255
245256 const int current_best_score = left_optimal_dt->misclassification_score + right_optimal_dt->misclassification_score ;
246257
0 commit comments