Skip to content

Commit 2bf9c91

Browse files
Carole SudreCarole Sudre
authored andcommitted
Allowing for squeezing if number of dimensions differ but same size
1 parent f60a3be commit 2bf9c91

1 file changed

Lines changed: 12 additions & 1 deletion

File tree

MetricsReloaded/processes/mixed_measures_processes.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -705,8 +705,9 @@ def __init__(
705705
self.connectivity_type = connectivity_type
706706
ndim = 0
707707
self.pixdim = pixdim
708+
self.squeeze_ref_and_pred_to_size()
708709
if len(self.pred)>0:
709-
ndim = np.asarray(self.pred[0]).ndim
710+
ndim = np.asarray(self.ref[0]).ndim
710711
if len(self.pixdim) == 0 and ndim>0:
711712
self.pixdim = np.ones([ndim])
712713
elif ndim>0:
@@ -722,6 +723,16 @@ def __init__(
722723
if pred_proba is None or pred_proba[0] is None:
723724
self.flag_valid_proba = False
724725

726+
def squeeze_ref_and_pred_to_size(self):
727+
for i,(p,r) in enumerate(zip(self.pred, self.ref)):
728+
if np.size(np.asarray(p)) == np.size(np.asarray(r)) and np.asarray(p).ndim != np.asarray(r).ndim:
729+
warnings.warn("There is a dimensional mismatch between pred and ref despite same size")
730+
p = np.squeeze(np.asarray(p))
731+
r = np.squeeze(np.asarray(r))
732+
self.pred[i] = p
733+
self.ref[i] = r
734+
return
735+
725736
def per_label_dict(self):
726737
list_bin = []
727738
list_mt = []

0 commit comments

Comments
 (0)