Skip to content

Commit 05266b1

Browse files
author
SamoraHunter
committed
minor optimisation fix
1 parent 0b53b50 commit 05266b1

1 file changed

Lines changed: 3 additions & 2 deletions

File tree

ml_grid/pipeline/grid_search_cross_validate.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,7 @@ def __init__(
133133
"keras" in method_name.lower()
134134
or "xgb" in method_name.lower()
135135
or "catboost" in method_name.lower()
136+
or "neural" in method_name.lower()
136137
)
137138

138139
global _TF_INITIALIZED
@@ -692,7 +693,7 @@ def __init__(
692693
scoring=self.metric_list,
693694
cv=self.cv,
694695
n_jobs=final_cv_n_jobs, # Use adjusted n_jobs
695-
pre_dispatch=80,
696+
pre_dispatch="2*n_jobs",
696697
error_score=self.error_raise, # Raise error if cross-validation fails
697698
)
698699

@@ -741,7 +742,7 @@ def __init__(
741742
scoring=self.metric_list,
742743
cv=self.cv,
743744
n_jobs=final_cv_n_jobs, # Use adjusted n_jobs
744-
pre_dispatch=80,
745+
pre_dispatch="2*n_jobs",
745746
error_score=self.error_raise, # Raise error if cross-validation fails
746747
)
747748
except Exception as e:

0 commit comments

Comments
 (0)