Skip to content

Commit d2712fe

Browse files
authored
Merge pull request #3942 from h-mayorquin/better_assertion_for_shift_times
Add informative assertions to `slice_time`
2 parents 2914abd + b9374f9 commit d2712fe

3 files changed

Lines changed: 54 additions & 7 deletions

File tree

src/spikeinterface/core/baserecording.py

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -541,10 +541,10 @@ def shift_times(self, shift: int | float, segment_index: int | None = None) -> N
541541
else:
542542
segments_to_shift = (segment_index,)
543543

544-
for idx in segments_to_shift:
545-
rs = self._recording_segments[idx]
544+
for segment_index in segments_to_shift:
545+
rs = self._recording_segments[segment_index]
546546

547-
if self.has_time_vector(segment_index=idx):
547+
if self.has_time_vector(segment_index=segment_index):
548548
rs.time_vector += shift
549549
else:
550550
new_start_time = 0 + shift if rs.t_start is None else rs.t_start + shift
@@ -781,11 +781,37 @@ def time_slice(self, start_time: float | None, end_time: float | None) -> BaseRe
781781
BaseRecording
782782
A new recording object with only samples between start_time and end_time
783783
"""
784+
num_segments = self.get_num_segments()
785+
assert (
786+
num_segments == 1
787+
), f"Time slicing is only supported for single segment recordings. Found {num_segments} segments."
788+
789+
t_start = self.get_start_time()
790+
t_end = self.get_end_time()
791+
792+
if start_time is not None:
793+
t_start = self.get_start_time()
794+
t_start_too_early = start_time < t_start
795+
if t_start_too_early:
796+
raise ValueError(f"start_time {start_time} is before the start time {t_start} of the recording.")
797+
t_start_too_late = start_time > t_end
798+
if t_start_too_late:
799+
raise ValueError(f"start_time {start_time} is after the end time {t_end} of the recording.")
800+
start_frame = self.time_to_sample_index(start_time)
801+
else:
802+
start_frame = None
784803

785-
assert self.get_num_segments() == 1, "Time slicing is only supported for single segment recordings."
804+
if end_time is not None:
805+
t_end_too_early = end_time < t_start
806+
if t_end_too_early:
807+
raise ValueError(f"end_time {end_time} is before the start time {t_start} of the recording.")
786808

787-
start_frame = self.time_to_sample_index(start_time) if start_time else None
788-
end_frame = self.time_to_sample_index(end_time) if end_time else None
809+
t_end_too_late = end_time > t_end
810+
if t_end_too_late:
811+
raise ValueError(f"end_time {end_time} is after the end time {t_end} of the recording.")
812+
end_frame = self.time_to_sample_index(end_time)
813+
else:
814+
end_frame = None
789815

790816
return self.frame_slice(start_frame=start_frame, end_frame=end_frame)
791817

src/spikeinterface/core/generate.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def generate_recording(
3333
set_probe: bool | None = True,
3434
ndim: int | None = 2,
3535
seed: int | None = None,
36-
) -> NumpySorting:
36+
) -> BaseRecording:
3737
"""
3838
Generate a lazy recording object.
3939
Useful for testing API and algos.

src/spikeinterface/core/tests/test_baserecording.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -412,6 +412,27 @@ def test_time_slice():
412412
assert np.allclose(sliced_recording_times.get_traces(), sliced_recording_frames.get_traces())
413413

414414

415+
def test_out_of_range_time_slice():
416+
recording = generate_recording(durations=[0.100]) # duration = 0.1 s
417+
recording.shift_times(1.0) # shifts start time to 1.0 s, end time to 1.1 s
418+
419+
# start_time before recording
420+
with pytest.raises(ValueError, match="start_time .* is before the start time"):
421+
recording.time_slice(start_time=0, end_time=None)
422+
423+
# end_time before start of recording
424+
with pytest.raises(ValueError, match="end_time .* is before the start time"):
425+
recording.time_slice(start_time=None, end_time=0.5)
426+
427+
# start_time after end of recording
428+
with pytest.raises(ValueError, match="start_time .* is after the end time"):
429+
recording.time_slice(start_time=2.0, end_time=None)
430+
431+
# end_time after end of recording
432+
with pytest.raises(ValueError, match="end_time .* is after the end time"):
433+
recording.time_slice(start_time=None, end_time=2.0)
434+
435+
415436
def test_time_slice_with_time_vector():
416437

417438
# Case with time vector

0 commit comments

Comments
 (0)