33import os
44import shutil
55import tempfile
6+ import uuid
67from typing import Any , Dict , List , Optional , Tuple
78
89import h2o
@@ -43,6 +44,20 @@ class H2OBaseClassifier(BaseEstimator, ClassifierMixin):
4344
4445 MIN_SAMPLES_FOR_STABLE_FIT = 10
4546
47+ # Class-level cache for init parameters to avoid repeated inspect.signature calls
48+ _init_param_names_cache = {}
49+
50+ # Class-level cache for estimator class signatures to avoid repeated inspect.signature calls in fit()
51+ _estimator_signature_cache = {}
52+
53+ # Class-level set of keys to exclude from get_params, avoiding recreation overhead
54+ _excluded_param_keys = {
55+ "estimator_class" ,
56+ "logger" ,
57+ "model" ,
58+ "model_id" ,
59+ }
60+
4661 def __init__ (self , estimator_class = None , ** kwargs ):
4762 """Initializes the H2OBaseClassifier.
4863
@@ -100,6 +115,8 @@ def __init__(self, estimator_class=None, **kwargs):
100115 # H2O models are not safe with joblib's process-based parallelism.
101116 self ._n_jobs = 1
102117
118+ self ._cached_param_names = None # Cache for get_params
119+
103120 def __del__ (self ):
104121 """Cleans up the shared checkpoint directory if this is the last instance."""
105122 # This is a best-effort cleanup. In multi-process scenarios,
@@ -335,7 +352,11 @@ def _prepare_fit(
335352 )
336353
337354 train_df = pd .concat ([X , y_series ], axis = 1 )
338- train_h2o = h2o .H2OFrame (train_df )
355+ # Optimization: Provide destination_frame to avoid expensive gc.get_referrers() name search
356+ train_h2o = h2o .H2OFrame (
357+ train_df ,
358+ destination_frame = f"train_{ uuid .uuid4 ().hex } "
359+ )
339360
340361 # Explicitly convert the outcome column to factor
341362 train_h2o [outcome_var ] = train_h2o [outcome_var ].asfactor ()
@@ -349,7 +370,12 @@ def _prepare_fit(
349370 model_params = self ._get_model_params ()
350371
351372 # Get valid parameters for the specific H2O estimator
352- estimator_params = inspect .signature (self .estimator_class ).parameters
373+ # Optimization: Use cached signature
374+ if self .estimator_class not in self ._estimator_signature_cache :
375+ self ._estimator_signature_cache [self .estimator_class ] = inspect .signature (
376+ self .estimator_class
377+ ).parameters
378+ estimator_params = self ._estimator_signature_cache [self .estimator_class ]
353379
354380 # If there's only one feature, prevent H2O from dropping it if it's constant
355381 if len (x_vars ) == 1 and self .estimator_class :
@@ -381,9 +407,12 @@ def _get_model_params(self) -> Dict[str, Any]:
381407 if k != "estimator_class"
382408 }
383409
384- valid_param_keys = set (
385- inspect .signature (self .estimator_class ).parameters .keys ()
386- )
410+ # Optimization: Use cached signature
411+ if self .estimator_class not in self ._estimator_signature_cache :
412+ self ._estimator_signature_cache [self .estimator_class ] = inspect .signature (
413+ self .estimator_class
414+ ).parameters
415+ valid_param_keys = set (self ._estimator_signature_cache [self .estimator_class ].keys ())
387416
388417 model_params = {
389418 key : value for key , value in all_params .items () if key in valid_param_keys
@@ -473,10 +502,22 @@ def fit(self, X: pd.DataFrame, y: pd.Series, **kwargs) -> "H2OBaseClassifier":
473502 # Sanitize parameters to prevent backend errors (e.g. HGLM)
474503 self ._sanitize_model_params ()
475504
505+ # --- OPTIMIZATION: Force disable progress bar immediately before train ---
506+ # This addresses profiling results showing significant overhead in progressbar.py
507+ if not getattr (global_parameters , "h2o_show_progress" , False ):
508+ h2o .no_progress ()
509+
476510 # Call the train() method with ONLY the data-related arguments
477511 self .logger .debug ("Calling H2O model.train()..." )
478512 self .model_ .train (x = x_vars , y = outcome_var , training_frame = train_h2o )
479513
514+ # --- OPTIMIZATION: Explicitly cleanup training frame ---
515+ # This reduces overhead from H2O's reference counting (gc.get_referrers)
516+ try :
517+ h2o .remove (train_h2o )
518+ except Exception :
519+ pass
520+
480521 # Store model_id for recovery - THIS IS CRITICAL for predict() to work
481522 self .logger .debug (
482523 f"H2O train complete, extracting model_id from { self .model_ } "
@@ -566,11 +607,11 @@ def predict(self, X: pd.DataFrame) -> np.ndarray:
566607 )
567608 return np .full (len (X ), dummy_prediction )
568609
569- # Ensure H2O is running
570- self ._ensure_h2o_is_running ()
571-
572- # OPTIMIZATION: Lazy load model. Only check if we don't have the object.
610+ # OPTIMIZATION: Lazy load model and optimistically assume H2O is running.
611+ # Only check/init cluster if we don't have the model object or if prediction fails.
612+ # This saves an API call (h2o.cluster()) per predict.
573613 if self .model_ is None :
614+ self ._ensure_h2o_is_running ()
574615 self ._ensure_model_is_loaded ()
575616
576617 try :
@@ -589,7 +630,10 @@ def predict(self, X: pd.DataFrame) -> np.ndarray:
589630 }
590631
591632 tmp_frame = h2o .H2OFrame (
592- X , column_names = self .feature_names_ , column_types = col_types
633+ X ,
634+ column_names = self .feature_names_ ,
635+ column_types = col_types ,
636+ destination_frame = f"pred_{ uuid .uuid4 ().hex } "
593637 )
594638
595639 # Optimization: Use the temporary frame directly.
@@ -607,6 +651,7 @@ def predict(self, X: pd.DataFrame) -> np.ndarray:
607651 # Try reloading and predicting again.
608652 self .logger .debug (f"Prediction failed ({ e } ), attempting to reload model..." )
609653 try :
654+ self ._ensure_h2o_is_running ()
610655 self ._ensure_model_is_loaded ()
611656 predictions = self .model_ .predict (test_h2o )
612657 except Exception as e2 :
@@ -626,9 +671,27 @@ def predict(self, X: pd.DataFrame) -> np.ndarray:
626671 raise RuntimeError (f"H2O prediction failed: { e2 } " )
627672
628673 # Extract predictions
674+ # Optimization: Download full frame to avoid H2O slicing overhead (gc.get_referrers ~48s)
675+ # Slicing in H2O (predictions['predict']) creates a new unnamed frame which triggers a stack scan.
629676 pred_df = predictions .as_data_frame (use_multi_thread = False )
677+
678+ # --- OPTIMIZATION: Explicitly cleanup frames ---
679+ try :
680+ h2o .remove (test_h2o )
681+ h2o .remove (predictions )
682+ except Exception :
683+ pass
684+
630685 if "predict" in pred_df .columns :
631- return pred_df ["predict" ].values .ravel ()
686+ preds = pred_df ["predict" ].values .ravel ()
687+ # OPTIMIZATION: Cast to the same type as classes_ to avoid object/string overhead
688+ # in sklearn metrics (which triggers expensive np.unique on object arrays).
689+ if self .classes_ is not None :
690+ try :
691+ preds = preds .astype (self .classes_ .dtype )
692+ except (ValueError , TypeError ):
693+ pass
694+ return preds
632695 else :
633696 raise RuntimeError ("Prediction output missing 'predict' column" )
634697
@@ -680,11 +743,9 @@ def predict_proba(self, X: pd.DataFrame) -> np.ndarray:
680743 dummy_probas = np .full ((len (X ), n_classes ), 1 / n_classes )
681744 return dummy_probas
682745
683- # Ensure H2O is running
684- self ._ensure_h2o_is_running ()
685-
686- # OPTIMIZATION: Lazy load model.
746+ # OPTIMIZATION: Lazy load model and optimistically assume H2O is running.
687747 if self .model_ is None :
748+ self ._ensure_h2o_is_running ()
688749 self ._ensure_model_is_loaded ()
689750
690751 # Create H2O frame with explicit column names
@@ -697,7 +758,10 @@ def predict_proba(self, X: pd.DataFrame) -> np.ndarray:
697758 }
698759
699760 test_h2o = h2o .H2OFrame (
700- X , column_names = self .feature_names_ , column_types = col_types
761+ X ,
762+ column_names = self .feature_names_ ,
763+ column_types = col_types ,
764+ destination_frame = f"prob_{ uuid .uuid4 ().hex } "
701765 )
702766 except Exception as e :
703767 raise RuntimeError (f"Failed to create H2O frame for prediction: { e } " )
@@ -709,6 +773,7 @@ def predict_proba(self, X: pd.DataFrame) -> np.ndarray:
709773 # Retry logic for unloaded models
710774 self .logger .debug (f"Prediction failed ({ e } ), attempting to reload model..." )
711775 try :
776+ self ._ensure_h2o_is_running ()
712777 self ._ensure_model_is_loaded ()
713778 predictions = self .model_ .predict (test_h2o )
714779 except Exception as e2 :
@@ -728,7 +793,20 @@ def predict_proba(self, X: pd.DataFrame) -> np.ndarray:
728793 raise RuntimeError (f"H2O prediction failed: { e2 } " )
729794
730795 # Extract probabilities (drop the 'predict' column)
731- prob_df = predictions .drop ("predict" ).as_data_frame (use_multi_thread = False )
796+ # Optimization: Download full frame then drop in pandas to avoid H2O slicing overhead
797+ res_df = predictions .as_data_frame (use_multi_thread = False )
798+ if "predict" in res_df .columns :
799+ prob_df = res_df .drop (columns = ["predict" ])
800+ else :
801+ prob_df = res_df
802+
803+ # --- OPTIMIZATION: Explicitly cleanup frames ---
804+ try :
805+ h2o .remove (test_h2o )
806+ h2o .remove (predictions )
807+ except Exception :
808+ pass
809+
732810 return prob_df .values
733811
734812 def _ensure_model_is_loaded (self ):
@@ -828,21 +906,6 @@ def __deepcopy__(self, memo):
828906 )
829907 return cloned
830908
831- def __sklearn_clone__ (self ):
832- """Custom cloning method for sklearn compatibility.
833-
834- This ensures that when sklearn clones the estimator, we return a properly
835- initialized new instance with the same parameters.
836- """
837- # Get all parameters (not fitted attributes)
838- params = self .get_params (deep = False )
839- # Create new instance with same parameters
840- cloned = self .__class__ (** params )
841- self .logger .debug (
842- f"__sklearn_clone__ called: original instance { id (self )} , clone instance { id (cloned )} "
843- )
844- return cloned # Removing dead code
845-
846909 def _get_param_names (self ):
847910 """Get parameter names for the estimator.
848911
@@ -851,24 +914,45 @@ def _get_param_names(self):
851914
852915 CRITICAL: This should ONLY return parameter names, NOT fitted attribute names.
853916 """
854- init_signature = inspect .signature (self .__class__ .__init__ )
855- init_params = [
856- p .name
857- for p in init_signature .parameters .values ()
858- if p .name not in ("self" , "args" , "kwargs" )
859- ]
917+ if self ._cached_param_names is not None :
918+ return self ._cached_param_names
919+
920+ # Optimization: Cache the signature inspection on the class
921+ cls = self .__class__
922+ if cls not in self ._init_param_names_cache :
923+ init_signature = inspect .signature (cls .__init__ )
924+ self ._init_param_names_cache [cls ] = [
925+ p .name
926+ for p in init_signature .parameters .values ()
927+ if p .name not in ("self" , "args" , "kwargs" )
928+ ]
929+
930+ init_params = self ._init_param_names_cache [cls ]
931+
932+ # Optimization: Use sets for O(1) lookup
933+ init_params_set = set (init_params )
860934
861935 extra_params = [
862936 key
863937 for key in self .__dict__
864938 if not key .startswith ("_" )
865939 and not (key .endswith ("_" ) and key != "lambda_" ) # Allow lambda_
866- and key not in init_params
867- and key not in ["estimator_class" , "logger" ]
868- and key not in ["model" , "model_" , "classes_" , "feature_names_" , "model_id" ]
940+ and key not in init_params_set
941+ and key not in self ._excluded_param_keys
869942 ]
870943
871- return sorted (init_params + extra_params )
944+ self ._cached_param_names = sorted (init_params + extra_params )
945+ return self ._cached_param_names
946+
def get_params(self, deep=True):
    """Return this estimator's parameters as a flat dict.

    Fast-path override of ``BaseEstimator.get_params``. The ``deep``
    recursion step is deliberately skipped: H2O wrapper estimators never
    hold nested sklearn estimators, so a plain attribute lookup per
    parameter name is sufficient and avoids BaseEstimator overhead.

    Args:
        deep: Accepted for sklearn API compatibility; has no effect here.

    Returns:
        dict: Mapping of parameter name to current value.
    """
    params = {}
    for name in self._get_param_names():
        params[name] = getattr(self, name)
    return params
873957 def set_params (self : "H2OBaseClassifier" , ** kwargs : Any ) -> "H2OBaseClassifier" :
874958 """Sets the parameters of this estimator, compatible with scikit-learn.
@@ -911,6 +995,8 @@ def set_params(self: "H2OBaseClassifier", **kwargs: Any) -> "H2OBaseClassifier":
911995 # This shouldn't happen for our use case, but handle it anyway
912996 setattr (self , key , value )
913997
998+ self ._cached_param_names = None # Invalidate cache
999+
9141000 # Restore fitted attributes
9151001 for attr , value in fitted_attributes .items ():
9161002 setattr (self , attr , value )
0 commit comments