1+ import itertools
2+ from collections import defaultdict
13from jaxkineticmodel .simulated_dbtl .dbtl import DesignBuildTestLearnCycle
24from logging import getLogger
35from source .design_strategies import stratified_sampling , best_sampling
46import pandas as pd
7+ import xgboost as xgb
8+ from scipy .special import softmax
59
610logger = getLogger (__name__ )
711
@@ -43,6 +47,13 @@ def update_dbtl(self):
4347 data = self .dbtl .test_format_dataset (designs , production_values , self .dbtl .parameters )
4448 self .data [self .cycle_id ] = data
4549 elif dbtl_config ['design_build_test' ]['design_method' ] == "ml_assisted_library_transform" :
50+ #this method cannot be done as the first round, since it requires a model
51+ assert self .cycle_status != 0
52+
53+ dbtl_config_dbt = dbtl_config ['design_build_test' ]
54+ designs , production_values , strain_promoters = self .ml_assisted_library_transform (dbtl_config_dbt )
55+
56+
4657 logger .error ("Not implemented yet" )
4758 else :
4859 logger .error ("This design scenario is not implemented. Choose from options"
@@ -57,6 +68,71 @@ def update_dbtl(self):
5768 "[greedy_baseline, ..]" )
5869 return best_design , best_producer
5970
71+ def ml_assisted_library_transform (self ,
72+ dbtl_config_dbt ):
73+ parameter_perturbations = self .config ['optimization_settings' ]['parameters_perturbation_values' ]
74+ n_strains_screened = dbtl_config_dbt ['design_method_hyperparams' ]['n_screened_strains' ]
75+ n_engineered_positions = dbtl_config_dbt ['n_engineered_positions' ]
76+ n_strains = dbtl_config_dbt ['n_strains' ]
77+ beta = dbtl_config_dbt ['beta' ] #exploration/exploitation
78+
79+ ## train a model based on previous data
80+ if dbtl_config_dbt ['ml_method' ]== "xgboost" :
81+ xgbparameters = {'tree_method' : 'auto' , 'reg_lambda' : 1 , 'max_depth' : 2 , "disable_default_eval_metric" : 0 }
82+ alternative_params = {'num_boost_round' : 10 , 'early_stopping_rounds' : 40 }
83+
84+ if dbtl_config_dbt ['data_strategy' ] == "all" :
85+ cycle_names = self .data .keys ()
86+ data = pd .concat ([self .data [i ] for i in cycle_names ])
87+ else :
88+ logger .error ("data_strategy not supported yet. Choose all data strategy" )
89+ bst , r2_scores = self .dbtl .learn_train_model (data = self .data ,
90+ target = self .target ,
91+ model_type = "XGBoost" ,
92+ args = (xgbparameters , alternative_params ), test_size = 0.20 )
93+ else :
94+ logger .error (f"This ml_method { dbtl_config_dbt ['ml_method' ]} is not implemented. Choose from"
95+ f"options xgboost, or implement your own method" )
96+
97+ #construct library (we do this always)
98+ self .dbtl .design_establish_library_elements (parameter_perturbations = parameter_perturbations )
99+ _ = self .dbtl .design_assign_positions (n_positions = n_engineered_positions )
100+
101+ ## construct probability distribution
102+ position_elements = []
103+ for position in self .dbtl .library .columns .get_level_values (0 ).unique ():
104+ temp = list (zip (self .dbtl .library [position ]['parameter_name' ].values , self .dbtl .library [position ]['promoter_value' ].values ))
105+ position_elements .append (temp )
106+ combinatorial_designs = list (itertools .product (* position_elements ))
107+
108+ all_designs = []
109+ for k , design in enumerate (combinatorial_designs ):
110+ design = combine_duplicate_keys (design )
111+ all_designs .append (design )
112+
113+ ## we now need to assign the probabilities determined using the XGBoost model. This requires that we save the data from the previous cycle
114+ all_designs = pd .DataFrame (all_designs ).fillna (0 ) + 1
115+ all_designs_xgb = xgb .DMatrix (all_designs )
116+ y_predicted = bst .predict (all_designs_xgb )
117+ softmax_distribution = softmax (y_predicted * beta )
118+
119+ positions = [f"pos_{ i } " for i in range (n_engineered_positions )]
120+ pandas_combinatorial_designs = pd .DataFrame (combinatorial_designs , columns = positions )
121+
122+ pandas_combinatorial_designs ['softmax' ] = softmax_distribution
123+
124+ probability_dist_per_position = {}
125+ for position in positions :
126+ probabilities = pandas_combinatorial_designs .groupby (position )['softmax' ].sum ()
127+ probability_dist_per_position [position ] = probabilities .to_dict ()
128+
129+
130+
131+ _ = self .dbtl .design_assign_probabilities (probabilities_per_position = probability_dist_per_position ) #this now needs to be done
132+
133+ logger .info ("ml_assisted_library_transform" )
134+
135+
60136 def library_transform (self ,
61137 dbtl_config_dbt ):
62138 """Generates a random set of designs.
@@ -119,3 +195,11 @@ def greedy_baseline(self):
119195
120196 return best_design , best_producer
121197
198+
199+
def combine_duplicate_keys(tuples):
    """Merge (key, value) pairs sharing a key by summing their promoter values.

    Keys are coerced to ``str`` before merging, so e.g. ``1`` and ``"1"``
    accumulate into the same entry. Returns a ``defaultdict(float)``.
    """
    combined = defaultdict(float)
    for pair in tuples:
        combined[str(pair[0])] = combined[str(pair[0])] + pair[1]
    return combined
0 commit comments