Skip to content

Commit f925e98

Browse files
authored
Merge pull request #8 from CESNET/lightGBMClassifier
Introduce LGBM support via new LightGBMClassifier
2 parents 5fcc26f + eaae553 commit f925e98

8 files changed

Lines changed: 496 additions & 5 deletions

File tree

cmake/dependencies.cmake

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# Project dependencies
22
find_package(Armadillo REQUIRED)
33
find_package(Boost REQUIRED COMPONENTS regex serialization)
4+
find_package(LIGHTGBM REQUIRED)
45
find_package(MLPACK REQUIRED)
56
find_package(OpenMP REQUIRED)
67
find_package(Python3 REQUIRED COMPONENTS Development NumPy)
@@ -17,4 +18,3 @@ if(OpenMP_CXX_FOUND)
1718
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}")
1819
add_compile_options(${OpenMP_CXX_FLAGS})
1920
endif()
20-

cmake/modules/FindLIGHTGBM.cmake

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
# Find the LightGBM includes and library
2+
#
3+
# This module defines the following IMPORTED targets:
4+
#
5+
# LightGBM::lightgbm - The LightGBM library, if found.
6+
#
7+
# This module will set the following variables in your project:
8+
#
9+
# LIGHTGBM_INCLUDE_DIRS - where to find <LightGBM/c_api.h>, etc.
10+
# LIGHTGBM_FOUND - True if the LightGBM library has been found.
11+
12+
find_package(PkgConfig QUIET)
13+
if(PKG_CONFIG_FOUND)
14+
pkg_check_modules(PC_LIGHTGBM QUIET lightgbm)
15+
endif()
16+
17+
# Find headers
18+
find_path(
19+
LIGHTGBM_INCLUDE_DIR
20+
NAMES LightGBM/c_api.h
21+
HINTS ${PC_LIGHTGBM_INCLUDEDIR} ${PC_LIGHTGBM_INCLUDE_DIRS}
22+
/usr/local/include
23+
/usr/include
24+
)
25+
26+
# Find library
27+
find_library(
28+
LIGHTGBM_LIBRARY
29+
NAMES lightgbm lib_lightgbm _lightgbm
30+
HINTS ${PC_LIGHTGBM_LIBDIR} ${PC_LIGHTGBM_LIBRARY_DIRS}
31+
/usr/local/lib
32+
/usr/lib
33+
)
34+
35+
include(FindPackageHandleStandardArgs)
36+
find_package_handle_standard_args(
37+
LIGHTGBM
38+
REQUIRED_VARS LIGHTGBM_INCLUDE_DIR LIGHTGBM_LIBRARY
39+
)
40+
41+
if(LIGHTGBM_FOUND)
42+
set(LIGHTGBM_INCLUDE_DIRS ${LIGHTGBM_INCLUDE_DIR})
43+
set(LIGHTGBM_LIBRARIES ${LIGHTGBM_LIBRARY})
44+
45+
add_library(LightGBM::lightgbm SHARED IMPORTED)
46+
set_target_properties(LightGBM::lightgbm PROPERTIES
47+
IMPORTED_LOCATION "${LIGHTGBM_LIBRARY}"
48+
INTERFACE_INCLUDE_DIRECTORIES "${LIGHTGBM_INCLUDE_DIRS}"
49+
)
50+
endif()
51+
52+
mark_as_advanced(LIGHTGBM_INCLUDE_DIR LIGHTGBM_LIBRARY)
53+

include/wif/classifiers/genericMlClassifier.hpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@
1616
namespace WIF {
1717

1818
/**
19-
* @brief Abstract class specifying interfaces for ML classifiers (ScikitMlClassifier and
20-
* MlpackClassifier)
19+
* @brief Abstract class specifying interfaces for ML classifiers (ScikitMlClassifier,
20+
* MlpackClassifier and LightGBMClassifier)
2121
*
2222
*/
2323
class GenericMlClassifier : public Classifier {
@@ -32,7 +32,7 @@ class GenericMlClassifier : public Classifier {
3232
* @brief Reload the model from file, which was set in the constructor
3333
*
3434
* @param logicalName contains the logical name of the trained model. The parameter is used only
35-
* with MlpackClassifier (it is unused with ScikitMlClassifier)
35+
* with MlpackClassifier (it is unused with ScikitMlClassifier and LightGBMClassifier)
3636
*/
3737
virtual void reloadModelFromDisk(const std::string& logicalName = "trained_data") = 0;
3838
};
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
/**
2+
* @file
3+
* @author Jachym Hudlicky <hudlijac@fit.cvut.cz>
4+
* @brief LightGBM classifier interface
5+
*
6+
* SPDX-License-Identifier: BSD-3-Clause
7+
*/
8+
9+
#pragma once
10+
11+
#include "wif/classifiers/genericMlClassifier.hpp"
12+
#include "wif/ml/lightGBMWrapper.hpp"
13+
14+
#include <memory>
15+
#include <string>
16+
#include <vector>
17+
18+
namespace WIF {
19+
/**
20+
* @brief Classifier performing ML classification which is interconnected with LightGBM library
21+
*
22+
*/
23+
class LightGBMClassifier : public GenericMlClassifier {
24+
public:
25+
/**
26+
* @brief Construct a new LightGBM Classifier object
27+
*
28+
* @param path contains the path to the file with the trained model
29+
*/
30+
LightGBMClassifier(const std::string& path);
31+
32+
/**
33+
* @brief Set feature IDs which will be used for classification
34+
*
35+
* @param sourceFeatureIDs
36+
*/
37+
void setFeatureSourceIDs(const std::vector<FeatureID>& sourceFeatureIDs) override;
38+
39+
/**
40+
* @brief Classify single flowFeature object
41+
* See std::vector<ClfResult> classify(const std::vector<FlowFeatures>&) for more details
42+
*
43+
* @param flowFeatures flow features to classify
44+
* @return ClfResult result of the classification, which contains
45+
* vector<double> with probabilities for each class
46+
*/
47+
ClfResult classify(const FlowFeatures& flowFeatures) override;
48+
49+
/**
50+
* @brief Classify a burst of flow features
51+
*
52+
* @param burstOfFlowsFeatures the burst of flow features to classify
53+
* @return std::vector<ClfResult> classification results with ClfResult object for each flow
54+
* features object
55+
*/
56+
std::vector<ClfResult> classify(const std::vector<FlowFeatures>& burstOfFlowFeatures) override;
57+
58+
/**
59+
* @brief Return the path of the ML model, which is currently loaded
60+
* @return const std::string& path of the model
61+
*/
62+
const std::string& getMlModelPath() const noexcept override;
63+
64+
/**
65+
* @brief Reload used ML model from disk
66+
*
67+
* @param logicalName is unused
68+
*/
69+
void
70+
reloadModelFromDisk([[maybe_unused]] const std::string& logicalName = "trained_data") override;
71+
72+
private:
73+
/**
74+
* @brief Pointer to wrapper object with loaded lightGBM model
75+
*/
76+
std::unique_ptr<LightGBMWrapper> m_lightGBMWrapper;
77+
};
78+
} // namespace WIF

include/wif/ml/lightGBMWrapper.hpp

Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
/**
2+
* @file
3+
* @author Jachym Hudlicky <hudlijac@fit.cvut.cz>
4+
* @brief LightGBM wrapper interface
5+
*
6+
* SPDX-License-Identifier: BSD-3-Clause
7+
*/
8+
9+
#pragma once
10+
11+
#include "wif/storage/clfResult.hpp"
12+
#include "wif/storage/flowFeatures.hpp"
13+
14+
#include <LightGBM/c_api.h>
15+
#include <fstream>
16+
#include <iostream>
17+
#include <iterator>
18+
#include <map>
19+
#include <memory>
20+
#include <sstream>
21+
#include <stdexcept>
22+
#include <string>
23+
#include <utility>
24+
#include <vector>
25+
26+
namespace WIF {
27+
28+
/**
29+
* @brief Wrapper class which provides a bridge to LightGBM library
30+
*/
31+
class LightGBMWrapper {
32+
public:
33+
/**
34+
* @brief Construct a new LightGBM wrapper object
35+
*/
36+
LightGBMWrapper();
37+
38+
/**
39+
* @brief Construct a new LightGBM wrapper object
40+
*
41+
* @param modelPath contains path to the model file
42+
*/
43+
LightGBMWrapper(const std::string& modelPath);
44+
45+
/**
46+
* @brief Destruct the LightGBM wrapper object
47+
*/
48+
~LightGBMWrapper();
49+
50+
/**
51+
* @brief Set feature IDs which will be used for classification
52+
*
53+
* @param sourceFeatureIDs
54+
*/
55+
void setFeatureSourceIDs(const std::vector<FeatureID>& sourceFeatureIDs);
56+
57+
/**
58+
* @brief Getter for path of the used ML model
59+
* @return const std::string&
60+
*/
61+
const std::string& getModelPath() const;
62+
63+
/**
64+
* @brief Load the model from the file
65+
*
66+
* @param modelPath contains path to the model file.
67+
* @return bool true, if model was succesfully loaded. False if not.
68+
*/
69+
bool loadModel(const std::string& modelPath);
70+
71+
/**
72+
* @brief Classify single flowFeature object
73+
* See std::vector<ClfResult> classify(const std::vector<FlowFeatures>&) for more details
74+
*
75+
* @param flowFeatures flow features to classify
76+
* @return ClfResult result of the classification, which contains
77+
* vector<double> with probabilities for each class
78+
*/
79+
ClfResult classify(const FlowFeatures& flowFeatures);
80+
81+
/**
82+
* @brief Classify a burst of flow features
83+
*
84+
* @param burstOfFlowsFeatures the burst of flow features to classify
85+
* @return std::vector<ClfResult> the results of the classification. Each ClfResult contains
86+
* result of the classification, which contains vector<double> with
87+
* probabilities for each class
88+
*/
89+
std::vector<ClfResult> classify(const std::vector<FlowFeatures>& burstOfFeatures);
90+
91+
/**
92+
* @brief Return information about if ML model is loaded or not
93+
*
94+
* @return bool true, if ML model is loaded. False, if not
95+
*/
96+
bool isLoaded() const;
97+
98+
/**
99+
* @brief Train LightGBM model
100+
*
101+
* @param datasetFileName is the name of the file with training data
102+
* @param datasetParams are the additional parameters of the training dataset
103+
* @param numOfIterations is the number of training iterations (how many
104+
* LGBM_BoosterUpdateOneIter function is called, see
105+
* https://lightgbm.readthedocs.io/en/stable/C-API.html)
106+
* @param params are parameters in format ‘key1=value1 key2=value2’ (see
107+
* https://lightgbm.readthedocs.io/en/stable/C-API.html)
108+
* @param modelFileName name of the file where the trained model will be saved
109+
*/
110+
void train(
111+
const std::string& datasetFileName,
112+
const char* datasetParams = "header=true label=name:label",
113+
const unsigned numOfIterations = 100,
114+
const char* params
115+
= "boosting_type=gbdt objective=binary metric=auc num_leaves=30 learning_rate=0.05 "
116+
"feature_fraction=0.9",
117+
const std::string modelFileName = "model.txt");
118+
119+
private:
120+
/**
121+
* @brief LightGBM model
122+
*/
123+
BoosterHandle m_booster = nullptr;
124+
125+
/**
126+
* @brief Number of iterations of m_booster
127+
*/
128+
int m_outNumIterations = 0;
129+
130+
/**
131+
* @brief Bool value is true, if any ML model is correctly loaded. Otherwise it contains false
132+
*/
133+
bool m_isLoaded = false;
134+
135+
/**
136+
* @brief The path to currently loaded ML model path
137+
*/
138+
std::string m_modelPath;
139+
140+
/**
141+
* @brief Vector of feature IDs, which were set in setFeatureIDs method
142+
*/
143+
std::vector<FeatureID> m_featureIDs;
144+
};
145+
146+
} // namespace WIF

src/wif/CMakeLists.txt

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,16 @@
11
set(LIBWIF_SOURCES
22
classifiers/classifier.cpp
33
classifiers/ipPrefixClassifier.cpp
4+
classifiers/lightGBMClassifier.cpp
5+
classifiers/mlpackClassifier.cpp
46
classifiers/regexClassifier.cpp
57
classifiers/scikitMlClassifier.cpp
6-
classifiers/mlpackClassifier.cpp
78
combinators/averageCombinator.cpp
89
combinators/binaryDSTCombinator.cpp
910
combinators/majorityCombinator.cpp
1011
combinators/sumCombinator.cpp
1112
filesystem/fileModificationChecker.cpp
13+
ml/lightGBMWrapper.cpp
1214
ml/mlpackModels/decisionTreeModel.cpp
1315
ml/mlpackModels/hoeffdingTreeModel.cpp
1416
ml/mlpackModels/linearSVMModel.cpp
@@ -29,6 +31,7 @@ set(LIBWIF_LIBS
2931
Boost::regex
3032
Boost::serialization
3133
dstlib::dst
34+
LightGBM::lightgbm
3235
OpenMP::OpenMP_CXX
3336
Python3::Python
3437
Python3::NumPy
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
/**
2+
* @file
3+
* @author Jachym Hudlicky <hudlijac@fit.cvut.cz>
4+
* @brief LightGBM classifier implementation
5+
*
6+
* SPDX-License-Identifier: BSD-3-Clause
7+
*/
8+
9+
#include "wif/classifiers/lightGBMClassifier.hpp"
10+
11+
namespace WIF {
12+
13+
LightGBMClassifier::LightGBMClassifier(const std::string& path)
14+
{
15+
m_lightGBMWrapper = std::make_unique<LightGBMWrapper>(path);
16+
}
17+
18+
void LightGBMClassifier::setFeatureSourceIDs(const std::vector<FeatureID>& sourceFeatureIDs)
19+
{
20+
Classifier::setFeatureSourceIDs(sourceFeatureIDs);
21+
m_lightGBMWrapper->setFeatureSourceIDs(sourceFeatureIDs);
22+
}
23+
24+
ClfResult LightGBMClassifier::classify(const FlowFeatures& flowFeatures)
25+
{
26+
return m_lightGBMWrapper->classify(flowFeatures);
27+
}
28+
29+
std::vector<ClfResult>
30+
LightGBMClassifier::classify(const std::vector<FlowFeatures>& burstOfFlowFeatures)
31+
{
32+
return m_lightGBMWrapper->classify(burstOfFlowFeatures);
33+
}
34+
35+
const std::string& LightGBMClassifier::getMlModelPath() const noexcept
36+
{
37+
return m_lightGBMWrapper->getModelPath();
38+
}
39+
40+
void LightGBMClassifier::reloadModelFromDisk([[maybe_unused]] const std::string& logicalName)
41+
{
42+
m_lightGBMWrapper->loadModel(m_lightGBMWrapper->getModelPath());
43+
}
44+
45+
} // namespace WIF

0 commit comments

Comments
 (0)