22from os .path import join , basename
33from multiprocessing import Pool , cpu_count
44import numpy as np
5+ import warnings
56
67from ctc_metrics .metrics import (
78 valid , det , seg , tra , ct , tf , bc , raw_division_metrics , cca , mota , hota , idf1 , chota , mtml , faf ,
@@ -171,6 +172,19 @@ def calculate_metrics(
171172 traj ["labels_comp_merged" ] = new_labels
172173 traj ["mapped_comp_merged" ] = new_mapped
173174
175+ # Check if a manual i was defined for BC(i)
176+ max_i_for_bci = 3
177+ for m in metrics :
178+ if m .startswith ("BC(" ):
179+ try :
180+ i = int (m [3 :- 1 ])
181+ if i > max_i_for_bci :
182+ max_i_for_bci = i
183+ if "BC" not in metrics :
184+ metrics .append ("BC" )
185+ except ValueError :
186+ warnings .warn (f"{ m } is not a valid metric identifier!." )
187+
174188 # Prepare intermediate results
175189 graph_operations = {}
176190 if "DET" in metrics or "TRA" in metrics :
@@ -225,7 +239,7 @@ def calculate_metrics(
225239 traj ["labels_ref" ], traj ["mapped_ref" ], traj ["mapped_comp" ])
226240
227241 if "BC" in metrics :
228- for i in range (4 ):
242+ for i in range (max_i_for_bci + 1 ):
229243 tp , fp , fn = raw_division_metrics (comp_tracks , ref_tracks ,
230244 traj ["mapped_ref" ], traj ["mapped_comp" ],
231245 i = i )
@@ -243,13 +257,13 @@ def calculate_metrics(
243257
244258 if "CT" in metrics and "BC" in metrics and \
245259 "CCA" in metrics and "TF" in metrics :
246- for i in range (4 ):
260+ for i in range (max_i_for_bci + 1 ):
247261 results [f"BIO({ i } )" ] = bio (
248262 results ["CT" ], results ["TF" ],
249263 results [f"BC({ i } )" ], results ["CCA" ])
250264
251265 if "BIO" in results and "LNK" in results :
252- for i in range (4 ):
266+ for i in range (max_i_for_bci + 1 ):
253267 results [f"OP_CLB({ i } )" ] = op_clb (
254268 results ["LNK" ], results [f"BIO({ i } )" ])
255269
@@ -365,7 +379,7 @@ def parse_args():
365379 parser .add_argument ('--tra' , action = "store_true" )
366380 parser .add_argument ('--ct' , action = "store_true" )
367381 parser .add_argument ('--tf' , action = "store_true" )
368- parser .add_argument ('--bc' , action = "store_true" )
382+ parser .add_argument ('--bc' , type = int , default = 0 )
369383 parser .add_argument ('--cca' , action = "store_true" )
370384 parser .add_argument ('--mota' , action = "store_true" )
371385 parser .add_argument ('--hota' , action = "store_true" )
@@ -391,7 +405,7 @@ def main():
391405 ("TRA" , args .tra ),
392406 ("CT" , args .ct ),
393407 ("TF" , args .tf ),
394- ("BC" , args .bc ),
408+ (f "BC( { args . bc } ) " , args .bc ),
395409 ("CCA" , args .cca ),
396410 ("MOTA" , args .mota ),
397411 ("HOTA" , args .hota ),
0 commit comments