@@ -442,8 +442,8 @@ def callback(self, clusters, prefix, seen=None):
442442 d = prefix [- 1 ].dim
443443
444444 # Construct a representation of the halo accesses
445- processed = []
446- for i , c in enumerate (clusters ):
445+ processed = list ( clusters )
446+ for n , c in enumerate (clusters ):
447447 if c .properties .is_sequential (d ) or \
448448 c in seen :
449449 continue
@@ -453,10 +453,6 @@ def callback(self, clusters, prefix, seen=None):
453453 not d ._defines & hs .distributed_aindices :
454454 continue
455455
456- if any (_halo_write (ci , hs ) for ci in clusters [:i ]):
457- # If there's a halo write before `c`, then we cannot inject the HaloTouch
458- continue
459-
460456 points = set ()
461457 for f in hs .fmapper :
462458 for a in c .scope .getreads (f ):
@@ -484,10 +480,16 @@ def callback(self, clusters, prefix, seen=None):
484480
485481 halo_touch = c .rebuild (exprs = expr , ispace = ispace , properties = properties )
486482
487- processed .append (halo_touch )
488- seen .update ({halo_touch , c })
483+ # Insert `halo_touch` at the top of the IterationSpace within which
484+ # `c` is scheduled
485+ index = 0
486+ for i in reversed (range (n )):
487+ if not processed [i ].ispace .is_subset (c .ispace ):
488+ index = i + 1
489+ break
490+ processed .insert (index , halo_touch )
489491
490- processed . extend ( clusters )
492+ seen . update ({ halo_touch , c } )
491493
492494 return processed
493495
@@ -785,21 +787,3 @@ def normalize_reductions_sparse(cluster, sregistry):
785787 processed .append (e )
786788
787789 return cluster .rebuild (processed )
788-
789-
790- def _halo_write (c , hs ):
791- """
792- Check if the cluster `c` writes into any of the local values read by `hs`.
793- """
794- for f in hs .fmapper :
795- if not any (f .grid .distributor .topology .get (d , 1 ) > 1
796- for d in hs .dimensions ):
797- # Not distributed halo dimension, write does not impact the halo exchange
798- continue
799-
800- if any (set (a .access .indices ) & hs .loc_values for a in c .scope .getwrites (f )):
801- # Writing into a local value, which is read by the halo exchange,
802- # creates a write dependency
803- return True
804-
805- return False
0 commit comments