Skip to content

Commit c970cc5

Browse files
committed
LightGBMWrapper - FIX some variable names, types, and code style
1 parent d8bee3f commit c970cc5

4 files changed

Lines changed: 54 additions & 22 deletions

File tree

cmake/dependencies.cmake

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,7 @@
22
find_package(Armadillo REQUIRED)
33
find_package(Boost REQUIRED COMPONENTS regex serialization)
44
find_package(LIGHTGBM REQUIRED)
5+
find_package(MLPACK REQUIRED)
56
find_package(OpenMP REQUIRED)
67
find_package(Python3 REQUIRED COMPONENTS Development NumPy)
78

cmake/modules/FindMLPACK.cmake

Lines changed: 40 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,40 @@
1+
# Find the mlpack includes and library
2+
#
3+
# This module defines the following IMPORTED targets:
4+
#
5+
# mlpack::mlpack - The "mlpack" library, if found. (NOTE: this module does not currently create this target; it only sets the MLPACK_* variables.)
6+
#
7+
# This module will set the following variables in your project:
8+
#
9+
# MLPACK_INCLUDE_DIRS - where to find <mlpack/core.hpp>, etc.
10+
# MLPACK_FOUND - True if the mlpack library has been found.
11+
12+
# Use pkg-config (if available) to get the library directories and then use
13+
# these values as hints for find_path() and find_library() functions.
14+
find_package(PkgConfig QUIET)
15+
if (PKG_CONFIG_FOUND)
16+
pkg_check_modules(PC_MLPACK QUIET mlpack)
17+
endif()
18+
19+
find_path(
20+
MLPACK_INCLUDE_DIR mlpack
21+
HINTS ${PC_MLPACK_INCLUDEDIR} ${PC_MLPACK_INCLUDE_DIRS}
22+
PATH_SUFFIXES include
23+
)
24+
25+
if (PC_MLPACK_VERSION)
26+
# Version extracted from pkg-config
27+
set(MLPACK_VERSION_STRING ${PC_MLPACK_VERSION})
28+
endif()
29+
30+
# Handle find_package() arguments (i.e. QUIETLY and REQUIRED) and set
31+
# MLPACK_FOUND to TRUE if all listed variables are filled.
32+
include(FindPackageHandleStandardArgs)
33+
find_package_handle_standard_args(
34+
MLPACK
35+
REQUIRED_VARS MLPACK_INCLUDE_DIR
36+
VERSION_VAR MLPACK_VERSION_STRING
37+
)
38+
39+
set(MLPACK_INCLUDE_DIRS ${MLPACK_INCLUDE_DIR})
40+
mark_as_advanced(MLPACK_INCLUDE_DIR)

include/wif/ml/lightGBMWrapper.hpp

Lines changed: 6 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -99,21 +99,18 @@ class LightGBMWrapper {
9999
* @brief Train LightGBM model
100100
*
101101
* @param datasetFileName is the name of the file with training data
102-
* @param datasetParams are the additional parameters of the training dataset. Default:
103-
* eader=true label=name:label
102+
* @param datasetParams are the additional parameters of the training dataset
104103
* @param numOfIterations is the number of training iterations (how many
105104
* LGBM_BoosterUpdateOneIter function is called, see
106-
* https://lightgbm.readthedocs.io/en/stable/C-API.html). Default: 100
105+
* https://lightgbm.readthedocs.io/en/stable/C-API.html)
107106
* @param params are parameters in format ‘key1=value1 key2=value2’ (see
108-
* https://lightgbm.readthedocs.io/en/stable/C-API.html). Default: boosting_type=gbdt
109-
* objective=binary metric=auc num_leaves=30 learning_rate=0.05 feature_fraction=0.9
110-
* @param modelFileName name of the file where the trained model will be saved. Default:
111-
* model.txt
107+
* https://lightgbm.readthedocs.io/en/stable/C-API.html)
108+
* @param modelFileName name of the file where the trained model will be saved
112109
*/
113110
void train(
114-
const std::string datasetFileName,
111+
const std::string& datasetFileName,
115112
const char* datasetParams = "header=true label=name:label",
116-
const int numOfIterations = 100,
113+
const unsigned numOfIterations = 100,
117114
const char* params
118115
= "boosting_type=gbdt objective=binary metric=auc num_leaves=30 learning_rate=0.05 "
119116
"feature_fraction=0.9",

src/wif/ml/lightGBMWrapper.cpp

Lines changed: 7 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -42,7 +42,6 @@ bool LightGBMWrapper::loadModel(const std::string& modelPath)
4242
if (!LGBM_BoosterCreateFromModelfile(modelPath.c_str(), &m_outNumIterations, &m_booster)) {
4343
m_isLoaded = true;
4444
m_modelPath = modelPath;
45-
4645
return true;
4746
}
4847

@@ -54,14 +53,11 @@ ClfResult LightGBMWrapper::classify(const FlowFeatures& flowFeatures)
5453
std::vector<double> dataToClassify; // classified features from flowfeatures are extracted here
5554
int64_t outLen; // length of output result
5655
int numOfClasses; // number of classes
57-
5856
LGBM_BoosterGetNumClasses(m_booster, &numOfClasses);
5957

6058
std::vector<double> pred(numOfClasses); // vector with predictions
61-
6259
for (const auto& featureID : m_featureIDs) {
6360
double value = flowFeatures.get<double>(featureID);
64-
6561
dataToClassify.push_back(value);
6662
}
6763

@@ -86,8 +82,7 @@ std::vector<ClfResult> LightGBMWrapper::classify(const std::vector<FlowFeatures>
8682
std::vector<double> dataToClassify; // Classified features from flowfeatures are extracted here
8783
int64_t outLen; // length of output result
8884
int numOfClasses; // number of classes
89-
std::vector<ClfResult>
90-
burstResults; // vector with predictions in ClfResult format for return value
85+
std::vector<ClfResult> burstResults; // vector with predictions in ClfResult format
9186

9287
burstResults.reserve(burstOfFeatures.size());
9388
LGBM_BoosterGetNumClasses(m_booster, &numOfClasses);
@@ -97,7 +92,6 @@ std::vector<ClfResult> LightGBMWrapper::classify(const std::vector<FlowFeatures>
9792
for (const auto& feature : burstOfFeatures) { // data preparation for classification
9893
for (const auto& featureId : m_featureIDs) {
9994
double value = feature.get<double>(featureId);
100-
10195
dataToClassify.push_back(value);
10296
}
10397
}
@@ -116,10 +110,10 @@ std::vector<ClfResult> LightGBMWrapper::classify(const std::vector<FlowFeatures>
116110
&outLen,
117111
pred.data());
118112

119-
for (size_t i = 0; i < burstOfFeatures.size(); ++i) { // converting pred to burstResults
113+
for (unsigned idx = 0; idx < burstOfFeatures.size(); ++idx) { // converting pred to burstResults
120114
std::vector<double> probabilities(
121-
pred.begin() + i * numOfClasses,
122-
pred.begin() + (i + 1) * numOfClasses);
115+
pred.begin() + idx * numOfClasses,
116+
pred.begin() + (idx + 1) * numOfClasses);
123117
burstResults.emplace_back(probabilities);
124118
}
125119

@@ -132,9 +126,9 @@ bool LightGBMWrapper::isLoaded() const
132126
}
133127

134128
void LightGBMWrapper::train(
135-
const std::string datasetFileName,
129+
const std::string& datasetFileName,
136130
const char* datasetParams,
137-
const int numOfIterations,
131+
const unsigned numOfIterations,
138132
const char* params,
139133
const std::string modelFileName)
140134
{
@@ -152,7 +146,7 @@ void LightGBMWrapper::train(
152146
throw std::runtime_error("Error creating booster");
153147
}
154148

155-
for (int i = 0; i < numOfIterations; ++i) { // training
149+
for (unsigned i = 0; i < numOfIterations; ++i) { // training
156150
int isFinished;
157151
LGBM_BoosterUpdateOneIter(m_booster, &isFinished);
158152
if (isFinished) {

0 commit comments

Comments
 (0)