
Commit caa9d54

Author: SamoraHunter (committed)
Message: Further optimisations. Auto filter out client_idcode column.
Parent: 67b7937

4 files changed: 213 additions & 72 deletions

File tree (diff shown below for this file only):

ml_grid/model_classes/H2OBaseClassifier.py (128 additions & 42 deletions)
@@ -3,6 +3,7 @@
 import os
 import shutil
 import tempfile
+import uuid
 from typing import Any, Dict, List, Optional, Tuple

 import h2o
@@ -43,6 +44,20 @@ class H2OBaseClassifier(BaseEstimator, ClassifierMixin):

     MIN_SAMPLES_FOR_STABLE_FIT = 10

+    # Class-level cache for init parameters to avoid repeated inspect.signature calls
+    _init_param_names_cache = {}
+
+    # Class-level cache for estimator class signatures to avoid repeated inspect.signature calls in fit()
+    _estimator_signature_cache = {}
+
+    # Class-level set of keys to exclude from get_params, avoiding recreation overhead
+    _excluded_param_keys = {
+        "estimator_class",
+        "logger",
+        "model",
+        "model_id",
+    }
+
     def __init__(self, estimator_class=None, **kwargs):
         """Initializes the H2OBaseClassifier.

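The three class-level attributes added above are shared by all instances, so each estimator class pays the `inspect.signature` cost at most once per process. A minimal, self-contained sketch of the memoisation pattern (the `SignatureCache` and `Example` names are illustrative, not from the repository):

    import inspect

    class SignatureCache:
        # Shared across instances: each class's signature is inspected once.
        _signature_cache = {}

        @classmethod
        def params_for(cls, target_cls):
            if target_cls not in cls._signature_cache:
                cls._signature_cache[target_cls] = inspect.signature(target_cls).parameters
            return cls._signature_cache[target_cls]

    class Example:
        def __init__(self, alpha=1.0, beta=2.0):
            self.alpha, self.beta = alpha, beta

    print(list(SignatureCache.params_for(Example)))  # ['alpha', 'beta']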
@@ -100,6 +115,8 @@ def __init__(self, estimator_class=None, **kwargs):
         # H2O models are not safe with joblib's process-based parallelism.
         self._n_jobs = 1

+        self._cached_param_names = None  # Cache for get_params
+
     def __del__(self):
         """Cleans up the shared checkpoint directory if this is the last instance."""
         # This is a best-effort cleanup. In multi-process scenarios,
@@ -335,7 +352,11 @@ def _prepare_fit(
         )

         train_df = pd.concat([X, y_series], axis=1)
-        train_h2o = h2o.H2OFrame(train_df)
+        # Optimization: Provide destination_frame to avoid expensive gc.get_referrers() name search
+        train_h2o = h2o.H2OFrame(
+            train_df,
+            destination_frame=f"train_{uuid.uuid4().hex}"
+        )

         # Explicitly convert the outcome column to factor
         train_h2o[outcome_var] = train_h2o[outcome_var].asfactor()
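Per the commit's own comments, constructing an `H2OFrame` without `destination_frame` makes the h2o client search for a suitable key via `gc.get_referrers()`; supplying a unique key up front skips that search. A sketch of the same call pattern, assuming a local H2O cluster can be started:

    import uuid

    import h2o
    import pandas as pd

    h2o.init()  # assumes a local H2O backend is available

    df = pd.DataFrame({"x": [1.0, 2.0, 3.0], "y": [0, 1, 0]})

    # Naming the frame explicitly sidesteps the client-side key search that
    # the commit attributes to gc.get_referrers().
    frame = h2o.H2OFrame(df, destination_frame=f"train_{uuid.uuid4().hex}")
    frame["y"] = frame["y"].asfactor()  # outcome as categorical, as in _prepare_fit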
@@ -349,7 +370,12 @@
         model_params = self._get_model_params()

         # Get valid parameters for the specific H2O estimator
-        estimator_params = inspect.signature(self.estimator_class).parameters
+        # Optimization: Use cached signature
+        if self.estimator_class not in self._estimator_signature_cache:
+            self._estimator_signature_cache[self.estimator_class] = inspect.signature(
+                self.estimator_class
+            ).parameters
+        estimator_params = self._estimator_signature_cache[self.estimator_class]

         # If there's only one feature, prevent H2O from dropping it if it's constant
         if len(x_vars) == 1 and self.estimator_class:
@@ -381,9 +407,12 @@ def _get_model_params(self) -> Dict[str, Any]:
             if k != "estimator_class"
         }

-        valid_param_keys = set(
-            inspect.signature(self.estimator_class).parameters.keys()
-        )
+        # Optimization: Use cached signature
+        if self.estimator_class not in self._estimator_signature_cache:
+            self._estimator_signature_cache[self.estimator_class] = inspect.signature(
+                self.estimator_class
+            ).parameters
+        valid_param_keys = set(self._estimator_signature_cache[self.estimator_class].keys())

         model_params = {
             key: value for key, value in all_params.items() if key in valid_param_keys
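Both call sites now consult the shared `_estimator_signature_cache`, and the cached parameter set is used to drop any kwargs the target estimator's constructor would reject. A hypothetical standalone version of that filter (the names are invented for illustration):

    import inspect

    _SIG_CACHE = {}  # stand-in for the class-level _estimator_signature_cache

    def filter_valid_params(estimator_cls, all_params):
        # Memoise the signature, then keep only kwargs the constructor accepts.
        if estimator_cls not in _SIG_CACHE:
            _SIG_CACHE[estimator_cls] = inspect.signature(estimator_cls).parameters
        valid = set(_SIG_CACHE[estimator_cls])
        return {k: v for k, v in all_params.items() if k in valid}

    class FakeEstimator:
        def __init__(self, ntrees=50, max_depth=5):
            self.ntrees, self.max_depth = ntrees, max_depth

    print(filter_valid_params(FakeEstimator, {"ntrees": 10, "unknown": True}))
    # -> {'ntrees': 10}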
@@ -473,10 +502,22 @@ def fit(self, X: pd.DataFrame, y: pd.Series, **kwargs) -> "H2OBaseClassifier":
         # Sanitize parameters to prevent backend errors (e.g. HGLM)
         self._sanitize_model_params()

+        # --- OPTIMIZATION: Force disable progress bar immediately before train ---
+        # This addresses profiling results showing significant overhead in progressbar.py
+        if not getattr(global_parameters, "h2o_show_progress", False):
+            h2o.no_progress()
+
         # Call the train() method with ONLY the data-related arguments
         self.logger.debug("Calling H2O model.train()...")
         self.model_.train(x=x_vars, y=outcome_var, training_frame=train_h2o)

+        # --- OPTIMIZATION: Explicitly cleanup training frame ---
+        # This reduces overhead from H2O's reference counting (gc.get_referrers)
+        try:
+            h2o.remove(train_h2o)
+        except Exception:
+            pass
+
         # Store model_id for recovery - THIS IS CRITICAL for predict() to work
         self.logger.debug(
             f"H2O train complete, extracting model_id from {self.model_}"
@@ -566,11 +607,11 @@ def predict(self, X: pd.DataFrame) -> np.ndarray:
             )
             return np.full(len(X), dummy_prediction)

-        # Ensure H2O is running
-        self._ensure_h2o_is_running()
-
-        # OPTIMIZATION: Lazy load model. Only check if we don't have the object.
+        # OPTIMIZATION: Lazy load model and optimistically assume H2O is running.
+        # Only check/init cluster if we don't have the model object or if prediction fails.
+        # This saves an API call (h2o.cluster()) per predict.
         if self.model_ is None:
+            self._ensure_h2o_is_running()
             self._ensure_model_is_loaded()

         try:
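The reordering makes predict() optimistic: the cluster health check now runs only on the slow path (no model object yet, or a failed predict that is retried below). The control flow reduces to a generic try-fast/recover/retry shape; a hypothetical distillation:

    # Hypothetical distillation of the retry shape used by predict()/predict_proba().
    def optimistic_call(primary, recover):
        """Try the fast path; on failure, run recovery once and retry."""
        try:
            return primary()
        except Exception:
            recover()  # e.g. re-init the H2O cluster, reload the model by id
            return primary()

    # Usage sketch (names are placeholders, not the wrapper's real API):
    # predictions = optimistic_call(
    #     lambda: model.predict(test_h2o),
    #     lambda: (ensure_h2o_is_running(), ensure_model_is_loaded()),
    # )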
@@ -589,7 +630,10 @@
             }

             tmp_frame = h2o.H2OFrame(
-                X, column_names=self.feature_names_, column_types=col_types
+                X,
+                column_names=self.feature_names_,
+                column_types=col_types,
+                destination_frame=f"pred_{uuid.uuid4().hex}"
             )

             # Optimization: Use the temporary frame directly.
@@ -607,6 +651,7 @@
             # Try reloading and predicting again.
             self.logger.debug(f"Prediction failed ({e}), attempting to reload model...")
             try:
+                self._ensure_h2o_is_running()
                 self._ensure_model_is_loaded()
                 predictions = self.model_.predict(test_h2o)
             except Exception as e2:
@@ -626,9 +671,27 @@ def predict(self, X: pd.DataFrame) -> np.ndarray:
                 raise RuntimeError(f"H2O prediction failed: {e2}")

         # Extract predictions
+        # Optimization: Download full frame to avoid H2O slicing overhead (gc.get_referrers ~48s)
+        # Slicing in H2O (predictions['predict']) creates a new unnamed frame which triggers a stack scan.
         pred_df = predictions.as_data_frame(use_multi_thread=False)
+
+        # --- OPTIMIZATION: Explicitly cleanup frames ---
+        try:
+            h2o.remove(test_h2o)
+            h2o.remove(predictions)
+        except Exception:
+            pass
+
         if "predict" in pred_df.columns:
-            return pred_df["predict"].values.ravel()
+            preds = pred_df["predict"].values.ravel()
+            # OPTIMIZATION: Cast to the same type as classes_ to avoid object/string overhead
+            # in sklearn metrics (which triggers expensive np.unique on object arrays).
+            if self.classes_ is not None:
+                try:
+                    preds = preds.astype(self.classes_.dtype)
+                except (ValueError, TypeError):
+                    pass
+            return preds
         else:
             raise RuntimeError("Prediction output missing 'predict' column")

@@ -680,11 +743,9 @@ def predict_proba(self, X: pd.DataFrame) -> np.ndarray:
             dummy_probas = np.full((len(X), n_classes), 1 / n_classes)
             return dummy_probas

-        # Ensure H2O is running
-        self._ensure_h2o_is_running()
-
-        # OPTIMIZATION: Lazy load model.
+        # OPTIMIZATION: Lazy load model and optimistically assume H2O is running.
         if self.model_ is None:
+            self._ensure_h2o_is_running()
             self._ensure_model_is_loaded()

         # Create H2O frame with explicit column names
@@ -697,7 +758,10 @@
             }

             test_h2o = h2o.H2OFrame(
-                X, column_names=self.feature_names_, column_types=col_types
+                X,
+                column_names=self.feature_names_,
+                column_types=col_types,
+                destination_frame=f"prob_{uuid.uuid4().hex}"
             )
         except Exception as e:
             raise RuntimeError(f"Failed to create H2O frame for prediction: {e}")
@@ -709,6 +773,7 @@
             # Retry logic for unloaded models
             self.logger.debug(f"Prediction failed ({e}), attempting to reload model...")
             try:
+                self._ensure_h2o_is_running()
                 self._ensure_model_is_loaded()
                 predictions = self.model_.predict(test_h2o)
             except Exception as e2:
@@ -728,7 +793,20 @@ def predict_proba(self, X: pd.DataFrame) -> np.ndarray:
                 raise RuntimeError(f"H2O prediction failed: {e2}")

         # Extract probabilities (drop the 'predict' column)
-        prob_df = predictions.drop("predict").as_data_frame(use_multi_thread=False)
+        # Optimization: Download full frame then drop in pandas to avoid H2O slicing overhead
+        res_df = predictions.as_data_frame(use_multi_thread=False)
+        if "predict" in res_df.columns:
+            prob_df = res_df.drop(columns=["predict"])
+        else:
+            prob_df = res_df
+
+        # --- OPTIMIZATION: Explicitly cleanup frames ---
+        try:
+            h2o.remove(test_h2o)
+            h2o.remove(predictions)
+        except Exception:
+            pass
+
         return prob_df.values

     def _ensure_model_is_loaded(self):
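predict_proba() mirrors predict(): download once, drop the label column in pandas, then best-effort-delete both server-side frames. The swallowed-exception cleanup could be factored into a small helper; a hypothetical `remove_quietly` (not in the repository):

    import h2o

    def remove_quietly(*frames):
        # Best-effort deletion of server-side objects; failures (already
        # removed, cluster gone) are swallowed, matching the try/except
        # wrapped around h2o.remove in the diff.
        for frame in frames:
            try:
                h2o.remove(frame)
            except Exception:
                pass

    # e.g. remove_quietly(test_h2o, predictions) after downloading the results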
@@ -828,21 +906,6 @@ def __deepcopy__(self, memo):
         )
         return cloned

-    def __sklearn_clone__(self):
-        """Custom cloning method for sklearn compatibility.
-
-        This ensures that when sklearn clones the estimator, we return a properly
-        initialized new instance with the same parameters.
-        """
-        # Get all parameters (not fitted attributes)
-        params = self.get_params(deep=False)
-        # Create new instance with same parameters
-        cloned = self.__class__(**params)
-        self.logger.debug(
-            f"__sklearn_clone__ called: original instance {id(self)}, clone instance {id(cloned)}"
-        )
-        return cloned  # Removing dead code
-
     def _get_param_names(self):
         """Get parameter names for the estimator.

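Removing `__sklearn_clone__` is safe because it only re-implemented what `sklearn.base.clone` already does by default: read `get_params(deep=False)` and call the constructor with the result. A quick demonstration of that default path:

    from sklearn.base import BaseEstimator, clone

    class Tiny(BaseEstimator):
        def __init__(self, alpha=1.0):
            self.alpha = alpha

    est = Tiny(alpha=0.5)
    copy = clone(est)  # default path: type(est)(**est.get_params(deep=False))
    print(copy.alpha, copy is est)  # 0.5 False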
@@ -851,24 +914,45 @@ def _get_param_names(self):

         CRITICAL: This should ONLY return parameter names, NOT fitted attribute names.
         """
-        init_signature = inspect.signature(self.__class__.__init__)
-        init_params = [
-            p.name
-            for p in init_signature.parameters.values()
-            if p.name not in ("self", "args", "kwargs")
-        ]
+        if self._cached_param_names is not None:
+            return self._cached_param_names
+
+        # Optimization: Cache the signature inspection on the class
+        cls = self.__class__
+        if cls not in self._init_param_names_cache:
+            init_signature = inspect.signature(cls.__init__)
+            self._init_param_names_cache[cls] = [
+                p.name
+                for p in init_signature.parameters.values()
+                if p.name not in ("self", "args", "kwargs")
+            ]
+
+        init_params = self._init_param_names_cache[cls]
+
+        # Optimization: Use sets for O(1) lookup
+        init_params_set = set(init_params)

         extra_params = [
             key
             for key in self.__dict__
             if not key.startswith("_")
             and not (key.endswith("_") and key != "lambda_")  # Allow lambda_
-            and key not in init_params
-            and key not in ["estimator_class", "logger"]
-            and key not in ["model", "model_", "classes_", "feature_names_", "model_id"]
+            and key not in init_params_set
+            and key not in self._excluded_param_keys
         ]

-        return sorted(init_params + extra_params)
+        self._cached_param_names = sorted(init_params + extra_params)
+        return self._cached_param_names
+
+    def get_params(self, deep=True):
+        """Get parameters for this estimator.
+
+        Optimized implementation that bypasses BaseEstimator.get_params overhead.
+        """
+        # Optimization: Directly construct dict from cached param names
+        # We bypass the deep=True recursion check for speed, as H2O estimators
+        # don't have nested sklearn estimators.
+        return {key: getattr(self, key) for key in self._get_param_names()}

     def set_params(self: "H2OBaseClassifier", **kwargs: Any) -> "H2OBaseClassifier":
         """Sets the parameters of this estimator, compatible with scikit-learn.
@@ -911,6 +995,8 @@ def set_params(self: "H2OBaseClassifier", **kwargs: Any) -> "H2OBaseClassifier":
                 # This shouldn't happen for our use case, but handle it anyway
                 setattr(self, key, value)

+        self._cached_param_names = None  # Invalidate cache
+
         # Restore fitted attributes
         for attr, value in fitted_attributes.items():
             setattr(self, attr, value)
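Because `_get_param_names` is now memoised per instance, set_params() must drop that cache so a dynamic attribute set through set_params shows up in the next get_params() call. A minimal illustration of why the invalidation matters (class and parameter names are invented):

    class Cached:
        # A memoised name list must be invalidated whenever set_params
        # can introduce new attributes.
        def __init__(self):
            self._cached_names = None

        def _names(self):
            if self._cached_names is None:
                self._cached_names = sorted(
                    k for k in self.__dict__ if not k.startswith("_")
                )
            return self._cached_names

        def set_params(self, **kwargs):
            for key, value in kwargs.items():
                setattr(self, key, value)
            self._cached_names = None  # invalidate, mirroring the diff
            return self

    obj = Cached().set_params(ntrees=50)
    print(obj._names())                          # ['ntrees']
    print(obj.set_params(max_depth=3)._names())  # ['max_depth', 'ntrees']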
