Skip to content

Commit 735eb77

Browse files
committed
Refactor DiscretizationCollection constructor
1 parent e38eb1d commit 735eb77

1 file changed

Lines changed: 171 additions & 52 deletions

File tree

grudge/discretization.py

Lines changed: 171 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,10 @@ class DiscretizationCollection:
8080

8181
def __init__(self, array_context: ArrayContext, mesh: Mesh,
8282
order=None,
83-
discr_tag_to_group_factory=None, mpi_communicator=None,
83+
discr_tag_to_group_factory=None,
84+
volume_discr=None,
85+
dist_boundary_connections=None,
86+
mpi_communicator=None,
8487
# FIXME: `quad_tag_to_group_factory` is deprecated
8588
quad_tag_to_group_factory=None):
8689
"""
@@ -93,6 +96,13 @@ def __init__(self, array_context: ArrayContext, mesh: Mesh,
9396
to be carried out, or *None* to indicate that operations with this
9497
discretization tag should be carried out with the standard volume
9598
discretization.
99+
:arg volume_discr: A :class:`meshmode.discretization.Discretization`
100+
object for the base (:class:`grudge.dof_desc.DISCR_TAG_BASE`)
101+
volume discretization.
102+
:arg dist_boundary_connections: A dictionary whose keys denote the
103+
partition group index and map to the appropriate face connections
104+
for distributed boundaries, if any.
105+
:arg mpi_communicator: An (optional) MPI communicator.
96106
"""
97107

98108
if (quad_tag_to_group_factory is not None
@@ -122,8 +132,11 @@ def __init__(self, array_context: ArrayContext, mesh: Mesh,
122132
"one of 'order' and 'discr_tag_to_group_factory' must be given"
123133
)
124134

125-
discr_tag_to_group_factory = {
126-
DISCR_TAG_BASE: PolynomialWarpAndBlendGroupFactory(order=order)}
135+
elment_grp = PolynomialWarpAndBlendGroupFactory(order=order)
136+
self.discr_tag_to_group_factory = {
137+
DISCR_TAG_BASE: elment_grp,
138+
DISCR_TAG_MODAL: _generate_modal_group_factory(elment_grp)
139+
}
127140
else:
128141
if order is not None:
129142
discr_tag_to_group_factory = discr_tag_to_group_factory.copy()
@@ -136,20 +149,23 @@ def __init__(self, array_context: ArrayContext, mesh: Mesh,
136149
discr_tag_to_group_factory[DISCR_TAG_BASE] = \
137150
PolynomialWarpAndBlendGroupFactory(order=order)
138151

139-
# Modal discr should always comes from the base discretization
140-
discr_tag_to_group_factory[DISCR_TAG_MODAL] = \
141-
_generate_modal_group_factory(
142-
discr_tag_to_group_factory[DISCR_TAG_BASE]
143-
)
144-
145-
self.discr_tag_to_group_factory = discr_tag_to_group_factory
152+
if DISCR_TAG_MODAL not in discr_tag_to_group_factory:
153+
discr_tag_to_group_factory[DISCR_TAG_MODAL] = \
154+
_generate_modal_group_factory(
155+
discr_tag_to_group_factory[DISCR_TAG_BASE]
156+
)
146157

147-
from meshmode.discretization import Discretization
158+
self.discr_tag_to_group_factory = discr_tag_to_group_factory
148159

149-
self._volume_discr = Discretization(
150-
array_context, mesh,
151-
self.group_factory_for_discretization_tag(DISCR_TAG_BASE)
152-
)
160+
# FIXME
161+
if volume_discr is None:
162+
from meshmode.discretization import Discretization
163+
self._volume_discr = Discretization(
164+
array_context, mesh,
165+
self.group_factory_for_discretization_tag(DISCR_TAG_BASE)
166+
)
167+
else:
168+
self._volume_discr = volume_discr
153169

154170
# NOTE: Can be removed when symbolics are completely removed
155171
# {{{ management of discretization-scoped common subexpressions
@@ -162,9 +178,17 @@ def __init__(self, array_context: ArrayContext, mesh: Mesh,
162178

163179
# }}}
164180

165-
self._dist_boundary_connections = \
166-
self._set_up_distributed_communication(
167-
mpi_communicator, array_context)
181+
# FIXME
182+
if dist_boundary_connections is None:
183+
self._dist_boundary_connections = \
184+
set_up_distributed_communication(
185+
self._setup_actx, mesh,
186+
self._volume_discr,
187+
self.discr_tag_to_group_factory,
188+
comm=mpi_communicator
189+
)
190+
else:
191+
self._dist_boundary_connections = dist_boundary_connections
168192

169193
self.mpi_communicator = mpi_communicator
170194

@@ -188,40 +212,6 @@ def is_management_rank(self):
188212
return self.mpi_communicator.Get_rank() \
189213
== self.get_management_rank_index()
190214

191-
def _set_up_distributed_communication(self, mpi_communicator, array_context):
192-
from_dd = DOFDesc("vol", DISCR_TAG_BASE)
193-
194-
boundary_connections = {}
195-
196-
from meshmode.distributed import get_connected_partitions
197-
connected_parts = get_connected_partitions(self._volume_discr.mesh)
198-
199-
if connected_parts:
200-
if mpi_communicator is None:
201-
raise RuntimeError("must supply an MPI communicator when using a "
202-
"distributed mesh")
203-
204-
grp_factory = \
205-
self.group_factory_for_discretization_tag(DISCR_TAG_BASE)
206-
207-
local_boundary_connections = {}
208-
for i_remote_part in connected_parts:
209-
local_boundary_connections[i_remote_part] = self.connection_from_dds(
210-
from_dd, DOFDesc(BTAG_PARTITION(i_remote_part),
211-
DISCR_TAG_BASE))
212-
213-
from meshmode.distributed import MPIBoundaryCommSetupHelper
214-
with MPIBoundaryCommSetupHelper(mpi_communicator, array_context,
215-
local_boundary_connections, grp_factory) as bdry_setup_helper:
216-
while True:
217-
conns = bdry_setup_helper.complete_some()
218-
if not conns:
219-
break
220-
for i_remote_part, conn in conns.items():
221-
boundary_connections[i_remote_part] = conn
222-
223-
return boundary_connections
224-
225215
def get_distributed_boundary_swap_connection(self, dd):
226216
warn("`DiscretizationCollection.get_distributed_boundary_swap_connection` "
227217
"is deprecated and will go away in 2022. Use "
@@ -636,6 +626,135 @@ def normal(self, dd):
636626
# }}}
637627

638628

629+
def make_discretization_collection(
630+
array_context: ArrayContext, mesh: Mesh,
631+
order=None,
632+
discr_tag_to_group_factory=None,
633+
mpi_communicator=None) -> DiscretizationCollection:
634+
"""
635+
:arg discr_tag_to_group_factory: A mapping from discretization tags
636+
(typically one of: :class:`grudge.dof_desc.DISCR_TAG_BASE`,
637+
:class:`grudge.dof_desc.DISCR_TAG_MODAL`, or
638+
:class:`grudge.dof_desc.DISCR_TAG_QUAD`) to a
639+
:class:`~meshmode.discretization.poly_element.ElementGroupFactory`
640+
indicating with which type of discretization the operations are
641+
to be carried out, or *None* to indicate that operations with this
642+
discretization tag should be carried out with the standard volume
643+
discretization.
644+
"""
645+
from meshmode.discretization.poly_element import \
646+
PolynomialWarpAndBlendGroupFactory
647+
648+
if discr_tag_to_group_factory is None:
649+
if order is None:
650+
raise TypeError(
651+
"one of 'order' and 'discr_tag_to_group_factory' must be given"
652+
)
653+
654+
# Default choice: warp and blend simplex element group
655+
discr_tag_to_group_factory = {
656+
DISCR_TAG_BASE: PolynomialWarpAndBlendGroupFactory(order=order)
657+
}
658+
else:
659+
if order is not None:
660+
discr_tag_to_group_factory = discr_tag_to_group_factory.copy()
661+
if DISCR_TAG_BASE in discr_tag_to_group_factory:
662+
raise ValueError(
663+
"if 'order' is given, 'discr_tag_to_group_factory' must "
664+
"not have a key of DISCR_TAG_BASE"
665+
)
666+
667+
discr_tag_to_group_factory[DISCR_TAG_BASE] = \
668+
PolynomialWarpAndBlendGroupFactory(order=order)
669+
670+
# Modal discr should always comes from the base discretization
671+
discr_tag_to_group_factory[DISCR_TAG_MODAL] = \
672+
_generate_modal_group_factory(
673+
discr_tag_to_group_factory[DISCR_TAG_BASE]
674+
)
675+
676+
from meshmode.discretization import Discretization
677+
678+
# Define the base discretization
679+
volume_discr = Discretization(
680+
array_context, mesh,
681+
discr_tag_to_group_factory[DISCR_TAG_BASE]
682+
)
683+
684+
# Define boundary connections
685+
dist_boundary_connections = set_up_distributed_communication(
686+
array_context, mesh,
687+
volume_discr,
688+
discr_tag_to_group_factory, comm=mpi_communicator
689+
)
690+
691+
return DiscretizationCollection(
692+
array_context, mesh, order=order,
693+
discr_tag_to_group_factory=discr_tag_to_group_factory,
694+
volume_discr=volume_discr,
695+
dist_boundary_connections=dist_boundary_connections,
696+
mpi_communicator=mpi_communicator
697+
)
698+
699+
700+
def set_up_distributed_communication(
701+
array_context: ArrayContext, mesh: Mesh,
702+
volume_discr,
703+
discr_tag_to_group_factory, comm=None) -> dict:
704+
"""
705+
:arg volume_discr: A :class:`meshmode.discretization.Discretization`
706+
object for the base (:class:`grudge.dof_desc.DISCR_TAG_BASE`)
707+
volume discretization.
708+
:arg discr_tag_to_group_factory: A mapping from discretization tags
709+
(typically one of: :class:`grudge.dof_desc.DISCR_TAG_BASE`,
710+
:class:`grudge.dof_desc.DISCR_TAG_MODAL`, or
711+
:class:`grudge.dof_desc.DISCR_TAG_QUAD`) to a
712+
:class:`~meshmode.discretization.poly_element.ElementGroupFactory`
713+
indicating with which type of discretization the operations are
714+
to be carried out.
715+
:arg comm: An MPI communicator.
716+
"""
717+
from_dd = DOFDesc("vol", DISCR_TAG_BASE)
718+
719+
boundary_connections = {}
720+
721+
from meshmode.distributed import get_connected_partitions
722+
723+
connected_parts = get_connected_partitions(mesh)
724+
725+
if connected_parts:
726+
if comm is None:
727+
raise RuntimeError(
728+
"Must supply an MPI communicator when using a "
729+
"distributed mesh"
730+
)
731+
732+
grp_factory = discr_tag_to_group_factory[DISCR_TAG_BASE]
733+
734+
local_boundary_connections = {}
735+
for i_remote_part in connected_parts:
736+
to_dd = DOFDesc(BTAG_PARTITION(i_remote_part), DISCR_TAG_BASE)
737+
local_boundary_connections[i_remote_part] = \
738+
make_face_restriction(array_context,
739+
volume_discr,
740+
grp_factory,
741+
boundary_tag=to_dd.domain_tag.tag)
742+
743+
from meshmode.distributed import MPIBoundaryCommSetupHelper
744+
745+
with MPIBoundaryCommSetupHelper(comm, array_context,
746+
local_boundary_connections,
747+
grp_factory) as bdry_setup_helper:
748+
while True:
749+
conns = bdry_setup_helper.complete_some()
750+
if not conns:
751+
break
752+
for i_remote_part, conn in conns.items():
753+
boundary_connections[i_remote_part] = conn
754+
755+
return boundary_connections
756+
757+
639758
class DGDiscretizationWithBoundaries(DiscretizationCollection):
640759
def __init__(self, *args, **kwargs):
641760
warn("DGDiscretizationWithBoundaries is deprecated and will go away "

0 commit comments

Comments
 (0)