Skip to content

Commit 5b2c708

Browse files
Add script for statistical analysis WIP (#1155)
* Add script for statistical analysis WIP * Fix dataset pool names * Update scripts for statistical analysis * Add doc and more cosmetic changes --------- Co-authored-by: Anwai Archit <anwai.archit@gmail.com>
1 parent 504a7e9 commit 5b2c708

4 files changed

Lines changed: 269 additions & 44 deletions

File tree

scripts/apg_experiments/README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,3 +21,4 @@ The top-level folder contains scripts to evaluate other models with `micro-sam`,
2121
- `plot_qualitative.py`: Scripts to display qualitative results over all datasets.
2222
- `plot_quantitative.py`: Scripts to display quantitative results over all datasets.
2323
- `plot_util.py`: Stores related information helpful for plotting.
24+
- `statistical_analysis`: Scripts for performing statistical analysis on quantitative results computed per image.

scripts/apg_experiments/prepare_baselines.py

Lines changed: 28 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -65,16 +65,18 @@ def run_baseline_engine(image, method, **kwargs):
6565

6666
def run_default_baselines(dataset_name, method, model_type, experiment_folder, target=None):
6767
# Prepare the results folder.
68-
res_folder = os.path.join(experiment_folder, "results")
68+
res_folder = os.path.join(experiment_folder, "results", method, model_type)
6969
inference_folder = os.path.join(experiment_folder, "inference", f"{dataset_name}_{method}_{model_type}")
7070
os.makedirs(res_folder, exist_ok=True)
7171
os.makedirs(inference_folder, exist_ok=True)
7272

73-
fnext = (target if model_type == "sam3" else model_type)
74-
csv_path = os.path.join(res_folder, f"{dataset_name}_{method}_{fnext}.csv")
73+
csv_path = os.path.join(res_folder, f"{dataset_name}.csv")
7574
if os.path.exists(csv_path):
76-
print(pd.read_csv(csv_path))
77-
print(f"The results are computed and stored at '{csv_path}'.")
75+
df = pd.read_csv(csv_path)
76+
print(df)
77+
mean_msa = df["msa"].mean()
78+
print(f"\nThe results are computed and stored at '{csv_path}'.")
79+
print(f"Mean MSA for {dataset_name}: {mean_msa:.4f}")
7880
return
7981

8082
# Get the image and label paths.
@@ -83,7 +85,7 @@ def run_default_baselines(dataset_name, method, model_type, experiment_folder, t
8385
assert isinstance(method, str)
8486
kwargs = {}
8587
if method in ["ais", "amg"]:
86-
predictor, segmenter = get_predictor_and_segmenter(model_type=model_type, segmentation_mode="amg")
88+
predictor, segmenter = get_predictor_and_segmenter(model_type=model_type, segmentation_mode=method)
8789
kwargs["predictor"] = predictor
8890
kwargs["segmenter"] = segmenter
8991
elif method == "apg":
@@ -105,7 +107,7 @@ def run_default_baselines(dataset_name, method, model_type, experiment_folder, t
105107
kwargs["processor"] = Sam3Processor(model)
106108
kwargs["prompt"] = target
107109

108-
msas, sa50s, precisions, recalls, f1s = [], [], [], [], []
110+
per_image_results = []
109111
for curr_image_path, curr_label_path in tqdm(
110112
zip(image_paths, label_paths), total=len(image_paths),
111113
desc=f"Run '{method}' baseline for '{model_type}' on '{dataset_name}'",
@@ -124,23 +126,25 @@ def run_default_baselines(dataset_name, method, model_type, experiment_folder, t
124126
fname = os.path.join(inference_folder, f"{Path(curr_image_path).stem}.tif")
125127
imageio.imwrite(fname, segmentation, compression="zlib")
126128

127-
msas.append(msa)
128-
sa50s.append(sas[0])
129-
precisions.append(stats["precision"])
130-
recalls.append(stats["recall"])
131-
f1s.append(stats["f1"])
132-
133-
results = {
134-
"mSA": np.mean(msas),
135-
"SA50": np.mean(sa50s),
136-
"Precision": np.mean(precisions),
137-
"Recall": np.mean(recalls),
138-
"F1": np.mean(f1s),
139-
}
140-
results = pd.DataFrame.from_dict([results])
141-
results.to_csv(csv_path)
142-
print(results)
143-
print(f"The results above are stored at '{csv_path}'.")
129+
# Store per-image metrics
130+
per_image_results.append({
131+
"image": os.path.basename(curr_image_path),
132+
"label": os.path.basename(curr_label_path),
133+
"msa": msa,
134+
"sa50": sas[0],
135+
"precision": stats["precision"],
136+
"recall": stats["recall"],
137+
"f1": stats["f1"],
138+
})
139+
140+
# Create DataFrame with per-image results
141+
results_df = pd.DataFrame(per_image_results)
142+
results_df.to_csv(csv_path, index=False)
143+
print(results_df)
144+
145+
# Compute and print mean MSA
146+
mean_msa = results_df["msa"].mean()
147+
print(f"\nMean MSA for {dataset_name}: {mean_msa:.4f}")
144148

145149

146150
def main(args):
Lines changed: 222 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,222 @@
1+
import numpy as np
2+
import pandas as pd
3+
import seaborn as sns
4+
import matplotlib.pyplot as plt
5+
from scipy.stats import shapiro, wilcoxon
6+
7+
8+
METRIC = "msa"
CRITERION = 0.05


def statistical_analysis_dataset(dataset, method1_path, method2_path, verbose=True):
    """Run a paired statistical test comparing two methods on one dataset.

    Loads the per-image metric values (column ``METRIC``) for both methods
    from ``./results/<method_path>/<dataset>.csv`` and applies a one-sided
    Wilcoxon signed-rank test to the paired differences.

    Args:
        dataset: Name of the dataset (the CSV filename stem).
        method1_path: Path fragment below './results' for the first method.
        method2_path: Path fragment below './results' for the second method.
        verbose: Whether to print the normality check and the test outcome.

    Returns:
        Tuple ``(is_better, is_significant)``: ``is_better`` is True if
        method1 scores higher overall, ``is_significant`` is True if the
        difference passes the significance criterion.
    """
    res1 = pd.read_csv(f"./results/{method1_path}/{dataset}.csv")[METRIC].values
    res2 = pd.read_csv(f"./results/{method2_path}/{dataset}.csv")[METRIC].values
    assert res1.shape == res2.shape

    diff = res1 - res2

    # Edge case: identical per-image results. 'wilcoxon' raises a ValueError
    # for an all-zero difference vector, so report "no difference" directly.
    if not np.any(diff):
        if verbose:
            print("Both methods have identical results; no test performed.")
        return False, False

    # Normality check is informational only; the Wilcoxon test below does
    # not assume a gaussian distribution of the differences.
    _, p_gauss = shapiro(diff)
    if verbose:
        print("P-value for gaussian distribution:", p_gauss)

    # Direction of the hypothesis: whichever method has the larger total score.
    is_better = diff.sum() > 0
    _, p = wilcoxon(diff, alternative="greater" if is_better else "less")
    is_significant = p < CRITERION

    if verbose:
        print(
            "Hypothesis:", method1_path if is_better else method2_path, "is better than",
            method2_path if is_better else method1_path
        )
        print("Result:", "True" if is_significant else "False", f"(p = {p:.4f})")

    return is_better, is_significant
35+
36+
37+
def statistical_analysis_pair(datasets, method1_path, method2_path, verbose=False):
    """Compare two methods over a collection of datasets.

    Runs the per-dataset statistical test for every dataset and tallies how
    often each method wins significantly, or whether no significant
    difference was found.

    Args:
        datasets: Sequence of dataset names to compare on.
        method1_path: Path fragment below './results' for the first method.
        method2_path: Path fragment below './results' for the second method.
        verbose: Whether to print per-dataset details and the final tally.

    Returns:
        Tuple ``(better1, better2, neutral)`` with the significant win counts
        of each method and the number of datasets without a significant
        difference.
    """
    wins_first, wins_second, ties = 0, 0, 0

    for dataset in datasets:
        is_better, is_significant = statistical_analysis_dataset(
            dataset, method1_path, method2_path, verbose=verbose
        )
        if not is_significant:
            ties += 1
        elif is_better:
            wins_first += 1
        else:
            wins_second += 1

    # Sanity check: every dataset landed in exactly one bucket.
    assert wins_first + wins_second + ties == len(datasets)
    if verbose:
        print(method1_path, "better than", method2_path, ":", wins_first)
        print(method2_path, "better than", method1_path, ":", wins_second)
        print("No difference:", ties)
    return wins_first, wins_second, ties
57+
58+
59+
def get_datasets(domain):
    """Return the benchmark datasets belonging to one imaging domain.

    Args:
        domain: One of "fluo_cells", "fluo_nuclei", "label_free",
            "histopatho". An unknown domain raises a ``KeyError``.

    Returns:
        List of dataset names for the requested domain (always nine).
    """
    fluo_cell_pool = [
        "cellpose", "covid_if", "hpa", "plantseg_root", "plantseg_ovules",
        "pnas_arabidopsis", "tissuenet", "cellbindb", "mouse_embryo",
    ]
    fluo_nuclei_pool = [
        "arvidsson", "bitdepth_nucseg", "dsb", "dynamicnuclearnet",
        "gonuclear", "ifnuclei", "nis3d", "parhyale_regen", "u20s",
    ]
    label_free_pool = [
        "deepbacs", "deepseas", "livecell", "omnipose", "usiigaci",
        "vicar", "toiam", "yeaz", "segpc",
    ]
    histopatho_pool = [
        "cytodark0", "ihc_tma", "monuseg", "lynsec", "nuinsseg",
        "pannuke", "puma", "tnbc", "cryonuseg",
    ]

    domain_to_ds = {
        "fluo_cells": fluo_cell_pool,
        "fluo_nuclei": fluo_nuclei_pool,
        "label_free": label_free_pool,
        "histopatho": histopatho_pool,
    }
    datasets = domain_to_ds[domain]
    # Every pool is expected to contain exactly nine datasets.
    assert len(datasets) == 9
    return datasets
109+
110+
111+
def _plot_comparison_heatmap(domain, comparison_df, title=None):
    """Render the pairwise comparison table as an annotated heatmap.

    Saves 'comparison_heatmap_<domain>.png' and '.svg' in the working
    directory. ``comparison_df`` is a square DataFrame whose off-diagonal
    cells are strings of the form "wins_row / wins_col / neutral" and whose
    diagonal cells are "-".
    """
    # Extract wins for method in row vs method in column
    n = len(comparison_df)
    win_matrix = np.zeros((n, n))

    for i in range(n):
        for j in range(n):
            if i != j:
                # Cell format is "a / b / c"; the first part is the row
                # method's significant win count, which drives the cell color.
                parts = comparison_df.iloc[i, j].split(' / ')
                win_matrix[i, j] = int(parts[0])  # wins for row method

    # Masking the diagonal to exclude it from coloring.
    mask = np.eye(n, dtype=bool)

    fig, ax = plt.subplots(figsize=(10, 8))
    # Color by win count but annotate with the full "a / b / c" strings.
    # Centering the diverging colormap at half the dataset count makes
    # "wins on more than half the datasets" read as green.
    sns.heatmap(
        win_matrix, annot=comparison_df.values, fmt='',
        cmap='RdYlGn', center=len(get_datasets(domain))/2,
        xticklabels=comparison_df.columns,
        yticklabels=comparison_df.index,
        cbar_kws={'label': 'Wins'}, ax=ax,
        mask=mask, linewidths=0.5, linecolor='#A9A9A9'
    )

    plt.title(title)
    plt.tight_layout()
    plt.savefig(f'comparison_heatmap_{domain}.png', dpi=400, bbox_inches='tight')
    plt.savefig(f'comparison_heatmap_{domain}.svg', dpi=400, bbox_inches='tight')
    plt.close()
140+
141+
142+
def compare_all():
    """Run all pairwise method comparisons per domain and plot heatmaps.

    For every imaging domain, compares each pair of methods across the
    domain's datasets via the Wilcoxon-based analysis and renders an
    annotated win/loss/tie heatmap ('comparison_heatmap_<domain>.png/.svg').
    """
    # Sorting out the paths where the methods' results exist.
    method_configs = {
        "amg": "amg/vit_b",
        "ais_lm": "ais/vit_b_lm",
        "ais_histo": "ais/vit_b_histopathology",
        "cellpose3": "cellpose/cyto3",
        "cellpose4": "cellpose/cpsam",
        "cellsam": "cellsam/cellsam",
        "sam3": "sam3/cell",
        "apg_lm": "apg/vit_b_lm",
        "apg_histo": "apg/vit_b_histopathology",
    }

    # Methods to compare within each domain (histopathology uses the
    # pathology-specific model variants).
    domain_methods = {
        "fluo_cells": ["amg", "ais_lm", "cellpose3", "cellpose4", "cellsam", "sam3", "apg_lm"],
        "fluo_nuclei": ["amg", "ais_lm", "cellpose3", "cellpose4", "cellsam", "sam3", "apg_lm"],
        "label_free": ["amg", "ais_lm", "cellpose3", "cellpose4", "cellsam", "sam3", "apg_lm"],
        "histopatho": ["amg", "ais_histo", "cellpose3", "cellpose4", "cellsam", "sam3", "apg_histo"],
    }

    # Human-readable method names used as heatmap tick labels.
    display_names = {
        "amg": "AMG (SAM)",
        "ais_lm": "AIS (μSAM)",
        "ais_histo": "AIS\n(PathoSAM)",
        "cellsam": "CellSAM",
        "cellpose3": "Cellpose 3",
        "cellpose4": "CellposeSAM",
        "sam3": "SAM3",
        "apg_lm": r"$\mathbf{APG}$" + r" $\mathbf{(μSAM)}$",
        "apg_histo": r"$\mathbf{APG}$" + "\n" + r"$\mathbf{(PathoSAM)}$",
    }

    # Plot titles per domain.
    custom_titles = {
        "fluo_cells": "Fluorescence Microscopy (Cell Segmentation)",
        "fluo_nuclei": "Fluorescence Microscopy (Nucleus Segmentation)",
        "label_free": "Label-Free Microscopy (Cell Segmentation)",
        "histopatho": "Histopathology (Nucleus Segmentation)",
    }

    for domain in ["fluo_cells", "fluo_nuclei", "label_free", "histopatho"]:
        datasets = get_datasets(domain)
        methods = domain_methods[domain]

        # Build the pairwise comparison table row by row; diagonal cells
        # (a method against itself) are marked "-".
        table = []
        for row_idx, row_method in enumerate(methods):
            row_cells = []
            for col_idx, col_method in enumerate(methods):
                if row_idx == col_idx:
                    row_cells.append("-")
                    continue

                wins_row, wins_col, no_diff = statistical_analysis_pair(
                    datasets, method_configs[row_method], method_configs[col_method]
                )
                row_cells.append(f"{wins_row} / {wins_col} / {no_diff}")
            table.append(row_cells)

        # Let's use expected display names.
        labels = [display_names[m] for m in methods]
        comparison = pd.DataFrame(table, index=labels, columns=labels)

        # Let's visualize the results
        _plot_comparison_heatmap(domain, comparison, title=custom_titles[domain])
        print(f"Generated heatmap for {domain}: comparison_heatmap_{domain}.png")
215+
216+
217+
def main():
    """Entry point: run all pairwise method comparisons and create heatmaps."""
    compare_all()


if __name__ == "__main__":
    main()

scripts/apg_experiments/submit_evaluation.py

Lines changed: 18 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -9,26 +9,21 @@ def write_batch_script(
99
):
1010
"""Writing scripts to submit multiple evaluations relevant for APG.
1111
"""
12-
if method == "cellpose":
13-
if model_type == "cyto3":
14-
env = "cp3"
15-
elif model_type == "cpsam":
16-
env = "cp4"
17-
else:
18-
raise ValueError
12+
if method == "cellpose" and model_type == "cyto3":
13+
env = "cp3"
1914
else:
2015
env = "super"
2116

2217
batch_script = f"""#!/bin/bash
2318
#SBATCH -t 2-00:00:00
2419
#SBATCH --nodes=1
2520
#SBATCH --ntasks=1
26-
#SBATCH -p grete:shared
27-
#SBATCH -G A100:1
21+
#SBATCH -p grete-h100:shared
22+
#SBATCH -G H100:1
2823
#SBATCH -A nim00007
2924
#SBATCH -c 16
3025
#SBATCH --mem 64G
31-
#SBATCH --constraint=inet,80gb
26+
#SBATCH --constraint=inet
3227
#SBATCH --job-name=apg_evaluation
3328
3429
source ~/.bashrc
@@ -83,33 +78,36 @@ def submit_slurm(args):
8378

8479
method_combinations = [
8580
# SAM-based models
86-
# ["amg", "vit_b"],
87-
# ["amg", generalist_model],
88-
# ["ais", generalist_model],
81+
["amg", "vit_b"],
82+
["ais", generalist_model],
8983
["apg", generalist_model],
9084
# SAM3
91-
# ["sam3", "cells"],
85+
["sam3", "cell"],
9286
# And other external methods.
93-
# ["cellpose", "cyto3"],
94-
# ["cellpose", "cpsam"],
95-
# ["cellsam", "cellsam"],
87+
["cellpose", "cyto3"],
88+
["cellpose", "cpsam"],
89+
["cellsam", "cellsam"],
9690
]
9791

9892
if dataset_name is None:
9993
if generalist_model == "vit_b_lm":
10094
datasets = [
10195
# Label-free
102-
"livecell", "omnipose", "deepbacs", "usiigaci", "vicar", "deepseas", "toiam",
96+
"livecell", "omnipose", "deepbacs", "usiigaci", "vicar",
97+
"deepseas", "toiam", "yeaz", "segpc",
10398
# Fluo (nuclei)
10499
"dynamicnuclearnet", "u20s", "arvidsson", "ifnuclei",
105100
"gonuclear", "nis3d", "parhyale_regen", "dsb", "bitdepth_nucseg",
106101
# Fluo (cells)
107102
"cellpose", "cellbindb", "tissuenet", "plantseg_root", "covid_if",
108-
"hpa", "plantseg_ovules", "pnas_arabidopsis",
103+
"hpa", "plantseg_ovules", "pnas_arabidopsis", "mouse_embryo",
109104
]
110105
else: # Histopathology
111106
assert generalist_model == "vit_b_histopathology"
112-
datasets = ["ihc_tma", "lynsec", "pannuke", "monuseg", "tnbc", "nuinsseg", "puma", "cytodark0"]
107+
datasets = [
108+
"ihc_tma", "lynsec", "pannuke", "monuseg", "tnbc",
109+
"nuinsseg", "puma", "cytodark0", "cryonuseg"
110+
]
113111
else:
114112
datasets = [dataset_name]
115113

0 commit comments

Comments
 (0)