Skip to content

Commit e9b39ef

Browse files
authored
Merge pull request #3829 from SpikeInterface/fix-sortbyproperty-args
BUG FIX: ensure matching of args in aggregate_channels and ChannelAggregationRecording
2 parents b8f7253 + 16dafc0 commit e9b39ef

2 files changed

Lines changed: 36 additions & 11 deletions

File tree

src/spikeinterface/core/channelsaggregationrecording.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,15 @@ class ChannelsAggregationRecording(BaseRecording):
1414
1515
"""
1616

17-
def __init__(self, recording_list_or_dict, renamed_channel_ids=None):
17+
def __init__(self, recording_list_or_dict=None, renamed_channel_ids=None, recording_list=None):
18+
19+
if recording_list is not None:
20+
warnings.warn(
21+
"`recording_list` is deprecated and will be removed in 0.105.0. Please use `recording_list_or_dict` instead.",
22+
category=DeprecationWarning,
23+
stacklevel=2,
24+
)
25+
recording_list_or_dict = recording_list
1826

1927
if isinstance(recording_list_or_dict, dict):
2028
recording_list = list(recording_list_or_dict.values())
@@ -258,12 +266,4 @@ def aggregate_channels(
258266
The aggregated recording object
259267
"""
260268

261-
if recording_list is not None:
262-
warnings.warn(
263-
"`recording_list` is deprecated and will be removed in 0.105.0. Please use `recording_list_or_dict` instead.",
264-
category=DeprecationWarning,
265-
stacklevel=2,
266-
)
267-
recording_list_or_dict = recording_list
268-
269-
return ChannelsAggregationRecording(recording_list_or_dict, renamed_channel_ids)
269+
return ChannelsAggregationRecording(recording_list_or_dict, renamed_channel_ids, recording_list)

src/spikeinterface/core/tests/test_loading.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,14 @@
11
import pytest
22

33
import numpy as np
4-
from spikeinterface import generate_ground_truth_recording, create_sorting_analyzer, load, SortingAnalyzer, Templates
4+
from spikeinterface import (
5+
generate_ground_truth_recording,
6+
create_sorting_analyzer,
7+
load,
8+
SortingAnalyzer,
9+
Templates,
10+
aggregate_channels,
11+
)
512
from spikeinterface.core.motion import Motion
613
from spikeinterface.core.generate import generate_unit_locations, generate_templates
714
from spikeinterface.core.testing import check_recordings_equal, check_sortings_equal
@@ -176,6 +183,24 @@ def test_load_motion(tmp_path, generate_motion_object):
176183
assert motion == motion_loaded
177184

178185

186+
def test_load_aggregate_recording_from_json(generate_recording_sorting, tmp_path):
187+
"""
188+
Save, then load an aggregated recording using its provenance.json file.
189+
"""
190+
191+
recording, _ = generate_recording_sorting
192+
193+
recording.set_property("group", [0, 0, 1, 1])
194+
list_of_recs = list(recording.split_by("group").values())
195+
aggregated_rec = aggregate_channels(list_of_recs)
196+
197+
recording_path = tmp_path / "aggregated_recording"
198+
aggregated_rec.save_to_folder(folder=recording_path)
199+
loaded_rec = load(recording_path / "provenance.json", base_folder=recording_path)
200+
201+
assert np.all(loaded_rec.get_property("group") == recording.get_property("group"))
202+
203+
179204
@pytest.mark.streaming_extractors
180205
@pytest.mark.skipif(not HAVE_S3, reason="s3fs not installed")
181206
def test_remote_recording():

0 commit comments

Comments
 (0)