Skip to content

Commit fd4bb89

Browse files
authored
Merge pull request #3934 from chrishalcrow/fix-interp-dtype
Bug fix: cast corrected motion recording to float
2 parents 7c99530 + 07548c7 commit fd4bb89

2 files changed

Lines changed: 10 additions & 3 deletions

File tree

src/spikeinterface/preprocessing/motion.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -443,9 +443,6 @@ def compute_motion(
443443
t1 = time.perf_counter()
444444
run_times["estimate_motion"] = t1 - t0
445445

446-
if recording.get_dtype().kind != "f":
447-
recording = recording.astype("float32")
448-
449446
motion_info = dict(
450447
parameters=parameters,
451448
run_times=run_times,
@@ -554,6 +551,9 @@ def correct_motion(
554551
**job_kwargs,
555552
)
556553

554+
if recording.get_dtype().kind != "f":
555+
recording = recording.astype("float32")
556+
557557
recording_corrected = interpolate_motion(recording, motion, **interpolate_motion_kwargs)
558558

559559
if not output_motion and not output_motion_info:

src/spikeinterface/preprocessing/tests/test_motion.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,13 @@ def test_estimate_and_correct_motion(create_cache_folder):
3232
assert motion_info_loaded["motion"] == motion_info["motion"]
3333

3434

35+
def test_estimate_and_correct_motion_int():
36+
37+
rec = generate_recording(durations=[30.0], num_channels=12).astype(int)
38+
rec_corrected = correct_motion(rec, estimate_motion_kwargs={"win_step_um": 50, "win_scale_um": 100})
39+
assert rec_corrected.get_dtype().kind == "f"
40+
41+
3542
def test_get_motion_parameters_preset():
3643
from pprint import pprint
3744

0 commit comments

Comments
 (0)