Skip to content

Commit 86fa36f

Browse files
committed
compiler: refine distributed dimension check for smarter halotouch
1 parent 7633a1a commit 86fa36f

4 files changed

Lines changed: 21 additions & 15 deletions

File tree

devito/ir/clusters/algorithms.py

Lines changed: 14 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -453,7 +453,7 @@ def callback(self, clusters, prefix, seen=None):
453453
not d._defines & hs.distributed_aindices:
454454
continue
455455

456-
if any(halo_write(ci, hs) for ci in clusters[:i]):
456+
if any(_halo_write(ci, hs) for ci in clusters[:i]):
457457
# If there's a halo write before `c`, then we cannot inject the HaloTouch
458458
continue
459459

@@ -787,12 +787,19 @@ def normalize_reductions_sparse(cluster, sregistry):
787787
return cluster.rebuild(processed)
788788

789789

790-
def halo_write(c, hs):
791-
loc_vals = hs.loc_values
792-
790+
def _halo_write(c, hs):
791+
"""
792+
Check if the cluster `c` writes into any of the local values read by `hs`.
793+
"""
793794
for f in hs.fmapper:
794-
for a in c.scope.getwrites(f):
795-
if set(a.access.indices) & loc_vals:
796-
return True
795+
if not any(f.grid.distributor.topology.get(d, 1) > 1
796+
for d in hs.dimensions):
797+
# Not distributed halo dimension, write does not impact the halo exchange
798+
continue
799+
800+
if any(set(a.access.indices) & hs.loc_values for a in c.scope.getwrites(f)):
801+
# Writing into a local value, which is read by the halo exchange,
802+
# creates a write dependency
803+
return True
797804

798805
return False

devito/mpi/distributed.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -261,7 +261,7 @@ def nprocs_local(self):
261261

262262
@property
263263
def topology(self):
264-
return self._topology
264+
return DimensionTuple(*self._topology, getters=self.dimensions)
265265

266266
@property
267267
def topology_logical(self):

devito/mpi/halo_scheme.py

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -642,12 +642,11 @@ def classify(exprs, ispace):
642642
f"scheme for `{f}` along Dimension `{d}`")
643643
elif hl.pop() is STENCIL:
644644
halos.append(Halo(d, s))
645-
else:
645+
elif d._defines & set(ispace.itdims):
646646
raw_loc_indices[d].append(s)
647647

648648
loc_indices, loc_dirs = process_loc_indices(raw_loc_indices,
649649
ispace.directions)
650-
651650
mapper[f] = HaloSchemeEntry(loc_indices, loc_dirs, halos, dims)
652651

653652
return mapper

tests/test_mpi.py

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -2,7 +2,6 @@
22

33
import numpy as np
44
import pytest
5-
from test_dse import TestTTI
65

76
from conftest import _R, assert_blocking, assert_structure, body0
87
from devito import (
@@ -2216,6 +2215,7 @@ def test_halo_inner_dim(self, mode):
22162215
eq = [Eq(v.forward, v + 1), Eq(e, v.forward.dydz)]
22172216

22182217
op = Operator(eq, opt=('advanced', {'blocklevels': 0}))
2218+
22192219
assert_structure(op, ['txyz', 't', 'txyz', 'txyz'], 'txyzxyzz')
22202220
op(time=100)
22212221

@@ -2754,7 +2754,7 @@ def test_haloupdate_same_timestep_v2(self, mode):
27542754
assert titer.dim is grid.time_dim
27552755
assert titer.nodes[0].body[0].body[0].is_List
27562756
assert len(titer.nodes[0].body[0].body[0].body[0].body) == 1
2757-
assert not titer.nodes[0].body[0].body[0].body[0].body[0].is_Call
2757+
assert titer.nodes[0].body[0].body[0].body[0].body[0].is_Call
27582758

27592759
op.apply(time=0)
27602760

@@ -3156,8 +3156,8 @@ def test_fission_due_to_antidep(self, mode):
31563156
# First, check the generated code
31573157
assert_structure(op1, ['t',
31583158
't,x0_blk0,y0_blk0,x,y,z',
3159-
't,x1_blk0,y1_blk0,x,y,z'],
3160-
'tx0_blk0y0_blk0xyzx1_blk0y1_blk0xyz')
3159+
't,x0_blk0,y0_blk0,x,y,z'],
3160+
'tx0_blk0y0_blk0xyzz')
31613161

31623162
def init(f, v=1):
31633163
f.data[:] = np.indices(grid.shape).sum(axis=0) % (.004*v) + .01
@@ -3531,9 +3531,9 @@ def test_issue_2448_backward(self, mode):
35313531

35323532
class TestTTIOp:
35333533

3534-
@pytest.mark.skipif(TestTTI is None, reason="Requires installing the tests")
35353534
@pytest.mark.parallel(mode=1)
35363535
def test_halo_structure(self, mode):
3536+
from test_dse import TestTTI
35373537
solver = TestTTI().tti_operator(opt='advanced', space_order=8)
35383538
op = solver.op_fwd(save=False)
35393539

0 commit comments

Comments
 (0)