|
14 | 14 | from __future__ import annotations |
15 | 15 |
|
16 | 16 | from abc import abstractmethod, ABC |
17 | | -import cProfile |
18 | | -import pstats |
19 | 17 |
|
20 | | -import contextlib |
21 | 18 | from copy import deepcopy |
22 | 19 | import sys |
23 | 20 | from datetime import datetime, timedelta |
|
26 | 23 | from time import time |
27 | 24 |
|
28 | 25 | import numpy as np |
29 | | -from dask.distributed import get_client, Client, LocalCluster, performance_report |
| 26 | +from dask.distributed import get_client, Client |
30 | 27 |
|
31 | 28 | from geoapps_utils.base import Driver, Options |
32 | 29 | from geoapps_utils.run import load_ui_json_as_dict |
|
74 | 71 | from simpeg_drivers.joint.options import BaseJointOptions |
75 | 72 | from simpeg_drivers.utils.nested import tile_locations |
76 | 73 | from simpeg_drivers.utils.regularization import cell_neighbors, set_rotated_operators |
77 | | -from simpeg_drivers.utils.utils import validate_out_group |
| 74 | +from simpeg_drivers.utils.utils import validate_out_group, start_dask_run |
78 | 75 |
|
79 | 76 | mlogger = logging.getLogger("distributed") |
80 | 77 | mlogger.setLevel(logging.WARNING) |
@@ -500,47 +497,7 @@ def start_dask_run( |
500 | 497 | :param n_workers: Number of workers to use. |
501 | 498 | :param n_threads: Number of threads to use. |
502 | 499 | """ |
503 | | - ui_json = load_ui_json_as_dict(json_path) |
504 | | - |
505 | | - n_workers = ui_json.get("n_workers", n_workers) |
506 | | - n_threads = ui_json.get("n_threads", n_threads) |
507 | | - save_report = ui_json.get("performance_report", False) |
508 | | - |
509 | | - if (n_workers is not None and n_workers > 1) or n_threads is not None: |
510 | | - cluster = LocalCluster( |
511 | | - processes=True, |
512 | | - n_workers=n_workers, |
513 | | - threads_per_worker=n_threads, |
514 | | - ) |
515 | | - else: |
516 | | - cluster = None |
517 | | - |
518 | | - profiler = cProfile.Profile() |
519 | | - profiler.enable() |
520 | | - |
521 | | - with ( |
522 | | - cluster.get_client() |
523 | | - if cluster is not None |
524 | | - else contextlib.nullcontext() as context_client |
525 | | - ): |
526 | | - # Full run |
527 | | - with ( |
528 | | - performance_report(filename=json_path.parent / "dask_profile.html") |
529 | | - if (save_report and isinstance(context_client, Client)) |
530 | | - else contextlib.nullcontext() |
531 | | - ): |
532 | | - cls.start(json_path) |
533 | | - sys.stdout.close() |
534 | | - |
535 | | - profiler.disable() |
536 | | - |
537 | | - if save_report: |
538 | | - with open( |
539 | | - json_path.parent / "runtime_profile.txt", encoding="utf-8", mode="w" |
540 | | - ) as s: |
541 | | - ps = pstats.Stats(profiler, stream=s) |
542 | | - ps.sort_stats("cumulative") |
543 | | - ps.print_stats() |
| 500 | + start_dask_run(cls, json_path, n_workers=n_workers, n_threads=n_threads) |
544 | 501 |
|
545 | 502 | @property |
546 | 503 | def workers(self) -> list[tuple[str]]: |
|
0 commit comments