7070 "nearest_chans" : 8 ,
7171 "nearest_templates" : 35 ,
7272 "max_channel_distance" : 5 ,
73- "templates_from_data" : False ,
7473 "n_templates" : 10 ,
7574 "n_pcs" : 3 ,
7675 "Th_single_ch" : 4 ,
109108 # max_peels is not affecting the results in this short dataset
110109 PARAMETERS_NOT_AFFECTING_RESULTS .append ("max_peels" )
111110
111+ if parse (kilosort .__version__ ) >= parse ("4.0.33" ):
112+ PARAMS_TO_TEST_DICT .update ({"cluster_neighbors" : 11 })
113+ PARAMETERS_NOT_AFFECTING_RESULTS .append ("cluster_neighbors" )
114+
112115
113116PARAMS_TO_TEST = list (PARAMS_TO_TEST_DICT .keys ())
114117
@@ -178,11 +181,11 @@ def _save_ground_truth_recording(self, recording, tmp_path):
178181 """
179182 paths = {
180183 "session_scope_tmp_path" : tmp_path ,
181- "recording_path" : tmp_path / "my_test_recording" ,
184+ "recording_path" : tmp_path / "my_test_recording" / "traces_cached_seg0.raw" ,
182185 "probe_path" : tmp_path / "my_test_probe.prb" ,
183186 }
184187
185- recording .save (folder = paths ["recording_path" ], overwrite = True )
188+ recording .save (folder = paths ["recording_path" ]. parent , overwrite = True )
186189
187190 probegroup = recording .get_probegroup ()
188191 write_prb (paths ["probe_path" ].as_posix (), probegroup )
@@ -214,7 +217,7 @@ def test_default_settings_all_represented(self):
214217 tested_keys += additional_non_tested_keys
215218
216219 for param_key in DEFAULT_SETTINGS :
217- if param_key not in ["n_chan_bin" , "fs" , "tmin" , "tmax" ]:
220+ if param_key not in ["n_chan_bin" , "fs" , "tmin" , "tmax" , "templates_from_data" ]:
218221 assert param_key in tested_keys , f"param: { param_key } in DEFAULT SETTINGS but not tested."
219222
220223 def test_spikeinterface_defaults_against_kilsort (self ):
@@ -234,8 +237,11 @@ def test_spikeinterface_defaults_against_kilsort(self):
234237
235238 # Testing Arguments ###
236239 def test_set_files_arguments (self ):
240+ expected_arguments = ["settings" , "filename" , "probe" , "probe_name" , "data_dir" , "results_dir" , "bad_channels" ]
241+ if parse (kilosort .__version__ ) >= parse ("4.0.34" ):
242+ expected_arguments += ["shank_idx" ]
237243 self ._check_arguments (
238- set_files , [ "settings" , "filename" , "probe" , "probe_name" , "data_dir" , "results_dir" , "bad_channels" ]
244+ set_files , expected_arguments
239245 )
240246
241247 def test_initialize_ops_arguments (self ):
@@ -533,33 +539,60 @@ def test_kilosort4_skip_preprocessing_correction(self, tmp_path, monkeypatch, pa
533539 kilosort_output_dir = tmp_path / "kilosort_output_dir"
534540 spikeinterface_output_dir = tmp_path / "spikeinterface_output_dir"
535541
536- def monkeypatch_filter_function (self , X , ops = None , ibatch = None ):
537- """
538- This is a direct copy of the kilosort io.BinaryFiltered.filter
539- function, with hp_filter and whitening matrix code sections, and
540- comments removed. This is the easiest way to monkeypatch (tried a few approaches)
541- """
542- if self .chan_map is not None :
543- X = X [self .chan_map ]
544-
545- if self .invert_sign :
546- X = X * - 1
547-
548- X = X - X .mean (1 ).unsqueeze (1 )
549- if self .do_CAR :
550- X = X - torch .median (X , 0 )[0 ]
551-
552- if self .hp_filter is not None :
553- pass
554-
555- if self .artifact_threshold < np .inf :
556- if torch .any (torch .abs (X ) >= self .artifact_threshold ):
557- return torch .zeros_like (X )
558-
559- if self .whiten_mat is not None :
560- pass
561- return X
562-
542+ if parse (kilosort .__version__ ) >= parse ("4.0.33" ):
543+ def monkeypatch_filter_function (self , X , ops = None , ibatch = None , skip_preproc = False ):
544+ """
545+ This is a direct copy of the kilosort io.BinaryFiltered.filter
546+ function, with hp_filter and whitening matrix code sections, and
547+ comments removed. This is the easiest way to monkeypatch (tried a few approaches)
548+ """
549+ if self .chan_map is not None :
550+ X = X [self .chan_map ]
551+
552+ if self .invert_sign :
553+ X = X * - 1
554+
555+ X = X - X .mean (1 ).unsqueeze (1 )
556+ if self .do_CAR :
557+ X = X - torch .median (X , 0 )[0 ]
558+
559+ if self .hp_filter is not None :
560+ pass
561+
562+ if self .artifact_threshold < np .inf :
563+ if torch .any (torch .abs (X ) >= self .artifact_threshold ):
564+ return torch .zeros_like (X )
565+
566+ if self .whiten_mat is not None :
567+ pass
568+ return X
569+ else :
570+ def monkeypatch_filter_function (self , X , ops = None , ibatch = None ):
571+ """
572+ This is a direct copy of the kilosort io.BinaryFiltered.filter
573+ function, with hp_filter and whitening matrix code sections, and
574+ comments removed. This is the easiest way to monkeypatch (tried a few approaches)
575+ """
576+ if self .chan_map is not None :
577+ X = X [self .chan_map ]
578+
579+ if self .invert_sign :
580+ X = X * - 1
581+
582+ X = X - X .mean (1 ).unsqueeze (1 )
583+ if self .do_CAR :
584+ X = X - torch .median (X , 0 )[0 ]
585+
586+ if self .hp_filter is not None :
587+ pass
588+
589+ if self .artifact_threshold < np .inf :
590+ if torch .any (torch .abs (X ) >= self .artifact_threshold ):
591+ return torch .zeros_like (X )
592+
593+ if self .whiten_mat is not None :
594+ pass
595+ return X
563596 monkeypatch .setattr ("kilosort.io.BinaryFiltered.filter" , monkeypatch_filter_function )
564597
565598 ks_settings , _ , ks_format_probe = self ._get_kilosort_native_settings (recording , paths , param_key , param_value )
@@ -620,7 +653,7 @@ def _get_kilosort_native_settings(self, recording, paths, param_key, param_value
620653 are through the function, these are split here.
621654 """
622655 settings = {
623- "data_dir " : paths ["recording_path" ],
656+ "filename " : paths ["recording_path" ],
624657 "n_chan_bin" : recording .get_num_channels (),
625658 "fs" : recording .get_sampling_frequency (),
626659 }
0 commit comments