Skip to content

Commit 004c95b

Browse files
committed
add type hints to mpi_distribute
1 parent 35e4ce9 commit 004c95b

1 file changed

Lines changed: 10 additions & 4 deletions

File tree

meshmode/distributed.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,9 @@
4040
from dataclasses import dataclass
4141
from contextlib import contextmanager
4242
import numpy as np
43-
from typing import List, Set, Union, Mapping, cast, Sequence, TYPE_CHECKING
43+
from typing import (
44+
Any, Optional, List, Set, Union, Mapping, cast, Sequence, TYPE_CHECKING
45+
)
4446

4547
from arraycontext import ArrayContext
4648
from meshmode.discretization.connection import (
@@ -80,14 +82,18 @@ def _duplicate_mpi_comm(mpi_comm):
8082
dup_comm.Free()
8183

8284

83-
def mpi_distribute(mpi_comm, source_data=None, source_rank=0):
85+
def mpi_distribute(
86+
mpi_comm: "mpi4py.MPI.Intracomm",
87+
source_data: Optional[Mapping[int, Any]] = None,
88+
source_rank: int = 0) -> Optional[Any]:
8489
"""
8590
Distribute data to a set of processes.
8691
8792
:arg mpi_comm: An ``MPI.Intracomm``
8893
:arg source_data: A :class:`dict` mapping destination ranks to data to be sent.
8994
Only present on the source rank.
9095
:arg source_rank: The rank from which the data is being sent.
96+
9197
:returns: The data local to the current process if there is any, otherwise
9298
*None*.
9399
"""
@@ -101,7 +107,7 @@ def mpi_distribute(mpi_comm, source_data=None, source_rank=0):
101107
if source_data is None:
102108
raise TypeError("source rank has no data.")
103109

104-
sending_to = np.full(num_proc, False)
110+
sending_to = [False] * num_proc
105111
for dest_rank in source_data.keys():
106112
sending_to[dest_rank] = True
107113

@@ -121,7 +127,7 @@ def mpi_distribute(mpi_comm, source_data=None, source_rank=0):
121127
MPI.Request.waitall(reqs)
122128

123129
else:
124-
receiving = mpi_comm.scatter(None, root=source_rank)
130+
receiving = mpi_comm.scatter([], root=source_rank)
125131

126132
if receiving:
127133
local_data = mpi_comm.recv(source=source_rank)

0 commit comments

Comments
 (0)