|
1 | 1 | import pickle |
| 2 | +import joblib |
2 | 3 | import os |
3 | 4 | from pure_sklearn.map import convert_estimator |
4 | 5 |
|
5 | 6 |
|
6 | | -def load_classifier(path_to_sav): |
| 7 | +def load_classifier_SIMBA(path_to_sav): |
7 | 8 | """Load saved classifier""" |
8 | 9 | file = open(path_to_sav, "rb") |
9 | 10 | classifier = pickle.load(file) |
10 | 11 | file.close() |
11 | 12 | return classifier |
12 | 13 |
|
| 14 | +def load_classifier_BSOID(path_to_sav): |
| 15 | + """Load saved classifier""" |
| 16 | + file = open(path_to_sav, "rb") |
| 17 | + clf = joblib.load(file) |
| 18 | + file.close() |
| 19 | + return clf |
13 | 20 |
|
14 | | -def convert_classifier(path): |
| 21 | + |
| 22 | +def convert_classifier(path, origin: str): |
15 | 23 | # convert to pure python estimator |
16 | | - print("Loading classifier...") |
17 | | - clf = load_classifier(path) |
18 | 24 | dir_path = os.path.dirname(path) |
19 | 25 | filename = os.path.basename(path) |
20 | 26 | filename, _ = filename.split(".") |
21 | | - clf_pure_predict = convert_estimator(clf) |
22 | | - with open(dir_path + "/" + filename + "_pure.sav", "wb") as f: |
23 | | - pickle.dump(clf_pure_predict, f) |
| 27 | + |
| 28 | + print("Loading classifier...") |
| 29 | + if origin.lower() == 'simba': |
| 30 | + clf = load_classifier_SIMBA(path) |
| 31 | + clf_pure_predict = convert_estimator(clf) |
| 32 | + with open(dir_path + "/" + filename + "_pure.sav", "wb") as f: |
| 33 | + pickle.dump(clf_pure_predict, f) |
| 34 | + |
| 35 | + elif origin.lower() == 'bsoid': |
| 36 | + clf_pack = load_classifier_BSOID(path) |
| 37 | + # bsoid exported classfier has format [a, b, c, clf, d, e] |
| 38 | + clf_pure_predict = convert_estimator(clf_pack[3]) |
| 39 | + clf_pack[3] =clf_pure_predict |
| 40 | + with open(dir_path + "/" + filename + "_pure.sav", "wb") as f: |
| 41 | + joblib.dump(clf_pack, f) |
| 42 | + else: |
| 43 | + raise ValueError(f'{origin} is not a valid classifier origin.') |
| 44 | + |
24 | 45 | print(f"Converted Classifier {filename}") |
25 | 46 |
|
26 | 47 |
|
27 | 48 | if __name__ == "__main__": |
| 49 | + |
| 50 | + """Converted BSOID Classifiers are not integrated yet, although you can already convert them here""" |
28 | 51 | path_to_classifier = "PATH_TO_CLASSIFIER" |
29 | | - convert_classifier(path_to_classifier) |
| 52 | + convert_classifier(path_to_classifier, origin= 'SIMBA') |
0 commit comments