Skip to content

Commit 7633a1a

Browse files
committed
compiler: fix halo placement for non-outer dim exchange
1 parent d9dd186 commit 7633a1a

4 files changed

Lines changed: 65 additions & 33 deletions

File tree

devito/ir/clusters/algorithms.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -443,7 +443,7 @@ def callback(self, clusters, prefix, seen=None):
443443

444444
# Construct a representation of the halo accesses
445445
processed = []
446-
for c in clusters:
446+
for i, c in enumerate(clusters):
447447
if c.properties.is_sequential(d) or \
448448
c in seen:
449449
continue
@@ -453,6 +453,10 @@ def callback(self, clusters, prefix, seen=None):
453453
not d._defines & hs.distributed_aindices:
454454
continue
455455

456+
if any(halo_write(ci, hs) for ci in clusters[:i]):
457+
# If there's a halo write before `c`, then we cannot inject the HaloTouch
458+
continue
459+
456460
points = set()
457461
for f in hs.fmapper:
458462
for a in c.scope.getreads(f):
@@ -781,3 +785,14 @@ def normalize_reductions_sparse(cluster, sregistry):
781785
processed.append(e)
782786

783787
return cluster.rebuild(processed)
788+
789+
790+
def halo_write(c, hs):
791+
loc_vals = hs.loc_values
792+
793+
for f in hs.fmapper:
794+
for a in c.scope.getwrites(f):
795+
if set(a.access.indices) & loc_vals:
796+
return True
797+
798+
return False

devito/mpi/halo_scheme.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from operator import attrgetter
66

77
import sympy
8-
from sympy import Max, Min
8+
from sympy import Max, Min, S
99

1010
from devito import configuration
1111
from devito.data import CENTER, CORE, LEFT, OWNED, RIGHT
@@ -514,6 +514,31 @@ def merge(self, hs):
514514
fmapper[f] = fmapper.get(f, hse).merge(hse)
515515
return HaloScheme.build(fmapper, self.honored)
516516

517+
def _is_iter_carried(self, scope):
518+
"""
519+
True if the HaloScheme is iteration-carried, i.e., it induces
520+
a halo exchange that requires values from the previous iteration(s); False
521+
otherwise.
522+
"""
523+
524+
def rule0(dep):
525+
# E.g., `dep=W<f,[t1, x]> -> R<f,[t0, x-1]>`, `d=t` => OK
526+
return not any(dep.distance_mapper[d] is S.Infinity for d in dep.cause)
527+
528+
def rule1(dep, loc_indices):
529+
# E.g., `dep=W<f,[t1, x+1]> -> R<f,[t1, xl+1]>`, `loc_indices={t: t0}` => OK
530+
return any(dep.distance_mapper[d] == 0 and
531+
dep.source[d] is not v and
532+
dep.sink[d] is not v
533+
for d, v in loc_indices.items())
534+
535+
for f, v in self.fmapper.items():
536+
for dep in scope.d_flow.project(f):
537+
if not rule0(dep) and not rule1(dep, v.loc_indices):
538+
return False
539+
540+
return True
541+
517542

518543
def classify(exprs, ispace):
519544
"""

devito/passes/iet/mpi.py

Lines changed: 2 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,7 @@ def _hoist_invariant(iet):
206206
# Ensure there's another HaloScheme that could cover for
207207
# us should we get hoisted while still satisfying the
208208
# data dependences
209-
if hsf1.issubset(hsf0) and _is_iter_carried(hsf1, scope):
209+
if hsf1.issubset(hsf0) and hsf1._is_iter_carried(scope):
210210
hs, hsf = hs1, hsf1
211211
elif hsf0.issubset(hsf1) and hs0 is halo_spots[0]:
212212
# Special case
@@ -474,32 +474,6 @@ def _check_control_flow(hs0, hs1, cond_mapper):
474474
return cond0 != cond1
475475

476476

477-
def _is_iter_carried(hsf, scope):
478-
"""
479-
True if the provided HaloScheme `hsf` is iteration-carried, i.e., it induces
480-
a halo exchange that requires values from the previous iteration(s); False
481-
otherwise.
482-
"""
483-
484-
def rule0(dep):
485-
# E.g., `dep=W<f,[t1, x]> -> R<f,[t0, x-1]>`, `d=t` => OK
486-
return not any(dep.distance_mapper[d] is S.Infinity for d in dep.cause)
487-
488-
def rule1(dep, loc_indices):
489-
# E.g., `dep=W<f,[t1, x+1]> -> R<f,[t1, xl+1]>`, `loc_indices={t: t0}` => OK
490-
return any(dep.distance_mapper[d] == 0 and
491-
dep.source[d] is not v and
492-
dep.sink[d] is not v
493-
for d, v in loc_indices.items())
494-
495-
for f, v in hsf.fmapper.items():
496-
for dep in scope.d_flow.project(f):
497-
if not rule0(dep) and not rule1(dep, v.loc_indices):
498-
return False
499-
500-
return True
501-
502-
503477
def _is_mergeable(hsf0, hsf1, scope):
504478
"""
505479
True if `hsf1` can be merged into `hsf0`, i.e., if they are compatible
@@ -515,7 +489,7 @@ def _is_mergeable(hsf0, hsf1, scope):
515489
return False
516490

517491
# Finally, check the data dependences would be satisfied
518-
return _is_iter_carried(hsf1, scope)
492+
return hsf1._is_iter_carried(scope)
519493

520494

521495
def _semantical_eq_loc_indices(hsf0, hsf1):

tests/test_mpi.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2203,6 +2203,24 @@ def test_lift_halo_update_outside_distributed(self, mode):
22032203
halo_update = tloop.nodes[0].body[0].body[0].body[0]
22042204
assert isinstance(halo_update, HaloUpdateList)
22052205

2206+
@pytest.mark.parallel(mode=4)
2207+
def test_halo_inner_dim(self, mode):
2208+
grid = Grid((11, 11, 11))
2209+
2210+
np.random.seed(0)
2211+
v = TimeFunction(name="v", grid=grid, space_order=4,
2212+
time_order=1, save=Buffer(1))
2213+
v.data[:] = np.random.randn(*grid.shape)
2214+
e = TimeFunction(name="dummy", grid=grid, space_order=4, time_order=0)
2215+
2216+
eq = [Eq(v.forward, v + 1), Eq(e, v.forward.dydz)]
2217+
2218+
op = Operator(eq, opt=('advanced', {'blocklevels': 0}))
2219+
assert_structure(op, ['txyz', 't', 'txyz', 'txyz'], 'txyzxyzz')
2220+
op(time=100)
2221+
2222+
assert np.isclose(norm(e), 23484.863, rtol=0, atol=1e-1)
2223+
22062224

22072225
class TestOperatorAdvanced:
22082226

@@ -2736,7 +2754,7 @@ def test_haloupdate_same_timestep_v2(self, mode):
27362754
assert titer.dim is grid.time_dim
27372755
assert titer.nodes[0].body[0].body[0].is_List
27382756
assert len(titer.nodes[0].body[0].body[0].body[0].body) == 1
2739-
assert titer.nodes[0].body[0].body[0].body[0].body[0].is_Call
2757+
assert not titer.nodes[0].body[0].body[0].body[0].body[0].is_Call
27402758

27412759
op.apply(time=0)
27422760

@@ -3138,8 +3156,8 @@ def test_fission_due_to_antidep(self, mode):
31383156
# First, check the generated code
31393157
assert_structure(op1, ['t',
31403158
't,x0_blk0,y0_blk0,x,y,z',
3141-
't,x0_blk0,y0_blk0,x,y,z'],
3142-
't,x0_blk0,y0_blk0,x,y,z,z')
3159+
't,x1_blk0,y1_blk0,x,y,z'],
3160+
'tx0_blk0y0_blk0xyzx1_blk0y1_blk0xyz')
31433161

31443162
def init(f, v=1):
31453163
f.data[:] = np.indices(grid.shape).sum(axis=0) % (.004*v) + .01

0 commit comments

Comments
 (0)