Skip to content

Commit 573cb2c

Browse files
committed
update to SLEAP 1.13
1 parent 63483b7 commit 573cb2c

2 files changed

Lines changed: 10 additions & 11 deletions

File tree

DeepLabStream.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -380,16 +380,15 @@ def get_pose_mp(input_q, output_q):
380380
if input_q.full():
381381
index, frame = input_q.get()
382382
start_time = time.time()
383-
input_frame = frame[:, :, ::-1]
384-
# this is weird, but without it, it does not seem to work...
385-
frames = np.array([input_frame])
386-
prediction = sleap_model.predict(frames[[0]], batch_size=1)
387-
# check if this is multiple animal instances or single animal model
388-
if sleap_model.name == "single_instance_inference_model":
389-
# get predictions (wrap it again, so the behavior is the same for both model types)
390-
peaks = np.array([prediction["peaks"][0, :]])
391-
else:
392-
peaks = prediction["instance_peaks"][0, :]
383+
# Make sure image is (1, height, width, channels) and uint8
384+
# (height, width) -> (height, width, 1)
385+
frame = np.expand_dims(frame, axis=-1) if frame.ndim == 2 else frame
386+
# (height, width, channels) -> (1, height, width, channels)
387+
frame = np.expand_dims(frame, axis=0) if frame.ndim == 3 else frame
388+
# predict_on_batch is MUCH faster as it does not retrace the model graph for same size inputs
389+
pred = sleap_model.predict_on_batch(frame)
390+
peaks = pred["instance_peaks"][0] # (n_poses, n_nodes, 2)
391+
393392
analysis_time = time.time() - start_time
394393
output_q.put((index, peaks, analysis_time))
395394
else:

utils/poser.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -371,7 +371,7 @@ def load_dlc_live():
371371

372372

373373
def load_sleap():
374-
model = load_model(MODEL_PATH)
374+
model = load_model(MODEL_PATH, batch_size=1)
375375
model.inference_model
376376
return model.inference_model
377377

0 commit comments

Comments
 (0)