@@ -22,13 +22,13 @@ class GRINMolecularPredictor(GNNMolecularPredictor):
2222
2323 References
2424 ----------
25- - Learning Repetition-Invariant Representations for Polymer Informatics.
25+ - Learning Repetition-Invariant Representations for Polymer Informatics. NeurIPS 2025
2626 https://arxiv.org/pdf/2505.10726
2727
2828 Parameters
2929 ----------
30- polymer_train_augmentation : int , default=None
31- Number of times to repeat the polymer for training.
30+ repetition_augmentation : bool , default=False
31+ Whether to enable polymer augmentation for training.
3232 l1_penalty : float, default=1e-3
3333 Weight for the L1 penalty.
3434 epochs_to_penalize : int, default=100
@@ -86,7 +86,7 @@ class GRINMolecularPredictor(GNNMolecularPredictor):
8686 def __init__ (
8787 self ,
8888 # GRIN-specific parameters
89- polymer_train_augmentation : Optional [ int ] = None ,
89+ repetition_augmentation : bool = False ,
9090 l1_penalty : float = 1e-3 ,
9191 epochs_to_penalize : int = 100 ,
9292 # Core model parameters
@@ -148,11 +148,21 @@ def __init__(
148148 )
149149
150150 # GRIN-specific parameters
151- self .polymer_train_augmentation = polymer_train_augmentation
151+ self .repetition_augmentation = repetition_augmentation
152152 self .l1_penalty = l1_penalty
153153 self .epochs_to_penalize = epochs_to_penalize
154154 self .model_class = GRIN
155155
156+ # Check CombineMols dependency if polymer augmentation is enabled
157+ if self .repetition_augmentation :
158+ try :
159+ from CombineMols .CombineMols import CombineMols
160+ except ImportError :
161+ raise ImportError (
162+ "CombineMols is required for repetition augmentation for polymer. "
163+ "Please install it using: pip install CombineMols"
164+ )
165+
156166
157167 @staticmethod
158168 def _get_param_names () -> List [str ]:
@@ -210,8 +220,8 @@ def fit(
210220 optimizer , scheduler = self ._setup_optimizers ()
211221
212222 # Prepare datasets and loaders
213- if self .polymer_train_augmentation is not None :
214- X_train_aug , y_train_aug = SmilesRepeat (self . polymer_train_augmentation ).repeat (X_train , y_train )
223+ if self .repetition_augmentation :
224+ X_train_aug , y_train_aug = SmilesRepeat (2 ).repeat (X_train , y_train )
215225 X_train = X_train + X_train_aug
216226 if y_train_aug is not None :
217227 if isinstance (y_train , np .ndarray ):
@@ -236,8 +246,8 @@ def fit(
236246 UserWarning
237247 )
238248 else :
239- if self .polymer_train_augmentation is not None :
240- X_val_aug , y_val_aug = SmilesRepeat (self . polymer_train_augmentation ).repeat (X_val , y_val )
249+ if self .repetition_augmentation :
250+ X_val_aug , y_val_aug = SmilesRepeat (2 ).repeat (X_val , y_val )
241251 X_val = X_val + X_val_aug
242252 if y_val_aug is not None :
243253 if isinstance (y_val , np .ndarray ):
@@ -376,15 +386,15 @@ def _train_epoch(self, train_loader, optimizer, epoch, global_pbar=None):
376386
377387 return losses
378388
379- def predict (self , X : List [str ], test_augmentation : Optional [ int ] = None ) -> Dict [str , np .ndarray ]:
389+ def predict (self , X : List [str ], test_time_augmentation : bool = False ) -> Dict [str , np .ndarray ]:
380390 """Make predictions using the fitted model.
381391
382392 Parameters
383393 ----------
384394 X : List[str]
385395 List of SMILES strings to make predictions for
386- test_augmentation : int, optional
387- Number of times to repeat the polymer for making predictions.
396+ test_time_augmentation : bool, default=False
397+ Whether to enable polymer augmentation for making predictions.
388398
389399 Returns
390400 -------
@@ -394,8 +404,8 @@ def predict(self, X: List[str], test_augmentation: Optional[int] = None) -> Dict
394404
395405 """
396406 self ._check_is_fitted ()
397- if test_augmentation is not None :
398- X_aug , _ = SmilesRepeat (test_augmentation ).repeat (X )
407+ if test_time_augmentation :
408+ X_aug , _ = SmilesRepeat (2 ).repeat (X )
399409 X = X_aug
400410
401411 # Convert to PyTorch Geometric format and create loader
0 commit comments