Skip to content

Commit 013bb51

Browse files
committed
compiler: switch to better halo placement
1 parent 86fa36f commit 013bb51

2 files changed

Lines changed: 15 additions & 30 deletions

File tree

devito/ir/clusters/algorithms.py

Lines changed: 11 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -442,8 +442,8 @@ def callback(self, clusters, prefix, seen=None):
442442
d = prefix[-1].dim
443443

444444
# Construct a representation of the halo accesses
445-
processed = []
446-
for i, c in enumerate(clusters):
445+
processed = list(clusters)
446+
for n, c in enumerate(clusters):
447447
if c.properties.is_sequential(d) or \
448448
c in seen:
449449
continue
@@ -453,10 +453,6 @@ def callback(self, clusters, prefix, seen=None):
453453
not d._defines & hs.distributed_aindices:
454454
continue
455455

456-
if any(_halo_write(ci, hs) for ci in clusters[:i]):
457-
# If there's a halo write before `c`, then we cannot inject the HaloTouch
458-
continue
459-
460456
points = set()
461457
for f in hs.fmapper:
462458
for a in c.scope.getreads(f):
@@ -484,10 +480,16 @@ def callback(self, clusters, prefix, seen=None):
484480

485481
halo_touch = c.rebuild(exprs=expr, ispace=ispace, properties=properties)
486482

487-
processed.append(halo_touch)
488-
seen.update({halo_touch, c})
483+
# Insert `halo_touch` at the top of the IterationSpace within which
484+
# `c` is scheduled
485+
index = 0
486+
for i in reversed(range(n)):
487+
if not processed[i].ispace.is_subset(c.ispace):
488+
index = i + 1
489+
break
490+
processed.insert(index, halo_touch)
489491

490-
processed.extend(clusters)
492+
seen.update({halo_touch, c})
491493

492494
return processed
493495

@@ -785,21 +787,3 @@ def normalize_reductions_sparse(cluster, sregistry):
785787
processed.append(e)
786788

787789
return cluster.rebuild(processed)
788-
789-
790-
def _halo_write(c, hs):
791-
"""
792-
Check if the cluster `c` writes into any of the local values read by `hs`.
793-
"""
794-
for f in hs.fmapper:
795-
if not any(f.grid.distributor.topology.get(d, 1) > 1
796-
for d in hs.dimensions):
797-
# Not distributed halo dimension, write does not impact the halo exchange
798-
continue
799-
800-
if any(set(a.access.indices) & hs.loc_values for a in c.scope.getwrites(f)):
801-
# Writing into a local value, which is read by the halo exchange,
802-
# creates a write dependency
803-
return True
804-
805-
return False

tests/test_mpi.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2752,9 +2752,10 @@ def test_haloupdate_same_timestep_v2(self, mode):
27522752

27532753
titer = op.body.body[-1].body[0]
27542754
assert titer.dim is grid.time_dim
2755-
assert titer.nodes[0].body[0].body[0].is_List
2756-
assert len(titer.nodes[0].body[0].body[0].body[0].body) == 1
2757-
assert titer.nodes[0].body[0].body[0].body[0].body[0].is_Call
2755+
block = titer.nodes[0].body[0].body[1]
2756+
assert block.is_List
2757+
assert len(block.body) == 3
2758+
assert block.body[0].body[0].is_Call
27582759

27592760
op.apply(time=0)
27602761

0 commit comments

Comments
 (0)