|
24 | 24 | import tempfile |
25 | 25 | from abc import ABC, abstractmethod |
26 | 26 | from collections import Counter |
| 27 | +from concurrent.futures import ProcessPoolExecutor, as_completed |
27 | 28 | from dataclasses import dataclass, field |
28 | 29 | from typing import Any |
29 | 30 |
|
@@ -2511,203 +2512,150 @@ def _run_batch_parallel( |
2511 | 2512 | history_log: str, |
2512 | 2513 | priority_log: str, |
2513 | 2514 | ) -> None: |
2514 | | - """Execute all tasks in the batch as parallel spawned subprocesses. |
2515 | | -
|
2516 | | - Bug fixes (ported from mapper (2).py): |
2517 | | - 1. Moved ``del active[pid]`` inside a finally block in the polling loop |
2518 | | - so it always executes even when _process_single_result() raises. |
2519 | | - 2. Wrapped _process_single_result() in try/except inside the timeout |
2520 | | - block so that cleanup completes reliably after a timeout. |
2521 | | - 3. Removed the ``is_submitted`` guard in the finally block. |
2522 | | - Entries remaining in active are guaranteed to be unprocessed, so |
2523 | | - the guard is unnecessary and was causing _close_queue() to be skipped. |
2524 | | - 4. Per-process timeout is measured independently using proc_start stored |
2525 | | - per entry. The previous implementation shared a single batch-start |
2526 | | - time, causing later-started processes to inherit a shorter effective |
2527 | | - timeout. |
| 2515 | + """Execute all tasks in the batch via :class:`~concurrent.futures.ProcessPoolExecutor`. |
| 2516 | +
|
| 2517 | + Replaces the previous manual ``multiprocessing.Process`` + ``Queue`` |
| 2518 | + polling loop. Key design decisions: |
| 2519 | +
|
| 2520 | + * ``max_tasks_per_child=1`` — each task runs in a fresh child process, |
| 2521 | + isolating the ``os.chdir()`` call inside :func:`_autots_worker` from |
| 2522 | + other concurrently running children and from the parent process. |
| 2523 | + * ``_autots_worker`` is called directly; no result queue is needed |
| 2524 | + because :class:`~concurrent.futures.Future` already carries either |
| 2525 | + the return value or the raised exception. |
| 2526 | + * The executor is managed manually (not via ``with executor:``) so that |
| 2527 | + on timeout we can force-kill all worker processes *before* calling |
| 2528 | + ``shutdown()``. Using the context manager would call |
| 2529 | + ``shutdown(wait=True)`` unconditionally, causing the main process to |
| 2530 | + block forever when a hung external binary does not exit on its own. |
| 2531 | +
|
| 2532 | + Timeout handling |
| 2533 | + ---------------- |
| 2534 | + ``future.cancel()`` only removes *pending* (not yet started) futures |
| 2535 | + from the internal work queue; it cannot stop an already-running worker |
| 2536 | + process. To avoid the freeze described above, on :exc:`TimeoutError` |
| 2537 | + we: |
| 2538 | +
|
| 2539 | + 1. Cancel all not-yet-started futures. |
| 2540 | + 2. Force-kill every live worker process via ``executor._processes`` |
| 2541 | + (a ``{pid: Process}`` dict maintained by the standard library). |
| 2542 | + This is an intentional use of a private attribute; no public API |
| 2543 | + exists for this operation. |
| 2544 | + 3. Call ``executor.shutdown(wait=False, cancel_futures=True)`` so the |
| 2545 | + executor's bookkeeping threads are released without waiting. |
| 2546 | +
|
| 2547 | + When ``worker_timeout_s`` is ``None`` (recommended default for long- |
| 2548 | + running chemistry calculations), ``as_completed`` blocks indefinitely |
| 2549 | + and the timeout path is never entered. |
2528 | 2550 | """ |
2529 | | - import queue as _queue_mod |
2530 | | - |
2531 | | - # ── Build per-task config and write config_used.json ───────────── |
2532 | | - active: dict[int, tuple] = {} |
| 2551 | + futures_map: dict = {} |
| 2552 | + |
| 2553 | + # ``max_tasks_per_child=1`` ensures a fresh process is spawned for |
| 2554 | + # every task, safely isolating os.chdir() calls within the worker. |
| 2555 | + executor = ProcessPoolExecutor( |
| 2556 | + max_workers=self.n_parallel, |
| 2557 | + mp_context=self._mp_ctx, |
| 2558 | + max_tasks_per_child=1, |
| 2559 | + ) |
2533 | 2560 |
|
2534 | | - for task, run_dir, gamma_sign, atom_i, atom_j, iteration in batch: |
2535 | | - workspace = os.path.join(run_dir, "autots_workspace") |
2536 | | - config = self._make_autots_config(task, workspace) |
2537 | | - try: |
2538 | | - with open(os.path.join(run_dir, "config_used.json"), "w", encoding="utf-8") as fh: |
2539 | | - json.dump(config, fh, indent=2, default=str) |
2540 | | - except Exception as exc: |
2541 | | - logger.warning("Could not write config_used.json: %s", exc) |
| 2561 | + timed_out = False |
| 2562 | + try: |
| 2563 | + # ── Submit all tasks ────────────────────────────────────────── |
| 2564 | + for task, run_dir, gamma_sign, atom_i, atom_j, iteration in batch: |
| 2565 | + workspace = os.path.join(run_dir, "autots_workspace") |
| 2566 | + config = self._make_autots_config(task, workspace) |
| 2567 | + try: |
| 2568 | + with open( |
| 2569 | + os.path.join(run_dir, "config_used.json"), "w", encoding="utf-8" |
| 2570 | + ) as fh: |
| 2571 | + json.dump(config, fh, indent=2, default=str) |
| 2572 | + except Exception as exc: |
| 2573 | + logger.warning("Could not write config_used.json: %s", exc) |
2542 | 2574 |
|
2543 | | - q = self._mp_ctx.Queue() |
2544 | | - proc = self._mp_ctx.Process( |
2545 | | - target=_autots_worker_with_queue, |
2546 | | - args=(config, run_dir, workspace, q), |
2547 | | - ) |
2548 | | - try: |
2549 | | - proc.start() |
2550 | | - except OSError: |
2551 | | - self._close_queue(q) |
2552 | | - logger.error( |
2553 | | - "_run_batch_parallel: proc.start() failed for run %s — treating as FAILED.", |
| 2575 | + future = executor.submit( |
| 2576 | + _autots_worker, |
| 2577 | + config, |
2554 | 2578 | run_dir, |
| 2579 | + workspace, |
2555 | 2580 | ) |
2556 | | - self._process_single_result( |
2557 | | - task, run_dir, [], "FAILED", iteration, |
2558 | | - history_log, gamma_sign, atom_i, atom_j, |
| 2581 | + futures_map[future] = ( |
| 2582 | + task, run_dir, iteration, gamma_sign, atom_i, atom_j |
2559 | 2583 | ) |
2560 | | - continue |
2561 | 2584 |
|
2562 | | - if proc.pid is None: |
| 2585 | + # ── Collect results as each future completes ────────────────── |
| 2586 | + # When worker_timeout_s is None, as_completed blocks indefinitely. |
| 2587 | + # When set, it is one batch-wide deadline counted from this call; TimeoutError is raised once it passes. |
| 2588 | + try: |
| 2589 | + for future in as_completed(futures_map, timeout=self.worker_timeout_s): |
| 2590 | + task, run_dir, iteration, gamma_sign, atom_i, atom_j = ( |
| 2591 | + futures_map[future] |
| 2592 | + ) |
| 2593 | + try: |
| 2594 | + profile_dirs = future.result() |
| 2595 | + self._process_single_result( |
| 2596 | + task, run_dir, profile_dirs, "DONE", iteration, |
| 2597 | + history_log, gamma_sign, atom_i, atom_j, |
| 2598 | + ) |
| 2599 | + except Exception as exc: |
| 2600 | + logger.error( |
| 2601 | + "_run_batch_parallel: worker failed for %s: %s", |
| 2602 | + run_dir, exc, |
| 2603 | + ) |
| 2604 | + self._process_single_result( |
| 2605 | + task, run_dir, [], "FAILED", iteration, |
| 2606 | + history_log, gamma_sign, atom_i, atom_j, |
| 2607 | + ) |
| 2608 | + |
| 2609 | + except TimeoutError: |
| 2610 | + timed_out = True |
2563 | 2611 | logger.error( |
2564 | | - "_run_batch_parallel: subprocess failed to start (pid=None) " |
2565 | | - "for run %s — treating as FAILED.", run_dir, |
2566 | | - ) |
2567 | | - self._process_single_result( |
2568 | | - task, run_dir, [], "FAILED", iteration, |
2569 | | - history_log, gamma_sign, atom_i, atom_j, |
| 2612 | + "_run_batch_parallel: batch-level timeout (%ds) exceeded — " |
| 2613 | + "force-killing all remaining worker processes.", |
| 2614 | + self.worker_timeout_s, |
2570 | 2615 | ) |
2571 | | - self._close_queue(q) |
2572 | | - proc.join(timeout=5) |
2573 | | - continue |
2574 | 2616 |
|
2575 | | - # Store proc_start per process for independent timeout measurement |
2576 | | - active[proc.pid] = (proc, q, task, run_dir, iteration, |
2577 | | - gamma_sign, atom_i, atom_j, time.time()) |
| 2617 | + # ── Step 1: cancel not-yet-started futures ──────────────── |
| 2618 | + for future in futures_map: |
| 2619 | + future.cancel() |
| 2620 | + |
| 2621 | + # ── Step 2: force-kill running worker processes ─────────── |
| 2622 | + # executor._processes is a {pid: multiprocessing.Process} dict |
| 2623 | + # maintained by ProcessPoolExecutor. No public API exposes |
| 2624 | + # individual worker handles, so this private attribute is the |
| 2625 | + # only way to SIGKILL hung workers (any subprocess a worker spawned is not signalled). |
| 2626 | + worker_procs = getattr(executor, "_processes", {}) |
| 2627 | + for pid, proc in list(worker_procs.items()): |
| 2628 | + if proc.is_alive(): |
| 2629 | + logger.warning( |
| 2630 | + "_run_batch_parallel: force-killing worker pid=%d", pid |
| 2631 | + ) |
| 2632 | + proc.kill() |
2578 | 2633 |
|
2579 | | - poll_interval = 60.0 |
2580 | | - try: |
2581 | | - while active: |
2582 | | - # ── Per-process timeout check ───────────────────────────── |
2583 | | - # Each process is timed independently using proc_start (index 8 |
2584 | | - # in the active tuple). Fixes the bug where a shared batch-start |
2585 | | - # time gave later-started processes a shorter effective timeout. |
2586 | | - if self.worker_timeout_s is not None: |
2587 | | - now = time.time() |
2588 | | - timed_out_pids = [ |
2589 | | - p for p, e in active.items() |
2590 | | - if now - e[8] >= self.worker_timeout_s |
2591 | | - ] |
2592 | | - for pid in timed_out_pids: |
2593 | | - proc, q, task, run_dir, it, gs, ai, aj, proc_start = active[pid] |
2594 | | - self._kill_proc(proc) |
2595 | | - self._close_queue(q) |
| 2634 | + # ── Step 3: mark all incomplete futures as TIMEOUT ──────── |
| 2635 | + for future, meta in futures_map.items(): |
| 2636 | + task, run_dir, iteration, gamma_sign, atom_i, atom_j = meta |
| 2637 | + if not future.done(): |
2596 | 2638 | logger.error( |
2597 | | - "Worker timed out after %.0fs (limit=%ds): %s", |
2598 | | - time.time() - proc_start, self.worker_timeout_s, run_dir, |
| 2639 | + "_run_batch_parallel: worker timed out (limit=%ds): %s", |
| 2640 | + self.worker_timeout_s, run_dir, |
2599 | 2641 | ) |
2600 | | - # Catch exceptions from _process_single_result after timeout |
2601 | 2642 | try: |
2602 | 2643 | self._process_single_result( |
2603 | | - task, run_dir, [], "TIMEOUT", it, |
2604 | | - history_log, gs, ai, aj, |
| 2644 | + task, run_dir, [], "TIMEOUT", iteration, |
| 2645 | + history_log, gamma_sign, atom_i, atom_j, |
2605 | 2646 | ) |
2606 | 2647 | except Exception as exc: |
2607 | 2648 | logger.error( |
2608 | 2649 | "_process_single_result failed after TIMEOUT (%s): %s", |
2609 | 2650 | run_dir, exc, |
2610 | 2651 | ) |
2611 | | - del active[pid] |
2612 | | - if not active: |
2613 | | - break |
2614 | | - |
2615 | | - # ── Poll each active process ────────────────────────────── |
2616 | | - for pid, (proc, q, task, run_dir, it, |
2617 | | - gs, ai, aj, proc_start) in list(active.items()): |
2618 | | - try: |
2619 | | - tag, payload = q.get_nowait() |
2620 | | - |
2621 | | - except _queue_mod.Empty: |
2622 | | - if not proc.is_alive(): |
2623 | | - # Worker exited without placing a result (crash) |
2624 | | - try: |
2625 | | - tag, payload = q.get(timeout=5.0) |
2626 | | - except _queue_mod.Empty: |
2627 | | - logger.error( |
2628 | | - "AutoTS worker terminated unexpectedly (run=%s)", |
2629 | | - run_dir, |
2630 | | - ) |
2631 | | - tag, payload = "err", None |
2632 | | - finally: |
2633 | | - proc.join(timeout=30) |
2634 | | - self._close_queue(q) |
2635 | | - else: |
2636 | | - continue |
2637 | | - |
2638 | | - else: |
2639 | | - proc.join(timeout=120) |
2640 | | - if proc.is_alive(): |
2641 | | - self._kill_proc(proc) |
2642 | | - self._close_queue(q) |
2643 | | - |
2644 | | - # ── Result processing: del active[pid] guaranteed by finally ─ |
2645 | | - # Bug fix: before the fix, if _process_single_result raised, |
2646 | | - # del active[pid] was never reached, leaving a stale entry |
2647 | | - # that could be processed twice. |
2648 | | - try: |
2649 | | - if tag == "err": |
2650 | | - logger.error("AutoTS failed for %s:\n%s", run_dir, payload) |
2651 | | - self._process_single_result( |
2652 | | - task, run_dir, [], "FAILED", it, |
2653 | | - history_log, gs, ai, aj, |
2654 | | - ) |
2655 | | - else: |
2656 | | - self._process_single_result( |
2657 | | - task, run_dir, payload, "DONE", it, |
2658 | | - history_log, gs, ai, aj, |
2659 | | - ) |
2660 | | - finally: |
2661 | | - # Always executes (core of the fix) |
2662 | | - del active[pid] |
2663 | | - |
2664 | | - if active: |
2665 | | - # Adjust sleep time to the nearest upcoming timeout deadline |
2666 | | - sleep_t = poll_interval |
2667 | | - if self.worker_timeout_s is not None: |
2668 | | - now = time.time() |
2669 | | - sleep_t = min(poll_interval, max(0.0, min( |
2670 | | - self.worker_timeout_s - (now - e[8]) |
2671 | | - for e in active.values() |
2672 | | - ))) |
2673 | | - if sleep_t > 0: |
2674 | | - time.sleep(sleep_t) |
2675 | 2652 |
|
2676 | 2653 | finally: |
2677 | | - # ── Clean up remaining active processes ─────────────────────── |
2678 | | - # Entries still in active are guaranteed to be unprocessed. |
2679 | | - # Bug fix: the previous is_submitted() guard was removed so that |
2680 | | - # _close_queue() is always executed per-entry. |
2681 | | - for pid, (proc, q, task, run_dir, it, |
2682 | | - gs, ai, aj, _) in list(active.items()): |
2683 | | - if proc.is_alive(): |
2684 | | - self._kill_proc(proc) |
2685 | | - |
2686 | | - status, payload = "FAILED", [] |
2687 | | - try: |
2688 | | - tag, raw = q.get_nowait() |
2689 | | - if tag == "ok": |
2690 | | - status, payload = "DONE", raw |
2691 | | - else: |
2692 | | - logger.error( |
2693 | | - "AutoTS failed (finally) %s:\n%s", run_dir, raw |
2694 | | - ) |
2695 | | - except _queue_mod.Empty: |
2696 | | - pass |
2697 | | - |
2698 | | - try: |
2699 | | - self._process_single_result( |
2700 | | - task, run_dir, payload, status, it, |
2701 | | - history_log, gs, ai, aj, |
2702 | | - ) |
2703 | | - except Exception as exc: |
2704 | | - logger.error( |
2705 | | - "_process_single_result failed (finally) %s: %s", |
2706 | | - run_dir, exc, |
2707 | | - ) |
2708 | | - finally: |
2709 | | - # Executes reliably because the is_submitted guard is gone |
2710 | | - self._close_queue(q) |
| 2654 | + # Shut down the executor. After force-killing all workers in the |
| 2655 | + # timeout path, wait=False avoids a redundant join on dead processes. |
| 2656 | + # In the normal path (no timeout), wait=True ensures all workers are |
| 2657 | + # cleanly joined before proceeding. |
| 2658 | + executor.shutdown(wait=not timed_out, cancel_futures=timed_out) |
2711 | 2659 |
|
2712 | 2660 | try: |
2713 | 2661 | self.graph.save(self.graph_json_path) |
|
0 commit comments