11from typing import Optional , Union , Dict , Any , List , Callable , Literal
22
33import torch
4-
4+ import numpy as np
5+ import warnings
6+ import copy
7+ from torch_geometric .data import Data , DataLoader
8+ from tqdm import tqdm
59from .model import GRIN
10+ from .utils import SmilesRepeat
611from ..gnn .modeling_gnn import GNNMolecularPredictor
12+ from ...utils import graph_from_smiles
713from ...utils .search import (
814 ParameterSpec ,
915 ParameterType ,
@@ -21,6 +27,8 @@ class GRINMolecularPredictor(GNNMolecularPredictor):
2127
2228 Parameters
2329 ----------
30+ polymer_train_augmentation : int, default=None
31+ Number of times to repeat the polymer for training.
2432 l1_penalty : float, default=1e-3
2533 Weight for the L1 penalty.
2634 epochs_to_penalize : int, default=100
@@ -78,6 +86,7 @@ class GRINMolecularPredictor(GNNMolecularPredictor):
7886 def __init__ (
7987 self ,
8088 # GRIN-specific parameters
89+ polymer_train_augmentation : Optional [int ] = None ,
8190 l1_penalty : float = 1e-3 ,
8291 epochs_to_penalize : int = 100 ,
8392 # Core model parameters
@@ -139,6 +148,7 @@ def __init__(
139148 )
140149
141150 # GRIN-specific parameters
151+ self .polymer_train_augmentation = polymer_train_augmentation
142152 self .l1_penalty = l1_penalty
143153 self .epochs_to_penalize = epochs_to_penalize
144154 self .model_class = GRIN
@@ -160,6 +170,184 @@ def _get_model_params(self, checkpoint: Optional[Dict] = None) -> Dict[str, Any]
160170 base_params = super ()._get_model_params (checkpoint )
161171 return base_params
162172
def fit(
    self,
    X_train: List[str],
    y_train: Optional[Union[List, np.ndarray]],
    X_val: Optional[List[str]] = None,
    y_val: Optional[Union[List, np.ndarray]] = None,
    X_unlbl: Optional[List[str]] = None,
) -> "GRINMolecularPredictor":
    """Fit the model to the training data with optional validation set.

    Parameters
    ----------
    X_train : List[str]
        Training set input molecular structures as SMILES strings
    y_train : Union[List, np.ndarray]
        Training set target values for property prediction
    X_val : List[str], optional
        Validation set input molecular structures as SMILES strings.
        If None, training data will be used for validation
    y_val : Union[List, np.ndarray], optional
        Validation set target values. Required if X_val is provided
    X_unlbl : List[str], optional
        Unlabeled set input molecular structures as SMILES strings.
        (Accepted for interface compatibility; not used in this method.)

    Returns
    -------
    self : GRINMolecularPredictor
        Fitted estimator
    """
    if (X_val is None) != (y_val is None):
        raise ValueError(
            "Both X_val and y_val must be provided for validation. "
            f"Got X_val={X_val is not None}, y_val={y_val is not None}"
        )

    def _augment(X, y):
        # Append `polymer_train_augmentation`-fold repeated polymer SMILES
        # (and matching targets, preserving list vs. ndarray type) to the data.
        X_aug, y_aug = SmilesRepeat(self.polymer_train_augmentation).repeat(X, y)
        X = X + X_aug
        if y_aug is not None:
            if isinstance(y, np.ndarray):
                y = np.concatenate([y, np.array(y_aug)], axis=0)
            else:
                y = list(y) + list(y_aug)
        return X, y

    def _log(log_dict):
        # Route per-epoch status to the configured verbosity channel.
        # (Single helper avoids the duplicated branches — and the
        # "print_stament" typo — of the previous inline version.)
        if self.verbose == "progress_bar" and global_pbar:
            global_pbar.set_postfix(log_dict)
        elif self.verbose == "print_statement":
            print(log_dict)

    self._initialize_model(self.model_class)
    self.model.initialize_parameters()
    optimizer, scheduler = self._setup_optimizers()

    # Prepare datasets and loaders
    if self.polymer_train_augmentation is not None:
        X_train, y_train = _augment(X_train, y_train)

    X_train, y_train = self._validate_inputs(X_train, y_train)
    train_dataset = self._convert_to_pytorch_data(X_train, y_train)
    train_loader = DataLoader(
        train_dataset,
        batch_size=self.batch_size,
        shuffle=True,
        num_workers=0,
    )

    if X_val is None or y_val is None:
        # Fall back to the training loader for model selection.
        val_loader = train_loader
        warnings.warn(
            "No validation set provided. Using training set for validation. "
            "This may lead to overfitting.",
            UserWarning,
        )
    else:
        if self.polymer_train_augmentation is not None:
            X_val, y_val = _augment(X_val, y_val)

        X_val, y_val = self._validate_inputs(X_val, y_val)
        val_dataset = self._convert_to_pytorch_data(X_val, y_val)
        val_loader = DataLoader(
            val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=0,
        )

    # Initialize training state
    self.fitting_loss = []
    self.fitting_epoch = 0
    best_state_dict = None
    best_eval = float("-inf") if self.evaluate_higher_better else float("inf")
    cnt_wait = 0

    # Global progress bar over all optimizer steps (epochs * batches).
    total_steps = self.epochs * len(train_loader)
    global_pbar = None
    if self.verbose == "progress_bar":
        global_pbar = tqdm(
            total=total_steps,
            desc="Training Progress",
            unit="step",
            dynamic_ncols=True,
        )

    for epoch in range(self.epochs):
        # Training phase
        train_losses = self._train_epoch(train_loader, optimizer, epoch, global_pbar)
        mean_loss = float(np.mean(train_losses))
        self.fitting_loss.append(mean_loss)

        # Validation phase
        current_eval = self._evaluation_epoch(val_loader)

        if scheduler:
            scheduler.step(current_eval)

        # Model selection (check if current evaluation is better)
        is_better = (
            current_eval > best_eval
            if self.evaluate_higher_better
            else current_eval < best_eval
        )

        if is_better:
            self.fitting_epoch = epoch
            best_eval = current_eval
            # Deep copy so later epochs cannot mutate the best snapshot.
            best_state_dict = copy.deepcopy(self.model.state_dict())
            cnt_wait = 0
            _log({
                "Epoch": f"{epoch + 1}/{self.epochs}",
                "Loss": f"{mean_loss:.4f}",
                f"{self.evaluate_name}": f"{best_eval:.4f}",
                "Status": "✓ Best",
            })
        else:
            cnt_wait += 1
            _log({
                "Epoch": f"{epoch + 1}/{self.epochs}",
                "Loss": f"{mean_loss:.4f}",
                f"{self.evaluate_name}": f"{current_eval:.4f}",
                "Wait": f"{cnt_wait}/{self.patience}",
            })
            if cnt_wait > self.patience:
                _log({
                    "Status": "Early Stopped",
                    "Epoch": f"{epoch + 1}/{self.epochs}",
                })
                break

    # Close global progress bar (idempotent; also covers early stop)
    if global_pbar is not None:
        global_pbar.close()

    # Restore best model
    if best_state_dict is not None:
        self.model.load_state_dict(best_state_dict)
    else:
        warnings.warn(
            "No improvement was achieved during training. "
            "The model may not be fitted properly.",
            UserWarning,
        )

    self.is_fitted_ = True
    return self
350+
163351 def _train_epoch (self , train_loader , optimizer , epoch , global_pbar = None ):
164352 self .model .train ()
165353 losses = []
@@ -186,4 +374,53 @@ def _train_epoch(self, train_loader, optimizer, epoch, global_pbar=None):
186374 })
187375 losses .append (loss .item ())
188376
189- return losses
377+ return losses
378+
def predict(self, X: List[str], test_augmentation: Optional[int] = None) -> Dict[str, np.ndarray]:
    """Make predictions using the fitted model.

    Parameters
    ----------
    X : List[str]
        List of SMILES strings to make predictions for
    test_augmentation : int, optional
        Number of times to repeat the polymer for making predictions.

    Returns
    -------
    Dict[str, np.ndarray]
        Dictionary containing:
        - 'prediction': Model predictions (shape: [n_samples, n_tasks])

    """
    self._check_is_fitted()

    # Optionally replace the inputs with their repeated-polymer variants.
    if test_augmentation is not None:
        repeated, _ = SmilesRepeat(test_augmentation).repeat(X)
        X = repeated

    # Convert to PyTorch Geometric format and create loader
    smiles, _ = self._validate_inputs(X)
    loader = DataLoader(
        self._convert_to_pytorch_data(smiles),
        batch_size=self.batch_size,
        shuffle=False,
    )

    if self.model is None:
        raise RuntimeError("Model not initialized")

    # Move the model to the target device and switch to eval mode.
    self.model = self.model.to(self.device)
    self.model.eval()

    # Pick the iteration wrapper according to the verbosity setting.
    iterator = loader
    if self.verbose == "progress_bar":
        iterator = tqdm(loader, desc="Predicting")
    elif self.verbose == "print_statement":
        print("Predicting...")

    batch_outputs = []
    with torch.no_grad():
        for batch in iterator:
            batch = batch.to(self.device)
            result = self.model(batch)
            batch_outputs.append(result["prediction"].cpu().numpy())

    return {
        "prediction": np.concatenate(batch_outputs, axis=0),
    }
0 commit comments