Skip to content

Commit 50a6702

Browse files
committed
Added best-model plot and integrated it into master; added a function to produce a summary DataFrame of the best result per outcome variable.
1 parent ecc889f commit 50a6702

4 files changed

Lines changed: 371 additions & 13 deletions

File tree

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,3 +192,4 @@ doc/_build/
192192
.vscode/
193193
>>>>>>> 68b72edb1ca57c5b15fc038fd0f5b650fd782b00
194194
packages.txt
195+
notebooks/best_models_summary.csv
Lines changed: 216 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,216 @@
1+
# plot_best_model.py
2+
"""
3+
Module for analyzing and visualizing the single best performing model for each outcome.
4+
"""
5+
6+
import pandas as pd
7+
import matplotlib.pyplot as plt
8+
import seaborn as sns
9+
from typing import List, Optional, Tuple, Dict, Any
10+
import warnings
11+
import textwrap
12+
import ast
13+
14+
from .core import get_clean_data
15+
from .plot_hyperparameters import HyperparameterAnalysisPlotter # To reuse parsing logic
16+
17+
# Cap on the number of outcomes plotted automatically, so a dataset with many
# outcome variables does not flood the session with figures.
MAX_OUTCOMES_TO_PLOT = 10
19+
20+
class BestModelAnalyzerPlotter:
    """
    Analyzes and plots the characteristics of the best performing model for each
    outcome variable.

    The input is an aggregated-results DataFrame with one row per trained model,
    an 'outcome_variable' column, and one column per performance metric
    (e.g. 'auc', 'f1', 'mcc', 'accuracy').
    """

    def __init__(self, data: pd.DataFrame):
        """
        Initialize the plotter.

        Args:
            data: Aggregated results DataFrame. Must contain 'outcome_variable'.

        Raises:
            ValueError: If 'outcome_variable' is missing from ``data``.
        """
        if 'outcome_variable' not in data.columns:
            raise ValueError("Data must contain an 'outcome_variable' column for this analysis.")

        self.data = data
        self.clean_data = get_clean_data(data)

        # Feature-category and pipeline-parameter column names mirror the other
        # analysis modules so plots stay consistent across the package.
        self.feature_categories = [
            'age', 'sex', 'bmi', 'ethnicity', 'bloods', 'diagnostic_order', 'drug_order',
            'annotation_n', 'meta_sp_annotation_n', 'meta_sp_annotation_mrc_n',
            'annotation_mrc_n', 'core_02', 'bed', 'vte_status', 'hosp_site',
            'core_resus', 'news', 'date_time_stamp'
        ]
        self.pipeline_params = ['resample', 'scale', 'param_space_size', 'percent_missing']

        plt.style.use('default')
        sns.set_palette("muted")

    @staticmethod
    def _format_metric(value) -> str:
        """Format a numeric metric to 4 decimal places; fall back to str() otherwise."""
        return f"{value:.4f}" if isinstance(value, (int, float)) else str(value)

    def _get_best_models(self, metric: str) -> pd.DataFrame:
        """
        Find the single best model for each outcome variable based on ``metric``.

        Rows with a missing metric value are dropped first: ``groupby().idxmax()``
        raises a ValueError on a group whose values are all NaN.

        Args:
            metric: Name of the metric column to maximize.

        Returns:
            One row per outcome variable, sorted by ``metric`` descending.

        Raises:
            ValueError: If ``metric`` is not a column of the cleaned data.
        """
        if metric not in self.clean_data.columns:
            raise ValueError(f"Metric '{metric}' not found in the data.")

        scored = self.clean_data.dropna(subset=[metric])
        # idxmax gives the row label of the best score within each outcome group.
        best_rows = scored.loc[scored.groupby('outcome_variable')[metric].idxmax()]
        return best_rows.sort_values(by=metric, ascending=False)

    def plot_best_model_summary(self,
                                metric: str = 'auc',
                                outcomes_to_plot: Optional[List[str]] = None,
                                figsize: Tuple[int, int] = (14, 9)):
        """
        For each outcome, finds the best performing model and generates a summary plot
        detailing its characteristics, including algorithm, hyperparameters, and
        pipeline settings.

        Args:
            metric: The performance metric to use for determining the "best" model.
            outcomes_to_plot: An optional list of specific outcome variables to analyze.
                If None, analyzes all outcomes up to MAX_OUTCOMES_TO_PLOT.
            figsize: The figure size for each summary plot.
        """
        best_models_df = self._get_best_models(metric)

        if outcomes_to_plot:
            # Filter to only the requested outcomes.
            best_models_df = best_models_df[best_models_df['outcome_variable'].isin(outcomes_to_plot)]
            if best_models_df.empty:
                print(f"Warning: No data found for the specified outcomes: {outcomes_to_plot}")
                return
        elif len(best_models_df) > MAX_OUTCOMES_TO_PLOT:
            warnings.warn(
                f"Found {len(best_models_df)} unique outcomes. To avoid excessive plotting, "
                f"showing summaries for the top {MAX_OUTCOMES_TO_PLOT} outcomes based on the '{metric}' metric. "
                "Use the 'outcomes_to_plot' parameter to specify which outcomes to analyze.",
                stacklevel=2
            )
            # best_models_df is already sorted by metric descending.
            best_models_df = best_models_df.head(MAX_OUTCOMES_TO_PLOT)

        print(f"--- Generating Best Model Summaries (Metric: {metric.upper()}) ---")
        for _, model_series in best_models_df.iterrows():
            self._plot_single_model_summary(model_series, metric, figsize)

    def _plot_single_model_summary(self, model_series: pd.Series, metric: str, figsize: Tuple[int, int]):
        """
        Generates a single 2x2 summary plot for one model.

        Args:
            model_series: One row of the best-models DataFrame.
            metric: The metric used to select this model (shown in the summary).
            figsize: Figure size passed to ``plt.subplots``.
        """
        fig, axes = plt.subplots(2, 2, figsize=figsize)
        fig.suptitle(f"Best Model Analysis for: {model_series['outcome_variable']}", fontsize=16, fontweight='bold')

        # Subplot 1: Key Information (Text)
        self._plot_key_info(axes[0, 0], model_series, metric)

        # Subplot 2: Hyperparameters
        self._plot_hyperparameters(axes[0, 1], model_series)

        # Subplot 3: Feature Categories Used
        self._plot_feature_categories(axes[1, 0], model_series)

        # Subplot 4: Pipeline Parameters
        self._plot_pipeline_parameters(axes[1, 1], model_series)

        # Leave room for the suptitle above the 2x2 grid.
        plt.tight_layout(rect=[0, 0.03, 1, 0.95])
        plt.show()

    def _plot_key_info(self, ax: "plt.Axes", model_series: pd.Series, metric: str):
        """Render a text panel with the algorithm name and headline metrics."""
        ax.set_title("Model & Performance Summary", fontsize=12, fontweight='bold')
        ax.axis('off')

        score = model_series.get(metric, 'N/A')
        score_str = self._format_metric(score)

        # Every metric goes through _format_metric: formatting the 'N/A'
        # fallback with ':.4f' directly would raise a ValueError.
        info_text = (
            f"Algorithm: {model_series.get('method_name', 'N/A')}\n"
            f"Best Score ({metric.upper()}): {score_str}\n"
            f"Number of Features: {model_series.get('nb_size', 'N/A')}\n"
            f"Run Timestamp: {model_series.get('run_timestamp', 'N/A')}\n\n"
            f"Other Metrics:\n"
            f"  - F1: {self._format_metric(model_series.get('f1', 'N/A'))}\n"
            f"  - MCC: {self._format_metric(model_series.get('mcc', 'N/A'))}\n"
            f"  - Accuracy: {self._format_metric(model_series.get('accuracy', 'N/A'))}\n"
        )

        ax.text(0.05, 0.95, info_text, transform=ax.transAxes, ha='left', va='top', fontsize=11,
                bbox=dict(boxstyle='round,pad=0.5', fc='aliceblue', ec='grey', lw=1))

    def _plot_hyperparameters(self, ax: "plt.Axes", model_series: pd.Series):
        """Render a text panel listing the model's parsed hyperparameters."""
        ax.set_title("Hyperparameters", fontsize=12, fontweight='bold')
        ax.axis('off')

        params = {}
        if 'algorithm_implementation' in model_series and pd.notna(model_series['algorithm_implementation']):
            # Reuse parsing logic from HyperparameterAnalysisPlotter.
            params = HyperparameterAnalysisPlotter._parse_model_string_to_params(model_series['algorithm_implementation'])

        if not params:
            ax.text(0.5, 0.5, "Hyperparameters not available\nor could not be parsed.",
                    transform=ax.transAxes, ha='center', va='center', fontsize=10)
            return

        param_str = ""
        for key, val in params.items():
            val_str = str(val)
            # Wrap long values so the monospace panel stays readable.
            if len(val_str) > 40:
                val_str = textwrap.fill(val_str, width=40, subsequent_indent='    ')
            param_str += f"{key}: {val_str}\n"

        ax.text(0.05, 0.95, param_str.strip(), transform=ax.transAxes, ha='left', va='top',
                fontsize=9, family='monospace',
                bbox=dict(boxstyle='round,pad=0.5', fc='lightyellow', ec='grey', lw=1))

    def _plot_feature_categories(self, ax: "plt.Axes", model_series: pd.Series):
        """Render a bar panel of the feature categories enabled for this model."""
        ax.set_title("Feature Categories Used", fontsize=12, fontweight='bold')

        used_categories = {}
        for cat in self.feature_categories:
            if cat in model_series and pd.notna(model_series[cat]):
                val = model_series[cat]
                try:
                    # Category flags may arrive as strings ('true'/'False'/'1');
                    # capitalize() + literal_eval turns them into Python values.
                    is_used = ast.literal_eval(str(val).capitalize()) if isinstance(val, str) else bool(val)
                except (ValueError, SyntaxError):
                    # Unparseable flag: treat the category as unused.
                    is_used = False

                if is_used:
                    used_categories[cat.replace("_", " ").title()] = 1

        if not used_categories:
            ax.text(0.5, 0.5, "No feature category information available.",
                    transform=ax.transAxes, ha='center', va='center', fontsize=10)
            ax.set_xticks([])
            ax.set_yticks([])
            return

        cat_df = pd.DataFrame.from_dict(used_categories, orient='index', columns=['Used']).sort_index()

        sns.barplot(x=cat_df.index, y=cat_df['Used'], ax=ax, palette='viridis', hue=cat_df.index, legend=False)
        ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha='right')
        ax.set_xlabel("")
        ax.set_ylabel("Enabled")
        ax.set_yticks([0, 1])
        ax.set_yticklabels(['', 'Yes'])
        ax.grid(axis='y', linestyle='--', alpha=0.7)

    def _plot_pipeline_parameters(self, ax: "plt.Axes", model_series: pd.Series):
        """Render a text panel listing the pipeline settings used for this run."""
        ax.set_title("Pipeline Settings", fontsize=12, fontweight='bold')
        ax.axis('off')

        pipeline_settings = {}
        for param in self.pipeline_params:
            if param in model_series and pd.notna(model_series[param]):
                pipeline_settings[param.replace("_", " ").title()] = model_series[param]

        if not pipeline_settings:
            ax.text(0.5, 0.5, "No pipeline setting information available.",
                    transform=ax.transAxes, ha='center', va='center', fontsize=10)
            return

        settings_str = ""
        for key, val in pipeline_settings.items():
            settings_str += f"{key}: {val}\n"

        ax.text(0.05, 0.95, settings_str.strip(), transform=ax.transAxes, ha='left', va='top',
                fontsize=11, bbox=dict(boxstyle='round,pad=0.5', fc='honeydew', ec='grey', lw=1))

0 commit comments

Comments
 (0)