1111import bisect
1212import concurrent .futures
1313import traceback
14+ import atexit
1415import glob
1516import heapq
1617import json
2728from dataclasses import dataclass , field
2829
2930import numpy as np
30- from scipy .spatial .distance import cdist
31+ from scipy .spatial .distance import cdist , pdist
3132from scipy .optimize import linear_sum_assignment
3233
3334
@@ -363,7 +364,8 @@ def _pca_align(coords: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
363364 if len (coords ) < 2 :
364365 return coords , np .ones (3 )
365366
366- eigvals , eigvecs = np .linalg .eigh (np .cov (coords .T ))
367+ cov_matrix = (coords .T @ coords ) / max (1 , len (coords ) - 1 )
368+ eigvals , eigvecs = np .linalg .eigh (cov_matrix )
367369
368370 # Descending eigenvalue order → canonical axis labelling.
369371 order = eigvals .argsort ()[::- 1 ]
@@ -571,12 +573,11 @@ def _kabsch_rmsd(pa: np.ndarray, pb: np.ndarray) -> float:
571573 SVD from finding a reflection that would artificially lower the
572574 RMSD for enantiomeric pairs.
573575 """
574- U , _ , Vt = np .linalg .svd (pb .T @ pa )
575- D = np .diag ([1.0 , 1.0 , np .linalg .det (Vt .T @ U .T )])
576- R = Vt .T @ D @ U .T
577- diff = pa - pb @ R .T
578- return float (np .sqrt ((diff ** 2 ).sum () / len (pa )))
579-
576+ U , S , Vt = np .linalg .svd (pb .T @ pa )
577+ d = np .sign (np .linalg .det (U ) * np .linalg .det (Vt ))
578+ E0 = np .sum (pa ** 2 ) + np .sum (pb ** 2 )
579+ rmsd_sq = max (0.0 , E0 - 2.0 * (S [0 ] + S [1 ] + d * S [2 ])) / len (pa )
580+ return float (np .sqrt (rmsd_sq ))
580581
581582# ===========================================================================
582583# Section 2b : BondTopologyChecker
@@ -639,10 +640,9 @@ def fingerprint(
639640
640641 radii_arr = np .array ([elem_radius [s ] for s in symbols ], dtype = np .float64 )
641642
642- # Pairwise distances — vectorised via cdist .
643- dmat = cdist ( coords , coords )
643+ # Pairwise distances — condensed upper-triangle form via pdist (half the memory of a full cdist matrix; ordering matches np.triu_indices(n, k=1)).
644+ dists = pdist ( coords )
644645 ii , jj = np .triu_indices (n , k = 1 )
645- dists = dmat [ii , jj ]
646646
647647 # Per-pair bonding threshold.
648648 thresholds = self .covalent_margin * (radii_arr [ii ] + radii_arr [jj ])
@@ -764,27 +764,13 @@ def push(self, task: ExplorationTask) -> bool:
764764
765765 task .priority = self .compute_priority (task )
766766
767- # Extract the absolute (raw) energy of the source node so the heap
768- # can be ordered by raw_energy. Lower raw_energy == higher Boltzmann
769- # priority, so a min-heap on raw_energy gives correct pop() order
770- # without ever rebuilding the heap when ref_e changes.
771- # Fall back to delta_E_hartree + current ref_e for legacy callers that
772- # do not populate source_node_energy.
773- node_g = task .metadata .get ("source_node_free_energy" )
774- node_e = task .metadata .get ("source_node_energy" )
775- raw_e : float = (
776- node_g if node_g is not None
777- else node_e if node_e is not None
778- else (
779- task .metadata .get ("delta_E_hartree" , 0.0 )
780- + (self ._current_ref_e or 0.0 )
781- )
782- )
783-
784767 self ._task_counter += 1
785768 task_id = self ._task_counter
786769 self ._tasks [task_id ] = task
787- heapq .heappush (self ._heap , (raw_e , self ._push_counter , task_id ))
770+
771+
772+ heapq .heappush (self ._heap , (- task .priority , self ._push_counter , task_id ))
773+
788774 self ._push_counter += 1
789775 self ._submitted .add (key )
790776 return True
@@ -826,7 +812,7 @@ def pop(self) -> ExplorationTask | None:
826812 )
827813
828814 while self ._heap :
829- raw_e , _ , task_id = heapq .heappop (self ._heap )
815+ _ , _ , task_id = heapq .heappop (self ._heap )
830816 task = self ._tasks .pop (task_id , None )
831817 if task is None :
832818 continue # stale entry — task was already removed
@@ -998,12 +984,12 @@ class ExploredPairsLog:
998984 filepath : str
999985 Absolute path to the text file used for persistence.
1000986 flush_interval : int
1001- Number of records to batch before flushing to disk. Defaults to 50.
987+ Number of records to batch before flushing to disk. Defaults to 100.
1002988 Higher values reduce I/O overhead but risk losing the latest
1003989 records if the process crashes unexpectedly.
1004990 """
1005991
1006- def __init__ (self , filepath : str , flush_interval : int = 50 ) -> None :
992+ def __init__ (self , filepath : str , flush_interval : int = 100 ) -> None :
1007993 self ._filepath = filepath
1008994 self ._flush_interval = flush_interval
1009995 self ._write_count = 0
@@ -1018,6 +1004,7 @@ def __init__(self, filepath: str, flush_interval: int = 50) -> None:
10181004 # write() rather than open/write/close.
10191005 # The handle is flushed periodically based on flush_interval.
10201006 self ._fh = open (self ._filepath , "a" , encoding = "utf-8" ) # noqa: SIM115
1007+ atexit .register (self .close )
10211008
10221009 # ------------------------------------------------------------------ #
10231010 # Private helpers #
0 commit comments