Skip to content

Commit 0b53b50

Browse files
author
SamoraHunter
committed
Optimize metric calculation and H2O prediction robustness
- `ml_grid/util/project_score_save.py`: Explicitly convert prediction targets to numpy arrays before calculating metrics (MCC, F1, etc.). This bypasses expensive pandas overhead in scikit-learn's `np.unique` checks. - `ml_grid/model_classes/H2OBaseClassifier.py`: - Implement "lazy loading" for H2O models in `predict` and `predict_proba` to reduce redundant API calls to the H2O cluster. - Add retry logic and fallback handling for `java.lang.NullPointerException` crashes in the H2O backend during prediction.
1 parent f89b312 commit 0b53b50

2 files changed

Lines changed: 63 additions & 36 deletions

File tree

ml_grid/model_classes/H2OBaseClassifier.py

Lines changed: 45 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -559,8 +559,9 @@ def predict(self, X: pd.DataFrame) -> np.ndarray:
559559
# Ensure H2O is running
560560
self._ensure_h2o_is_running()
561561

562-
# Ensure the model is loaded (critical for cross-validation)
563-
self._ensure_model_is_loaded()
562+
# OPTIMIZATION: Lazy load model. Only check if we don't have the object.
563+
if self.model_ is None:
564+
self._ensure_model_is_loaded()
564565

565566
try:
566567
# --- ROBUSTNESS FIX for java.lang.NullPointerException ---
@@ -592,20 +593,27 @@ def predict(self, X: pd.DataFrame) -> np.ndarray:
592593
try:
593594
predictions = self.model_.predict(test_h2o)
594595
except Exception as e:
595-
# --- FIX: Catch H2O backend crashes (NPE) during prediction and fallback ---
596-
if "java.lang.NullPointerException" in str(e):
597-
self.logger.warning(
598-
f"H2O backend crashed with NPE during predict(). Returning dummy predictions. Details: {e}"
599-
)
600-
# Fallback: predict the first class (usually 0)
601-
dummy_val = (
602-
self.classes_[0]
603-
if self.classes_ is not None and len(self.classes_) > 0
604-
else 0
605-
)
606-
return np.full(len(X), dummy_val)
596+
# If prediction failed, it might be because the model was unloaded/GC'd on server.
597+
# Try reloading and predicting again.
598+
self.logger.debug(f"Prediction failed ({e}), attempting to reload model...")
599+
try:
600+
self._ensure_model_is_loaded()
601+
predictions = self.model_.predict(test_h2o)
602+
except Exception as e2:
603+
# --- FIX: Catch H2O backend crashes (NPE) during prediction and fallback ---
604+
if "java.lang.NullPointerException" in str(e):
605+
self.logger.warning(
606+
f"H2O backend crashed with NPE during predict(). Returning dummy predictions. Details: {e}"
607+
)
608+
# Fallback: predict the first class (usually 0)
609+
dummy_val = (
610+
self.classes_[0]
611+
if self.classes_ is not None and len(self.classes_) > 0
612+
else 0
613+
)
614+
return np.full(len(X), dummy_val)
607615

608-
raise RuntimeError(f"H2O prediction failed: {e}")
616+
raise RuntimeError(f"H2O prediction failed: {e2}")
609617

610618
# Extract predictions
611619
pred_df = predictions.as_data_frame(use_multi_thread=False)
@@ -665,8 +673,9 @@ def predict_proba(self, X: pd.DataFrame) -> np.ndarray:
665673
# Ensure H2O is running
666674
self._ensure_h2o_is_running()
667675

668-
# Ensure the model is loaded
669-
self._ensure_model_is_loaded()
676+
# OPTIMIZATION: Lazy load model.
677+
if self.model_ is None:
678+
self._ensure_model_is_loaded()
670679

671680
# Create H2O frame with explicit column names
672681
try:
@@ -687,20 +696,26 @@ def predict_proba(self, X: pd.DataFrame) -> np.ndarray:
687696
try:
688697
predictions = self.model_.predict(test_h2o)
689698
except Exception as e:
690-
# --- FIX: Catch H2O backend crashes (NPE) during prediction and fallback ---
691-
if "java.lang.NullPointerException" in str(e):
692-
self.logger.warning(
693-
f"H2O backend crashed with NPE during predict_proba(). Returning dummy probabilities. Details: {e}"
694-
)
695-
# Fallback: uniform probabilities
696-
n_classes = (
697-
len(self.classes_)
698-
if self.classes_ is not None and len(self.classes_) > 0
699-
else 2
700-
)
701-
return np.full((len(X), n_classes), 1.0 / n_classes)
699+
# Retry logic for unloaded models
700+
self.logger.debug(f"Prediction failed ({e}), attempting to reload model...")
701+
try:
702+
self._ensure_model_is_loaded()
703+
predictions = self.model_.predict(test_h2o)
704+
except Exception as e2:
705+
# --- FIX: Catch H2O backend crashes (NPE) during prediction and fallback ---
706+
if "java.lang.NullPointerException" in str(e):
707+
self.logger.warning(
708+
f"H2O backend crashed with NPE during predict_proba(). Returning dummy probabilities. Details: {e}"
709+
)
710+
# Fallback: uniform probabilities
711+
n_classes = (
712+
len(self.classes_)
713+
if self.classes_ is not None and len(self.classes_) > 0
714+
else 2
715+
)
716+
return np.full((len(X), n_classes), 1.0 / n_classes)
702717

703-
raise RuntimeError(f"H2O prediction failed: {e}")
718+
raise RuntimeError(f"H2O prediction failed: {e2}")
704719

705720
# Extract probabilities (drop the 'predict' column)
706721
prob_df = predictions.drop("predict").as_data_frame(use_multi_thread=False)

ml_grid/util/project_score_save.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -187,9 +187,21 @@ def update_score_log(
187187
column_list = _get_score_log_columns(list(global_params.metric_list.keys()))
188188
line = pd.DataFrame(data=None, columns=column_list)
189189

190+
# --- OPTIMIZATION: Pre-process targets for faster metric calculation ---
191+
# Convert to numpy arrays to avoid pandas overhead in sklearn metrics
192+
y_test_np = y_test.values if hasattr(y_test, "values") else y_test
193+
best_pred_np = best_pred_orig.values if hasattr(best_pred_orig, "values") else best_pred_orig
194+
195+
# Attempt to convert to integers (e.g. "0"/"1" strings from H2O) for faster np.unique
196+
try:
197+
y_test_np = y_test_np.astype(int)
198+
best_pred_np = best_pred_np.astype(int)
199+
except (ValueError, TypeError):
200+
pass
201+
190202
# best_pred_orig = grid.best_estimator_.predict(X_test_orig)
191203
try:
192-
auc = metrics.roc_auc_score(y_test, best_pred_orig)
204+
auc = metrics.roc_auc_score(y_test_np, best_pred_np)
193205
except Exception as e:
194206
logger.warning(f"Could not calculate AUC score: {e}")
195207
logger.debug(f"y_test unique values: {y_test.unique()!s}")
@@ -198,11 +210,11 @@ def update_score_log(
198210
)
199211
auc = np.nan
200212

201-
mcc = matthews_corrcoef(y_test, best_pred_orig)
202-
f1 = f1_score(y_test, best_pred_orig, average="binary")
203-
precision = precision_score(y_test, best_pred_orig, average="binary")
204-
recall = recall_score(y_test, best_pred_orig, average="binary")
205-
accuracy = accuracy_score(y_test, best_pred_orig)
213+
mcc = matthews_corrcoef(y_test_np, best_pred_np)
214+
f1 = f1_score(y_test_np, best_pred_np, average="binary")
215+
precision = precision_score(y_test_np, best_pred_np, average="binary")
216+
recall = recall_score(y_test_np, best_pred_np, average="binary")
217+
accuracy = accuracy_score(y_test_np, best_pred_np)
206218

207219
# get info from current settings iter...local_param_dict ml_grid_object
208220
for key in ml_grid_object.local_param_dict:

0 commit comments

Comments
 (0)