Skip to content

Commit 7b6354f

Browse files
authored
Merge pull request #1552 from Libensemble/refactor/310_plus_typehints
Refactor/310 plus typehints
2 parents 5892690 + c1040ff commit 7b6354f

26 files changed

Lines changed: 327 additions & 297 deletions

docs/data_structures/libE_specs.rst

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -205,14 +205,14 @@ libEnsemble is primarily customized by setting options within a ``LibeSpecs`` cl
205205
**save_H_and_persis_on_abort** [bool] = ``True``:
206206
Save states of ``H`` and ``persis_info`` to file on aborting after an exception.
207207

208-
**save_H_on_completion** Optional[bool] = ``False``
208+
**save_H_on_completion** bool | None = ``False``
209209
Save state of ``H`` to file upon completing a workflow. Also enabled when either ``save_every_k_sims``
210210
or ``save_every_k_gens`` is set.
211211

212-
**save_H_with_date** Optional[bool] = ``False``
212+
**save_H_with_date** bool | None = ``False``
213213
Save ``H`` filename contains date and timestamp.
214214

215-
**H_file_prefix** Optional[str] = ``"libE_history"``
215+
**H_file_prefix** str | None = ``"libE_history"``
216216
Prefix for ``H`` filename.
217217

218218
**use_persis_return_gen** [bool] = ``False``:

libensemble/alloc_funcs/give_sim_work_first.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import time
2-
from typing import Tuple
32

43
import numpy as np
54
import numpy.typing as npt
@@ -15,7 +14,7 @@ def give_sim_work_first(
1514
alloc_specs: dict,
1615
persis_info: dict,
1716
libE_info: dict,
18-
) -> Tuple[dict]:
17+
) -> tuple[dict]:
1918
"""
2019
Decide what should be given to workers. This allocation function gives any
2120
available simulation work first, and only when all simulations are

libensemble/ensemble.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import importlib
22
import json
33
import logging
4-
from typing import Optional
54

65
import numpy.typing as npt
76
import tomli
@@ -270,15 +269,15 @@ class Ensemble:
270269

271270
def __init__(
272271
self,
273-
sim_specs: Optional[SimSpecs] = SimSpecs(),
274-
gen_specs: Optional[GenSpecs] = GenSpecs(),
275-
exit_criteria: Optional[ExitCriteria] = {},
276-
libE_specs: Optional[LibeSpecs] = LibeSpecs(),
277-
alloc_specs: Optional[AllocSpecs] = AllocSpecs(),
278-
persis_info: Optional[dict] = {},
279-
executor: Optional[Executor] = None,
280-
H0: Optional[npt.NDArray] = None,
281-
parse_args: Optional[bool] = False,
272+
sim_specs: SimSpecs | None = SimSpecs(),
273+
gen_specs: GenSpecs | None = GenSpecs(),
274+
exit_criteria: ExitCriteria | None = {},
275+
libE_specs: LibeSpecs | None = LibeSpecs(),
276+
alloc_specs: AllocSpecs | None = AllocSpecs(),
277+
persis_info: dict | None = {},
278+
executor: Executor | None = None,
279+
H0: npt.NDArray | None = None,
280+
parse_args: bool | None = False,
282281
):
283282
self.sim_specs = sim_specs
284283
self.gen_specs = gen_specs

libensemble/executors/balsam_executor.py

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -74,11 +74,12 @@ class HelloApp(ApplicationDefinition):
7474
.. _Globus: https://www.globus.org/
7575
"""
7676

77+
from __future__ import annotations
78+
7779
import datetime
7880
import logging
7981
import os
8082
import time
81-
from typing import Any, Dict, List, Optional, Union
8283

8384
from balsam import util
8485

@@ -106,9 +107,9 @@ class BalsamTask(Task):
106107

107108
def __init__(
108109
self,
109-
app: Optional[Application] = None,
110+
app: Application | None = None,
110111
app_args: dict = None,
111-
workdir: Optional[str] = None,
112+
workdir: str | None = None,
112113
stdout: str = None,
113114
stderr: str = None,
114115
workerid: int = None,
@@ -122,7 +123,7 @@ def __init__(
122123
# May want to override workdir with Balsam value when it exists
123124
Task.__init__(self, app, app_args, workdir, stdout, stderr, workerid)
124125

125-
def _get_time_since_balsam_submit(self) -> Union[int, float]:
126+
def _get_time_since_balsam_submit(self) -> int:
126127
"""Return time since balsam task entered ``RUNNING`` state"""
127128
event_query = EventLog.objects.filter(job_id=self.process.id, to_state="RUNNING")
128129
if not len(event_query):
@@ -203,7 +204,7 @@ def poll(self) -> None:
203204
self.state = "FAILED"
204205
self._set_complete()
205206

206-
def wait(self, timeout: Optional[int] = None) -> None:
207+
def wait(self, timeout: int | None = None) -> None:
207208
"""Waits on completion of the task or raises ``TimeoutExpired``.
208209
209210
Status attributes of task are updated on completion.
@@ -280,10 +281,10 @@ def add_app(self, *args) -> None:
280281
def register_app(
281282
self,
282283
BalsamApp: ApplicationDefinition,
283-
app_name: Optional[str] = None,
284-
calc_type: Optional[str] = None,
284+
app_name: str | None = None,
285+
calc_type: str | None = None,
285286
desc: str = None,
286-
precedent: Optional[str] = None,
287+
precedent: str | None = None,
287288
) -> None:
288289
"""Registers a Balsam ``ApplicationDefinition`` to libEnsemble. This class
289290
instance *must* have a ``site`` and ``command_template`` specified. See
@@ -331,9 +332,9 @@ def submit_allocation(
331332
job_mode: str = "mpi",
332333
queue: str = "local",
333334
project: str = "local",
334-
optional_params: Dict[Any, Any] = {},
335-
filter_tags: Dict[Any, Any] = {},
336-
partitions: List[Any] = [],
335+
optional_params: dict = {},
336+
filter_tags: dict = {},
337+
partitions: list = [],
337338
) -> BatchJob:
338339
"""
339340
Submits a Balsam ``BatchJob`` machine allocation request to Balsam.
@@ -435,14 +436,14 @@ def set_resources(self, resources: str) -> None:
435436

436437
def submit(
437438
self,
438-
calc_type: Optional[str] = None,
439-
app_name: Optional[str] = None,
439+
calc_type: str | None = None,
440+
app_name: str | None = None,
440441
app_args: dict = None,
441442
num_procs: int = None,
442443
num_nodes: int = None,
443444
procs_per_node: int = None,
444445
max_tasks_per_node: int = None,
445-
machinefile: Optional[str] = None,
446+
machinefile: str | None = None,
446447
gpus_per_rank: int = 0,
447448
transfers: dict = {},
448449
workdir: str = "",

libensemble/executors/executor.py

Lines changed: 25 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import sys
1313
import time
1414
from pathlib import Path
15-
from typing import Any, Optional, Union
15+
from typing import Any
1616

1717
import libensemble.utils.launcher as launcher
1818
from libensemble.message_numbers import (
@@ -24,7 +24,6 @@
2424
WORKER_DONE,
2525
WORKER_KILL_ON_TIMEOUT,
2626
)
27-
from libensemble.resources.resources import Resources
2827
from libensemble.utils.timer import TaskTimer
2928

3029
logger = logging.getLogger(__name__)
@@ -78,10 +77,10 @@ class Application:
7877
def __init__(
7978
self,
8079
full_path: str,
81-
name: Optional[str] = None,
82-
calc_type: Optional[str] = "sim",
83-
desc: Optional[str] = None,
84-
pyobj: Optional[Any] = None, # used by balsam_executor to store ApplicationDefinition
80+
name: str | None = None,
81+
calc_type: str | None = "sim",
82+
desc: str | None = None,
83+
pyobj: Any | None = None, # used by balsam_executor to store ApplicationDefinition
8584
precedent: str = "",
8685
) -> None:
8786
"""Instantiates a new Application instance."""
@@ -101,7 +100,7 @@ def __init__(
101100
self.app_cmd = " ".join(filter(None, [self.precedent, self.full_path]))
102101

103102

104-
def jassert(test: Optional[Union[Application, bool]], *args) -> None:
103+
def jassert(test: Application | bool | None, *args) -> None:
105104
"Version of assert that raises a ExecutorException"
106105
if not test:
107106
raise ExecutorException(*args)
@@ -170,7 +169,7 @@ def _add_to_env(self, key, value):
170169
"""Add to task environment - overwrites if already set"""
171170
self.env[key] = value
172171

173-
def workdir_exists(self) -> Optional[bool]:
172+
def workdir_exists(self) -> bool | None:
174173
"""Returns true if the task's workdir exists"""
175174
return self.workdir and os.path.exists(self.workdir)
176175

@@ -260,7 +259,7 @@ def poll(self) -> None:
260259

261260
self._set_complete()
262261

263-
def wait(self, timeout: Optional[float] = None) -> None:
262+
def wait(self, timeout: float | None = None) -> None:
264263
"""Waits on completion of the task or raises TimeoutExpired exception
265264
266265
Status attributes of task are updated on completion.
@@ -288,7 +287,7 @@ def wait(self, timeout: Optional[float] = None) -> None:
288287

289288
self._set_complete()
290289

291-
def result(self, timeout: Optional[Union[int, float]] = None) -> str:
290+
def result(self, timeout: int | float | None = None) -> str:
292291
"""Wrapper for task.wait() that also returns the task's status on completion.
293292
294293
Parameters
@@ -303,7 +302,7 @@ def result(self, timeout: Optional[Union[int, float]] = None) -> str:
303302
self.wait(timeout=timeout)
304303
return self.state
305304

306-
def exception(self, timeout: Optional[Union[int, float]] = None):
305+
def exception(self, timeout: int | float | None = None):
307306
"""Wrapper for task.wait() that instead returns the task's error code on completion.
308307
309308
Parameters
@@ -386,7 +385,7 @@ class Executor:
386385

387386
executor = None
388387

389-
def _wait_on_start(self, task: Task, fail_time: Optional[int] = None) -> None:
388+
def _wait_on_start(self, task: Task, fail_time: int | None = None) -> None:
390389
"""Called by submit when wait_on_start is True.
391390
392391
Blocks until task polls as having started.
@@ -472,7 +471,7 @@ def default_app(self, calc_type: str) -> Application:
472471
jassert(app, f"Default {calc_type} app is not set")
473472
return app
474473

475-
def set_resources(self, resources: Resources):
474+
def set_resources(self, resources):
476475
# Does not use resources
477476
pass
478477

@@ -493,9 +492,9 @@ def set_gen_procs_gpus(self, libE_info):
493492
def register_app(
494493
self,
495494
full_path: str,
496-
app_name: Optional[str] = None,
497-
calc_type: Optional[str] = None,
498-
desc: Optional[str] = None,
495+
app_name: str | None = None,
496+
calc_type: str | None = None,
497+
desc: str | None = None,
499498
precedent: str = "",
500499
) -> None:
501500
"""Registers a user application to libEnsemble.
@@ -571,7 +570,7 @@ def manager_kill_received(self) -> bool:
571570
return False
572571

573572
def polling_loop(
574-
self, task: Task, timeout: Optional[int] = None, delay: float = 0.1, poll_manager: bool = False
573+
self, task: Task, timeout: int | None = None, delay: float = 0.1, poll_manager: bool = False
575574
) -> int:
576575
"""Optional, blocking, generic task status polling loop. Operates until the task
577576
finishes, times out, or is optionally killed via a manager signal. On completion, returns a
@@ -637,7 +636,7 @@ def polling_loop(
637636

638637
return calc_status
639638

640-
def get_task(self, taskid: Union[str, int]) -> Optional[Task]:
639+
def get_task(self, taskid: str | int) -> Task | None:
641640
"""Returns the task object for the supplied task ID"""
642641
task = next((j for j in self.list_of_tasks if j.id == taskid), None)
643642
if task is None:
@@ -681,14 +680,14 @@ def _check_app_exists(self, full_path: str) -> None:
681680

682681
def submit(
683682
self,
684-
calc_type: Optional[str] = None,
685-
app_name: Optional[str] = None,
686-
app_args: Optional[str] = None,
687-
stdout: Optional[str] = None,
688-
stderr: Optional[str] = None,
689-
dry_run: Optional[bool] = False,
690-
wait_on_start: Optional[bool] = False,
691-
env_script: Optional[str] = None,
683+
calc_type: str | None = None,
684+
app_name: str | None = None,
685+
app_args: str | None = None,
686+
stdout: str | None = None,
687+
stderr: str | None = None,
688+
dry_run: bool | None = False,
689+
wait_on_start: bool | None = False,
690+
env_script: str | None = None,
692691
) -> Task:
693692
"""Create a new task and run as a local serial subprocess.
694693

libensemble/executors/mpi_executor.py

Lines changed: 21 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,11 @@
1515
import logging
1616
import os
1717
import time
18-
from typing import List, Optional, Union
1918

2019
import libensemble.utils.launcher as launcher
2120
from libensemble.executors.executor import Executor, ExecutorException, Task
2221
from libensemble.executors.mpi_runner import MPIRunner
2322
from libensemble.resources.mpi_resources import get_MPI_variant
24-
from libensemble.resources.resources import Resources
2523

2624
logger = logging.getLogger(__name__)
2725
# To change logging level for just this module
@@ -137,11 +135,11 @@ def set_gen_procs_gpus(self, libE_info):
137135
self.gen_nprocs = libE_info.get("num_procs")
138136
self.gen_ngpus = libE_info.get("num_gpus")
139137

140-
def set_resources(self, resources: Resources) -> None:
138+
def set_resources(self, resources) -> None:
141139
self.resources = resources
142140

143141
def _launch_with_retries(
144-
self, task: Task, subgroup_launch: bool, wait_on_start: bool, run_cmd: List[str], use_shell: bool
142+
self, task: Task, subgroup_launch: bool, wait_on_start: bool, run_cmd: list[str], use_shell: bool
145143
) -> None:
146144
"""Launch task with retry mechanism"""
147145
retry_count = 0
@@ -189,25 +187,25 @@ def _launch_with_retries(
189187

190188
def submit(
191189
self,
192-
calc_type: Optional[str] = None,
193-
app_name: Optional[str] = None,
194-
num_procs: Optional[int] = None,
195-
num_nodes: Optional[int] = None,
196-
procs_per_node: Optional[int] = None,
197-
num_gpus: Optional[int] = None,
198-
machinefile: Optional[str] = None,
199-
app_args: Optional[str] = None,
200-
stdout: Optional[str] = None,
201-
stderr: Optional[str] = None,
202-
stage_inout: Optional[str] = None,
203-
hyperthreads: Optional[bool] = False,
204-
dry_run: Optional[bool] = False,
205-
wait_on_start: Optional[bool] = False,
206-
extra_args: Optional[str] = None,
207-
auto_assign_gpus: Optional[bool] = False,
208-
match_procs_to_gpus: Optional[bool] = False,
209-
env_script: Optional[str] = None,
210-
mpi_runner_type: Optional[Union[str, dict]] = None,
190+
calc_type: str | None = None,
191+
app_name: str | None = None,
192+
num_procs: int | None = None,
193+
num_nodes: int | None = None,
194+
procs_per_node: int | None = None,
195+
num_gpus: int | None = None,
196+
machinefile: str | None = None,
197+
app_args: str | None = None,
198+
stdout: str | None = None,
199+
stderr: str | None = None,
200+
stage_inout: str | None = None,
201+
hyperthreads: bool | None = False,
202+
dry_run: bool | None = False,
203+
wait_on_start: bool | None = False,
204+
extra_args: str | None = None,
205+
auto_assign_gpus: bool | None = False,
206+
match_procs_to_gpus: bool | None = False,
207+
env_script: str | None = None,
208+
mpi_runner_type: str | dict | None = None,
211209
) -> Task:
212210
"""Creates a new task, and either executes or schedules execution.
213211

0 commit comments

Comments
 (0)