Skip to content

Commit 0b69fec

Browse files
authored
Merge pull request #3794 from h-mayorquin/scale_to_scale_to_uV
`return_scaled` to `return_in_uV`
2 parents dfbc584 + 50ebf7c commit 0b69fec

21 files changed

Lines changed: 321 additions & 125 deletions

src/spikeinterface/core/analyzer_extension_core.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,7 @@ def _run(self, verbose=False, **job_kwargs):
194194
self.nbefore,
195195
self.nafter,
196196
mode=mode,
197-
return_scaled=self.sorting_analyzer.return_scaled,
197+
return_in_uV=self.sorting_analyzer.return_in_uV,
198198
file_path=file_path,
199199
dtype=self.params["dtype"],
200200
sparsity_mask=sparsity_mask,
@@ -216,7 +216,7 @@ def _set_params(
216216
if dtype is None:
217217
dtype = recording.get_dtype()
218218

219-
if np.issubdtype(dtype, np.integer) and self.sorting_analyzer.return_scaled:
219+
if np.issubdtype(dtype, np.integer) and self.sorting_analyzer.return_in_uV:
220220
dtype = "float32"
221221

222222
dtype = np.dtype(dtype)
@@ -427,7 +427,7 @@ def _run(self, verbose=False, **job_kwargs):
427427
# retrieve spike vector and the sampling
428428
some_spikes = self.sorting_analyzer.get_extension("random_spikes").get_random_spikes()
429429

430-
return_scaled = self.sorting_analyzer.return_scaled
430+
return_in_uV = self.sorting_analyzer.return_in_uV
431431

432432
return_std = "std" in self.params["operators"]
433433
output = estimate_templates_with_accumulator(
@@ -436,7 +436,7 @@ def _run(self, verbose=False, **job_kwargs):
436436
unit_ids,
437437
self.nbefore,
438438
self.nafter,
439-
return_scaled=return_scaled,
439+
return_in_uV=return_in_uV,
440440
return_std=return_std,
441441
verbose=verbose,
442442
**job_kwargs,
@@ -648,7 +648,7 @@ def get_templates(self, unit_ids=None, operator="average", percentile=None, save
648648
channel_ids=self.sorting_analyzer.channel_ids,
649649
unit_ids=unit_ids,
650650
probe=self.sorting_analyzer.get_probe(),
651-
is_scaled=self.sorting_analyzer.return_scaled,
651+
is_scaled=self.sorting_analyzer.return_in_uV,
652652
)
653653
else:
654654
raise ValueError("`outputs` must be 'numpy' or 'Templates'")
@@ -732,7 +732,7 @@ def _merge_extension_data(
732732
def _run(self, verbose=False, **job_kwargs):
733733
self.data["noise_levels"] = get_noise_levels(
734734
self.sorting_analyzer.recording,
735-
return_scaled=self.sorting_analyzer.return_scaled,
735+
return_in_uV=self.sorting_analyzer.return_in_uV,
736736
**self.params,
737737
**job_kwargs,
738738
)

src/spikeinterface/core/baserecording.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -295,7 +295,8 @@ def get_traces(
295295
end_frame: int | None = None,
296296
channel_ids: list | np.array | tuple | None = None,
297297
order: "C" | "F" | None = None,
298-
return_scaled: bool = False,
298+
return_scaled: bool | None = None,
299+
return_in_uV: bool = False,
299300
cast_unsigned: bool = False,
300301
) -> np.ndarray:
301302
"""Returns traces from recording.
@@ -312,7 +313,11 @@ def get_traces(
312313
The channel ids. If None, all channels are used, default: None
313314
order : "C" | "F" | None, default: None
314315
The order of the traces ("C" | "F"). If None, traces are returned as they are
315-
return_scaled : bool, default: False
316+
return_scaled : bool | None, default: None
317+
DEPRECATED. Use return_in_uV instead.
318+
If True and the recording has scaling (gain_to_uV and offset_to_uV properties),
319+
traces are scaled to uV
320+
return_in_uV : bool, default: False
316321
If True and the recording has scaling (gain_to_uV and offset_to_uV properties),
317322
traces are scaled to uV
318323
cast_unsigned : bool, default: False
@@ -327,7 +332,7 @@ def get_traces(
327332
Raises
328333
------
329334
ValueError
330-
If return_scaled is True, but recording does not have scaled traces
335+
If return_in_uV is True, but recording does not have scaled traces
331336
"""
332337
segment_index = self._check_segment_index(segment_index)
333338
channel_indices = self.ids_to_indices(channel_ids, prefer_slice=True)
@@ -351,15 +356,24 @@ def get_traces(
351356
traces = traces.astype(f"int{2 * (dtype.itemsize) * 8}") - 2 ** (nbits - 1)
352357
traces = traces.astype(f"int{dtype.itemsize * 8}")
353358

354-
if return_scaled:
359+
# Handle deprecated return_scaled parameter
360+
if return_scaled is not None:
361+
warnings.warn(
362+
"`return_scaled` is deprecated and will be removed in a future version. Use `return_in_uV` instead.",
363+
category=DeprecationWarning,
364+
stacklevel=2,
365+
)
366+
return_in_uV = return_scaled
367+
368+
if return_in_uV:
355369
if not self.has_scaleable_traces():
356370
if self._dtype.kind == "f":
357371
# here we do not truely have scale but we assume this is scaled
358372
# this helps a lot for simulated data
359373
pass
360374
else:
361375
raise ValueError(
362-
"This recording does not support return_scaled=True (need gain_to_uV and offset_"
376+
"This recording does not support return_in_uV=True (need gain_to_uV and offset_"
363377
"to_uV properties)"
364378
)
365379
else:

src/spikeinterface/core/basesnippets.py

Lines changed: 75 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -97,17 +97,51 @@ def get_snippets(
9797
indices=None,
9898
segment_index: Union[int, None] = None,
9999
channel_ids: Union[list, None] = None,
100-
return_scaled=False,
100+
return_scaled: bool | None = None,
101+
return_in_uV: bool = False,
101102
):
103+
"""
104+
Return the snippets, optionally for a subset of samples and/or channels
105+
106+
Parameters
107+
----------
108+
indices : list[int], default: None
109+
Indices of the snippets to return. If None, all snippets are returned.
110+
segment_index : Union[int, None], default: None
111+
The segment index to get snippets from. If snippets is multi-segment, it is required.
112+
channel_ids : Union[list, None], default: None
113+
The channel ids. If None, all channels are used.
114+
return_scaled : bool | None, default: None
115+
DEPRECATED. Use return_in_uV instead.
116+
If True and the snippets has scaling (gain_to_uV and offset_to_uV properties),
117+
snippets are scaled to uV
118+
return_in_uV : bool, default: False
119+
If True and the snippets has scaling (gain_to_uV and offset_to_uV properties),
120+
snippets are scaled to uV
121+
122+
Returns
123+
-------
124+
np.array
125+
The snippets (num_snippets, num_samples, num_channels)
126+
"""
102127
segment_index = self._check_segment_index(segment_index)
103128
spts = self._snippets_segments[segment_index]
104129
channel_indices = self.ids_to_indices(channel_ids, prefer_slice=True)
105130
wfs = spts.get_snippets(indices, channel_indices=channel_indices)
106131

107-
if return_scaled:
132+
# Handle deprecated return_scaled parameter
133+
if return_scaled is not None:
134+
warn(
135+
"`return_scaled` is deprecated and will be removed in a future version. Use `return_in_uV` instead.",
136+
category=DeprecationWarning,
137+
stacklevel=2,
138+
)
139+
return_in_uV = return_scaled
140+
141+
if return_in_uV:
108142
if not self.has_scaleable_traces():
109143
raise ValueError(
110-
"These snippets do not support return_scaled=True (need gain_to_uV and offset_" "to_uV properties)"
144+
"These snippets do not support return_in_uV=True (need gain_to_uV and offset_" "to_uV properties)"
111145
)
112146
else:
113147
gains = self.get_property("gain_to_uV")
@@ -123,13 +157,49 @@ def get_snippets_from_frames(
123157
start_frame: Union[int, None] = None,
124158
end_frame: Union[int, None] = None,
125159
channel_ids: Union[list, None] = None,
126-
return_scaled=False,
160+
return_scaled: bool | None = None,
161+
return_in_uV: bool = False,
127162
):
163+
"""
164+
Return the snippets from frames, optionally for a subset of samples and/or channels
165+
166+
Parameters
167+
----------
168+
segment_index : Union[int, None], default: None
169+
The segment index to get snippets from. If snippets is multi-segment, it is required.
170+
start_frame : Union[int, None], default: None
171+
The start frame. If None, 0 is used.
172+
end_frame : Union[int, None], default: None
173+
The end frame. If None, the number of samples in the segment is used.
174+
channel_ids : Union[list, None], default: None
175+
The channel ids. If None, all channels are used.
176+
return_scaled : bool | None, default: None
177+
DEPRECATED. Use return_in_uV instead.
178+
If True and the snippets has scaling (gain_to_uV and offset_to_uV properties),
179+
snippets are scaled to uV
180+
return_in_uV : bool, default: False
181+
If True and the snippets has scaling (gain_to_uV and offset_to_uV properties),
182+
snippets are scaled to uV
183+
184+
Returns
185+
-------
186+
np.array
187+
The snippets (num_snippets, num_samples, num_channels)
188+
"""
128189
segment_index = self._check_segment_index(segment_index)
129190
spts = self._snippets_segments[segment_index]
130191
indices = spts.frames_to_indices(start_frame, end_frame)
131192

132-
return self.get_snippets(indices, channel_ids=channel_ids, return_scaled=return_scaled)
193+
# Handle deprecated return_scaled parameter
194+
if return_scaled is not None:
195+
warn(
196+
"`return_scaled` is deprecated and will be removed in a future version. Use `return_in_uV` instead.",
197+
category=DeprecationWarning,
198+
stacklevel=2,
199+
)
200+
return_in_uV = return_scaled
201+
202+
return self.get_snippets(indices, channel_ids=channel_ids, return_in_uV=return_in_uV)
133203

134204
def _save(self, format="binary", **save_kwargs):
135205
raise NotImplementedError

src/spikeinterface/core/recording_tools.py

Lines changed: 39 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -380,7 +380,8 @@ def write_to_h5_dataset_format(
380380
chunk_memory="500M",
381381
verbose=False,
382382
auto_cast_uint=True,
383-
return_scaled=False,
383+
return_scaled=None,
384+
return_in_uV=False,
384385
):
385386
"""
386387
Save the traces of a recording extractor in an h5 dataset.
@@ -414,7 +415,11 @@ def write_to_h5_dataset_format(
414415
If True, output is verbose (when chunks are used)
415416
auto_cast_uint : bool, default: True
416417
If True, unsigned integers are automatically cast to int if the specified dtype is signed
417-
return_scaled : bool, default: False
418+
return_scaled : bool | None, default: None
419+
DEPRECATED. Use return_in_uV instead.
420+
If True and the recording has scaling (gain_to_uV and offset_to_uV properties),
421+
traces are dumped to uV
422+
return_in_uV : bool, default: False
418423
If True and the recording has scaling (gain_to_uV and offset_to_uV properties),
419424
traces are dumped to uV
420425
"""
@@ -459,7 +464,15 @@ def write_to_h5_dataset_format(
459464
chunk_size = ensure_chunk_size(recording, chunk_size=chunk_size, chunk_memory=chunk_memory, n_jobs=1)
460465

461466
if chunk_size is None:
462-
traces = recording.get_traces(cast_unsigned=cast_unsigned, return_scaled=return_scaled)
467+
# Handle deprecated return_scaled parameter
468+
if return_scaled is not None:
469+
warnings.warn(
470+
"`return_scaled` is deprecated and will be removed in a future version. Use `return_in_uV` instead.",
471+
category=DeprecationWarning,
472+
)
473+
return_in_uV = return_scaled
474+
475+
traces = recording.get_traces(cast_unsigned=cast_unsigned, return_scaled=return_in_uV)
463476
if dtype is not None:
464477
traces = traces.astype(dtype_file, copy=False)
465478
if time_axis == 1:
@@ -484,7 +497,7 @@ def write_to_h5_dataset_format(
484497
start_frame=i * chunk_size,
485498
end_frame=min((i + 1) * chunk_size, num_frames),
486499
cast_unsigned=cast_unsigned,
487-
return_scaled=return_scaled,
500+
return_scaled=return_in_uV if return_scaled is None else return_scaled,
488501
)
489502
chunk_frames = traces.shape[0]
490503
if dtype is not None:
@@ -599,7 +612,9 @@ def get_random_recording_slices(
599612
return recording_slices
600613

601614

602-
def get_random_data_chunks(recording, return_scaled=False, concatenated=True, **random_slices_kwargs):
615+
def get_random_data_chunks(
616+
recording, return_scaled=None, return_in_uV=False, concatenated=True, **random_slices_kwargs
617+
):
603618
"""
604619
Extract random chunks across segments.
605620
@@ -636,7 +651,7 @@ def get_random_data_chunks(recording, return_scaled=False, concatenated=True, **
636651
start_frame=start_frame,
637652
end_frame=end_frame,
638653
segment_index=segment_index,
639-
return_scaled=return_scaled,
654+
return_scaled=return_in_uV if return_scaled is None else return_scaled,
640655
)
641656
chunk_list.append(traces_chunk)
642657

@@ -713,17 +728,18 @@ def _noise_level_chunk(segment_index, start_frame, end_frame, worker_ctx):
713728
return noise_levels
714729

715730

716-
def _noise_level_chunk_init(recording, return_scaled, method):
731+
def _noise_level_chunk_init(recording, return_in_uV, method):
717732
worker_ctx = {}
718733
worker_ctx["recording"] = recording
719-
worker_ctx["return_scaled"] = return_scaled
734+
worker_ctx["return_scaled"] = return_in_uV
720735
worker_ctx["method"] = method
721736
return worker_ctx
722737

723738

724739
def get_noise_levels(
725740
recording: "BaseRecording",
726-
return_scaled: bool = True,
741+
return_scaled: bool | None = None,
742+
return_in_uV: bool = True,
727743
method: Literal["mad", "std"] = "mad",
728744
force_recompute: bool = False,
729745
random_slices_kwargs: dict = {},
@@ -745,7 +761,10 @@ def get_noise_levels(
745761
746762
recording : BaseRecording
747763
The recording extractor to get noise levels
748-
return_scaled : bool
764+
return_scaled : bool | None, default: None
765+
DEPRECATED. Use return_in_uV instead.
766+
If True, returned noise levels are scaled to uV
767+
return_in_uV : bool, default: True
749768
If True, returned noise levels are scaled to uV
750769
method : "mad" | "std", default: "mad"
751770
The method to use to estimate noise levels
@@ -763,7 +782,15 @@ def get_noise_levels(
763782
Noise levels for each channel
764783
"""
765784

766-
if return_scaled:
785+
# Handle deprecated return_scaled parameter
786+
if return_scaled is not None:
787+
warnings.warn(
788+
"`return_scaled` is deprecated and will be removed in a future version. Use `return_in_uV` instead.",
789+
category=DeprecationWarning,
790+
)
791+
return_in_uV = return_scaled
792+
793+
if return_in_uV:
767794
key = f"noise_level_{method}_scaled"
768795
else:
769796
key = f"noise_level_{method}_raw"
@@ -797,7 +824,7 @@ def append_noise_chunk(res):
797824

798825
func = _noise_level_chunk
799826
init_func = _noise_level_chunk_init
800-
init_args = (recording, return_scaled, method)
827+
init_args = (recording, return_in_uV, method)
801828
executor = ChunkRecordingExecutor(
802829
recording,
803830
func,

0 commit comments

Comments
 (0)