Skip to content

Commit 63e4154

Browse files
author
SamoraHunter
committed
further optimisations
1 parent caa9d54 commit 63e4154

1 file changed

Lines changed: 47 additions & 43 deletions

File tree

ml_grid/util/project_score_save.py

Lines changed: 47 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,9 @@ def update_score_log(
185185
logger = logging.getLogger("ml_grid")
186186
logger.info("Writing grid permutation to log")
187187
# write line to best grid scores---------------------
188+
189+
# --- OPTIMIZATION: Construct dictionary first to avoid slow DataFrame element-wise setting ---
190+
row_data = {}
188191
column_list = _get_score_log_columns(list(global_params.metric_list.keys()))
189192
line = pd.DataFrame(data=None, columns=column_list)
190193

@@ -240,19 +243,17 @@ def update_score_log(
240243
mcc = matthews_corrcoef(y_test_np, best_pred_np)
241244
accuracy = accuracy_score(y_test_np, best_pred_np)
242245

243-
# get info from current settings iter...local_param_dict ml_grid_object
246+
# Populate row_data dictionary instead of repeated DataFrame indexing
244247
for key in ml_grid_object.local_param_dict:
245248
# print(key)
246249
if key != "data":
247250
if key in column_list:
248-
line[key] = [ml_grid_object.local_param_dict.get(key)]
251+
row_data[key] = ml_grid_object.local_param_dict.get(key)
249252
else:
250253
for key_1 in ml_grid_object.local_param_dict.get("data"):
251254
# print(key_1)
252255
if key_1 in column_list:
253-
line[key_1] = [
254-
ml_grid_object.local_param_dict.get("data").get(key_1)
255-
]
256+
row_data[key_1] = ml_grid_object.local_param_dict.get("data").get(key_1)
256257

257258
current_f = ml_grid_object.final_column_list
258259
# current_f = list(self.X_test.columns)
@@ -266,61 +267,64 @@ def update_score_log(
266267
# f_list.append(np.array(current_f_vector))
267268
f_list.append(current_f_vector)
268269

269-
line["algorithm_implementation"] = [current_algorithm]
270-
line["parameter_sample"] = [current_algorithm.get_params()]
271-
line["method_name"] = [method_name]
272-
line["nb_size"] = [sum(np.array(current_f_vector))]
273-
line["n_features"] = [len(current_f_vector)]
274-
line["f_list"] = [f_list]
275-
276-
line["auc"] = [auc]
277-
line["mcc"] = [mcc]
278-
line["f1"] = [f1]
279-
line["precision"] = [precision]
280-
line["recall"] = [recall]
281-
line["accuracy"] = [accuracy]
282-
line["support"] = [support_val]
283-
284-
line["X_train_size"] = [len(X_train)]
285-
line["X_test_orig_size"] = [len(X_test_orig)]
286-
line["X_test_size"] = [len(X_test)]
270+
row_data["algorithm_implementation"] = current_algorithm
271+
row_data["parameter_sample"] = current_algorithm.get_params()
272+
row_data["method_name"] = method_name
273+
row_data["nb_size"] = sum(np.array(current_f_vector))
274+
row_data["n_features"] = len(current_f_vector)
275+
row_data["f_list"] = f_list
276+
277+
row_data["auc"] = auc
278+
row_data["mcc"] = mcc
279+
row_data["f1"] = f1
280+
row_data["precision"] = precision
281+
row_data["recall"] = recall
282+
row_data["accuracy"] = accuracy
283+
row_data["support"] = support_val
284+
285+
row_data["X_train_size"] = len(X_train)
286+
row_data["X_test_orig_size"] = len(X_test_orig)
287+
row_data["X_test_size"] = len(X_test)
287288

288289
end = time.time()
289290

290291
logger.debug(f"Cross-validation scores: {scores}")
291-
line["run_time"] = end - start
292-
line["t_fits"] = pg
293-
line["n_fits"] = n_iter_v
294-
line["i"] = param_space_index # 0 # should be index of the iterator
295-
line["outcome_variable"] = ml_grid_object_iter.outcome_variable
296-
line["failed"] = failed
292+
row_data["run_time"] = end - start
293+
row_data["t_fits"] = pg
294+
row_data["n_fits"] = n_iter_v
295+
row_data["i"] = param_space_index # 0 # should be index of the iterator
296+
row_data["outcome_variable"] = ml_grid_object_iter.outcome_variable
297+
row_data["failed"] = failed
297298

298299
if bayessearch:
299300
try:
300-
line["fit_time_m"] = np.array([scores["fit_time"]]).mean()
301-
line["fit_time_std"] = np.array([scores["fit_time"]]).std()
301+
# Optimization: Use np.mean directly to avoid redundant array creation and nanmean overhead (~68s)
302+
row_data["fit_time_m"] = np.mean(scores["fit_time"])
303+
row_data["fit_time_std"] = np.std(scores["fit_time"])
302304

303-
line["score_time_m"] = np.array(scores["score_time"]).mean()
304-
line["score_time_std"] = np.array(scores["score_time"]).std()
305+
row_data["score_time_m"] = np.mean(scores["score_time"])
306+
row_data["score_time_std"] = np.std(scores["score_time"])
305307

306308
for metric in global_params.metric_list:
307-
line[f"{metric}_m"] = np.array(scores[f"test_{metric}"]).mean()
308-
line[f"{metric}_std"] = np.array(scores[f"test_{metric}"]).std()
309+
row_data[f"{metric}_m"] = np.mean(scores[f"test_{metric}"])
310+
row_data[f"{metric}_std"] = np.std(scores[f"test_{metric}"])
309311

310312
except Exception as e:
311313
logger.error(f"Error processing scores for BayesSearch: {e}")
312314
logger.debug(f"Scores dictionary: {scores}")
313315
else:
314-
line["fit_time_m"] = np.array(
315-
scores["fit_time"]
316-
).mean() # deprecated for bayes
317-
line["fit_time_std"] = np.array(scores["fit_time"]).std()
318-
line["score_time_m"] = np.array(scores["score_time"]).mean()
319-
line["score_time_std"] = np.array(scores["score_time"]).std()
316+
# Optimization: Use np.mean directly
317+
row_data["fit_time_m"] = np.mean(scores["fit_time"])
318+
row_data["fit_time_std"] = np.std(scores["fit_time"])
319+
row_data["score_time_m"] = np.mean(scores["score_time"])
320+
row_data["score_time_std"] = np.std(scores["score_time"])
320321

321322
for metric in global_params.metric_list:
322-
line[f"{metric}_m"] = np.array(scores[f"test_{metric}"]).mean()
323-
line[f"{metric}_std"] = np.array(scores[f"test_{metric}"]).std()
323+
row_data[f"{metric}_m"] = np.mean(scores[f"test_{metric}"])
324+
row_data[f"{metric}_std"] = np.std(scores[f"test_{metric}"])
325+
326+
# Create the DataFrame once with all data
327+
line = pd.DataFrame([row_data], columns=column_list)
324328

325329
logger.info(f"Logged results for method '{method_name}'")
326330
logger.debug(f"Log line data: \n{line.to_string()}")

0 commit comments

Comments
 (0)