Skip to content

Commit a7c3c5e

Browse files
committed
LightGBMWrapper - Type of return value of classify methods for binary classification was changed to two-element vector of probabilities (instead of one-element vector)
1 parent 504e4f7 commit a7c3c5e

1 file changed

Lines changed: 15 additions & 0 deletions

File tree

src/wif/ml/lightGBMWrapper.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,9 +53,11 @@ ClfResult LightGBMWrapper::classify(const FlowFeatures& flowFeatures)
5353
std::vector<double> dataToClassify; // classified features from flowfeatures are extracted here
5454
int64_t outLen; // length of output result
5555
int numOfClasses; // number of classes
56+
5657
LGBM_BoosterGetNumClasses(m_booster, &numOfClasses);
5758

5859
std::vector<double> pred(numOfClasses); // vector with predictions
60+
5961
for (const auto& featureID : m_featureIDs) {
6062
double value = flowFeatures.get<double>(featureID);
6163
dataToClassify.push_back(value);
@@ -74,6 +76,12 @@ ClfResult LightGBMWrapper::classify(const FlowFeatures& flowFeatures)
7476
&outLen,
7577
pred.data());
7678

79+
if (numOfClasses == 1) {
80+
double tmp = pred[0];
81+
82+
pred.insert(pred.begin(), (1.0 - tmp));
83+
}
84+
7785
return ClfResult(pred);
7886
}
7987

@@ -114,6 +122,13 @@ std::vector<ClfResult> LightGBMWrapper::classify(const std::vector<FlowFeatures>
114122
std::vector<double> probabilities(
115123
pred.begin() + idx * numOfClasses,
116124
pred.begin() + (idx + 1) * numOfClasses);
125+
126+
if (numOfClasses == 1) {
127+
double tmp = probabilities[0];
128+
129+
probabilities.insert(probabilities.begin(), (1.0 - tmp));
130+
}
131+
117132
burstResults.emplace_back(probabilities);
118133
}
119134

0 commit comments

Comments
 (0)