Skip to content

Commit 364c0cf

Browse files
committed
revise grin compatibility
1 parent c4d79a3 commit 364c0cf

4 files changed

Lines changed: 35 additions & 21 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
44

55
[project]
66
name = "torch_molecule"
7-
version = "0.1.6.post1"
7+
version = "0.1.7"
88
description = "Deep learning packages for molecular discovery with a simple sklearn-style interface"
99
authors = [{name = "Gang Liu", email = "gliu7@nd.edu"}]
1010
readme = "README.md"

tests/predictor/grin.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def test_grin_predictor():
4242
model = GRINMolecularPredictor(
4343
num_task=1,
4444
task_type="regression",
45-
polymer_train_augmentation=3,
45+
repetition_augmentation=3,
4646
num_layer=3,
4747
hidden_size=128,
4848
batch_size=4,
@@ -58,10 +58,10 @@ def test_grin_predictor():
5858

5959
# 2.2. Prediction test on different repeat times
6060
print("\n=== Testing model prediction with polymer train augmentation ===")
61-
predictions = model.predict(smiles_list, test_augmentation=3)
61+
predictions = model.predict(smiles_list, test_time_augmentation=3)
6262
print(f"Prediction shape: {predictions['prediction'].shape}")
6363
print(f"Prediction for new polymer repeat times 3: {predictions['prediction']}")
64-
predictions = model.predict(smiles_list, test_augmentation=5)
64+
predictions = model.predict(smiles_list, test_time_augmentation=5)
6565
print(f"Prediction for new polymer repeat times 5: {predictions['prediction']}")
6666

6767
# 3. Auto-fitting test with custom parameters

torch_molecule/predictor/grin/modeling_grin.py

Lines changed: 24 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,13 @@ class GRINMolecularPredictor(GNNMolecularPredictor):
2222
2323
References
2424
----------
25-
- Learning Repetition-Invariant Representations for Polymer Informatics.
25+
- Learning Repetition-Invariant Representations for Polymer Informatics. NeurIPS 2025
2626
https://arxiv.org/pdf/2505.10726
2727
2828
Parameters
2929
----------
30-
polymer_train_augmentation : int, default=None
31-
Number of times to repeat the polymer for training.
30+
repetition_augmentation : bool, default=False
31+
Whether to enable polymer augmentation for training.
3232
l1_penalty : float, default=1e-3
3333
Weight for the L1 penalty.
3434
epochs_to_penalize : int, default=100
@@ -86,7 +86,7 @@ class GRINMolecularPredictor(GNNMolecularPredictor):
8686
def __init__(
8787
self,
8888
# GRIN-specific parameters
89-
polymer_train_augmentation: Optional[int] = None,
89+
repetition_augmentation: bool = False,
9090
l1_penalty: float = 1e-3,
9191
epochs_to_penalize: int = 100,
9292
# Core model parameters
@@ -148,11 +148,21 @@ def __init__(
148148
)
149149

150150
# GRIN-specific parameters
151-
self.polymer_train_augmentation = polymer_train_augmentation
151+
self.repetition_augmentation = repetition_augmentation
152152
self.l1_penalty = l1_penalty
153153
self.epochs_to_penalize = epochs_to_penalize
154154
self.model_class = GRIN
155155

156+
# Check CombineMols dependency if polymer augmentation is enabled
157+
if self.repetition_augmentation:
158+
try:
159+
from CombineMols.CombineMols import CombineMols
160+
except ImportError:
161+
raise ImportError(
162+
"CombineMols is required for repetition augmentation for polymer. "
163+
"Please install it using: pip install CombineMols"
164+
)
165+
156166

157167
@staticmethod
158168
def _get_param_names() -> List[str]:
@@ -210,8 +220,8 @@ def fit(
210220
optimizer, scheduler = self._setup_optimizers()
211221

212222
# Prepare datasets and loaders
213-
if self.polymer_train_augmentation is not None:
214-
X_train_aug, y_train_aug = SmilesRepeat(self.polymer_train_augmentation).repeat(X_train, y_train)
223+
if self.repetition_augmentation:
224+
X_train_aug, y_train_aug = SmilesRepeat(2).repeat(X_train, y_train)
215225
X_train = X_train + X_train_aug
216226
if y_train_aug is not None:
217227
if isinstance(y_train, np.ndarray):
@@ -236,8 +246,8 @@ def fit(
236246
UserWarning
237247
)
238248
else:
239-
if self.polymer_train_augmentation is not None:
240-
X_val_aug, y_val_aug = SmilesRepeat(self.polymer_train_augmentation).repeat(X_val, y_val)
249+
if self.repetition_augmentation:
250+
X_val_aug, y_val_aug = SmilesRepeat(2).repeat(X_val, y_val)
241251
X_val = X_val + X_val_aug
242252
if y_val_aug is not None:
243253
if isinstance(y_val, np.ndarray):
@@ -376,15 +386,15 @@ def _train_epoch(self, train_loader, optimizer, epoch, global_pbar=None):
376386

377387
return losses
378388

379-
def predict(self, X: List[str], test_augmentation: Optional[int] = None) -> Dict[str, np.ndarray]:
389+
def predict(self, X: List[str], test_time_augmentation: bool = False) -> Dict[str, np.ndarray]:
380390
"""Make predictions using the fitted model.
381391
382392
Parameters
383393
----------
384394
X : List[str]
385395
List of SMILES strings to make predictions for
386-
test_augmentation : int, optional
387-
Number of times to repeat the polymer for making predictions.
396+
test_time_augmentation : bool, default=False
397+
Whether to enable polymer augmentation for making predictions.
388398
389399
Returns
390400
-------
@@ -394,8 +404,8 @@ def predict(self, X: List[str], test_augmentation: Optional[int] = None) -> Dict
394404
395405
"""
396406
self._check_is_fitted()
397-
if test_augmentation is not None:
398-
X_aug, _ = SmilesRepeat(test_augmentation).repeat(X)
407+
if test_time_augmentation:
408+
X_aug, _ = SmilesRepeat(2).repeat(X)
399409
X = X_aug
400410

401411
# Convert to PyTorch Geometric format and create loader

torch_molecule/predictor/grin/utils.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,13 @@
22
import numpy as np
33

44
from rdkit import Chem
5-
from CombineMols.CombineMols import CombineMols
65
from rdkit.Chem import Draw
76

7+
try:
8+
from CombineMols.CombineMols import CombineMols
9+
except ImportError:
10+
pass
11+
812
class SmilesRepeat():
913
def __init__(self, repeat_times) -> None:
1014
self.repeat_times = repeat_times
@@ -121,7 +125,7 @@ def edit_mol(self, ori_psmiles, des_psmiles) -> str:
121125
# Obtain connection info for future bonds/atoms remove/add
122126
info = self.get_connection_info(combo)
123127
if not info:
124-
print("************************** No Star Mark! **************************")
128+
print(f"************************** No Star Mark for polymer {ori_psmiles} **************************")
125129
return des_psmiles
126130

127131
# add a new bond between two star symbols and discard these two stars
@@ -144,7 +148,7 @@ def star_edge(self, ori_psmiles) -> str:
144148
ori_mol = self.get_mol(ori_psmiles)
145149
info = self.get_connection_info(ori_mol)
146150
if not info or not info["neighbor"]['path']:
147-
print(f"************************** No Star Mark: {ori_psmiles} **************************")
151+
print(f"************************** No Star Mark for polymer {ori_psmiles} **************************")
148152
return ori_psmiles
149153

150154
edsmiles = Chem.EditableMol(ori_mol)

0 commit comments

Comments
 (0)