Skip to content

Commit 2e2c4c9

Browse files
committed
bug fix double precision, re-use unique value index rather than threshold when splitting
1 parent d89e1f3 commit 2e2c4c9

3 files changed

Lines changed: 5 additions & 4 deletions

File tree

code/DataStructures/include/dataview.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ class Dataview {
4949
int get_class_number() const;
5050
const std::vector<int>& get_label_frequency() const;
5151

52-
static void split_data_points(const Dataview& current_dataview, int feature_index, int split_point, float threshold, Dataview& left_data, Dataview& right_data, int current_max_depth);
52+
static void split_data_points(const Dataview& current_dataview, int feature_index, int split_point, int split_unique_value_index, Dataview& left_data, Dataview& right_data, int current_max_depth);
5353
static void initialize_split_parameters(const std::vector<Dataset::FeatureElement>& current_feature, int class_number, const std::vector<int>& current_label_frequency, int split_point, std::vector<int> &left_label_frequency, std::vector<int> &right_label_frequency);
5454

5555
DataviewBitset& get_bitset() const {

code/DataStructures/src/dataview.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ const std::vector<int>& Dataview::get_possible_split_indices(int feature_index)
142142
return possible_split_indices[feature_index];
143143
}
144144

145-
void Dataview::split_data_points(const Dataview& current_dataview, int feature_index, int split_point, float threshold, Dataview& left_dataview, Dataview& right_dataview, int current_max_depth) {
145+
void Dataview::split_data_points(const Dataview& current_dataview, int feature_index, int split_point, int split_unique_value_index, Dataview& left_dataview, Dataview& right_dataview, int current_max_depth) {
146146
left_dataview.feature_data.resize(current_dataview.get_feature_number());
147147
right_dataview.feature_data.resize(current_dataview.get_feature_number());
148148

@@ -190,7 +190,7 @@ void Dataview::split_data_points(const Dataview& current_dataview, int feature_i
190190
std::vector<int> right_tree_right_label_frequency(right_dataview.label_frequency);
191191

192192
for (const auto& feature_data : it) {
193-
if (unsorted_split_feature[feature_data.data_point_index].value > threshold) {
193+
if (unsorted_split_feature[feature_data.data_point_index].unique_value_index >= split_unique_value_index) {
194194
right_split_feature_data[right_counter] = feature_data;
195195

196196
if (feature_data.unique_value_index != rigth_last_unique_index && rigth_last_unique_index != -1) {

code/Engine/src/general_solver.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,10 +71,11 @@ void GeneralSolver::create_optimal_decision_tree(const Dataview& dataview, const
7171

7272
const float threshold = mid > 0 ? (current_feature[possible_split_indices[mid - 1]].value + current_feature[split_point].value) / 2.0f
7373
: (current_feature[split_point].value + current_feature[0].value) / 2.0f;
74+
const int split_unique_value_index = current_feature[split_point].unique_value_index;
7475

7576
Dataview left_dataview = Dataview(dataview.get_class_number(), dataview.should_sort_by_gini_index());
7677
Dataview right_dataview = Dataview(dataview.get_class_number(), dataview.should_sort_by_gini_index());
77-
Dataview::split_data_points(dataview, feature_index, split_point, threshold, left_dataview, right_dataview, solution_configuration.max_depth);
78+
Dataview::split_data_points(dataview, feature_index, split_point, split_unique_value_index, left_dataview, right_dataview, solution_configuration.max_depth);
7879

7980
std::shared_ptr<Tree> left_optimal_dt = std::make_shared<Tree>(-1, current_optimal_decision_tree->misclassification_score);
8081
std::shared_ptr<Tree> right_optimal_dt = std::make_shared<Tree>(-1, current_optimal_decision_tree->misclassification_score);

0 commit comments

Comments
 (0)