Skip to content

Commit 62e70ec

Browse files
committed
Make core function start_dask_run a util.
1 parent 898927a commit 62e70ec

5 files changed

Lines changed: 103 additions & 54 deletions

File tree

simpeg_drivers/driver.py

Lines changed: 3 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,7 @@
1414
from __future__ import annotations
1515

1616
from abc import abstractmethod, ABC
17-
import cProfile
18-
import pstats
1917

20-
import contextlib
2118
from copy import deepcopy
2219
import sys
2320
from datetime import datetime, timedelta
@@ -26,7 +23,7 @@
2623
from time import time
2724

2825
import numpy as np
29-
from dask.distributed import get_client, Client, LocalCluster, performance_report
26+
from dask.distributed import get_client, Client
3027

3128
from geoapps_utils.base import Driver, Options
3229
from geoapps_utils.run import load_ui_json_as_dict
@@ -74,7 +71,7 @@
7471
from simpeg_drivers.joint.options import BaseJointOptions
7572
from simpeg_drivers.utils.nested import tile_locations
7673
from simpeg_drivers.utils.regularization import cell_neighbors, set_rotated_operators
77-
from simpeg_drivers.utils.utils import validate_out_group
74+
from simpeg_drivers.utils.utils import validate_out_group, start_dask_run
7875

7976
mlogger = logging.getLogger("distributed")
8077
mlogger.setLevel(logging.WARNING)
@@ -500,47 +497,7 @@ def start_dask_run(
500497
:param n_workers: Number of workers to use.
501498
:param n_threads: Number of threads to use.
502499
"""
503-
ui_json = load_ui_json_as_dict(json_path)
504-
505-
n_workers = ui_json.get("n_workers", n_workers)
506-
n_threads = ui_json.get("n_threads", n_threads)
507-
save_report = ui_json.get("performance_report", False)
508-
509-
if (n_workers is not None and n_workers > 1) or n_threads is not None:
510-
cluster = LocalCluster(
511-
processes=True,
512-
n_workers=n_workers,
513-
threads_per_worker=n_threads,
514-
)
515-
else:
516-
cluster = None
517-
518-
profiler = cProfile.Profile()
519-
profiler.enable()
520-
521-
with (
522-
cluster.get_client()
523-
if cluster is not None
524-
else contextlib.nullcontext() as context_client
525-
):
526-
# Full run
527-
with (
528-
performance_report(filename=json_path.parent / "dask_profile.html")
529-
if (save_report and isinstance(context_client, Client))
530-
else contextlib.nullcontext()
531-
):
532-
cls.start(json_path)
533-
sys.stdout.close()
534-
535-
profiler.disable()
536-
537-
if save_report:
538-
with open(
539-
json_path.parent / "runtime_profile.txt", encoding="utf-8", mode="w"
540-
) as s:
541-
ps = pstats.Stats(profiler, stream=s)
542-
ps.sort_stats("cumulative")
543-
ps.print_stats()
500+
start_dask_run(cls, json_path, n_workers=n_workers, n_threads=n_threads)
544501

545502
@property
546503
def workers(self) -> list[tuple[str]]:

simpeg_drivers/plate_simulation/driver.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
from simpeg_drivers.plate_simulation.models.parametric import Plate
3434
from simpeg_drivers.plate_simulation.models.series import DikeSwarm, Geology
3535
from simpeg_drivers.plate_simulation.options import PlateSimulationOptions
36-
from simpeg_drivers.utils.utils import validate_out_group
36+
from simpeg_drivers.utils.utils import start_dask_run, validate_out_group
3737

3838

3939
logger = get_logger(__name__, propagate=False)
@@ -278,8 +278,19 @@ def replicate(
278278
plates.append(new)
279279
return plates
280280

281+
@classmethod
282+
def start_dask_run(
283+
cls, json_path: Path, n_workers: int | None = None, n_threads: int | None = None
284+
):
285+
"""
286+
Sets Dask config settings.
287+
288+
:param json_path: Path to input file (.ui.json) for the application.
289+
:param n_workers: Number of workers to use.
290+
:param n_threads: Number of threads to use.
291+
"""
292+
start_dask_run(cls, json_path, n_workers=n_workers, n_threads=n_threads)
281293

282-
PlateSimulationDriver.start_dask_run = InversionDriver.start_dask_run
283294

284295
if __name__ == "__main__":
285296
file = Path(sys.argv[1])

simpeg_drivers/plate_simulation/match/driver.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
from simpeg_drivers.plate_simulation.options import ModelOptions, PlateSimulationOptions
3939
from simpeg_drivers.utils.utils import (
4040
get_default_parallelization_params,
41+
start_dask_run,
4142
validate_out_group,
4243
)
4344

@@ -382,8 +383,18 @@ def run_scores(self, spatial_projection, data) -> tuple[np.ndarray, np.ndarray]:
382383

383384
return scores, centers
384385

386+
@classmethod
387+
def start_dask_run(
388+
cls, json_path: Path, n_workers: int | None = None, n_threads: int | None = None
389+
):
390+
"""
391+
Sets Dask config settings.
385392
386-
PlateMatchDriver.start_dask_run = BaseDriver.start_dask_run
393+
:param json_path: Path to input file (.ui.json) for the application.
394+
:param n_workers: Number of workers to use.
395+
:param n_threads: Number of threads to use.
396+
"""
397+
start_dask_run(cls, json_path, n_workers=n_workers, n_threads=n_threads)
387398

388399

389400
def is_up_dip(data: np.ndarray) -> bool:
@@ -514,7 +525,6 @@ def batch_files_score(
514525

515526
if __name__ == "__main__":
516527
file = Path(sys.argv[1]).resolve()
528+
n_w, n_t = get_default_parallelization_params(file)
517529

518-
n_workers, n_threads = get_default_parallelization_params(file)
519-
520-
PlateMatchDriver.start_dask_run(file, n_workers=n_workers, n_threads=n_threads)
530+
PlateMatchDriver.start_dask_run(file, n_workers=n_w, n_threads=n_t)

simpeg_drivers/plate_simulation/sweep/driver.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
from simpeg_drivers.plate_simulation.options import PlateSimulationOptions
3838
from simpeg_drivers.plate_simulation.sweep.options import SweepOptions
3939
from simpeg_drivers.plate_simulation.sweep.uijson import PlateSweepUIJson
40-
from simpeg_drivers.utils.utils import validate_out_group
40+
from simpeg_drivers.utils.utils import start_dask_run, validate_out_group
4141

4242

4343
logger = get_logger(name=__name__, level_name=False, propagate=False, add_name=False)
@@ -192,8 +192,18 @@ def run_trial(
192192
del plate_sim
193193
return None
194194

195+
@classmethod
196+
def start_dask_run(
197+
cls, json_path: Path, n_workers: int | None = None, n_threads: int | None = None
198+
):
199+
"""
200+
Sets Dask config settings.
195201
196-
PlateSweepDriver.start_dask_run = BaseDriver.start_dask_run
202+
:param json_path: Path to input file (.ui.json) for the application.
203+
:param n_workers: Number of workers to use.
204+
:param n_threads: Number of threads to use.
205+
"""
206+
start_dask_run(cls, json_path, n_workers=n_workers, n_threads=n_threads)
197207

198208

199209
def forms_to_values(data: dict) -> dict:

simpeg_drivers/utils/utils.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,18 @@
1111

1212
from __future__ import annotations
1313

14+
import contextlib
15+
import cProfile
1416
import multiprocessing
17+
import pstats
18+
import sys
1519
from collections.abc import Sequence
1620
from copy import deepcopy
1721
from pathlib import Path
1822
from typing import TYPE_CHECKING
1923

2024
import numpy as np
25+
from dask.distributed import Client, LocalCluster, performance_report
2126
from discretize import TensorMesh, TreeMesh
2227
from discretize.utils import mesh_utils
2328
from geoapps_utils.base import Options
@@ -724,3 +729,59 @@ def validate_out_group(options: Options) -> SimPEGGroup:
724729
out_group.metadata = None
725730

726731
return out_group
732+
733+
734+
def start_dask_run(
735+
class_type,
736+
json_path: Path,
737+
n_workers: int | None = None,
738+
n_threads: int | None = None,
739+
):
740+
"""
741+
Sets Dask config settings.
742+
743+
:param json_path: Path to input file (.ui.json) for the application.
744+
:param n_workers: Number of workers to use.
745+
:param n_threads: Number of threads to use.
746+
"""
747+
ui_json = load_ui_json_as_dict(json_path)
748+
749+
n_workers = ui_json.get("n_workers", n_workers)
750+
n_threads = ui_json.get("n_threads", n_threads)
751+
save_report = ui_json.get("performance_report", False)
752+
753+
if (n_workers is not None and n_workers > 1) or n_threads is not None:
754+
cluster = LocalCluster(
755+
processes=True,
756+
n_workers=n_workers,
757+
threads_per_worker=n_threads,
758+
)
759+
else:
760+
cluster = None
761+
762+
profiler = cProfile.Profile()
763+
profiler.enable()
764+
765+
with (
766+
cluster.get_client()
767+
if cluster is not None
768+
else contextlib.nullcontext() as context_client
769+
):
770+
# Full run
771+
with (
772+
performance_report(filename=json_path.parent / "dask_profile.html")
773+
if (save_report and isinstance(context_client, Client))
774+
else contextlib.nullcontext()
775+
):
776+
class_type.start(json_path)
777+
sys.stdout.close()
778+
779+
profiler.disable()
780+
781+
if save_report:
782+
with open(
783+
json_path.parent / "runtime_profile.txt", encoding="utf-8", mode="w"
784+
) as s:
785+
ps = pstats.Stats(profiler, stream=s)
786+
ps.sort_stats("cumulative")
787+
ps.print_stats()

0 commit comments

Comments (0)