Skip to content

Commit 5be06e1

Browse files
authored
Merge pull request #2883 from devitocodes/op-arg-fix
compiler: fix operator arg processing and subsampling size
2 parents 2413d19 + ea5f763 commit 5be06e1

6 files changed

Lines changed: 84 additions & 14 deletions

File tree

devito/operations/interpolators.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,34 @@ def wrapper(interp, *args, **kwargs):
3434
return wrapper
3535

3636

37+
def check_coords(func):
38+
@wraps(func)
39+
def wrapper(interp, *args, **kwargs):
40+
inputs = args + as_tuple(kwargs.get('expr', ()))
41+
42+
# SubFunction of the SparseFunction use to create the interpolator
43+
sfunc = interp.sfunction
44+
45+
# SubFunctions found in the arguments of the interpolation/injection operation
46+
a_sfuncs = {f for f in retrieve_functions(inputs)
47+
if f.is_SparseFunction} - {sfunc}
48+
if not a_sfuncs:
49+
# Only uses the the interpolator's SparseFunction, so no need to check
50+
return func(interp, *args, **kwargs)
51+
52+
# Check that it uses the same coordinates as the interpolator's SparseFunction
53+
subfuncs = {getattr(sfunc, s, None) for s in sfunc._sub_functions}
54+
for f in a_sfuncs:
55+
for s in f._sub_functions:
56+
if getattr(f, s, None) not in subfuncs:
57+
raise ValueError(f"Interpolation/injection with {sfunc}"
58+
f"requires {f} "
59+
f"to use the same {s} as {sfunc}")
60+
61+
return func(interp, *args, **kwargs)
62+
return wrapper
63+
64+
3765
def _extract_subdomain(variables):
3866
"""
3967
Check if any of the variables provided are defined on a SubDomain
@@ -322,6 +350,7 @@ def _interp_idx(self, variables, implicit_dims=None, pos_only=(), subdomain=None
322350
return idx_subs, temps
323351

324352
@check_radius
353+
@check_coords
325354
def interpolate(self, expr, increment=False, self_subs=None, implicit_dims=None):
326355
"""
327356
Generate equations interpolating an arbitrary expression into ``self``.
@@ -342,6 +371,7 @@ def interpolate(self, expr, increment=False, self_subs=None, implicit_dims=None)
342371
return Interpolation(expr, increment, implicit_dims, self_subs, self)
343372

344373
@check_radius
374+
@check_coords
345375
def inject(self, field, expr, implicit_dims=None):
346376
"""
347377
Generate equations injecting an arbitrary expression into a field.

devito/operator/operator.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -603,11 +603,6 @@ def _prepare_arguments(self, autotune=None, estimate_memory=False, **kwargs):
603603
if i.is_Derived and i.parent in nodes]
604604
toposort = DAG(nodes, edges).topological_sort()
605605

606-
futures = {}
607-
for d in reversed(toposort):
608-
if set(d._arg_names).intersection(kwargs):
609-
futures.update(d._arg_values(self._dspace[d], args={}, **kwargs))
610-
611606
# Prepare to process data-carriers
612607
args = kwargs['args'] = ReducerMap()
613608

@@ -637,9 +632,6 @@ def _prepare_arguments(self, autotune=None, estimate_memory=False, **kwargs):
637632
for k, v in p._arg_values(estimate_memory=estimate_memory, **kwargs).items():
638633
if k not in args:
639634
args[k] = v
640-
elif k in futures:
641-
# An explicit override is later going to set `args[k]`
642-
pass
643635
elif k in kwargs:
644636
# User is in control
645637
# E.g., given a ConditionalDimension `t_sub` with factor `fact`
@@ -652,8 +644,11 @@ def _prepare_arguments(self, autotune=None, estimate_memory=False, **kwargs):
652644
f"`{k}={v}`, while `{k}={args[k]}` is expected. Perhaps "
653645
f"you forgot to override `{p}`?"
654646
)
647+
else:
648+
args[k] = args.unique(k, candidate=v)
655649

656-
args = kwargs['args'] = args.reduce_all()
650+
args.reduce_inplace()
651+
kwargs['args'] = args
657652

658653
for i in discretizations:
659654
args.update(i._arg_values(**kwargs))

devito/tools/data_structures.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ def update(self, values):
139139
else:
140140
self.extend(values)
141141

142-
def unique(self, key):
142+
def unique(self, key, candidate=None):
143143
"""
144144
Returns a unique value for a given key, if such a value
145145
exists, and raises a ``ValueError`` if it does not.
@@ -150,7 +150,7 @@ def unique(self, key):
150150
Key for which to retrieve a unique value.
151151
"""
152152
candidates = self.getall(key)
153-
candidates = [c for c in candidates if c is not None]
153+
candidates = [c for c in candidates + [candidate] if c is not None]
154154
if not candidates:
155155
return None
156156

devito/types/dimension.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -318,12 +318,12 @@ def _arg_values(self, interval, grid=None, args=None, **kwargs):
318318
# may represent sets of legal values. If that's the case, here we just
319319
# pick one. Note that we sort for determinism
320320
try:
321-
loc_minv = loc_minv.stop
321+
loc_minv = loc_minv.start
322322
except AttributeError:
323323
with suppress(TypeError):
324324
loc_minv = sorted(loc_minv).pop(0)
325325
try:
326-
loc_maxv = loc_maxv.stop
326+
loc_maxv = loc_maxv.stop - 1
327327
except AttributeError:
328328
with suppress(TypeError):
329329
loc_maxv = sorted(loc_maxv).pop(0)
@@ -1041,7 +1041,9 @@ def _arg_defaults(self, _min=None, size=None, alias=None):
10411041
raise ValueError(f"Incompatible size for ConditionalDimension "
10421042
f"{self.name}: {size} < {size0}")
10431043
else:
1044-
defaults[dim.parent.max_name] = range(d0, d0 + factor*size - 1)
1044+
# Given a factor the last time index is factor*(size - 1)
1045+
# The maximum allowed value is then factor*size - 1
1046+
defaults[dim.parent.max_name] = range(d0, d0 + factor*size)
10451047

10461048
return defaults
10471049

tests/test_interpolation.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1255,3 +1255,18 @@ def test_inject_subdomain_mpi(self, mode):
12551255
assert data1 == None # noqa
12561256
assert data2 == None # noqa
12571257
assert data3 == None # noqa
1258+
1259+
1260+
def test_wrong_coords():
1261+
grid = Grid(shape=(11, 11))
1262+
s = SparseFunction(name='src', npoint=1, grid=grid)
1263+
s2 = SparseFunction(name='src2', npoint=1, grid=grid)
1264+
u = Function(name='u', grid=grid)
1265+
1266+
with pytest.raises(ValueError) as vinfo:
1267+
s.inject(u, expr=s2)
1268+
assert "Interpolation/injection with" in str(vinfo.value)
1269+
1270+
with pytest.raises(ValueError) as vinfo:
1271+
s.interpolate(u + s2)
1272+
assert "Interpolation/injection with" in str(vinfo.value)

tests/test_operator.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1315,6 +1315,34 @@ def test_loose_kwargs(self):
13151315
# But the following should work perfectly fine
13161316
op.arguments(x_size=2, y_size=2)
13171317

1318+
@pytest.mark.parametrize('vfact', [1, 3, 4])
1319+
def test_apply_args_consitency(self, vfact):
1320+
nt = 201
1321+
grid = Grid(shape=(11, 11, 11))
1322+
time = grid.time_dim
1323+
1324+
u = TimeFunction(name='u', grid=grid, time_order=2, space_order=4)
1325+
rec = SparseTimeFunction(name='rec', grid=grid, npoint=1, nt=nt)
1326+
1327+
factor = Constant(name='factor', value=vfact, dtype=np.int32)
1328+
time_sub = ConditionalDimension(name='t_sub', parent=time, factor=factor)
1329+
usave = TimeFunction(name='usave', grid=grid, space_order=4, time_order=0,
1330+
save=nt, time_dim=time_sub)
1331+
1332+
eqns = [
1333+
Eq(u.forward, u + 1),
1334+
Eq(usave, u),
1335+
] + rec.interpolate(expr=u)
1336+
1337+
op = Operator(eqns, opt='noop')
1338+
args0 = op.arguments(time_m=0, time_M=nt-2)
1339+
args1 = op.arguments(time_m=0, time_M=nt-2, rec=rec, usave=usave)
1340+
1341+
for k, v in args0.items():
1342+
assert k in args1
1343+
if isinstance(v, int):
1344+
assert args1[k] == v
1345+
13181346

13191347
@skipif('device')
13201348
class TestDeclarator:

0 commit comments

Comments
 (0)