Skip to content

Commit 25a5d6c

Browse files
authored
Merge pull request #15 from CESNET/lightGBMClassifier
LightGBMWrapper - Type of return value of classify methods for binary classification was changed to two-element vector of probabilities (instead of one-element vector)
2 parents 504e4f7 + e084402 commit 25a5d6c

1 file changed

Lines changed: 12 additions & 1 deletion

File tree

src/wif/ml/lightGBMWrapper.cpp

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,8 @@ ClfResult LightGBMWrapper::classify(const FlowFeatures& flowFeatures)
5454
int64_t outLen; // length of output result
5555
int numOfClasses; // number of classes
5656
LGBM_BoosterGetNumClasses(m_booster, &numOfClasses);
57-
5857
std::vector<double> pred(numOfClasses); // vector with predictions
58+
5959
for (const auto& featureID : m_featureIDs) {
6060
double value = flowFeatures.get<double>(featureID);
6161
dataToClassify.push_back(value);
@@ -74,6 +74,11 @@ ClfResult LightGBMWrapper::classify(const FlowFeatures& flowFeatures)
7474
&outLen,
7575
pred.data());
7676

77+
if (numOfClasses == 1) {
78+
double tmp = pred[0];
79+
pred.insert(pred.begin(), (1.0 - tmp));
80+
}
81+
7782
return ClfResult(pred);
7883
}
7984

@@ -114,6 +119,12 @@ std::vector<ClfResult> LightGBMWrapper::classify(const std::vector<FlowFeatures>
114119
std::vector<double> probabilities(
115120
pred.begin() + idx * numOfClasses,
116121
pred.begin() + (idx + 1) * numOfClasses);
122+
123+
if (numOfClasses == 1) {
124+
double tmp = probabilities[0];
125+
probabilities.insert(probabilities.begin(), (1.0 - tmp));
126+
}
127+
117128
burstResults.emplace_back(probabilities);
118129
}
119130

0 commit comments

Comments
 (0)