Skip to content

Commit 3454816

Browse files
committed
hotfix to handle_nan reshape
added to do for pure BSOID clf added conversion for BSOID clf
1 parent 47d8127 commit 3454816

3 files changed

Lines changed: 41 additions & 12 deletions

File tree

convert_classifier.py

Lines changed: 31 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,52 @@
11
import pickle
2+
import joblib
23
import os
34
from pure_sklearn.map import convert_estimator
45

56

6-
def load_classifier(path_to_sav):
7+
def load_classifier_SIMBA(path_to_sav):
78
"""Load saved classifier"""
89
file = open(path_to_sav, "rb")
910
classifier = pickle.load(file)
1011
file.close()
1112
return classifier
1213

14+
def load_classifier_BSOID(path_to_sav):
15+
"""Load saved classifier"""
16+
file = open(path_to_sav, "rb")
17+
clf = joblib.load(file)
18+
file.close()
19+
return clf
1320

14-
def convert_classifier(path):
21+
22+
def convert_classifier(path, origin: str):
1523
# convert to pure python estimator
16-
print("Loading classifier...")
17-
clf = load_classifier(path)
1824
dir_path = os.path.dirname(path)
1925
filename = os.path.basename(path)
2026
filename, _ = filename.split(".")
21-
clf_pure_predict = convert_estimator(clf)
22-
with open(dir_path + "/" + filename + "_pure.sav", "wb") as f:
23-
pickle.dump(clf_pure_predict, f)
27+
28+
print("Loading classifier...")
29+
if origin.lower() == 'simba':
30+
clf = load_classifier_SIMBA(path)
31+
clf_pure_predict = convert_estimator(clf)
32+
with open(dir_path + "/" + filename + "_pure.sav", "wb") as f:
33+
pickle.dump(clf_pure_predict, f)
34+
35+
elif origin.lower() == 'bsoid':
36+
clf_pack = load_classifier_BSOID(path)
37+
# bsoid exported classfier has format [a, b, c, clf, d, e]
38+
clf_pure_predict = convert_estimator(clf_pack[3])
39+
clf_pack[3] =clf_pure_predict
40+
with open(dir_path + "/" + filename + "_pure.sav", "wb") as f:
41+
joblib.dump(clf_pack, f)
42+
else:
43+
raise ValueError(f'{origin} is not a valid classifier origin.')
44+
2445
print(f"Converted Classifier {filename}")
2546

2647

2748
if __name__ == "__main__":
49+
50+
"""Converted BSOID Classifiers are not integrated yet, although you can already convert them here"""
2851
path_to_classifier = "PATH_TO_CLASSIFIER"
29-
convert_classifier(path_to_classifier)
52+
convert_classifier(path_to_classifier, origin= 'SIMBA')

experiments/custom/classifier.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,7 @@ def classify(self, features):
117117
Adapted from BSOID; https://github.com/YttriLab/B-SOID
118118
"""
119119
labels_fslow = []
120+
# TODO: adapt to pure version of BSOID Classifier
120121
for i in range(0, len(features)):
121122
labels = self._classifier.predict(features[i].T)
122123
labels_fslow.append(labels)
@@ -191,13 +192,18 @@ def bsoid_feat_classifier_pool_run(input_q: mp.Queue, output_q: mp.Queue):
191192
if input_q.full():
192193
skel_time_window, feature_id = input_q.get()
193194
if skel_time_window is not None:
194-
start_time = time.time()
195+
start_time_feat = time.time()
195196
features = feature_extractor.extract_features(skel_time_window)
197+
end_time_feat = time.time()
198+
start_time_clf = time.time()
196199
last_prob = classifier.classify(features)
197200
output_q.put((last_prob, feature_id))
198201
end_time = time.time()
199-
# print("Classification time: {:.2f} msec".format((end_time-start_time)*1000))
200-
# print("Feature ID: "+ feature_id)
202+
print("Feature Extraction time: {:.2f} msec".format((end_time_feat - start_time_feat) * 1000))
203+
print("Classification time: {:.2f} msec".format((end_time-start_time_clf)*1000))
204+
print("Total time: {:.2f} msec".format((end_time-start_time_feat)*1000))
205+
print("Current motif: ", *last_prob)
206+
print("Feature ID: "+ str(feature_id))
201207
else:
202208
pass
203209

experiments/custom/featureextraction.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1281,7 +1281,7 @@ def handle_nan(self, pose_window):
12811281
# no NaN? then update the last valid pose for this body part
12821282
self._last_valid_pose[bp_num] = bp
12831283
#reshape to match requirements of bsoid feature extraction: time_window , bodyparts*2
1284-
pro_window = np.reshape(pro_window, (pro_window.shape[1], pro_window.shape[2]*pro_window.shape[3]))# converted to np.array
1284+
pro_window = np.reshape(pro_window, (pro_window.shape[0], pro_window.shape[1]*pro_window.shape[2]))# converted to np.array
12851285

12861286
return pro_window
12871287

0 commit comments

Comments
 (0)