Skip to content

Commit ecc889f

Browse files
Commit message: "performance modifications"
1 parent c2cf581 · commit ecc889f

2 files changed

Lines changed: 29 additions & 6 deletions

File tree

ml_grid/results_processing/plot_algorithms.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,10 @@
1111
from scipy.stats import ttest_ind
1212
from typing import List, Dict, Optional, Union, Tuple
1313
from ml_grid.results_processing.core import get_clean_data
14+
import warnings
15+
16+
# Maximum number of outcomes to display in stratified plots to avoid clutter.
17+
MAX_OUTCOMES_FOR_STRATIFIED_PLOT = 20
1418

1519

1620
class AlgorithmComparisonPlotter:
@@ -99,6 +103,15 @@ def _plot_stratified_algorithm_boxplots(self, metric: str, algorithms_to_plot: L
99103
raise ValueError("outcome_variable column not found for stratification")
100104

101105
outcomes = outcomes_to_plot or sorted(self.clean_data['outcome_variable'].unique())
106+
if len(outcomes) > MAX_OUTCOMES_FOR_STRATIFIED_PLOT:
107+
warnings.warn(
108+
f"Found {len(outcomes)} outcomes, which is more than the display limit of {MAX_OUTCOMES_FOR_STRATIFIED_PLOT}. "
109+
f"Displaying the first {MAX_OUTCOMES_FOR_STRATIFIED_PLOT}. "
110+
"Use the 'outcomes_to_plot' parameter to select specific outcomes.",
111+
stacklevel=2
112+
)
113+
outcomes = outcomes[:MAX_OUTCOMES_FOR_STRATIFIED_PLOT]
114+
102115
n_outcomes = len(outcomes)
103116

104117
# Calculate subplot layout
@@ -130,8 +143,9 @@ def _plot_stratified_algorithm_boxplots(self, metric: str, algorithms_to_plot: L
130143
if len(algo_data) > 0:
131144
mean_val = algo_data.mean()
132145
ax.scatter(j, mean_val, color='red', s=60, marker='D', zorder=10)
133-
134-
ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha='right')
146+
147+
ax.tick_params(axis='x', rotation=45)
148+
plt.setp(ax.get_xticklabels(), ha='right')
135149
ax.set_title(f'{outcome}\n{metric.upper()}', fontsize=11, fontweight='bold')
136150
ax.set_xlabel('Algorithm' if i >= len(outcomes) - cols else '')
137151
ax.set_ylabel(metric.upper() if i % cols == 0 else '')
@@ -277,6 +291,15 @@ def _plot_stratified_ranking(self, metric: str, algorithms_to_plot: List[str],
277291
outcomes_to_plot: List[str], top_n: int, figsize: Tuple[int, int]):
278292
"""Plot stratified ranking bar charts by outcome."""
279293
outcomes = outcomes_to_plot or sorted(self.clean_data['outcome_variable'].unique())
294+
if len(outcomes) > MAX_OUTCOMES_FOR_STRATIFIED_PLOT:
295+
warnings.warn(
296+
f"Found {len(outcomes)} outcomes, which is more than the display limit of {MAX_OUTCOMES_FOR_STRATIFIED_PLOT}. "
297+
f"Displaying the first {MAX_OUTCOMES_FOR_STRATIFIED_PLOT}. "
298+
"Use the 'outcomes_to_plot' parameter to select specific outcomes.",
299+
stacklevel=2
300+
)
301+
outcomes = outcomes[:MAX_OUTCOMES_FOR_STRATIFIED_PLOT]
302+
280303
n_outcomes = len(outcomes)
281304

282305
cols = min(2, n_outcomes)

ml_grid/results_processing/plot_distributions.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from ml_grid.results_processing.core import get_clean_data, stratify_by_outcome
1414

1515
# Maximum number of outcomes to display in stratified plots to avoid clutter.
16-
MAX_OUTCOMES_FOR_STRATIFIED_PLOT = 10
16+
MAX_OUTCOMES_FOR_STRATIFIED_PLOT = 20
1717
MAX_OUTCOMES_FOR_HEATMAP = 25
1818

1919
class DistributionPlotter:
@@ -112,7 +112,7 @@ def _plot_stratified_distributions(self, metrics: List[str], figsize: Tuple[int,
112112
raise ValueError("outcome_variable column not found for stratification")
113113

114114
outcomes = outcomes_to_plot or sorted(self.clean_data['outcome_variable'].unique())
115-
if len(outcomes) > MAX_OUTCOMES_FOR_STRATIFIED_PLOT and outcomes_to_plot is None:
115+
if len(outcomes) > MAX_OUTCOMES_FOR_STRATIFIED_PLOT:
116116
warnings.warn(
117117
f"Found {len(outcomes)} outcomes, which is more than the display limit of {MAX_OUTCOMES_FOR_STRATIFIED_PLOT}. "
118118
f"Displaying the first {MAX_OUTCOMES_FOR_STRATIFIED_PLOT}. "
@@ -194,7 +194,7 @@ def plot_comparative_distributions(self, metric: str = 'auc',
194194

195195
outcomes = outcomes_to_compare or sorted(self.clean_data['outcome_variable'].unique())
196196

197-
if len(outcomes) > MAX_OUTCOMES_FOR_STRATIFIED_PLOT and outcomes_to_compare is None:
197+
if len(outcomes) > MAX_OUTCOMES_FOR_STRATIFIED_PLOT:
198198
warnings.warn(
199199
f"Found {len(outcomes)} outcomes, which is more than the display limit of {MAX_OUTCOMES_FOR_STRATIFIED_PLOT}. "
200200
f"Displaying the first {MAX_OUTCOMES_FOR_STRATIFIED_PLOT}. "
@@ -453,7 +453,7 @@ def plot_metric_correlation_by_outcome(data: pd.DataFrame,
453453
available_metrics = [col for col in metrics if col in clean_data.columns]
454454

455455
outcomes = outcomes_to_plot or sorted(clean_data['outcome_variable'].unique())
456-
if len(outcomes) > MAX_OUTCOMES_FOR_STRATIFIED_PLOT and outcomes_to_plot is None:
456+
if len(outcomes) > MAX_OUTCOMES_FOR_STRATIFIED_PLOT:
457457
warnings.warn(
458458
f"Found {len(outcomes)} outcomes, which is more than the display limit of {MAX_OUTCOMES_FOR_STRATIFIED_PLOT}. "
459459
f"Displaying the first {MAX_OUTCOMES_FOR_STRATIFIED_PLOT}. "

0 commit comments

Comments (0)