4040from dataclasses import dataclass
4141from contextlib import contextmanager
4242import numpy as np
43- from typing import List , Set , Union , Mapping , cast , Sequence , TYPE_CHECKING
43+ from typing import (
44+ Any , Optional , List , Set , Union , Mapping , cast , Sequence , TYPE_CHECKING
45+ )
4446
4547from arraycontext import ArrayContext
4648from meshmode .discretization .connection import (
@@ -80,14 +82,18 @@ def _duplicate_mpi_comm(mpi_comm):
8082 dup_comm .Free ()
8183
8284
83- def mpi_distribute (mpi_comm , source_data = None , source_rank = 0 ):
85+ def mpi_distribute (
86+ mpi_comm : "mpi4py.MPI.Intracomm" ,
87+ source_data : Optional [Mapping [int , Any ]] = None ,
88+ source_rank : int = 0 ) -> Optional [Any ]:
8489 """
8590 Distribute data to a set of processes.
8691
8792 :arg mpi_comm: An ``MPI.Intracomm``
8893 :arg source_data: A :class:`dict` mapping destination ranks to data to be sent.
8994 Only present on the source rank.
9095 :arg source_rank: The rank from which the data is being sent.
96+
9197 :returns: The data local to the current process if there is any, otherwise
9298 *None*.
9399 """
@@ -101,7 +107,7 @@ def mpi_distribute(mpi_comm, source_data=None, source_rank=0):
101107 if source_data is None :
102108 raise TypeError ("source rank has no data." )
103109
104- sending_to = np . full ( num_proc , False )
110+ sending_to = [ False ] * num_proc
105111 for dest_rank in source_data .keys ():
106112 sending_to [dest_rank ] = True
107113
@@ -121,7 +127,7 @@ def mpi_distribute(mpi_comm, source_data=None, source_rank=0):
121127 MPI .Request .waitall (reqs )
122128
123129 else :
124- receiving = mpi_comm .scatter (None , root = source_rank )
130+ receiving = mpi_comm .scatter ([] , root = source_rank )
125131
126132 if receiving :
127133 local_data = mpi_comm .recv (source = source_rank )
0 commit comments