Skip to content

Commit b74c317

Browse files
committed
ml_assisted_library_transform implementation (work in progress)
1 parent 0cd96fb commit b74c317

5 files changed

Lines changed: 1527 additions & 451 deletions

File tree

scripts/210325_design_methods_simulation_dev.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
logger = getLogger(__name__)
1212

13-
configs_filenames = ["config_best_sampling"]
13+
configs_filenames = ["config_batch_model_v2_n_rounds"]
1414

1515
start = time.time()
1616
fig, ax = plt.subplots(figsize=(4, 4))

scripts/develop_library_transformation.ipynb

Lines changed: 1417 additions & 438 deletions
Large diffs are not rendered by default.

scripts/setupconfigfile.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -11,22 +11,23 @@
1111
output_name = "config_batch_model_v2_n_rounds"
1212
target = "product"
1313
run_id = 1
14-
n_cycles = 5
15-
n_experiments = [10, 10, 10, 10, 10]
16-
n_screened= 60
14+
n_cycles = 2
15+
n_experiments = [10, 10]
16+
n_screened = 30
1717

18-
n_engineered_positions = [6, 6, 6, 6, 6]
19-
design_method_per_cycle = ["library_transform", "library_transform",
20-
"library_transform", "library_transform", "library_transform"]
18+
n_engineered_positions = [6, 6]
19+
design_method_per_cycle = ["library_transform", "ml_assisted_library_transform"]
2120
noise_percentage = 0.1 # not a percentage
2221
noise_type = "heteroscedastic"
2322

24-
recommendation_method = ["greedy_baseline", "greedy_baseline",
25-
"greedy_baseline", "greedy_baseline",
26-
"greedy_baseline"]
23+
recommendation_method = ["greedy_baseline", "greedy_baseline"]
2724
hyperparams = {'library_transform': {"n_screened_strains": n_screened,
28-
"sequencing_selection_method": "best_sampling"},}
29-
# 'library_ml_assisted':{}} #includes
25+
"sequencing_selection_method": "best_sampling"},
26+
'ml_assisted_library_transform': {"n_screened_strains": n_screened,
27+
"ml_method" : "xgboost",
28+
"beta": 2**(2.5),
29+
"data_strategy" : "all", # there are multiple strategies here,
30+
"sequencing_selection_method": "best_sampling",},}
3031

3132
hyperparams = [hyperparams[design_method] for design_method in design_method_per_cycle]
3233

source/optimization_process.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
1+
import itertools
2+
from collections import defaultdict
13
from jaxkineticmodel.simulated_dbtl.dbtl import DesignBuildTestLearnCycle
24
from logging import getLogger
35
from source.design_strategies import stratified_sampling, best_sampling
46
import pandas as pd
7+
import xgboost as xgb
8+
from scipy.special import softmax
59

610
logger = getLogger(__name__)
711

@@ -43,6 +47,13 @@ def update_dbtl(self):
4347
data = self.dbtl.test_format_dataset(designs, production_values, self.dbtl.parameters)
4448
self.data[self.cycle_id] = data
4549
elif dbtl_config['design_build_test']['design_method'] == "ml_assisted_library_transform":
50+
#this method cannot be done as the first round, since it requires a model
51+
assert self.cycle_status != 0
52+
53+
dbtl_config_dbt = dbtl_config['design_build_test']
54+
designs, production_values, strain_promoters = self.ml_assisted_library_transform(dbtl_config_dbt)
55+
56+
4657
logger.error("Not implemented yet")
4758
else:
4859
logger.error("This design scenario is not implemented. Choose from options"
@@ -57,6 +68,71 @@ def update_dbtl(self):
5768
"[greedy_baseline, ..]")
5869
return best_design, best_producer
5970

71+
def ml_assisted_library_transform(self,
72+
dbtl_config_dbt):
73+
parameter_perturbations = self.config['optimization_settings']['parameters_perturbation_values']
74+
n_strains_screened = dbtl_config_dbt['design_method_hyperparams']['n_screened_strains']
75+
n_engineered_positions = dbtl_config_dbt['n_engineered_positions']
76+
n_strains = dbtl_config_dbt['n_strains']
77+
beta = dbtl_config_dbt['beta'] #exploration/exploitation
78+
79+
## train a model based on previous data
80+
if dbtl_config_dbt['ml_method']=="xgboost":
81+
xgbparameters = {'tree_method': 'auto', 'reg_lambda': 1, 'max_depth': 2, "disable_default_eval_metric": 0}
82+
alternative_params = {'num_boost_round': 10, 'early_stopping_rounds': 40}
83+
84+
if dbtl_config_dbt['data_strategy'] == "all":
85+
cycle_names = self.data.keys()
86+
data = pd.concat([self.data[i] for i in cycle_names])
87+
else:
88+
logger.error("data_strategy not supported yet. Choose all data strategy")
89+
bst, r2_scores = self.dbtl.learn_train_model(data=self.data,
90+
target=self.target,
91+
model_type="XGBoost",
92+
args=(xgbparameters, alternative_params), test_size=0.20)
93+
else:
94+
logger.error(f"This ml_method {dbtl_config_dbt['ml_method']} is not implemented. Choose from"
95+
f"options xgboost, or implement your own method")
96+
97+
#construct library (we do this always)
98+
self.dbtl.design_establish_library_elements(parameter_perturbations=parameter_perturbations)
99+
_ = self.dbtl.design_assign_positions(n_positions=n_engineered_positions)
100+
101+
## construct probability distribution
102+
position_elements = []
103+
for position in self.dbtl.library.columns.get_level_values(0).unique():
104+
temp = list(zip(self.dbtl.library[position]['parameter_name'].values, self.dbtl.library[position]['promoter_value'].values))
105+
position_elements.append(temp)
106+
combinatorial_designs = list(itertools.product(*position_elements))
107+
108+
all_designs = []
109+
for k, design in enumerate(combinatorial_designs):
110+
design = combine_duplicate_keys(design)
111+
all_designs.append(design)
112+
113+
## we now need to assign the probabilities determined using the XGBoost model. This requires that we save the data from the previous cycle
114+
all_designs = pd.DataFrame(all_designs).fillna(0) + 1
115+
all_designs_xgb = xgb.DMatrix(all_designs)
116+
y_predicted = bst.predict(all_designs_xgb)
117+
softmax_distribution = softmax(y_predicted * beta)
118+
119+
positions = [f"pos_{i}" for i in range(n_engineered_positions)]
120+
pandas_combinatorial_designs = pd.DataFrame(combinatorial_designs, columns=positions)
121+
122+
pandas_combinatorial_designs['softmax'] = softmax_distribution
123+
124+
probability_dist_per_position = {}
125+
for position in positions:
126+
probabilities = pandas_combinatorial_designs.groupby(position)['softmax'].sum()
127+
probability_dist_per_position[position] = probabilities.to_dict()
128+
129+
130+
131+
_ = self.dbtl.design_assign_probabilities(probabilities_per_position=probability_dist_per_position) #this now needs to be done
132+
133+
logger.info("ml_assisted_library_transform")
134+
135+
60136
def library_transform(self,
61137
dbtl_config_dbt):
62138
"""Generates a random set of designs.
@@ -119,3 +195,11 @@ def greedy_baseline(self):
119195

120196
return best_design, best_producer
121197

198+
199+
200+
def combine_duplicate_keys(tuples):
201+
"""Combines similar keys by adding up their promoter values"""
202+
combined = defaultdict(float)
203+
for key, value in tuples:
204+
combined[str(key)] += value
205+
return combined

todos.md

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,4 +22,16 @@ Overview of the update_dbtl function
2222

2323
11-04
2424
Development of the ML-assisted recommendation: develop_library_transformation.ipynb
25-
Now we just need to add it the probability distribution to assign probabilities
25+
Now we just need to add the probability distribution used to assign probabilities
26+
27+
22-04-2025
28+
There are several parameters necessary here. We will fix the method to XGBoost; there is a beta parameter
29+
scheme that could be set per round. The strategy for including data in the model. Perhaps at some point in
30+
the optimization, "forgetting" old data may work better than continuing to stack data on top.
31+
32+
33+
{"n_screened_strains": n_screened,
34+
"ml_method" : "xgboost",
35+
"beta": 2**(2.5),
36+
"data_strategy" : "all", # there are multiple strategies here,
37+
"sequencing_selection_method": "best_sampling",}

0 commit comments

Comments
 (0)