Skip to content

Commit bfec35a

Browse files
authored
Merge pull request #57 from Project-MONAI/aggregation
Aggregation
2 parents cb38dfc + 2bf9c91 commit bfec35a

6 files changed

Lines changed: 208 additions & 51 deletions

File tree

MetricsReloaded/metrics/pairwise_measures.py

Lines changed: 68 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -286,15 +286,29 @@ def __init__(
286286
self.flag_empty = empty
287287
self.flag_empty_pred = False
288288
self.flag_empty_ref = False
289-
if np.sum(self.pred) == 0:
289+
if int(np.sum(self.pred)) == 0:
290290
self.flag_empty_pred = True
291-
if np.sum(self.ref) == 0:
291+
if int(np.sum(self.ref)) == 0:
292292
self.flag_empty_ref = True
293293
self.measures = measures if measures is not None else self.measures_dict
294294
self.connectivity = connectivity_type
295295
self.pixdim = pixdim
296+
self.worse_dist = self.calculate_worse_dist()
296297
self.dict_args = dict_args
297298

299+
def calculate_worse_dist(self):
300+
shape = self.ref.shape
301+
pixdim = self.pixdim
302+
if pixdim is not None:
303+
mult_sp = shape * np.asarray(pixdim)
304+
else:
305+
mult_sp = shape
306+
print(mult_sp)
307+
max_dist = np.sqrt(np.sum(np.square(mult_sp)))
308+
print(max_dist)
309+
return max_dist
310+
311+
298312
def __fp_map(self):
299313
"""
300314
This function calculates the false positive map
@@ -479,6 +493,12 @@ def youden_index(self):
479493
480494
:return: youden_index
481495
"""
496+
if self.n_pos_ref() == 0:
497+
warnings.warn("Reference is empty - sensitivity is not defined")
498+
return np.nan
499+
if self.n_neg_ref() == 0:
500+
warnings.warn("Reference is fully positive, specificity is not defined")
501+
return np.nan
482502
youden_index = self.specificity() + self.sensitivity() - 1
483503
return youden_index
484504

@@ -533,6 +553,12 @@ def balanced_accuracy(self):
533553
534554
:return: balanced accuracy
535555
"""
556+
if self.n_neg_ref() == 0:
557+
warnings.warn('Reference All positive - speciicity not defined')
558+
return np.nan
559+
if self.n_pos_ref() == 0:
560+
warnings.warn("Reference All negative - sensitivity not defined")
561+
return np.nan
536562
balanced_accuracy = 0.5 * self.sensitivity() + 0.5 * self.specificity()
537563
return balanced_accuracy
538564

@@ -564,6 +590,9 @@ def false_positive_rate(self):
564590
565591
:return: false_positive_rate
566592
"""
593+
if self.n_neg_ref() == 0:
594+
warnings.warn("All positive in reference - FPR not defined")
595+
return np.nan
567596
false_positive_rate = self.fp() / self.n_neg_ref()
568597
return false_positive_rate
569598

@@ -578,6 +607,12 @@ def normalised_expected_cost(self):
578607

579608
prior_background = (self.tn() + self.fp()) / (np.size(self.ref))
580609
prior_foreground = (self.tp() + self.fn()) / np.size(self.ref)
610+
if self.n_pos_ref() == 0:
611+
warnings.warn("Reference empty - r_fn not defined")
612+
return np.nan
613+
if self.n_neg_ref() == 0:
614+
warnings.warn("Reference all positive - r_fp not defined")
615+
return np.nan
581616

582617
if "cost_fn" in self.dict_args.keys():
583618
c_fn = self.dict_args["cost_fn"]
@@ -681,7 +716,7 @@ def positive_likelihood_ratio(self):
681716
warnings.warn("reference empty - sensitivity not defined")
682717
return np.nan
683718
if self.specificity() == 1:
684-
warnings.warn("Perfect specifiicty - likelihood ratio not defined")
719+
warnings.warn("Perfect specificity - likelihood ratio not defined")
685720
return np.nan
686721
positive_likelihood_ratio = numerator / denominator
687722
return positive_likelihood_ratio
@@ -718,8 +753,8 @@ def positive_predictive_value(self):
718753
warnings.warn("ref and prediction empty ppv not defined")
719754
return np.nan
720755
else:
721-
warnings.warn("prediction empty, ppv not defined but set to 0")
722-
return 0
756+
warnings.warn("prediction empty, ppv not defined")
757+
return np.nan
723758
positive_predictive_value = self.tp() / (self.tp() + self.fp())
724759
return positive_predictive_value
725760

@@ -734,7 +769,7 @@ def recall(self):
734769
return np.nan
735770
if self.n_pos_pred() == 0:
736771
warnings.warn(
737-
"prediction is empty but ref not, recall not defined but set to 0"
772+
"prediction is empty but ref not, recall set to 0"
738773
)
739774
return 0
740775
recall = self.tp() / (self.tp() + self.fn())
@@ -760,8 +795,8 @@ def dsc(self):
760795
numerator = 2 * self.tp()
761796
denominator = self.n_pos_pred() + self.n_pos_ref()
762797
if denominator == 0:
763-
warnings.warn("Both Prediction and Reference are empty - set to 1 as correct solution even if not defined")
764-
return 1
798+
warnings.warn("Both Prediction and Reference are empty - not defined - can be set to 1 when aggregating")
799+
return np.nan
765800
else:
766801
dsc = numerator / denominator
767802
return dsc
@@ -796,15 +831,15 @@ def fbeta(self):
796831
np.square(beta) * self.positive_predictive_value() + self.recall()
797832
)
798833
if np.isnan(denominator):
799-
if self.fp() + self.fn() > 0:
834+
if self.fp() + self.fn() > 0: # Would occur if reference empty and prediction not
800835
return 0
801836
else:
802-
return 1 # Potentially modify to nan
837+
return np.nan # Potentially modify to nan
803838
elif denominator == 0:
804839
if self.fp() + self.fn() > 0:
805840
return 0
806841
else:
807-
return 1 # Potentially modify to nan
842+
return np.nan # Potentially modify to nan
808843
else:
809844
fbeta = numerator / denominator
810845
return fbeta
@@ -857,9 +892,9 @@ def negative_predictive_value(self):
857892
return np.nan # Potentially modify to 1
858893
else:
859894
warnings.warn(
860-
"Nothing negative in pred but should be NPV not defined but set to 0"
895+
"Nothing negative in pred but should be NPV not defined set to nan - possibly set to 0 in aggregation"
861896
)
862-
return 0
897+
return np.nan
863898
negative_predictive_value = self.tn() / (self.fn() + self.tn())
864899
return negative_predictive_value
865900

@@ -1071,18 +1106,19 @@ def centreline_dsc(self):
10711106
10721107
:return: cDSC
10731108
"""
1074-
if self.n_pos_pred == 0 and self.n_pos_ref == 0:
1075-
warnings.warn("Both reference and prediction are empty - setting to max")
1076-
return 1
1077-
top_prec = self.topology_precision()
1078-
top_sens = self.topology_sensitivity()
1079-
numerator = 2 * top_sens * top_prec
1080-
denominator = top_sens + top_prec
1081-
if np.isnan(top_sens) or np.isnan(top_sens):
1082-
warnings.warn("Topology sensitivity or precision not defined")
1109+
if int(self.n_pos_pred()) == 0 and int(self.n_pos_ref()) == 0:
1110+
warnings.warn("Both reference and prediction are empty - setting to nan - should be changed to max in aggregation")
10831111
return np.nan
1084-
cDSC = numerator / denominator
1085-
return cDSC
1112+
else:
1113+
top_prec = self.topology_precision()
1114+
top_sens = self.topology_sensitivity()
1115+
numerator = 2 * top_sens * top_prec
1116+
denominator = top_sens + top_prec
1117+
if np.isnan(top_sens) or np.isnan(top_sens):
1118+
warnings.warn("Topology sensitivity or precision not defined")
1119+
return np.nan
1120+
cDSC = numerator / denominator
1121+
return cDSC
10861122

10871123
def boundary_iou(self):
10881124
"""
@@ -1105,8 +1141,8 @@ def boundary_iou(self):
11051141
else:
11061142
distance = 1
11071143
if int(self.n_pos_ref()) == 0 and int(self.n_pos_pred()) == 0:
1108-
warnings.warn("Both prediction and reference empty - setting to max for boudnary ioU")
1109-
return 1
1144+
warnings.warn("Both prediction and reference empty - return nan but setting to max for boudnary ioU in aggregation")
1145+
return np.nan
11101146
else:
11111147
border_ref = MorphologyOps(self.ref, self.connectivity).border_map()
11121148
distance_border_ref = ndimage.distance_transform_edt(1 - border_ref)
@@ -1185,8 +1221,8 @@ def normalised_surface_distance(self):
11851221
warnings.warn('No value set up for NSD tolerance - default to 1')
11861222
tau = 1
11871223
if int(self.n_pos_pred()) == 0 and int(self.n_pos_ref()) == 0 :
1188-
warnings.warn("Both reference and prediction are empty - setting to best")
1189-
return 1
1224+
warnings.warn("Both reference and prediction are empty - setting to best in aggregation but returning nan here")
1225+
return np.nan
11901226
else:
11911227
dist_ref, dist_pred, border_ref, border_pred = self.border_distance()
11921228
reg_ref = np.where(
@@ -1201,7 +1237,8 @@ def normalised_surface_distance(self):
12011237
denominator = np.sum(border_ref) + np.sum(border_pred)
12021238
# print(numerator, denominator, tau)
12031239
return numerator / denominator
1204-
1240+
1241+
@CacheFunctionOutput
12051242
def measured_distance(self):
12061243
"""
12071244
This functions calculates the average symmetric distance and the
@@ -1218,8 +1255,8 @@ def measured_distance(self):
12181255
warnings.warn('Percentile not specified in options for Hausdorff distance - default set to 95')
12191256
perc = 95
12201257
if np.sum(self.pred + self.ref) == 0:
1221-
warnings.warn("Prediction and reference empty - distances set to 0")
1222-
return 0, 0, 0, 0
1258+
warnings.warn("Prediction and reference empty -not defined - need to set to 0 in aggregation ")
1259+
return np.nan, np.nan, np.nan, np.nan
12231260
if np.sum(self.pred) == 0 and np.sum(self.ref)>0:
12241261
warnings.warn("Prediction empty but reference not empty - need to set to worse case in aggregation")
12251262
return np.nan, np.nan, np.nan, np.nan

MetricsReloaded/metrics/prob_pairwise_measures.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,8 @@ def __init__(
7474
self.ref = ref_proba
7575
self.case = case
7676
self.flag_empty = empty
77+
self.flag_ref_empty = True if int(self.n_pos_ref()) == 0 else False
78+
self.flag_pred_empty = True if int(self.n_pos_pred()) == 0 else False
7779
self.dict_args = dict_args
7880
self.measures = measures if measures is not None else self.measures_dict
7981

@@ -96,6 +98,10 @@ def tn_thr(self, thresh):
9698
@CacheFunctionOutput
9799
def n_pos_ref(self):
98100
return np.sum(self.ref)
101+
102+
@CacheFunctionOutput
103+
def n_pos_pred(self):
104+
return np.sum(self.pred)
99105

100106
@CacheFunctionOutput
101107
def n_neg_ref(self):

MetricsReloaded/processes/mixed_measures_processes.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@
7373
"MultiLabelPairwiseMeasures",
7474
]
7575

76+
list_distance = ['masd','assd','hd','hd_perc']
7677

7778
class MixedLocSegPairwiseMeasure(object):
7879
"""
@@ -704,8 +705,9 @@ def __init__(
704705
self.connectivity_type = connectivity_type
705706
ndim = 0
706707
self.pixdim = pixdim
708+
self.squeeze_ref_and_pred_to_size()
707709
if len(self.pred)>0:
708-
ndim = np.asarray(self.pred[0]).ndim
710+
ndim = np.asarray(self.ref[0]).ndim
709711
if len(self.pixdim) == 0 and ndim>0:
710712
self.pixdim = np.ones([ndim])
711713
elif ndim>0:
@@ -721,6 +723,16 @@ def __init__(
721723
if pred_proba is None or pred_proba[0] is None:
722724
self.flag_valid_proba = False
723725

726+
def squeeze_ref_and_pred_to_size(self):
727+
for i,(p,r) in enumerate(zip(self.pred, self.ref)):
728+
if np.size(np.asarray(p)) == np.size(np.asarray(r)) and np.asarray(p).ndim != np.asarray(r).ndim:
729+
warnings.warn("There is a dimensional mismatch between pred and ref despite same size")
730+
p = np.squeeze(np.asarray(p))
731+
r = np.squeeze(np.asarray(r))
732+
self.pred[i] = p
733+
self.ref[i] = r
734+
return
735+
724736
def per_label_dict(self):
725737
list_bin = []
726738
list_mt = []
@@ -764,6 +776,16 @@ def per_label_dict(self):
764776
dict_bin = BPM.to_dict_meas()
765777
dict_bin["label"] = lab
766778
dict_bin["case"] = name
779+
dict_bin["worse_dist"] = BPM.worse_dist
780+
if any(x in self.measures_binary for x in list_distance):
781+
dict_bin["worse_dist"] = BPM.worse_dist
782+
dict_bin["check_empty"] = "None"
783+
if BPM.flag_empty_pred and BPM.flag_empty_ref:
784+
dict_bin["check_empty"] = "Both"
785+
elif BPM.flag_empty_ref:
786+
dict_bin["check_empty"] = "Ref"
787+
elif BPM.flag_empty_pred:
788+
dict_bin["check_empty"] = "Pred"
767789
list_bin.append(dict_bin)
768790
if self.flag_valid_proba and len(self.measures_mt)>0:
769791
PPM = ProbabilityPairwiseMeasures(
@@ -772,6 +794,7 @@ def per_label_dict(self):
772794
measures=self.measures_mt,
773795
dict_args=self.dict_args,
774796
)
797+
775798
dict_mt = PPM.to_dict_meas()
776799
dict_mt["label"] = lab
777800
dict_mt["case"] = name

MetricsReloaded/processes/overall_process.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,8 @@
227227

228228
MAX = 1000
229229

230+
LIST_DISTANCE = ['hd','masd','assd','hd_perc']
231+
230232
WORSE = {
231233
"ap": 0,
232234
"auroc": 0,
@@ -257,6 +259,38 @@
257259
"nsd": 0,
258260
}
259261

262+
BEST = {
263+
"ap": 1,
264+
"auroc": 1,
265+
"froc": 1,
266+
"sens@spec": 1,
267+
"sens@ppv": 1,
268+
"spec@sens": 1,
269+
"fppi@sens": 0,
270+
"ppv@sens": 1,
271+
"sens@fppi": 1,
272+
"fbeta": 1,
273+
"ec":0,
274+
"accuracy": 1,
275+
"ba": 1,
276+
"lr+": 1,
277+
"youden_ind": 1,
278+
"mcc": 1,
279+
"wck": 1,
280+
"cohens_kappa": 1,
281+
"iou": 1,
282+
"dsc": 1,
283+
"cldice": 1,
284+
"masd": 0,
285+
"assd": 0,
286+
"hd_perc": 0,
287+
"hd": 0,
288+
"boundary_iou": 1,
289+
"nsd": 1,
290+
}
291+
292+
NAN_LIST = ["iou","dsc","fbeta","masd",'cldice','hd','hd_perc','assd','boundary_iou','nsd']
293+
260294
class ProcessEvaluation(object):
261295
"""
262296
Performs the evaluation of the data stored in a pickled file according to all the measures, categories and choices of processing
@@ -484,6 +518,39 @@ def process_data(self):
484518
self.resmt = df_resmt
485519
self.resmcc = df_resmcc
486520
self.rescal = df_rescal
521+
self.create_mapping_column_nan_replaced_seg()
522+
return
523+
524+
def create_mapping_column_nan_replaced_seg(self):
525+
"""
526+
For each measure (segmentation) for which nan are possible
527+
creates an additional column in which nans are replaced by value (worse or best according to situation
528+
"""
529+
list_to_map = []
530+
for x in self.measures_boundary:
531+
if x in NAN_LIST:
532+
list_to_map.append(x)
533+
for x in self.measures_overlap:
534+
if x in NAN_LIST:
535+
list_to_map.append(x)
536+
for k in list_to_map:
537+
self.resseg[k+'_nanrep'] = self.resseg[k]
538+
539+
self.resseg[k+'_nanrep'] = np.where(np.logical_and(self.resseg[k].isna(),self.resseg['check_empty']=='Both')
540+
,BEST[k],self.resseg[k+'_nanrep'])
541+
self.resseg[k+'_nanrep'] = np.where(np.logical_and(self.resseg[k+'_nanrep'].isna(),self.resseg['check_empty']=='Ref')
542+
,WORSE[k],self.resseg[k+'_nanrep'])
543+
self.resseg[k+'_nanrep'] = np.where(np.logical_and(self.resseg[k+'_nanrep'].isna(),self.resseg['check_empty']=='Pred')
544+
,WORSE[k],self.resseg[k+'_nanrep'])
545+
self.resseg[k+'_nanrep'] = np.where(np.logical_and(self.resseg[k].isna(),k in LIST_DISTANCE)
546+
,self.resseg['worse_dist'],self.resseg[k+'_nanrep'])
547+
548+
return
549+
550+
551+
552+
553+
def identify_empty_ref(self):
487554
return
488555

489556
def complete_missing_cases(self):

0 commit comments

Comments
 (0)