|
1 | 1 | import pytest |
2 | 2 |
|
3 | 3 | import numpy as np |
4 | | -from spikeinterface import generate_ground_truth_recording, create_sorting_analyzer, load, SortingAnalyzer, Templates |
| 4 | +from spikeinterface import ( |
| 5 | + generate_ground_truth_recording, |
| 6 | + create_sorting_analyzer, |
| 7 | + load, |
| 8 | + SortingAnalyzer, |
| 9 | + Templates, |
| 10 | + aggregate_channels, |
| 11 | +) |
5 | 12 | from spikeinterface.core.motion import Motion |
6 | 13 | from spikeinterface.core.generate import generate_unit_locations, generate_templates |
7 | 14 | from spikeinterface.core.testing import check_recordings_equal, check_sortings_equal |
@@ -176,6 +183,24 @@ def test_load_motion(tmp_path, generate_motion_object): |
176 | 183 | assert motion == motion_loaded |
177 | 184 |
|
178 | 185 |
|
def test_load_aggregate_recording_from_json(generate_recording_sorting, tmp_path):
    """
    Round-trip an aggregated recording through its provenance.json file.

    The recording is tagged with a per-channel "group" property, split by
    that property, re-aggregated, saved to disk, and then reloaded from the
    saved provenance.json. The reloaded recording must keep the original
    per-channel "group" values.
    """
    rec, _ = generate_recording_sorting

    rec.set_property("group", [0, 0, 1, 1])
    grouped = rec.split_by("group")
    aggregated = aggregate_channels(list(grouped.values()))

    save_folder = tmp_path / "aggregated_recording"
    aggregated.save_to_folder(folder=save_folder)
    reloaded = load(save_folder / "provenance.json", base_folder=save_folder)

    assert np.all(reloaded.get_property("group") == rec.get_property("group"))
179 | 204 | @pytest.mark.streaming_extractors |
180 | 205 | @pytest.mark.skipif(not HAVE_S3, reason="s3fs not installed") |
181 | 206 | def test_remote_recording(): |
|
0 commit comments