Skip to content

Commit 8d0d403

Browse files
authored
Merge pull request #34 from SchwarzNeuroconLab/dev_bsoidupdate
updated likelihood handling for BSOID classification DLC pose estimation can now be filtered by likelihood in advanced settings. Values below the threshold will be set to NaN and need to be handled by calculate skeleton
2 parents 8897152 + 3fb9bdb commit 8d0d403

9 files changed

Lines changed: 294 additions & 165 deletions

File tree

DeepLabStream.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
CROP,
3636
CROP_X,
3737
CROP_Y,
38+
USE_DLSTREAM_POSTURE_DETECTION,
3839
)
3940
from utils.plotter import plot_bodyparts, plot_metadata_frame
4041
from utils.poser import (
@@ -336,11 +337,14 @@ def get_pose_mp(input_q, output_q):
336337
scmap, locref, pose = get_pose(
337338
frame, config, sess, inputs, outputs
338339
)
339-
peaks = find_local_peaks_new(
340-
scmap, locref, ANIMALS_NUMBER, config
341-
)
340+
if USE_DLSTREAM_POSTURE_DETECTION:
341+
""" This is a legacy function that was used in earlier versions"""
342+
peaks = find_local_peaks_new(
343+
scmap, locref, ANIMALS_NUMBER, config
344+
)
342345
# Use the line below to use raw DLC output rather then DLStream optimization
343-
#peaks = pose
346+
else:
347+
peaks = pose
344348
if MODEL_ORIGIN == "MADLC":
345349
peaks = get_ma_pose(frame, config, sess, inputs, outputs)
346350
analysis_time = time.time() - start_time

convert_classifier.py

Lines changed: 31 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,52 @@
11
import pickle
2+
import joblib
23
import os
34
from pure_sklearn.map import convert_estimator
45

56

6-
def load_classifier(path_to_sav):
7+
def load_classifier_SIMBA(path_to_sav):
78
"""Load saved classifier"""
89
file = open(path_to_sav, "rb")
910
classifier = pickle.load(file)
1011
file.close()
1112
return classifier
1213

14+
def load_classifier_BSOID(path_to_sav):
15+
"""Load saved classifier"""
16+
file = open(path_to_sav, "rb")
17+
clf = joblib.load(file)
18+
file.close()
19+
return clf
1320

14-
def convert_classifier(path):
21+
22+
def convert_classifier(path, origin: str):
1523
# convert to pure python estimator
16-
print("Loading classifier...")
17-
clf = load_classifier(path)
1824
dir_path = os.path.dirname(path)
1925
filename = os.path.basename(path)
2026
filename, _ = filename.split(".")
21-
clf_pure_predict = convert_estimator(clf)
22-
with open(dir_path + "/" + filename + "_pure.sav", "wb") as f:
23-
pickle.dump(clf_pure_predict, f)
27+
28+
print("Loading classifier...")
29+
if origin.lower() == 'simba':
30+
clf = load_classifier_SIMBA(path)
31+
clf_pure_predict = convert_estimator(clf)
32+
with open(dir_path + "/" + filename + "_pure.sav", "wb") as f:
33+
pickle.dump(clf_pure_predict, f)
34+
35+
elif origin.lower() == 'bsoid':
36+
clf_pack = load_classifier_BSOID(path)
37+
# bsoid exported classfier has format [a, b, c, clf, d, e]
38+
clf_pure_predict = convert_estimator(clf_pack[3])
39+
clf_pack[3] =clf_pure_predict
40+
with open(dir_path + "/" + filename + "_pure.sav", "wb") as f:
41+
joblib.dump(clf_pack, f)
42+
else:
43+
raise ValueError(f'{origin} is not a valid classifier origin.')
44+
2445
print(f"Converted Classifier {filename}")
2546

2647

2748
if __name__ == "__main__":
49+
50+
"""Converted BSOID Classifiers are not integrated yet, although you can already convert them here"""
2851
path_to_classifier = "PATH_TO_CLASSIFIER"
29-
convert_classifier(path_to_classifier)
52+
convert_classifier(path_to_classifier, origin= 'SIMBA')

experiments/custom/classifier.py

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import time
1111
import pickle
1212

13-
from utils.configloader import PATH_TO_CLASSIFIER, TIME_WINDOW
13+
from utils.configloader import PATH_TO_CLASSIFIER, TIME_WINDOW, FRAMERATE
1414
from experiments.custom.featureextraction import (
1515
SimbaFeatureExtractor,
1616
SimbaFeatureExtractorStandard14bp,
@@ -117,6 +117,7 @@ def classify(self, features):
117117
Adapted from BSOID; https://github.com/YttriLab/B-SOID
118118
"""
119119
labels_fslow = []
120+
# TODO: adapt to pure version of BSOID Classifier
120121
for i in range(0, len(features)):
121122
labels = self._classifier.predict(features[i].T)
122123
labels_fslow.append(labels)
@@ -183,21 +184,26 @@ def simba_feat_classifier_pool_run(input_q: mp.Queue, output_q: mp.Queue):
183184

184185

185186
def bsoid_feat_classifier_pool_run(input_q: mp.Queue, output_q: mp.Queue):
186-
feature_extractor = BsoidFeatureExtractor(TIME_WINDOW)
187+
feature_extractor = BsoidFeatureExtractor()
187188
classifier = BsoidClassifier() # initialize classifier
188189
while True:
189190
skel_time_window = None
190191
feature_id = 0
191192
if input_q.full():
192193
skel_time_window, feature_id = input_q.get()
193194
if skel_time_window is not None:
194-
start_time = time.time()
195+
start_time_feat = time.time()
195196
features = feature_extractor.extract_features(skel_time_window)
197+
end_time_feat = time.time()
198+
start_time_clf = time.time()
196199
last_prob = classifier.classify(features)
197200
output_q.put((last_prob, feature_id))
198201
end_time = time.time()
199-
# print("Classification time: {:.2f} msec".format((end_time-start_time)*1000))
200-
# print("Feature ID: "+ feature_id)
202+
print("Feature Extraction time: {:.2f} msec".format((end_time_feat - start_time_feat) * 1000))
203+
print("Classification time: {:.2f} msec".format((end_time-start_time_clf)*1000))
204+
print("Total time: {:.2f} msec".format((end_time-start_time_feat)*1000))
205+
print("Current motif: ", *last_prob)
206+
print("Feature ID: "+ str(feature_id))
201207
else:
202208
pass
203209

@@ -275,6 +281,8 @@ def pass_time_window(self, skel_time_window: tuple, debug: bool = False):
275281
:param debug bool: reporting of process + feature id to identify discrepancies in processing sequence
276282
"""
277283
for process in self._process_pool:
284+
#if the process is not already busy, feed it some new input and break the loop
285+
#this should only be valid the first time the process is fed.
278286
if not process["running"]:
279287
if process["input"].empty():
280288
process["input"].put(skel_time_window)
@@ -287,6 +295,8 @@ def pass_time_window(self, skel_time_window: tuple, debug: bool = False):
287295
)
288296
break
289297

298+
#if the process is busy but finished (has output), feed it some new input.
299+
#this should be the normal case
290300
elif process["input"].empty() and process["output"].full():
291301
process["input"].put(skel_time_window)
292302
if debug:
@@ -306,7 +316,12 @@ def get_result(self, debug: bool = False):
306316
"""
307317
result = (None, 0)
308318
for process in self._process_pool:
319+
#check if process is finished
309320
if process["output"].full():
321+
#take result and break the loop. This way two simultaneously finished processes are emptied in sequence
322+
#rather then overwriting the results of each other
323+
#the disadvantage is that the result won't be the latest classification but in the next in sequential order (to the last).
324+
#the advantage is that we won't miss any results this way and have "consistent" latency, which is the intended behavior.
310325
result = process["output"].get()
311326
if debug:
312327
print("Output", process["process"].name, "ID: " + str(result[1]))
@@ -381,13 +396,15 @@ def simba_classifier_run(input_q: mp.Queue, output_q: mp.Queue):
381396

382397

383398
def bsoid_classifier_run(input_q: mp.Queue, output_q: mp.Queue):
399+
#takes features from input and feeds them into classifier. Outputs classification
384400
classifier = BsoidClassifier() # initialize classifier
385401
while True:
386402
features = None
387403
if input_q.full():
388404
features = input_q.get()
389405
if features is not None:
390406
start_time = time.time()
407+
#last prob is a missleading name that comes from a binary classifier. B-SOID's output is a cluster id rather then the probability.
391408
last_prob = classifier.classify(features)
392409
output_q.put((last_prob))
393410
end_time = time.time()

0 commit comments

Comments
 (0)