-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path5_create_hard_test_pool_dataframe.py
More file actions
146 lines (110 loc) · 4.43 KB
/
5_create_hard_test_pool_dataframe.py
File metadata and controls
146 lines (110 loc) · 4.43 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
"""
Create a DataFrame from the hard test pool for easy access in benchmark creation.
The hard test pool serves as the basis for the benchmark dataset. To make it easily
accessible, we store it as a DataFrame containing basic information about each
molecule including SMILES, complexity, heavy atom count, and IUPAC names.
"""
from __future__ import annotations
import pickle
from dataclasses import dataclass
from multiprocessing import Pool
from pathlib import Path
from typing import List
import pandas as pd
from rdkit import Chem
from rdkit.Chem.GraphDescriptors import BertzCT
from tqdm import tqdm
# Directory containing this script (symlinks resolved).
SUBMISSION_ROOT = Path(__file__).resolve().parent
# The dataset pools live under "data/dataset_pools", two levels above the script directory.
DATA_ROOT = SUBMISSION_ROOT.parent.parent.joinpath("data", "dataset_pools")
# Default pickle locations; both sit in the "processed" subdirectory.
DEFAULT_INPUT_PATH = DATA_ROOT.joinpath("processed", "pubchem_train_test_pools.pkl")
DEFAULT_OUTPUT_PATH = DATA_ROOT.joinpath("processed", "hard_test_pool_dataframe.pkl")
@dataclass
class CreateHardTestPoolDataFrameConfig:
    """Settings for building the hard test pool DataFrame.

    Holds the pickle file that is read, the pickle file that is written, and
    the size of the worker pool used while computing molecular properties.
    """

    # Pickle file the pools are loaded from.
    input_path: Path = DEFAULT_INPUT_PATH
    # Pickle file the resulting DataFrame is written to.
    output_path: Path = DEFAULT_OUTPUT_PATH
    # Number of worker processes handed to multiprocessing.Pool.
    nbr_processes: int = 100
def _compute_bertz_complexity(mol: Chem.Mol) -> float | None:
"""
Compute Bertz complexity for a given RDKit molecule.
Args:
mol: The RDKit molecule object.
Returns:
Bertz complexity of the molecule, or None if computation fails.
"""
bertz_complexity = BertzCT(mol)
return bertz_complexity
def _compute_nbr_heavy_atoms(mol: Chem.Mol) -> int | None:
"""
Compute the number of heavy atoms in a given RDKit molecule.
Args:
mol: The RDKit molecule object.
Returns:
Number of heavy atoms in the molecule.
"""
return mol.GetNumHeavyAtoms() if mol else None
def _create_dataframe_from_pool(
    smiles_list: List[str], iupac_dict: dict, nbr_processes: int
) -> pd.DataFrame:
    """
    Create a DataFrame from a list of SMILES with molecular properties.

    Args:
        smiles_list: SMILES strings; every entry must be parseable by RDKit.
        iupac_dict: Dictionary mapping SMILES to IUPAC names.
        nbr_processes: Number of parallel processes to use.

    Returns:
        DataFrame with the columns ``smiles``, ``mol``, ``complexity``,
        ``nbr_heavy_atoms``, ``iupac_name`` and ``complexity_bin``.

    Raises:
        ValueError: If at least one SMILES cannot be converted to an RDKit
            mol object.
    """
    # Materialise once so the input is iterated in one fixed order everywhere
    # below, even if a caller hands in a non-list iterable.
    smiles_list = list(smiles_list)

    # Create rdkit mol objects
    mols = [Chem.MolFromSmiles(smi) for smi in smiles_list]

    # Raise instead of assert: asserts are stripped under ``python -O`` and
    # unparsable entries would then only fail later inside the workers.
    invalid_smiles = [smi for smi, mol in zip(smiles_list, mols) if mol is None]
    if invalid_smiles:
        raise ValueError(
            "Some SMILES could not be converted to rdkit mol objects. "
            f"{len(invalid_smiles)} failed, e.g. {invalid_smiles[:5]}"
        )

    # One pool serves both passes, so worker processes are started only once.
    with Pool(nbr_processes) as pool:
        # Compute Bertz complexity
        bertz_complexities = list(
            tqdm(
                pool.imap(_compute_bertz_complexity, mols),
                total=len(mols),
                desc="Computing Bertz complexities",
            )
        )
        # Compute number of heavy atoms
        nbr_heavy_atoms = list(
            tqdm(
                pool.imap(_compute_nbr_heavy_atoms, mols),
                total=len(mols),
                desc="Computing number of heavy atoms",
            )
        )

    # Create DataFrame
    df = pd.DataFrame(
        {
            "smiles": smiles_list,
            "mol": mols,
            "complexity": bertz_complexities,
            "nbr_heavy_atoms": nbr_heavy_atoms,
        }
    )

    # Include IUPAC names if available; SMILES without an entry map to NaN.
    df["iupac_name"] = df["smiles"].map(iupac_dict)

    # Create complexity bins: 0-100, 100-200, ..., 900-1000, 1000+
    # Bins are preliminarily defined here and can be adjusted as needed.
    # right=False makes each bin closed on the left, e.g. 100 falls in "100-200".
    bins = [0, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000, 1e9]
    labels = [f"{i}-{i+100}" for i in range(0, 1000, 100)] + ["1000+"]
    df["complexity_bin"] = pd.cut(df["complexity"], bins=bins, labels=labels, right=False)

    return df
def create_hard_test_pool_dataframe(cfg: CreateHardTestPoolDataFrameConfig) -> None:
    """Create and save a DataFrame from the hard test pool.

    Args:
        cfg: Configuration providing ``input_path``, ``output_path`` and
            ``nbr_processes``.

    Raises:
        FileNotFoundError: If ``cfg.input_path`` does not exist.
    """
    input_path = Path(cfg.input_path)
    output_path = Path(cfg.output_path)

    # Validate the input before touching the filesystem, so a wrong path does
    # not leave an empty output directory behind.
    if not input_path.exists():
        raise FileNotFoundError(f"Input file {input_path} does not exist.")

    # NOTE(security): pickle.load can execute arbitrary code embedded in the
    # file; only point input_path at pool files from a trusted source.
    with input_path.open("rb") as handle:
        # The pickled object unpacks into exactly four items; only the first
        # and the last are used here.
        hard_test_set, _, _, iupac_dict = pickle.load(handle)

    # Create the destination directory before the property computation so
    # filesystem problems surface before that work is done.
    output_path.parent.mkdir(parents=True, exist_ok=True)

    hard_test_df = _create_dataframe_from_pool(hard_test_set, iupac_dict, cfg.nbr_processes)

    with output_path.open("wb") as handle:
        pickle.dump(hard_test_df, handle)
if __name__ == "__main__":
    # Run with the default paths and worker count when executed as a script.
    default_cfg = CreateHardTestPoolDataFrameConfig()
    create_hard_test_pool_dataframe(default_cfg)