|
11 | 11 | from scipy.stats import ttest_ind |
12 | 12 | from typing import List, Dict, Optional, Union, Tuple |
13 | 13 | from ml_grid.results_processing.core import get_clean_data |
| 14 | +import warnings |
| 15 | + |
| 16 | +# Maximum number of outcomes to display in stratified plots to avoid clutter. |
| 17 | +MAX_OUTCOMES_FOR_STRATIFIED_PLOT = 20 |
14 | 18 |
|
15 | 19 |
|
16 | 20 | class AlgorithmComparisonPlotter: |
@@ -99,6 +103,15 @@ def _plot_stratified_algorithm_boxplots(self, metric: str, algorithms_to_plot: L |
99 | 103 | raise ValueError("outcome_variable column not found for stratification") |
100 | 104 |
|
101 | 105 | outcomes = outcomes_to_plot or sorted(self.clean_data['outcome_variable'].unique()) |
| 106 | + if len(outcomes) > MAX_OUTCOMES_FOR_STRATIFIED_PLOT: |
| 107 | + warnings.warn( |
| 108 | + f"Found {len(outcomes)} outcomes, which is more than the display limit of {MAX_OUTCOMES_FOR_STRATIFIED_PLOT}. " |
| 109 | + f"Displaying the first {MAX_OUTCOMES_FOR_STRATIFIED_PLOT}. " |
| 110 | + "Use the 'outcomes_to_plot' parameter to select specific outcomes.", |
| 111 | + stacklevel=2 |
| 112 | + ) |
| 113 | + outcomes = outcomes[:MAX_OUTCOMES_FOR_STRATIFIED_PLOT] |
| 114 | + |
102 | 115 | n_outcomes = len(outcomes) |
103 | 116 |
|
104 | 117 | # Calculate subplot layout |
@@ -130,8 +143,9 @@ def _plot_stratified_algorithm_boxplots(self, metric: str, algorithms_to_plot: L |
130 | 143 | if len(algo_data) > 0: |
131 | 144 | mean_val = algo_data.mean() |
132 | 145 | ax.scatter(j, mean_val, color='red', s=60, marker='D', zorder=10) |
133 | | - |
134 | | - ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha='right') |
| 146 | + |
| 147 | + ax.tick_params(axis='x', rotation=45) |
| 148 | + plt.setp(ax.get_xticklabels(), ha='right') |
135 | 149 | ax.set_title(f'{outcome}\n{metric.upper()}', fontsize=11, fontweight='bold') |
136 | 150 | ax.set_xlabel('Algorithm' if i >= len(outcomes) - cols else '') |
137 | 151 | ax.set_ylabel(metric.upper() if i % cols == 0 else '') |
@@ -277,6 +291,15 @@ def _plot_stratified_ranking(self, metric: str, algorithms_to_plot: List[str], |
277 | 291 | outcomes_to_plot: List[str], top_n: int, figsize: Tuple[int, int]): |
278 | 292 | """Plot stratified ranking bar charts by outcome.""" |
279 | 293 | outcomes = outcomes_to_plot or sorted(self.clean_data['outcome_variable'].unique()) |
| 294 | + if len(outcomes) > MAX_OUTCOMES_FOR_STRATIFIED_PLOT: |
| 295 | + warnings.warn( |
| 296 | + f"Found {len(outcomes)} outcomes, which is more than the display limit of {MAX_OUTCOMES_FOR_STRATIFIED_PLOT}. " |
| 297 | + f"Displaying the first {MAX_OUTCOMES_FOR_STRATIFIED_PLOT}. " |
| 298 | + "Use the 'outcomes_to_plot' parameter to select specific outcomes.", |
| 299 | + stacklevel=2 |
| 300 | + ) |
| 301 | + outcomes = outcomes[:MAX_OUTCOMES_FOR_STRATIFIED_PLOT] |
| 302 | + |
280 | 303 | n_outcomes = len(outcomes) |
281 | 304 |
|
282 | 305 | cols = min(2, n_outcomes) |
|
0 commit comments