@@ -559,8 +559,9 @@ def predict(self, X: pd.DataFrame) -> np.ndarray:
559559 # Ensure H2O is running
560560 self ._ensure_h2o_is_running ()
561561
562- # Ensure the model is loaded (critical for cross-validation)
563- self ._ensure_model_is_loaded ()
562+ # OPTIMIZATION: Lazy load model. Only check if we don't have the object.
563+ if self .model_ is None :
564+ self ._ensure_model_is_loaded ()
564565
565566 try :
566567 # --- ROBUSTNESS FIX for java.lang.NullPointerException ---
@@ -592,20 +593,27 @@ def predict(self, X: pd.DataFrame) -> np.ndarray:
592593 try :
593594 predictions = self .model_ .predict (test_h2o )
594595 except Exception as e :
595- # --- FIX: Catch H2O backend crashes (NPE) during prediction and fallback ---
596- if "java.lang.NullPointerException" in str (e ):
597- self .logger .warning (
598- f"H2O backend crashed with NPE during predict(). Returning dummy predictions. Details: { e } "
599- )
600- # Fallback: predict the first class (usually 0)
601- dummy_val = (
602- self .classes_ [0 ]
603- if self .classes_ is not None and len (self .classes_ ) > 0
604- else 0
605- )
606- return np .full (len (X ), dummy_val )
596+ # If prediction failed, it might be because the model was unloaded/GC'd on server.
597+ # Try reloading and predicting again.
598+ self .logger .debug (f"Prediction failed ({ e } ), attempting to reload model..." )
599+ try :
600+ self ._ensure_model_is_loaded ()
601+ predictions = self .model_ .predict (test_h2o )
602+ except Exception as e2 :
603+ # --- FIX: Catch H2O backend crashes (NPE) during prediction and fallback ---
604+ if "java.lang.NullPointerException" in str (e ):
605+ self .logger .warning (
606+ f"H2O backend crashed with NPE during predict(). Returning dummy predictions. Details: { e } "
607+ )
608+ # Fallback: predict the first class (usually 0)
609+ dummy_val = (
610+ self .classes_ [0 ]
611+ if self .classes_ is not None and len (self .classes_ ) > 0
612+ else 0
613+ )
614+ return np .full (len (X ), dummy_val )
607615
608- raise RuntimeError (f"H2O prediction failed: { e } " )
616+ raise RuntimeError (f"H2O prediction failed: { e2 } " )
609617
610618 # Extract predictions
611619 pred_df = predictions .as_data_frame (use_multi_thread = False )
@@ -665,8 +673,9 @@ def predict_proba(self, X: pd.DataFrame) -> np.ndarray:
665673 # Ensure H2O is running
666674 self ._ensure_h2o_is_running ()
667675
668- # Ensure the model is loaded
669- self ._ensure_model_is_loaded ()
676+ # OPTIMIZATION: Lazy load model.
677+ if self .model_ is None :
678+ self ._ensure_model_is_loaded ()
670679
671680 # Create H2O frame with explicit column names
672681 try :
@@ -687,20 +696,26 @@ def predict_proba(self, X: pd.DataFrame) -> np.ndarray:
687696 try :
688697 predictions = self .model_ .predict (test_h2o )
689698 except Exception as e :
690- # --- FIX: Catch H2O backend crashes (NPE) during prediction and fallback ---
691- if "java.lang.NullPointerException" in str (e ):
692- self .logger .warning (
693- f"H2O backend crashed with NPE during predict_proba(). Returning dummy probabilities. Details: { e } "
694- )
695- # Fallback: uniform probabilities
696- n_classes = (
697- len (self .classes_ )
698- if self .classes_ is not None and len (self .classes_ ) > 0
699- else 2
700- )
701- return np .full ((len (X ), n_classes ), 1.0 / n_classes )
699+ # Retry logic for unloaded models
700+ self .logger .debug (f"Prediction failed ({ e } ), attempting to reload model..." )
701+ try :
702+ self ._ensure_model_is_loaded ()
703+ predictions = self .model_ .predict (test_h2o )
704+ except Exception as e2 :
705+ # --- FIX: Catch H2O backend crashes (NPE) during prediction and fallback ---
706+ if "java.lang.NullPointerException" in str (e ):
707+ self .logger .warning (
708+ f"H2O backend crashed with NPE during predict_proba(). Returning dummy probabilities. Details: { e } "
709+ )
710+ # Fallback: uniform probabilities
711+ n_classes = (
712+ len (self .classes_ )
713+ if self .classes_ is not None and len (self .classes_ ) > 0
714+ else 2
715+ )
716+ return np .full ((len (X ), n_classes ), 1.0 / n_classes )
702717
703- raise RuntimeError (f"H2O prediction failed: { e } " )
718+ raise RuntimeError (f"H2O prediction failed: { e2 } " )
704719
705720 # Extract probabilities (drop the 'predict' column)
706721 prob_df = predictions .drop ("predict" ).as_data_frame (use_multi_thread = False )
0 commit comments