Skip to content

Commit 41debb9

Browse files
authored
Merge branch 'main' into improve_frame_slicing_assertion
2 parents 9b8b56f + 2914abd commit 41debb9

7 files changed

Lines changed: 36 additions & 33 deletions

File tree

doc/how_to/combine_recordings.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ Combine recordings in SpikeInterface
44
In this tutorial we will walk through combining multiple recording objects. Sometimes this occurs due to hardware
55
settings (e.g. Intan software has a default setting of new files every 1 minute) or the experimenter decides to
66
split their recording into multiple files for different experimental conditions. If the probe has not been moved,
7-
however, then during sorting it would likely make sense to combine these individual reocrding objects into one
7+
however, then during sorting it would likely make sense to combine these individual recording objects into one
88
recording object.
99

1010
**Why Combine?**

src/spikeinterface/core/baserecording.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -547,7 +547,8 @@ def shift_times(self, shift: int | float, segment_index: int | None = None) -> N
547547
if self.has_time_vector(segment_index=idx):
548548
rs.time_vector += shift
549549
else:
550-
rs.t_start += shift
550+
new_start_time = 0 + shift if rs.t_start is None else rs.t_start + shift
551+
rs.t_start = new_start_time
551552

552553
def sample_index_to_time(self, sample_ind, segment_index=None):
553554
"""
@@ -749,9 +750,9 @@ def frame_slice(self, start_frame: int | None, end_frame: int | None) -> BaseRec
749750
Parameters
750751
----------
751752
start_frame : int, optional
752-
The start frame, if not provided it is set to 0
753+
Start frame index. If None, defaults to the beginning of the recording (frame 0).
753754
end_frame : int, optional
754-
The end frame, it not provided it is set to the total number of samples
755+
End frame index. If None, defaults to the last frame of the recording.
755756
756757
Returns
757758
-------
@@ -771,9 +772,9 @@ def time_slice(self, start_time: float | None, end_time: float | None) -> BaseRe
771772
Parameters
772773
----------
773774
start_time : float, optional
774-
The start time in seconds. If not provided it is set to 0.
775+
Start time in seconds. If None, defaults to the beginning of the recording.
775776
end_time : float, optional
776-
The end time in seconds. If not provided it is set to the total duration.
777+
End time in seconds. If None, defaults to the end of the recording.
777778
778779
Returns
779780
-------

src/spikeinterface/core/tests/test_time_handling.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -435,3 +435,12 @@ def _get_sorting_with_recording_attached(self, recording_for_durations, recordin
435435
assert sorting.has_recording()
436436

437437
return sorting
438+
439+
440+
def test_shift_times_with_None_as_t_start():
441+
"""Ensures we can shift times even when t_stat is None which is interpeted as zero"""
442+
recording = generate_recording(num_channels=4, durations=[10])
443+
444+
assert recording._recording_segments[0].t_start is None
445+
recording.shift_times(shift=1.0) # Shift by one seconds should not generate an error
446+
assert recording.get_start_time() == 1.0

src/spikeinterface/extractors/nwbextractors.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -759,10 +759,13 @@ def _fetch_recording_segment_info_backend(self, file, cache, load_time_vector, s
759759
if "starting_time" in electrical_series.keys():
760760
t_start = electrical_series["starting_time"][()]
761761
sampling_frequency = electrical_series["starting_time"].attrs["rate"]
762+
timestamps = None
762763
elif "timestamps" in electrical_series.keys():
763764
timestamps = electrical_series["timestamps"][:]
764765
t_start = timestamps[0]
765766
sampling_frequency = 1.0 / np.median(np.diff(timestamps[:samples_for_rate_estimation]))
767+
else:
768+
raise ValueError("TimeSeries must have either starting_time or timestamps")
766769

767770
if load_time_vector and timestamps is not None:
768771
times_kwargs = dict(time_vector=electrical_series["timestamps"])
@@ -1572,6 +1575,8 @@ def _fetch_recording_segment_info(self, file, cache, load_time_vector, samples_f
15721575
timestamps = timeseries.timestamps
15731576
sampling_frequency = 1.0 / np.median(np.diff(timestamps[:samples_for_rate_estimation]))
15741577
t_start = timestamps[0]
1578+
else:
1579+
raise ValueError("TimeSeries must have either starting_time or timestamps")
15751580

15761581
if load_time_vector and timestamps is not None:
15771582
times_kwargs = dict(time_vector=timestamps)

src/spikeinterface/postprocessing/spike_amplitudes.py

Lines changed: 5 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -17,38 +17,19 @@ class ComputeSpikeAmplitudes(AnalyzerExtension):
1717
Computes the spike amplitudes.
1818
1919
Needs "templates" to be computed first.
20-
Localize spikes in 2D or 3D with several methods given the template.
20+
Computes spike amplitudes from the template's peak channel for every spike.
2121
2222
Parameters
2323
----------
2424
sorting_analyzer : SortingAnalyzer
2525
A SortingAnalyzer object
26-
ms_before : float, default: 0.5
27-
The left window, before a peak, in milliseconds
28-
ms_after : float, default: 0.5
29-
The right window, after a peak, in milliseconds
30-
spike_retriver_kwargs : dict
31-
A dictionary to control the behavior for getting the maximum channel for each spike
32-
This dictionary contains:
33-
* channel_from_template: bool, default: True
34-
For each spike is the maximum channel computed from template or re estimated at every spikes
35-
channel_from_template = True is old behavior but less acurate
36-
channel_from_template = False is slower but more accurate
37-
* radius_um: float, default: 50
38-
In case channel_from_template=False, this is the radius to get the true peak
39-
* peak_sign, default: "neg"
40-
In case channel_from_template=False, this is the peak sign.
41-
method : "center_of_mass" | "monopolar_triangulation" | "grid_convolution", default: "center_of_mass"
42-
The localization method to use
43-
**method_kwargs : dict, default: {}
44-
Kwargs which are passed to the method function. These can be found in the docstrings of `compute_center_of_mass`, `compute_grid_convolution` and `compute_monopolar_triangulation`.
45-
outputs : "numpy" | "by_unit", default: "numpy"
46-
The output format, either concatenated as numpy array or separated on a per unit basis
26+
peak_sign : "neg" | "pos" | "both", default: "neg"
27+
Sign of the template to compute extremum channel used to retrieve spike amplitudes.
4728
4829
Returns
4930
-------
50-
spike_locations: np.array
51-
All locations for all spikes and all units are concatenated
31+
spike_amplitudes: np.array
32+
All amplitudes for all spikes and all units are concatenated (along time, like in spike vector)
5233
5334
"""
5435

src/spikeinterface/preprocessing/motion.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -443,9 +443,6 @@ def compute_motion(
443443
t1 = time.perf_counter()
444444
run_times["estimate_motion"] = t1 - t0
445445

446-
if recording.get_dtype().kind != "f":
447-
recording = recording.astype("float32")
448-
449446
motion_info = dict(
450447
parameters=parameters,
451448
run_times=run_times,
@@ -554,6 +551,9 @@ def correct_motion(
554551
**job_kwargs,
555552
)
556553

554+
if recording.get_dtype().kind != "f":
555+
recording = recording.astype("float32")
556+
557557
recording_corrected = interpolate_motion(recording, motion, **interpolate_motion_kwargs)
558558

559559
if not output_motion and not output_motion_info:

src/spikeinterface/preprocessing/tests/test_motion.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,13 @@ def test_estimate_and_correct_motion(create_cache_folder):
3232
assert motion_info_loaded["motion"] == motion_info["motion"]
3333

3434

35+
def test_estimate_and_correct_motion_int():
36+
37+
rec = generate_recording(durations=[30.0], num_channels=12).astype(int)
38+
rec_corrected = correct_motion(rec, estimate_motion_kwargs={"win_step_um": 50, "win_scale_um": 100})
39+
assert rec_corrected.get_dtype().kind == "f"
40+
41+
3542
def test_get_motion_parameters_preset():
3643
from pprint import pprint
3744

0 commit comments

Comments
 (0)