Skip to content

Commit c4d79a3

Browse files
committed
Add polymer repetition augmentation for GRIN
1 parent 086699a commit c4d79a3

3 files changed

Lines changed: 539 additions & 15 deletions

File tree

tests/predictor/grin.py

Lines changed: 40 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -5,19 +5,19 @@
55
def test_grin_predictor():
66
# Test data
77
smiles_list = [
8-
'CNC[C@H]1OCc2cnnn2CCCC(=O)N([C@H](C)CO)C[C@@H]1C',
9-
'CNC[C@@H]1OCc2cnnn2CCCC(=O)N([C@H](C)CO)C[C@H]1C',
10-
'C[C@H]1CN([C@@H](C)CO)C(=O)CCCn2cc(nn2)CO[C@@H]1CN(C)C(=O)CCC(F)(F)F',
11-
'CC1=CC=C(C=C1)C2=CC(=NN2C3=CC=C(C=C3)S(=O)(=O)N)C(F)(F)F' # Additional molecule
8+
'FC(F)(F)C(C1=CC2=C(C=C1)C(=O)N(C2=O)C1=CC=C(CC2=CC=C(*)C=C2)C=C1)(C1=CC=C2C(=O)N(*)C(=O)C2=C1)C(F)(F)F',
9+
'CCCCCCCCCCCCC1=C(*)SC(*)=C1',
10+
'CC(C)C1=C(O*)C=CC(=C1)C(C1C=CC=CC1C(O)=O)C1=CC(C(C)C)=C(OC2=CC=C(C=C2)C(=O)C2=CC=C(*)C=C2)C=C1',
11+
'*OC1=CC=C(C=C1)C1(OC(=O)C2=C1C=CC=C2)C1=CC=C(OC2=CC=C(C=C2)S(=O)(=O)C2=CC=C(*)C=C2)C=C1'
1212
]
13-
properties = np.array([0, 0, 1, 1]) # Binary classification
13+
properties = np.array([19, 88.2, 22.8, 5.74]) # Regression
1414
print('smiles_list', len(smiles_list))
1515
print('properties', len(properties))
1616
# 1. Basic initialization test
1717
print("\n=== Testing model initialization ===")
1818
model = GRINMolecularPredictor(
1919
num_task=1,
20-
task_type="classification",
20+
task_type="regression",
2121
num_layer=3,
2222
hidden_size=128,
2323
batch_size=4,
@@ -26,18 +26,45 @@ def test_grin_predictor():
2626
)
2727
print("Model initialized successfully")
2828

29-
# 2. Basic fitting test
29+
# 1.2. Basic fitting test
3030
print("\n=== Testing model fitting ===")
3131
model.fit(smiles_list, properties)
3232
print("Model fitting completed")
3333

34-
# 3. Prediction test
34+
# 1.3. Prediction test
3535
print("\n=== Testing model prediction ===")
3636
predictions = model.predict(smiles_list)
3737
print(f"Prediction shape: {predictions['prediction'].shape}")
38-
print(f"Prediction for new molecule: {predictions['prediction']}")
38+
print(f"Prediction for new polymer repeat times 1: {predictions['prediction']}")
3939

40-
# 4. Auto-fitting test with custom parameters
40+
# 2. Initialize with polymer train augmentation
41+
print("\n=== Testing model initialization with polymer train augmentation ===")
42+
model = GRINMolecularPredictor(
43+
num_task=1,
44+
task_type="regression",
45+
polymer_train_augmentation=3,
46+
num_layer=3,
47+
hidden_size=128,
48+
batch_size=4,
49+
epochs=5, # Small number for testing
50+
verbose="progress_bar",
51+
)
52+
print("Model initialized with polymer train augmentation successfully")
53+
54+
# 2.1. Fitting test
55+
print("\n=== Testing model fitting with polymer train augmentation ===")
56+
model.fit(smiles_list, properties)
57+
print("Model fitting completed")
58+
59+
# 2.2. Prediction test on different repeat times
60+
print("\n=== Testing model prediction with polymer train augmentation ===")
61+
predictions = model.predict(smiles_list, test_augmentation=3)
62+
print(f"Prediction shape: {predictions['prediction'].shape}")
63+
print(f"Prediction for new polymer repeat times 3: {predictions['prediction']}")
64+
predictions = model.predict(smiles_list, test_augmentation=5)
65+
print(f"Prediction for new polymer repeat times 5: {predictions['prediction']}")
66+
67+
# 3. Auto-fitting test with custom parameters
4168
print("\n=== Testing model auto-fitting ===")
4269
search_parameters = {
4370
'num_layer': (2, 4),
@@ -70,7 +97,7 @@ def test_grin_predictor():
7097
}
7198
model_auto = GRINMolecularPredictor(
7299
num_task=1,
73-
task_type="classification",
100+
task_type="regression",
74101
epochs=3, # Small number for testing
75102
verbose="none"
76103
)
@@ -83,15 +110,15 @@ def test_grin_predictor():
83110
)
84111
print("Model auto-fitting completed")
85112

86-
# 5. Model saving and loading test
113+
# 4. Model saving and loading test
87114
print("\n=== Testing model saving and loading ===")
88115
save_path = "test_model.pt"
89116
model.save_to_local(save_path)
90117
print(f"Model saved to {save_path}")
91118

92119
new_model = GRINMolecularPredictor(
93120
num_task=1,
94-
task_type="classification"
121+
task_type="regression"
95122
)
96123
new_model.load_from_local(save_path)
97124
print("Model loaded successfully")

torch_molecule/predictor/grin/modeling_grin.py

Lines changed: 239 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,15 @@
11
from typing import Optional, Union, Dict, Any, List, Callable, Literal
22

33
import torch
4-
4+
import numpy as np
5+
import warnings
6+
import copy
7+
from torch_geometric.data import Data, DataLoader
8+
from tqdm import tqdm
59
from .model import GRIN
10+
from .utils import SmilesRepeat
611
from ..gnn.modeling_gnn import GNNMolecularPredictor
12+
from ...utils import graph_from_smiles
713
from ...utils.search import (
814
ParameterSpec,
915
ParameterType,
@@ -21,6 +27,8 @@ class GRINMolecularPredictor(GNNMolecularPredictor):
2127
2228
Parameters
2329
----------
30+
polymer_train_augmentation : int, default=None
31+
Number of times to repeat the polymer for training.
2432
l1_penalty : float, default=1e-3
2533
Weight for the L1 penalty.
2634
epochs_to_penalize : int, default=100
@@ -78,6 +86,7 @@ class GRINMolecularPredictor(GNNMolecularPredictor):
7886
def __init__(
7987
self,
8088
# GRIN-specific parameters
89+
polymer_train_augmentation: Optional[int] = None,
8190
l1_penalty: float = 1e-3,
8291
epochs_to_penalize: int = 100,
8392
# Core model parameters
@@ -139,6 +148,7 @@ def __init__(
139148
)
140149

141150
# GRIN-specific parameters
151+
self.polymer_train_augmentation = polymer_train_augmentation
142152
self.l1_penalty = l1_penalty
143153
self.epochs_to_penalize = epochs_to_penalize
144154
self.model_class = GRIN
@@ -160,6 +170,184 @@ def _get_model_params(self, checkpoint: Optional[Dict] = None) -> Dict[str, Any]
160170
base_params = super()._get_model_params(checkpoint)
161171
return base_params
162172

173+
def fit(
    self,
    X_train: List[str],
    y_train: Optional[Union[List, np.ndarray]],
    X_val: Optional[List[str]] = None,
    y_val: Optional[Union[List, np.ndarray]] = None,
    X_unlbl: Optional[List[str]] = None,
) -> "GRINMolecularPredictor":
    """Fit the model to the training data with optional validation set.

    When ``self.polymer_train_augmentation`` is set, both the training and
    (if provided) validation sets are extended with repeated-unit copies of
    each polymer SMILES before fitting (see ``_augment_with_repeats``).

    Parameters
    ----------
    X_train : List[str]
        Training set input molecular structures as SMILES strings
    y_train : Union[List, np.ndarray]
        Training set target values for property prediction
    X_val : List[str], optional
        Validation set input molecular structures as SMILES strings.
        If None, training data will be used for validation
    y_val : Union[List, np.ndarray], optional
        Validation set target values. Required if X_val is provided
    X_unlbl : List[str], optional
        Unlabeled set input molecular structures as SMILES strings.

    Returns
    -------
    self : GRINMolecularPredictor
        Fitted estimator

    Raises
    ------
    ValueError
        If exactly one of X_val / y_val is provided.
    """
    if (X_val is None) != (y_val is None):
        raise ValueError(
            "Both X_val and y_val must be provided for validation. "
            f"Got X_val={X_val is not None}, y_val={y_val is not None}"
        )

    self._initialize_model(self.model_class)
    self.model.initialize_parameters()
    optimizer, scheduler = self._setup_optimizers()

    # Prepare datasets and loaders (with optional polymer repetition augmentation)
    if self.polymer_train_augmentation is not None:
        X_train, y_train = self._augment_with_repeats(X_train, y_train)

    X_train, y_train = self._validate_inputs(X_train, y_train)
    train_dataset = self._convert_to_pytorch_data(X_train, y_train)
    train_loader = DataLoader(
        train_dataset,
        batch_size=self.batch_size,
        shuffle=True,
        num_workers=0
    )

    if X_val is None or y_val is None:
        val_loader = train_loader
        warnings.warn(
            "No validation set provided. Using training set for validation. "
            "This may lead to overfitting.",
            UserWarning
        )
    else:
        if self.polymer_train_augmentation is not None:
            X_val, y_val = self._augment_with_repeats(X_val, y_val)

        X_val, y_val = self._validate_inputs(X_val, y_val)
        val_dataset = self._convert_to_pytorch_data(X_val, y_val)
        val_loader = DataLoader(
            val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=0
        )

    # Initialize training state
    self.fitting_loss = []
    self.fitting_epoch = 0
    best_state_dict = None
    best_eval = float('-inf') if self.evaluate_higher_better else float('inf')
    cnt_wait = 0

    # One global progress bar covering every optimizer step across all epochs
    steps_per_epoch = len(train_loader)
    total_steps = self.epochs * steps_per_epoch
    global_pbar = None
    if self.verbose == "progress_bar":
        global_pbar = tqdm(
            total=total_steps,
            desc="Training Progress",
            unit="step",
            dynamic_ncols=True
        )

    for epoch in range(self.epochs):
        # Training phase
        train_losses = self._train_epoch(train_loader, optimizer, epoch, global_pbar)
        self.fitting_loss.append(float(np.mean(train_losses)))

        # Validation phase
        current_eval = self._evaluation_epoch(val_loader)

        if scheduler:
            scheduler.step(current_eval)

        # Model selection: keep weights from the best validation epoch,
        # not the last one
        is_better = (
            current_eval > best_eval if self.evaluate_higher_better
            else current_eval < best_eval
        )

        if is_better:
            self.fitting_epoch = epoch
            best_eval = current_eval
            best_state_dict = copy.deepcopy(self.model.state_dict())
            cnt_wait = 0
            self._log_fit_status({
                "Epoch": f"{epoch+1}/{self.epochs}",
                "Loss": f"{float(np.mean(train_losses)):.4f}",
                f"{self.evaluate_name}": f"{best_eval:.4f}",
                "Status": "✓ Best"
            }, global_pbar)
        else:
            cnt_wait += 1
            self._log_fit_status({
                "Epoch": f"{epoch+1}/{self.epochs}",
                "Loss": f"{float(np.mean(train_losses)):.4f}",
                f"{self.evaluate_name}": f"{current_eval:.4f}",
                "Wait": f"{cnt_wait}/{self.patience}"
            }, global_pbar)
            if cnt_wait > self.patience:
                # BUGFIX: the early-stop branch compared against the typo
                # "print_stament", so early stopping was never reported
                # under verbose="print_statement".
                self._log_fit_status({
                    "Status": "Early Stopped",
                    "Epoch": f"{epoch+1}/{self.epochs}"
                }, global_pbar)
                break

    # Close global progress bar
    if global_pbar is not None:
        global_pbar.close()

    # Restore best model
    if best_state_dict is not None:
        self.model.load_state_dict(best_state_dict)
    else:
        warnings.warn(
            "No improvement was achieved during training. "
            "The model may not be fitted properly.",
            UserWarning
        )

    self.is_fitted_ = True
    return self

def _augment_with_repeats(self, X, y):
    """Append repeated-unit copies of each polymer SMILES with matching labels.

    Uses ``SmilesRepeat(self.polymer_train_augmentation)`` to generate the
    augmented SMILES. Labels are concatenated in the same order; the input
    container type is preserved (ndarray stays ndarray, otherwise a list).
    Previously this logic was duplicated verbatim for train and validation.
    """
    X_aug, y_aug = SmilesRepeat(self.polymer_train_augmentation).repeat(X, y)
    X = X + X_aug
    if y_aug is not None:
        if isinstance(y, np.ndarray):
            y = np.concatenate([y, np.array(y_aug)], axis=0)
        else:
            y = list(y) + list(y_aug)
    return X, y

def _log_fit_status(self, log_dict, global_pbar):
    """Route a status dict to the progress bar or stdout per ``self.verbose``."""
    if self.verbose == "progress_bar" and global_pbar:
        global_pbar.set_postfix(log_dict)
    elif self.verbose == "print_statement":
        print(log_dict)
350+
163351
def _train_epoch(self, train_loader, optimizer, epoch, global_pbar=None):
164352
self.model.train()
165353
losses = []
@@ -186,4 +374,53 @@ def _train_epoch(self, train_loader, optimizer, epoch, global_pbar=None):
186374
})
187375
losses.append(loss.item())
188376

189-
return losses
377+
return losses
378+
379+
def predict(self, X: List[str], test_augmentation: Optional[int] = None) -> Dict[str, np.ndarray]:
    """Make predictions using the fitted model.

    Parameters
    ----------
    X : List[str]
        List of SMILES strings to make predictions for
    test_augmentation : int, optional
        Number of times to repeat the polymer for making predictions.

    Returns
    -------
    Dict[str, np.ndarray]
        Dictionary containing:
        - 'prediction': Model predictions (shape: [n_samples, n_tasks])
    """
    self._check_is_fitted()

    # Optionally replace each polymer with its repeated-unit version
    if test_augmentation is not None:
        repeated, _ = SmilesRepeat(test_augmentation).repeat(X)
        X = repeated

    # Convert to PyTorch Geometric format and create loader
    validated, _ = self._validate_inputs(X)
    eval_loader = DataLoader(
        self._convert_to_pytorch_data(validated),
        batch_size=self.batch_size,
        shuffle=False,
    )

    if self.model is None:
        raise RuntimeError("Model not initialized")

    # Run inference batch-by-batch on the configured device
    self.model = self.model.to(self.device)
    self.model.eval()
    collected = []
    with torch.no_grad():
        if self.verbose == "progress_bar":
            batches = tqdm(eval_loader, desc="Predicting")
        else:
            if self.verbose == "print_statement":
                print("Predicting...")
            batches = eval_loader
        for batch in batches:
            batch = batch.to(self.device)
            result = self.model(batch)
            collected.append(result["prediction"].cpu().numpy())

    return {
        "prediction": np.concatenate(collected, axis=0),
    }

0 commit comments

Comments
 (0)