# plot_best_model.py
"""
Module for analyzing and visualizing the single best performing model for each outcome.
"""
5+
import ast
import textwrap
import warnings
from typing import Any, Dict, List, Optional, Tuple

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

from .core import get_clean_data
from .plot_hyperparameters import HyperparameterAnalysisPlotter  # To reuse parsing logic
16+
# Limit on how many outcomes to plot automatically to avoid generating too many figures.
MAX_OUTCOMES_TO_PLOT = 10
19+
class BestModelAnalyzerPlotter:
    """
    Analyzes and plots the characteristics of the best performing model for
    each outcome variable.
    """

    def __init__(self, data: pd.DataFrame):
        """
        Initialize the plotter.

        Args:
            data: Aggregated results DataFrame. Must contain 'outcome_variable'.

        Raises:
            ValueError: If the 'outcome_variable' column is missing.
        """
        if 'outcome_variable' not in data.columns:
            raise ValueError("Data must contain an 'outcome_variable' column for this analysis.")

        self.data = data
        self.clean_data = get_clean_data(data)

        # Feature categories and pipeline parameters mirror those used by the
        # other plotting modules so summaries stay consistent across the package.
        self.feature_categories = [
            'age', 'sex', 'bmi', 'ethnicity', 'bloods', 'diagnostic_order', 'drug_order',
            'annotation_n', 'meta_sp_annotation_n', 'meta_sp_annotation_mrc_n',
            'annotation_mrc_n', 'core_02', 'bed', 'vte_status', 'hosp_site',
            'core_resus', 'news', 'date_time_stamp'
        ]
        self.pipeline_params = ['resample', 'scale', 'param_space_size', 'percent_missing']

        plt.style.use('default')
        sns.set_palette("muted")

    @staticmethod
    def _format_score(value: Any) -> str:
        """Format a numeric score to 4 decimal places; pass other values through as text."""
        # Guards against applying a float format spec to the 'N/A' fallback,
        # which would raise ValueError for any missing metric column.
        return f"{value:.4f}" if isinstance(value, (int, float)) else str(value)

    def _get_best_models(self, metric: str) -> pd.DataFrame:
        """
        Find the single best model for each outcome variable based on a given metric.

        Args:
            metric: Column name of the performance metric to maximise.

        Returns:
            DataFrame with one row per outcome (the row achieving the highest
            metric value), sorted by the metric in descending order.

        Raises:
            ValueError: If ``metric`` is not a column of the cleaned data.
        """
        if metric not in self.clean_data.columns:
            raise ValueError(f"Metric '{metric}' not found in the data.")

        # idxmax yields, per outcome group, the row label of the best score;
        # .loc then pulls those full rows back out.
        best_rows = self.clean_data.loc[
            self.clean_data.groupby('outcome_variable')[metric].idxmax()
        ]
        return best_rows.sort_values(by=metric, ascending=False)

    def plot_best_model_summary(self,
                                metric: str = 'auc',
                                outcomes_to_plot: Optional[List[str]] = None,
                                figsize: Tuple[int, int] = (14, 9)):
        """
        For each outcome, find the best performing model and generate a summary plot
        detailing its characteristics, including algorithm, hyperparameters, and
        pipeline settings.

        Args:
            metric: The performance metric to use for determining the "best" model.
            outcomes_to_plot: An optional list of specific outcome variables to analyze.
                If None, analyzes all outcomes up to MAX_OUTCOMES_TO_PLOT.
            figsize: The figure size for each summary plot.
        """
        best_models_df = self._get_best_models(metric)

        if outcomes_to_plot:
            # Restrict the summaries to the explicitly requested outcomes.
            best_models_df = best_models_df[best_models_df['outcome_variable'].isin(outcomes_to_plot)]
            if best_models_df.empty:
                print(f"Warning: No data found for the specified outcomes: {outcomes_to_plot}")
                return
        elif len(best_models_df) > MAX_OUTCOMES_TO_PLOT:
            warnings.warn(
                f"Found {len(best_models_df)} unique outcomes. To avoid excessive plotting, "
                f"showing summaries for the top {MAX_OUTCOMES_TO_PLOT} outcomes based on the '{metric}' metric. "
                "Use the 'outcomes_to_plot' parameter to specify which outcomes to analyze.",
                stacklevel=2
            )
            best_models_df = best_models_df.head(MAX_OUTCOMES_TO_PLOT)

        print(f"--- Generating Best Model Summaries (Metric: {metric.upper()}) ---")
        for _, model_series in best_models_df.iterrows():
            self._plot_single_model_summary(model_series, metric, figsize)

    def _plot_single_model_summary(self, model_series: pd.Series, metric: str, figsize: Tuple[int, int]):
        """Generate a single 2x2 summary figure for one model."""
        fig, axes = plt.subplots(2, 2, figsize=figsize)
        fig.suptitle(f"Best Model Analysis for: {model_series['outcome_variable']}",
                     fontsize=16, fontweight='bold')

        self._plot_key_info(axes[0, 0], model_series, metric)      # top-left: headline stats
        self._plot_hyperparameters(axes[0, 1], model_series)       # top-right: model params
        self._plot_feature_categories(axes[1, 0], model_series)    # bottom-left: features used
        self._plot_pipeline_parameters(axes[1, 1], model_series)   # bottom-right: pipeline

        # Leave headroom for the suptitle above the subplot grid.
        plt.tight_layout(rect=[0, 0.03, 1, 0.95])
        plt.show()

    def _plot_key_info(self, ax: "plt.Axes", model_series: pd.Series, metric: str):
        """Render a text panel with the model's algorithm and headline metrics."""
        ax.set_title("Model & Performance Summary", fontsize=12, fontweight='bold')
        ax.axis('off')

        score_str = self._format_score(model_series.get(metric, 'N/A'))

        # All metric values go through _format_score so a missing column
        # ('N/A' fallback) renders as text instead of crashing on ':.4f'.
        info_text = (
            f"Algorithm: {model_series.get('method_name', 'N/A')}\n"
            f"Best Score ({metric.upper()}): {score_str}\n"
            f"Number of Features: {model_series.get('nb_size', 'N/A')}\n"
            f"Run Timestamp: {model_series.get('run_timestamp', 'N/A')}\n\n"
            f"Other Metrics:\n"
            f"  - F1: {self._format_score(model_series.get('f1', 'N/A'))}\n"
            f"  - MCC: {self._format_score(model_series.get('mcc', 'N/A'))}\n"
            f"  - Accuracy: {self._format_score(model_series.get('accuracy', 'N/A'))}\n"
        )

        ax.text(0.05, 0.95, info_text, transform=ax.transAxes, ha='left', va='top', fontsize=11,
                bbox=dict(boxstyle='round,pad=0.5', fc='aliceblue', ec='grey', lw=1))

    def _plot_hyperparameters(self, ax: "plt.Axes", model_series: pd.Series):
        """Render the parsed hyperparameters of the winning model as text."""
        ax.set_title("Hyperparameters", fontsize=12, fontweight='bold')
        ax.axis('off')

        params = {}
        if 'algorithm_implementation' in model_series and pd.notna(model_series['algorithm_implementation']):
            # Reuse parsing logic from HyperparameterAnalysisPlotter.
            params = HyperparameterAnalysisPlotter._parse_model_string_to_params(
                model_series['algorithm_implementation'])

        if not params:
            ax.text(0.5, 0.5, "Hyperparameters not available\nor could not be parsed.",
                    transform=ax.transAxes, ha='center', va='center', fontsize=10)
            return

        param_str = ""
        for key, val in params.items():
            val_str = str(val)
            if len(val_str) > 40:
                # Wrap long values so the text box stays inside the axes.
                val_str = textwrap.fill(val_str, width=40, subsequent_indent='    ')
            param_str += f"{key}: {val_str}\n"

        ax.text(0.05, 0.95, param_str.strip(), transform=ax.transAxes, ha='left', va='top',
                fontsize=9, family='monospace',
                bbox=dict(boxstyle='round,pad=0.5', fc='lightyellow', ec='grey', lw=1))

    def _plot_feature_categories(self, ax: "plt.Axes", model_series: pd.Series):
        """Plot a bar chart of which feature categories were enabled for the model."""
        ax.set_title("Feature Categories Used", fontsize=12, fontweight='bold')

        used_categories = {}
        for cat in self.feature_categories:
            if cat in model_series and pd.notna(model_series[cat]):
                val = model_series[cat]
                try:
                    # Flags may arrive as strings ('true'/'FALSE'); capitalize()
                    # normalises them so literal_eval yields a real bool.
                    is_used = ast.literal_eval(str(val).capitalize()) if isinstance(val, str) else bool(val)
                except (ValueError, SyntaxError):
                    is_used = False

                if is_used:
                    used_categories[cat.replace("_", " ").title()] = 1

        if not used_categories:
            ax.text(0.5, 0.5, "No feature category information available.",
                    transform=ax.transAxes, ha='center', va='center', fontsize=10)
            ax.set_xticks([])
            ax.set_yticks([])
            return

        cat_df = pd.DataFrame.from_dict(used_categories, orient='index', columns=['Used']).sort_index()

        # hue=index + legend=False keeps seaborn's palette-per-bar behaviour
        # without a redundant legend.
        sns.barplot(x=cat_df.index, y=cat_df['Used'], ax=ax, palette='viridis',
                    hue=cat_df.index, legend=False)
        ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha='right')
        ax.set_xlabel("")
        ax.set_ylabel("Enabled")
        ax.set_yticks([0, 1])
        ax.set_yticklabels(['', 'Yes'])
        ax.grid(axis='y', linestyle='--', alpha=0.7)

    def _plot_pipeline_parameters(self, ax: "plt.Axes", model_series: pd.Series):
        """Render the pipeline settings (resampling, scaling, etc.) as text."""
        ax.set_title("Pipeline Settings", fontsize=12, fontweight='bold')
        ax.axis('off')

        pipeline_settings = {}
        for param in self.pipeline_params:
            if param in model_series and pd.notna(model_series[param]):
                pipeline_settings[param.replace("_", " ").title()] = model_series[param]

        if not pipeline_settings:
            ax.text(0.5, 0.5, "No pipeline setting information available.",
                    transform=ax.transAxes, ha='center', va='center', fontsize=10)
            return

        settings_str = "".join(f"{key}: {val}\n" for key, val in pipeline_settings.items())

        ax.text(0.05, 0.95, settings_str.strip(), transform=ax.transAxes, ha='left', va='top',
                fontsize=11, bbox=dict(boxstyle='round,pad=0.5', fc='honeydew', ec='grey', lw=1))