Skip to content

Commit 86c4d65

Browse files
authored
Add files via upload
1 parent c3594f9 commit 86c4d65

1 file changed

Lines changed: 118 additions & 78 deletions

File tree

multioptpy/Wrapper/mapper.py

Lines changed: 118 additions & 78 deletions
Original file line number | Diff line number | Diff line change
@@ -570,6 +570,21 @@ class ExplorationTask:
570570
priority: float = 0.0
571571
metadata: dict = field(default_factory=dict)
572572

573+
574+
@dataclass
575+
class NodeUpdateData:
576+
"""Typed container for pending node energy backfill data.
577+
578+
Both fields are nullable so that partial updates (e.g. only ``energy``
579+
is known at registration time) are correctly represented. ``frozen``
580+
is intentionally left ``False`` because queued updates are merged by
581+
assigning to these fields in place (see ``_find_or_register_node``).
582+
"""
583+
584+
energy: float | None = None
585+
free_energy: float | None = None
586+
587+
573588
class ExplorationQueue(ABC):
574589
"""Abstract base class for priority-ordered exploration queues.
575590
@@ -596,48 +611,86 @@ def compute_priority(self, task):
596611
"""
597612

598613
def __init__(self, rng_seed: int = 42) -> None:
599-
# _tasks: canonical list of ExplorationTask objects.
600-
# Kept as a real list so that subclasses (e.g. RCMCQueue) can call
601-
# .sort() and .pop(0) on it directly without breaking.
602-
self._tasks: list[ExplorationTask] = []
603-
# _heap: parallel min-heap of (-priority, counter, task) used by the
604-
# base-class push()/pop() for O(log n) insertion and extraction.
605-
# RCMCQueue overrides pop() entirely and never touches _heap.
606-
self._heap: list[tuple[float, int, ExplorationTask]] = []
614+
# _tasks: id(task) → ExplorationTask mapping.
615+
# Using a dict instead of a list gives O(1) deletion in pop() without
616+
# the O(N) list.remove() cost that previously negated the heap benefit.
617+
# Subclasses that need an ordered view should call
618+
# sorted(self._tasks.values(), ...) rather than operating on the dict
619+
# directly.
620+
self._tasks: dict[int, ExplorationTask] = {}
621+
# _heap: min-heap of (raw_energy, counter, task_id) for O(log N)
622+
# extraction. Entries keyed by id(task); stale entries (whose task_id
623+
# is absent from _tasks) are silently skipped in pop().
624+
# Storing raw_energy (absolute, not delta) means the heap order is
625+
# invariant to reference-energy shifts, so refresh_priorities() no
626+
# longer needs to rebuild the heap.
627+
self._heap: list[tuple[float, int, int]] = []
607628
self._push_counter: int = 0
608629
self._submitted: set[tuple] = set()
609630
self._rng = np.random.default_rng(rng_seed)
631+
# Latest reference energy supplied by refresh_priorities(); used for
632+
# lazy delta_E recalculation inside pop().
633+
self._current_ref_e: float | None = None
610634

611635
def push(self, task: ExplorationTask) -> bool:
612636
key = (task.node_id, tuple(task.afir_params))
613637
if key in self._submitted:
614638
return False
615639

616640
task.priority = self.compute_priority(task)
617-
# Update both _tasks (for subclass access) and _heap (for base-class pop).
618-
self._tasks.append(task)
619-
heapq.heappush(self._heap, (-task.priority, self._push_counter, task))
641+
642+
# Extract the absolute (raw) energy of the source node so the heap
643+
# can be ordered by raw_energy. Lower raw_energy == higher Boltzmann
644+
# priority, so a min-heap on raw_energy gives correct pop() order
645+
# without ever rebuilding the heap when ref_e changes.
646+
# Fall back to delta_E_hartree + current ref_e for legacy callers that
647+
# do not populate source_node_energy.
648+
node_g = task.metadata.get("source_node_free_energy")
649+
node_e = task.metadata.get("source_node_energy")
650+
raw_e: float = (
651+
node_g if node_g is not None
652+
else node_e if node_e is not None
653+
else (
654+
task.metadata.get("delta_E_hartree", 0.0)
655+
+ (self._current_ref_e or 0.0)
656+
)
657+
)
658+
659+
task_id = id(task)
660+
self._tasks[task_id] = task
661+
heapq.heappush(self._heap, (raw_e, self._push_counter, task_id))
620662
self._push_counter += 1
621663
self._submitted.add(key)
622664
return True
623665

624666
def pop(self) -> ExplorationTask | None:
625-
"""Pop the highest-priority task using the heap (O(log n)).
667+
"""Pop the highest-priority task using the heap (O(log N) amortised).
626668
627-
Also removes the task from ``_tasks`` so subclasses that iterate
628-
``_tasks`` see a consistent state.
669+
``_tasks`` deletion is O(1) because it is now a ``dict`` keyed by
670+
``id(task)``. Stale heap entries (whose ``task_id`` has already been
671+
removed from ``_tasks``) are silently skipped.
629672
630-
Stale heap entries (tasks that exist in the heap but have already
631-
been removed from ``_tasks`` by a prior rebuild) are silently
632-
skipped so that only valid tasks are ever returned to the caller.
673+
Once :meth:`refresh_priorities` has stored a reference energy,
674+
``delta_E_hartree`` and ``priority`` are recomputed lazily for the popped
675+
task (when its source-node energy is known) from ``_current_ref_e``,
676+
keeping the returned task current without O(N) heap reconstruction.
633677
"""
634678
while self._heap:
635-
_, _, task = heapq.heappop(self._heap)
636-
try:
637-
self._tasks.remove(task)
638-
return task
639-
except ValueError:
640-
continue # stale entry — skip and try the next one
679+
raw_e, _, task_id = heapq.heappop(self._heap)
680+
task = self._tasks.pop(task_id, None)
681+
if task is None:
682+
continue # stale entry — task was already removed
683+
684+
# Lazy delta_E / priority refresh for the single task being returned.
685+
if self._current_ref_e is not None:
686+
node_g = task.metadata.get("source_node_free_energy")
687+
node_e = task.metadata.get("source_node_energy")
688+
eff_e = node_g if node_g is not None else node_e
689+
if eff_e is not None:
690+
task.metadata["delta_E_hartree"] = eff_e - self._current_ref_e
691+
task.priority = self.compute_priority(task)
692+
693+
return task
641694
return None
642695

643696
def is_submitted(self, key: tuple) -> bool:
@@ -690,50 +743,36 @@ def should_add(self, node: "EQNode", reference_energy: float, **kwargs) -> bool:
690743
return bool(self._rng.random() < p)
691744

692745
def refresh_priorities(self, ref_e: float | None) -> None:
693-
"""Recompute priorities for all queued tasks using the current reference energy.
746+
"""Record the current reference energy for lazy evaluation at pop() time.
694747
695-
Should be called at the start of each iteration (before ``pop()``) so
696-
that tasks enqueued when the reference energy was higher are
697-
re-weighted against the latest minimum-energy node in the graph.
748+
The previous implementation rebuilt the entire heap with
749+
``heapq.heapify()`` on every iteration — O(N) cost. That rebuild is
750+
no longer needed: the heap is now ordered by *raw* (absolute) node
751+
energy rather than by priority, and that order is invariant to a uniform
752+
shift of the reference energy.
698753
699-
When ``task.metadata["source_node_free_energy"]`` is present the
700-
free energy is used as the effective energy for the ΔE calculation
701-
(so that the Boltzmann weight tracks the free-energy landscape).
702-
Otherwise ``task.metadata["source_node_energy"]`` (electronic
703-
energy) is used, preserving the original behaviour for nodes without
704-
thermochemistry results.
754+
The new design simply stores ``ref_e`` so that :meth:`pop` can
755+
recompute ``delta_E_hartree`` and ``priority`` for the single task
756+
it is about to return, eliminating the per-iteration O(N) rebuild
757+
entirely.
705758
706-
The ``ref_e`` argument should already be the unified reference
707-
returned by :meth:`NetworkGraph.reference_energy` (G-based when any
708-
node has thermochemistry, electronic otherwise), so the mixed-mode
709-
ΔE values remain directly comparable.
759+
For subclasses whose ``compute_priority`` is *not* monotone in
760+
``delta_E`` (i.e. a lower raw energy does not always imply a higher
761+
priority), the heap order may be approximate. In those cases
762+
subclasses should override ``pop()`` with a full-sort strategy
763+
(as ``RCMCQueue`` does) rather than relying on the heap.
710764
711-
If neither value is available the stored ``delta_E_hartree`` is left
712-
unchanged so the priority degrades gracefully to the value set at
713-
enqueue time.
714-
715-
The queue is re-sorted in descending priority order after all tasks
716-
have been updated.
765+
Parameters
766+
----------
767+
ref_e:
768+
Unified reference energy from
769+
:meth:`NetworkGraph.reference_energy` (G-based when any node
770+
carries thermochemistry, electronic otherwise). ``None`` is
771+
accepted and leaves the stored value unchanged, so the method
772+
is safe to call unconditionally at the start of each iteration.
717773
"""
718-
if not self._tasks or ref_e is None:
719-
return
720-
721-
for task in self._tasks:
722-
node_g = task.metadata.get("source_node_free_energy")
723-
node_e = task.metadata.get("source_node_energy")
724-
eff_e = node_g if node_g is not None else node_e
725-
if eff_e is not None:
726-
task.metadata["delta_E_hartree"] = eff_e - ref_e
727-
task.priority = self.compute_priority(task)
728-
729-
# Rebuild the heap from the updated _tasks list.
730-
# Use _push_counter-based indices to guarantee unique tie-breakers,
731-
# then advance _push_counter so that subsequent push() calls never
732-
# reuse a counter value (monotonic increase requirement).
733-
base = self._push_counter
734-
self._heap = [(-t.priority, base + i, t) for i, t in enumerate(self._tasks)]
735-
heapq.heapify(self._heap)
736-
self._push_counter = base + len(self._tasks)
774+
if ref_e is not None:
775+
self._current_ref_e = ref_e
737776

738777
def export_queue_status(self) -> list[dict]:
739778
return [
@@ -742,7 +781,7 @@ def export_queue_status(self) -> list[dict]:
742781
"priority": t.priority,
743782
"afir_params": t.afir_params,
744783
}
745-
for t in self._tasks
784+
for t in self._tasks.values()
746785
]
747786

748787
def __len__(self) -> int:
@@ -1769,9 +1808,9 @@ def __init__(
17691808
self._iteration: int = 0
17701809

17711810
# Accumulates energy backfill requests discovered during each iteration.
1772-
# Maps node_id -> {"energy": float|None, "free_energy": float|None}.
1811+
# Maps node_id -> NodeUpdateData (typed container for energy/free_energy).
17731812
# Applied in bulk by _flush_node_energy_updates() before graph.save().
1774-
self._pending_node_updates: dict[int, dict] = {}
1813+
self._pending_node_updates: dict[int, NodeUpdateData] = {}
17751814

17761815
os.makedirs(self.output_dir, exist_ok=True)
17771816
os.makedirs(self.work_dir, exist_ok=True)
@@ -2500,12 +2539,12 @@ def _flush_node_energy_updates(self) -> None:
25002539

25012540
changed: list[str] = []
25022541

2503-
new_e = updates.get("energy")
2542+
new_e = updates.energy
25042543
if new_e is not None and node.energy is None:
25052544
node.energy = new_e
25062545
changed.append(f"energy={new_e:.10f} Ha")
25072546

2508-
new_g = updates.get("free_energy")
2547+
new_g = updates.free_energy
25092548
if new_g is not None and node.free_energy is None:
25102549
node.free_energy = new_g
25112550
changed.append(f"free_energy={new_g:.10f} Ha")
@@ -2623,18 +2662,19 @@ def _find_or_register_node(
26232662
# the new incoming structure carries those values, queue an
26242663
# update so that _flush_node_energy_updates() can populate the
26252664
# null fields before the next graph.save().
2626-
needs_update: dict = {}
2627-
if existing.energy is None and energy is not None:
2628-
needs_update["energy"] = energy
2629-
if existing.free_energy is None and free_energy is not None:
2630-
needs_update["free_energy"] = free_energy
2631-
if needs_update:
2632-
prev = self._pending_node_updates.get(existing.node_id, {})
2665+
needs_update = NodeUpdateData(
2666+
energy=energy if existing.energy is None else None,
2667+
free_energy=free_energy if existing.free_energy is None else None,
2668+
)
2669+
if needs_update.energy is not None or needs_update.free_energy is not None:
2670+
prev = self._pending_node_updates.get(
2671+
existing.node_id, NodeUpdateData()
2672+
)
26332673
# Only overwrite fields that are still absent in prior queued updates.
2634-
if "energy" not in prev and "energy" in needs_update:
2635-
prev["energy"] = needs_update["energy"]
2636-
if "free_energy" not in prev and "free_energy" in needs_update:
2637-
prev["free_energy"] = needs_update["free_energy"]
2674+
if prev.energy is None and needs_update.energy is not None:
2675+
prev.energy = needs_update.energy
2676+
if prev.free_energy is None and needs_update.free_energy is not None:
2677+
prev.free_energy = needs_update.free_energy
26382678
self._pending_node_updates[existing.node_id] = prev
26392679
logger.debug(
26402680
"_find_or_register_node: queued backfill for EQ%d: %s",

0 commit comments

Comments
 (0)