@@ -380,16 +380,15 @@ def get_pose_mp(input_q, output_q):
380380 if input_q .full ():
381381 index , frame = input_q .get ()
382382 start_time = time .time ()
383- input_frame = frame [:, :, ::- 1 ]
384- # this is weird, but without it, it does not seem to work...
385- frames = np .array ([input_frame ])
386- prediction = sleap_model .predict (frames [[0 ]], batch_size = 1 )
387- # check if this is multiple animal instances or single animal model
388- if sleap_model .name == "single_instance_inference_model" :
389- # get predictions (wrap it again, so the behavior is the same for both model types)
390- peaks = np .array ([prediction ["peaks" ][0 , :]])
391- else :
392- peaks = prediction ["instance_peaks" ][0 , :]
383+ # Make sure image is (1, height, width, channels) and uint8
384+ # (height, width) -> (height, width, 1)
385+ frame = np .expand_dims (frame , axis = - 1 ) if frame .ndim == 2 else frame
386+ # (height, width, channels) -> (1, height, width, channels)
387+ frame = np .expand_dims (frame , axis = 0 ) if frame .ndim == 3 else frame
388+ # predict_on_batch is MUCH faster as it does not retrace the model graph for same size inputs
389+ pred = sleap_model .predict_on_batch (frame )
390+ peaks = pred ["instance_peaks" ][0 ] # (n_poses, n_nodes, 2)
391+
393392 analysis_time = time .time () - start_time
394393 output_q .put ((index , peaks , analysis_time ))
395394 else :
0 commit comments