55custom scoring function for ROC AUC that handles cases with a single class.
66"""
77
8+ import os
9+ import sys
810from typing import Any , Callable , Dict , List , Union
911import logging
1012
13+
14+ # --- FIX for TensorFlow "libdevice not found" error ---
15+ # This error occurs when TensorFlow on GPU cannot locate the CUDA device libraries,
16+ # a common issue with pip-installed CUDA in virtual environments. This code
17+ # finds the CUDA path within the environment and sets the XLA_FLAGS environment
18+ # variable to point to it. This must be done before TensorFlow is initialized.
def _configure_tensorflow_gpu_env():
    """Point XLA at pip-installed CUDA device libraries via ``XLA_FLAGS``.

    Works around TensorFlow's "libdevice not found" error. The function walks
    ``sys.path`` looking for a ``site-packages`` (or Debian/Colab-style
    ``dist-packages``) directory that contains
    ``nvidia/cuda_nvcc/nvvm/libdevice/libdevice.10.bc``. On the first match it
    appends ``--xla_gpu_cuda_data_dir=<.../nvidia/cuda_nvcc>`` to the
    ``XLA_FLAGS`` environment variable, keeping any flags already present.

    Nothing is changed when ``XLA_FLAGS`` already carries an
    ``--xla_gpu_cuda_data_dir`` option. If no candidate directory is found a
    warning is logged. The whole operation is best-effort: any unexpected
    exception is logged on the ``ml_grid`` logger and swallowed.

    This has to run before TensorFlow is initialized for the flag to be seen.

    Returns:
        None. The only effects are on ``os.environ`` and the log.
    """
    logger = logging.getLogger("ml_grid")
    try:
        existing_flags = os.environ.get("XLA_FLAGS", "")

        # Respect a CUDA data dir that was configured explicitly elsewhere.
        if "--xla_gpu_cuda_data_dir" in existing_flags:
            return

        for entry in sys.path:
            # sys.path may contain non-string objects (e.g. pathlib.Path).
            # A substring test on those raises TypeError, which previously
            # aborted the whole search through the broad handler below.
            if not isinstance(entry, str):
                continue
            if "site-packages" not in entry and "dist-packages" not in entry:
                continue

            cuda_dir = os.path.join(entry, "nvidia", "cuda_nvcc")
            # Check for the specific file XLA needs, not just the directory.
            libdevice_path = os.path.join(
                cuda_dir, "nvvm", "libdevice", "libdevice.10.bc"
            )
            if not (os.path.isdir(cuda_dir) and os.path.exists(libdevice_path)):
                continue

            logger.info(
                "Found CUDA libraries at %s. Configuring XLA_FLAGS.", cuda_dir
            )
            cuda_data_dir_flag = f"--xla_gpu_cuda_data_dir={cuda_dir}"
            # strip() drops the separator when there were no earlier flags.
            os.environ["XLA_FLAGS"] = f"{existing_flags} {cuda_data_dir_flag}".strip()
            return  # The first valid location wins.

        logger.warning(
            "libdevice.10.bc not found. TensorFlow XLA may fail. Ensure nvidia-cuda-nvcc-cu12 is installed."
        )
    except Exception as e:
        # Deliberately broad: environment setup must never break the import.
        logger.error(
            "Error configuring TensorFlow GPU environment: %s", e, exc_info=True
        )
56+
57+
58+ _configure_tensorflow_gpu_env ()
59+ # --- END FIX ---
60+
1161import numpy as np
12- from sklearn .metrics import make_scorer , roc_auc_score
62+ from sklearn .metrics import (
63+ accuracy_score ,
64+ f1_score ,
65+ make_scorer ,
66+ recall_score ,
67+ roc_auc_score ,
68+ )
1369
1470
1571def custom_roc_auc_score (y_true : np .ndarray , y_pred : np .ndarray ) -> float :
@@ -25,6 +81,14 @@ def custom_roc_auc_score(y_true: np.ndarray, y_pred: np.ndarray) -> float:
2581 Returns:
2682 float: The ROC AUC score, or np.nan if the score is undefined.
2783 """
84+ # If y_pred is None (which can happen if a model fails to predict), AUC is undefined.
85+ if y_pred is None :
86+ # For debugging: raise an error to get a stack trace pointing to the faulty model.
87+ # This will stop the execution but pinpoint the source of the None predictions.
88+ raise ValueError (
89+ "y_pred is None in custom_roc_auc_score. A model's predict() method failed."
90+ )
91+
2892 # Optimization: Check min/max instead of full unique sort (O(N) vs O(N log N))
2993 # If min == max, there is only one unique value (or array is empty/NaNs which implies undefined AUC)
3094 # Also handle Categorical data which may not support min/max if unordered
@@ -42,6 +106,36 @@ def custom_roc_auc_score(y_true: np.ndarray, y_pred: np.ndarray) -> float:
42106 return roc_auc_score (y_true , y_pred )
43107
44108
def custom_f1_score(y_true: np.ndarray, y_pred: np.ndarray, **kwargs) -> float:
    """Compute the F1 score, refusing to score missing predictions.

    Args:
        y_true: True labels, passed straight through to ``f1_score``.
        y_pred: Predicted labels, passed straight through to ``f1_score``.
        **kwargs: Extra keyword arguments forwarded unchanged to ``f1_score``.

    Returns:
        float: Whatever ``f1_score(y_true, y_pred, **kwargs)`` returns.

    Raises:
        ValueError: If ``y_pred`` is ``None``.
    """
    if y_pred is not None:
        return f1_score(y_true, y_pred, **kwargs)

    # Debugging aid: fail loudly so the stack trace shows where the None
    # predictions came from.
    message = "y_pred is None in custom_f1_score. A model's predict() method failed."
    raise ValueError(message)
117+
118+
def custom_accuracy_score(y_true: np.ndarray, y_pred: np.ndarray, **kwargs) -> float:
    """Compute the accuracy score, refusing to score missing predictions.

    Args:
        y_true: True labels, passed straight through to ``accuracy_score``.
        y_pred: Predicted labels, passed straight through to ``accuracy_score``.
        **kwargs: Extra keyword arguments forwarded unchanged to
            ``accuracy_score``.

    Returns:
        float: Whatever ``accuracy_score(y_true, y_pred, **kwargs)`` returns.

    Raises:
        ValueError: If ``y_pred`` is ``None``.
    """
    if y_pred is not None:
        return accuracy_score(y_true, y_pred, **kwargs)

    # Debugging aid: fail loudly so the stack trace shows where the None
    # predictions came from.
    message = "y_pred is None in custom_accuracy_score. A model's predict() method failed."
    raise ValueError(message)
127+
128+
def custom_recall_score(y_true: np.ndarray, y_pred: np.ndarray, **kwargs) -> float:
    """Compute the recall score, refusing to score missing predictions.

    Args:
        y_true: True labels, passed straight through to ``recall_score``.
        y_pred: Predicted labels, passed straight through to ``recall_score``.
        **kwargs: Extra keyword arguments forwarded unchanged to
            ``recall_score``.

    Returns:
        float: Whatever ``recall_score(y_true, y_pred, **kwargs)`` returns.

    Raises:
        ValueError: If ``y_pred`` is ``None``.
    """
    if y_pred is not None:
        return recall_score(y_true, y_pred, **kwargs)

    # Debugging aid: fail loudly so the stack trace shows where the None
    # predictions came from.
    message = "y_pred is None in custom_recall_score. A model's predict() method failed."
    raise ValueError(message)
137+
138+
45139class GlobalParameters :
46140 """A singleton class to manage global configuration parameters for ml_grid.
47141
@@ -109,6 +203,8 @@ class GlobalParameters:
109203 """If True, forces a second cross-validation run even if cached results are available. Defaults to False."""
110204 model_eval_time_limit : int
111205 """The time limit in seconds for a single model evaluation. Defaults to None (no limit)."""
206+ test_mode : bool
207+ """If True, uses minimal parameter spaces and reduced cross-validation for fast testing. Defaults to False."""
112208
113209 def __new__ (cls , * args : Any , ** kwargs : Any ) -> "GlobalParameters" :
114210 """Creates a new instance if one does not already exist (Singleton pattern)."""
@@ -155,13 +251,18 @@ def __init__(self, debug_level: int = 0, knn_n_jobs: int = -1) -> None:
155251 self .search_verbose = 0
156252 self .force_second_cv = False
157253 self .model_eval_time_limit = None
254+ self .test_mode = False
255+
256+ custom_auc_scorer = make_scorer (custom_roc_auc_score )
257+ custom_f1_scorer = make_scorer (custom_f1_score )
258+ custom_accuracy_scorer = make_scorer (custom_accuracy_score )
259+ custom_recall_scorer = make_scorer (custom_recall_score )
158260
159- custom_scorer = make_scorer (custom_roc_auc_score )
160261 self .metric_list = {
161- "auc" : custom_scorer ,
162- "f1" : "f1" ,
163- "accuracy" : "accuracy" ,
164- "recall" : "recall" ,
262+ "auc" : custom_auc_scorer ,
263+ "f1" : custom_f1_scorer ,
264+ "accuracy" : custom_accuracy_scorer ,
265+ "recall" : custom_recall_scorer ,
165266 }
166267
167268 def update_parameters (self , ** kwargs : Any ) -> None :
0 commit comments