Skip to content

Commit 8656cd4

Browse files
author
SamoraHunter
committed
Pass cached input df to avoid reloading each trial.
1 parent 2d73cf3 commit 8656cd4

2 files changed

Lines changed: 39 additions & 18 deletions

File tree

ml_grid/pipeline/data.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,7 @@ def __init__(
193193
time_series_mode: bool = False,
194194
model_class_dict: Optional[Dict[str, bool]] = None,
195195
outcome_var_override: Optional[str] = None,
196+
input_df: Optional[pd.DataFrame] = None,
196197
):
197198
"""Initializes the data pipeline object.
198199
@@ -224,6 +225,8 @@ def __init__(
224225
outcome_var_override (Optional[str], optional): A specific outcome
225226
variable name to use, overriding the one from `local_param_dict`.
226227
Defaults to None.
228+
input_df (Optional[pd.DataFrame], optional): A pre-loaded DataFrame to use
229+
instead of reading from file_name. Defaults to None.
227230
"""
228231

229232
self.additional_naming = additional_naming
@@ -249,7 +252,18 @@ def __init__(
249252

250253
pipeline_error = None
251254
try:
252-
self._load_data(file_name, test_sample_n, column_sample_n)
255+
if input_df is not None:
256+
self.df = input_df.copy()
257+
self.all_df_columns = list(self.df.columns)
258+
self.orignal_feature_names = self.all_df_columns.copy()
259+
self._log_feature_transformation(
260+
"Initial Load",
261+
len(self.all_df_columns),
262+
len(self.all_df_columns),
263+
"Initial data loaded from passed DataFrame.",
264+
)
265+
else:
266+
self._load_data(file_name, test_sample_n, column_sample_n)
253267
self._initial_feature_selection(
254268
local_param_dict, drop_term_list, outcome_var_override
255269
)

notebooks/unit_test_synthetic.ipynb

Lines changed: 24 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -992,15 +992,18 @@
992992
" # Retrieve trial timeout, defaulting to None if not set\n",
993993
" trial_timeout = config['hyperopt_settings'].get('trial_timeout', None)\n",
994994
"\n",
995-
" # 5. --- Determine Outcome Variables ---\n",
995+
" # 5. --- Determine Outcome Variables & PRE-LOAD DATA ---\n",
996+
" data_file_path = config['data']['file_path']\n",
997+
" if not Path(data_file_path).is_absolute():\n",
998+
" data_file_path = project_root / data_file_path\n",
999+
"\n",
1000+
" # OPTIMIZATION: Load data ONCE here\n",
1001+
" print(f\"Pre-loading data from {data_file_path}...\", flush=True)\n",
1002+
" GLOBAL_DF = pd.read_csv(data_file_path)\n",
1003+
"\n",
9961004
" outcome_var_list = []\n",
9971005
" if config['data'].get('multiple_outcomes', False):\n",
998-
" data_file_path = config['data']['file_path']\n",
999-
" if not Path(data_file_path).is_absolute():\n",
1000-
" data_file_path = project_root / data_file_path\n",
1001-
"\n",
1002-
" df = pd.read_csv(data_file_path)\n",
1003-
" outcome_var_list = [col for col in df.columns if 'outcome_var_' in col]\n",
1006+
" outcome_var_list = [col for col in GLOBAL_DF.columns if 'outcome_var_' in col]\n",
10041007
"\n",
10051008
" if not outcome_var_list:\n",
10061009
" raise ValueError(f\"No outcome variables found with 'outcome_var_' prefix in {data_file_path}\")\n",
@@ -1009,26 +1012,27 @@
10091012
" outcome_var_list = config['hyperopt_search_space']['outcome_var_n']\n",
10101013
"\n",
10111014
" # 6. --- Define the Objective Function for Hyperopt ---\n",
1012-
" def objective(params, outcome_var):\n",
1015+
" def objective(params, outcome_var, loaded_df):\n",
10131016
" \"\"\"\n",
1014-
" Objective function for hyperopt. It receives sampled parameters\n",
1015-
" and the specific outcome variable for the current run.\n",
1017+
" Objective function for hyperopt. It receives sampled parameters,\n",
1018+
" the specific outcome variable, and the PRE-LOADED DataFrame.\n",
10161019
" \"\"\"\n",
10171020
" try:\n",
10181021
" # Wrap the entire trial execution in the time_limit context manager\n",
10191022
" with time_limit(trial_timeout):\n",
10201023
" local_param_dict = params\n",
10211024
"\n",
1022-
" # Initialize the data pipeline\n",
1025+
" # Initialize the data pipeline using the cached DataFrame\n",
10231026
" ml_grid_object = pipe(\n",
1024-
" file_name=config['data']['file_path'],\n",
1027+
" file_name=None, # Not needed when input_df is provided\n",
10251028
" drop_term_list=config['data']['drop_term_list'],\n",
10261029
" model_class_dict=config['models'],\n",
10271030
" local_param_dict=local_param_dict,\n",
10281031
" base_project_dir=project_root,\n",
10291032
" experiment_dir=experiment_dir,\n",
10301033
" param_space_index=0, \n",
1031-
" outcome_var_override=outcome_var\n",
1034+
" outcome_var_override=outcome_var,\n",
1035+
" input_df=loaded_df # <--- PASS CACHED DATA\n",
10321036
" )\n",
10331037
"\n",
10341038
" # Execute the models\n",
@@ -1052,7 +1056,9 @@
10521056
" start_time = datetime.now()\n",
10531057
" print(f\"[{start_time}] Starting optimization for outcome: {outcome_var}\", flush=True)\n",
10541058
"\n",
1055-
" fmin_objective = partial(objective, outcome_var=outcome_var)\n",
1059+
" # Pass the global dataframe to the objective function via partial\n",
1060+
" # Joblib will handle the serialization/shared memory of GLOBAL_DF efficiently\n",
1061+
" fmin_objective = partial(objective, outcome_var=outcome_var, loaded_df=GLOBAL_DF)\n",
10561062
"\n",
10571063
" trials = Trials()\n",
10581064
" best = fmin(\n",
@@ -1068,8 +1074,8 @@
10681074
" failed_trials = [t for t in trials.results if t['status'] == 'fail']\n",
10691075
"\n",
10701076
" print(f\"[{end_time}] Finished {outcome_var} (Duration: {end_time - start_time})\", flush=True)\n",
1071-
" print(f\" -> Best param set for this outcome: {best}\", flush=True)\n",
1072-
" print(f\" -> Trials summary: {len(failed_trials)}/{len(trials.results)} failed.\", flush=True)\n",
1077+
"        print(f\"   -> Best param set for this outcome: {best}\", flush=True)\n",
1078
"        print(f\"   -> Trials summary: {len(failed_trials)}/{len(trials.results)} failed.\", flush=True)\n",
10731079
"\n",
10741080
" return (outcome_var, best)\n",
10751081
"\n",
@@ -1098,7 +1104,8 @@
10981104
" print(f\" Best parameter combination found: {best_params}\")\n",
10991105
"\n",
11001106
" end_total = datetime.now()\n",
1101-
" print(f\"\\nCompleted all optimizations at {end_total} (Total duration: {end_total - start_total})\")\n"
1107+
" print(f\"\\nCompleted all optimizations at {end_total} (Total duration: {end_total - start_total})\")\n",
1108+
"\n"
11021109
]
11031110
},
11041111
{

0 commit comments

Comments
 (0)