Skip to content

Commit ecf5356

Browse files
authored
Add files via upload
1 parent fa89a62 commit ecf5356

1 file changed

Lines changed: 90 additions & 48 deletions

File tree

multioptpy/Wrapper/mapper.py

Lines changed: 90 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,11 @@
1818
import logging
1919
import multiprocessing
2020
import os
21+
import re
2122
import copy
2223
import shutil
2324
import sys
2425
import tempfile
25-
import traceback
2626
from abc import ABC, abstractmethod
2727
from collections import Counter
2828
from dataclasses import dataclass, field
@@ -136,7 +136,7 @@ def _autots_worker_with_queue(
136136
parent process via *result_queue* as a ``(tag, payload)`` tuple:
137137
138138
* ``("ok", profiles)`` on success.
139-
* ``("err", exception)`` on failure.
139+
* ``("err", traceback_str)`` on failure (full traceback as a string).
140140
141141
This avoids the per-call ``ProcessPoolExecutor`` setup/teardown
142142
overhead incurred by the sequential (``n_parallel=1``) code path.
@@ -145,7 +145,9 @@ def _autots_worker_with_queue(
145145
profiles = _autots_worker(config, run_dir, workspace)
146146
result_queue.put(("ok", profiles))
147147
except Exception as exc: # noqa: BLE001
148-
result_queue.put(("err", f"{type(exc).__name__}: {exc}\n{traceback.format_exc()}"))
148+
# Serialise the full traceback as a string so it crosses the process
149+
# boundary safely (some exception types are not picklable).
150+
result_queue.put(("err", traceback.format_exc()))
149151

150152

151153
logger = logging.getLogger(__name__)
@@ -159,45 +161,80 @@ def _autots_worker_with_queue(
159161
# Section 1 : XYZ Utilities
160162
# ===========================================================================
161163

164+
165+
# Pre-compiled XYZ atom-line pattern. Placed at module level so the regex is
166+
# compiled exactly once rather than on every parse_xyz() call.
167+
_XYZ_PATTERN: re.Pattern = re.compile(
168+
r"\s*([A-Za-z]+)\s+"
169+
r"([+-]?(?:\d+(?:\.\d+)?)(?:[eE][+-]?\d+)?)\s+"
170+
r"([+-]?(?:\d+(?:\.\d+)?)(?:[eE][+-]?\d+)?)\s+"
171+
r"([+-]?(?:\d+(?:\.\d+)?)(?:[eE][+-]?\d+)?)\s*"
172+
)
173+
174+
175+
def get_pattern_xyz() -> re.Pattern:
176+
"""Return the pre-compiled XYZ atom-line regex.
177+
178+
Kept for backward compatibility; callers should prefer ``_XYZ_PATTERN``
179+
directly. The regex is no longer recompiled on each call.
180+
"""
181+
return _XYZ_PATTERN
182+
162183
def parse_xyz(filepath: str) -> tuple[list[str], np.ndarray]:
163-
with open(filepath, "r") as fh:
164-
lines = fh.readlines()
184+
with open(filepath, "r", encoding="utf-8") as fh:
165185

166-
n_atoms: int | None = None
167-
data_start: int = 0
186+
n_atoms = None
187+
header_line_idx = 0
188+
for line_idx, line in enumerate(fh, 1):
189+
stripped = line.strip()
190+
if not stripped:
191+
continue
192+
if stripped.isdigit():
193+
n_atoms = int(stripped)
194+
header_line_idx = line_idx
195+
break
196+
else:
197+
raise ValueError(
198+
f"Invalid XYZ format at {filepath}:{line_idx}: "
199+
f"Expected atom count, got '{stripped}'"
200+
)
168201

169-
non_blank = [(i, ln.strip()) for i, ln in enumerate(lines) if ln.strip()]
170-
if non_blank and non_blank[0][1].isdigit():
171-
n_atoms = int(non_blank[0][1])
172-
data_start = non_blank[0][0] + 2
202+
if n_atoms is None:
203+
raise ValueError(f"Empty or invalid XYZ file: {filepath}")
173204

174-
symbols: list[str] = []
175-
coords_raw: list[list[float]] = []
205+
# Skip comment line
206+
next(fh, None)
176207

177-
for ln in lines[data_start:]:
178-
parts = ln.split()
179-
if len(parts) < 4:
180-
continue
181-
try:
182-
symbols.append(parts[0])
183-
coords_raw.append([float(parts[1]), float(parts[2]), float(parts[3])])
184-
except ValueError:
185-
logger.warning(
186-
"parse_xyz: skipping malformed line in %s: %r", filepath, ln.strip()
187-
)
188-
continue
189-
if n_atoms is not None and len(symbols) >= n_atoms:
190-
break
208+
symbols: list[str] = []
209+
coords_raw: list[list[float]] = []
191210

192-
if n_atoms is not None and len(symbols) != n_atoms:
193-
raise ValueError(
194-
f"Expected {n_atoms} atoms in {filepath}, but parsed {len(symbols)}."
195-
)
211+
# line number of the first atom line = header + comment + 1
212+
atom_line_start = header_line_idx + 2
213+
for atom_line_idx, line in enumerate(fh, atom_line_start):
214+
stripped = line.strip()
215+
if not stripped:
216+
continue
217+
218+
match = _XYZ_PATTERN.match(line)
219+
if not match:
220+
raise ValueError(
221+
f"Invalid atom data at {filepath}:{atom_line_idx}: '{stripped}'"
222+
)
196223

197-
if not symbols:
198-
raise ValueError(f"No atomic coordinates found in: {filepath}")
224+
sym, x, y, z = match.groups()
225+
symbols.append(sym)
226+
coords_raw.append([float(x), float(y), float(z)])
227+
228+
if len(symbols) >= n_atoms:
229+
break
230+
231+
if len(symbols) < n_atoms:
232+
raise ValueError(
233+
f"Unexpected EOF in {filepath}: "
234+
f"Expected {n_atoms} atoms, but only found {len(symbols)}."
235+
)
199236

200-
return symbols, np.array(coords_raw, dtype=float)
237+
return symbols, np.array(coords_raw)
201238

202239
def distance_matrix(coords: np.ndarray) -> np.ndarray:
203240
# cdist avoids the (N,N,3) intermediate array produced by manual broadcasting.
@@ -648,13 +685,12 @@ def fingerprint(
648685
thresholds = self.covalent_margin * (radii_arr[ii] + radii_arr[jj])
649686
bonded_idx = np.where(dists <= thresholds)[0]
650687

651-
counts: dict[tuple[str, str], int] = {}
652-
for k in bonded_idx:
653-
si, sj = symbols[ii[k]], symbols[jj[k]]
654-
key = (si, sj) if si <= sj else (sj, si)
655-
counts[key] = counts.get(key, 0) + 1
656-
657-
return counts
688+
symbols_arr = np.array(symbols)
689+
bonded_si = symbols_arr[ii[bonded_idx]]
690+
bonded_sj = symbols_arr[jj[bonded_idx]]
691+
pairs = np.sort(np.column_stack((bonded_si, bonded_sj)), axis=1)
692+
unique_pairs, counts = np.unique(pairs, axis=0, return_counts=True)
693+
return {tuple(p): int(c) for p, c in zip(unique_pairs, counts)}
658694

659695
def has_rearrangement(
660696
self,
@@ -1807,6 +1843,12 @@ def parse(self, profile_dir: str) -> dict | None:
18071843
return None
18081844

18091845
ts_file = ts_matches[0]
1846+
if len(ts_matches) > 1:
1847+
logger.warning(
1848+
"ProfileParser: %d *_ts_final.xyz files found in %s — "
1849+
"using the first one (%s). Check for unexpected duplicates.",
1850+
len(ts_matches), profile_dir, ts_file,
1851+
)
18101852
energies = self._parse_energy_txt(txt_path)
18111853
free_energies = self._parse_free_energy_txt(txt_path)
18121854

@@ -1865,11 +1907,7 @@ def _parse_energy_txt(txt_path: str) -> dict:
18651907
if not os.path.isfile(txt_path):
18661908
return result
18671909

1868-
with open(txt_path, "r") as fh:
1869-
for line in fh:
1870-
stripped = line.strip()
1871-
# Stop at the free-energy section so G_tot is not mis-parsed
1872-
# as an electronic energy (the section starts with "# ===...
1910+
with open(txt_path, "r", encoding="utf-8") as fh:
18731911
# FREE ENERGY SECTION").
18741912
if "FREE ENERGY SECTION" in stripped:
18751913
break
@@ -1918,7 +1956,7 @@ def _parse_free_energy_txt(txt_path: str) -> dict:
19181956
return result
19191957

19201958
in_section = False
1921-
with open(txt_path, "r") as fh:
1959+
with open(txt_path, "r", encoding="utf-8") as fh:
19221960
for line in fh:
19231961
stripped = line.strip()
19241962
# Detect entry into free-energy section
@@ -3041,7 +3079,11 @@ def _run_autots(self, task: ExplorationTask, run_dir: str) -> list[str]:
30413079
)
30423080
tag, payload = result_q.get_nowait()
30433081
if tag == "err":
3044-
raise payload # re-raise the original worker exception
3082+
# payload is a formatted traceback string (see _autots_worker_with_queue).
3083+
# Wrap in RuntimeError so it can be raised and caught normally.
3084+
raise RuntimeError(
3085+
f"AutoTSWorkflow subprocess failed:\n{payload}"
3086+
)
30453087
profiles: list[str] = payload
30463088
return profiles
30473089

0 commit comments

Comments (0)