Skip to content

Commit 6661cc9

Browse files
author
Nikhil Chandra
committed
Moved previous changes to kilosortbase.py to kilosort4.py because they apply only to Kilosort4.
1 parent 2a992ba commit 6661cc9

2 files changed

Lines changed: 43 additions & 37 deletions

File tree

src/spikeinterface/sorters/external/kilosort4.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ def check_sorter_version(cls):
138138

139139
@classmethod
140140
def _setup_recording(cls, recording, sorter_output_folder, params, verbose):
141-
KilosortBase._generate_channel_map_file(recording, sorter_output_folder, format="json")
141+
cls._setup_json_probe_map(recording, sorter_output_folder)
142142

143143
if params["use_binary_file"]:
144144
if not recording.binary_compatible_with(time_axis=0, file_paths_length=1):
@@ -382,3 +382,28 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
382382
@classmethod
383383
def _get_result_from_folder(cls, sorter_output_folder):
384384
return KilosortBase._get_result_from_folder(sorter_output_folder)
385+
386+
@classmethod
387+
def _setup_json_probe_map(cls, recording, sorter_output_folder):
388+
"""Create a JSON probe map file for Kilosort4."""
389+
from kilosort.io import save_probe
390+
import numpy as np
391+
392+
groups = recording.get_channel_groups()
393+
positions = np.array(recording.get_channel_locations())
394+
if positions.shape[1] != 2:
395+
raise RuntimeError("3D 'location' are not supported. Set 2D locations instead.")
396+
397+
nchan = recording.get_num_channels()
398+
xcoords = ([p[0] for p in positions],)
399+
ycoords = ([p[1] for p in positions],)
400+
kcoords = (groups,)
401+
402+
probe = {
403+
'chanMap': np.arange(nchan),
404+
'xc': np.array(xcoords[0]).astype(float).squeeze(),
405+
'yc': np.array(ycoords[0]).astype(float).squeeze(),
406+
'kcoords': np.array(kcoords).astype(float).squeeze(),
407+
'n_chan': nchan,
408+
}
409+
save_probe(probe, str(sorter_output_folder / "chanMap.json"))

src/spikeinterface/sorters/external/kilosortbase.py

Lines changed: 17 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,9 @@ class KilosortBase:
2828
requires_binary_data = True
2929

3030
@staticmethod
31-
def _generate_channel_map_file(recording, sorter_output_folder, format="mat"):
31+
def _generate_channel_map_file(recording, sorter_output_folder):
3232
"""
3333
This function generates channel map data for kilosort and saves as `chanMap.mat`
34-
or `chanMap.json`, depending on the format chosen by the caller.
3534
3635
Loading example in Matlab (shouldn't be assigned to a variable):
3736
>> load('/path/to/sorter_output_folder/chanMap.mat');
@@ -42,15 +41,9 @@ def _generate_channel_map_file(recording, sorter_output_folder, format="mat"):
4241
The recording to generate the channel map file
4342
sorter_output_folder: pathlib.Path
4443
Path object to save `chanMap.mat` file
45-
format: str
46-
The format of the channel map file. Currently, 'mat'
47-
and 'json' are supported ('mat' is the default).
4844
"""
49-
if format not in ("mat", "json"):
50-
raise ValueError("Only 'mat' and 'json' formats are supported for Kilosort channel maps.")
51-
5245
# prepare electrode positions for this group (only one group, the split is done in basesorter)
53-
groups = recording.get_channel_groups()
46+
groups = [1] * recording.get_num_channels()
5447
positions = np.array(recording.get_channel_locations())
5548
if positions.shape[1] != 2:
5649
raise RuntimeError("3D 'location' are not supported. Set 2D locations instead")
@@ -60,33 +53,21 @@ def _generate_channel_map_file(recording, sorter_output_folder, format="mat"):
6053
ycoords = ([p[1] for p in positions],)
6154
kcoords = (groups,)
6255

63-
if format == "mat":
64-
channel_map = {}
65-
channel_map["Nchannels"] = nchan
66-
channel_map["connected"] = np.full((nchan, 1), True)
67-
channel_map["chanMap0ind"] = np.arange(nchan)
68-
channel_map["chanMap"] = channel_map["chanMap0ind"] + 1
69-
70-
channel_map["xcoords"] = np.array(xcoords).astype(float)
71-
channel_map["ycoords"] = np.array(ycoords).astype(float)
72-
channel_map["kcoords"] = np.array(kcoords).astype(float)
73-
74-
sample_rate = recording.get_sampling_frequency()
75-
channel_map["fs"] = float(sample_rate)
76-
import scipy.io
77-
78-
scipy.io.savemat(str(sorter_output_folder / "chanMap.mat"), channel_map)
79-
elif format == "json":
80-
from kilosort.io import save_probe
81-
82-
probe = {
83-
"chanMap": np.arange(nchan),
84-
"xc": np.array(xcoords[0]).astype(float).squeeze(),
85-
"yc": np.array(ycoords[0]).astype(float).squeeze(),
86-
"kcoords": np.array(kcoords).astype(float).squeeze(),
87-
"n_chan": nchan,
88-
}
89-
save_probe(probe, str(sorter_output_folder / "chanMap.json"))
56+
channel_map = {}
57+
channel_map["Nchannels"] = nchan
58+
channel_map["connected"] = np.full((nchan, 1), True)
59+
channel_map["chanMap0ind"] = np.arange(nchan)
60+
channel_map["chanMap"] = channel_map["chanMap0ind"] + 1
61+
62+
channel_map["xcoords"] = np.array(xcoords).astype(float)
63+
channel_map["ycoords"] = np.array(ycoords).astype(float)
64+
channel_map["kcoords"] = np.array(kcoords).astype(float)
65+
66+
sample_rate = recording.get_sampling_frequency()
67+
channel_map["fs"] = float(sample_rate)
68+
import scipy.io
69+
70+
scipy.io.savemat(str(sorter_output_folder / "chanMap.mat"), channel_map)
9071

9172
@classmethod
9273
def _generate_ops_file(cls, recording, params, sorter_output_folder, binary_file_path):

0 commit comments

Comments
 (0)