Skip to content

Commit 6ded0a0

Browse files
committed
compiler: fix operator arg processing and subsampling size
1 parent 2413d19 commit 6ded0a0

4 files changed

Lines changed: 33 additions & 12 deletions

File tree

devito/operator/operator.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -603,11 +603,6 @@ def _prepare_arguments(self, autotune=None, estimate_memory=False, **kwargs):
603603
if i.is_Derived and i.parent in nodes]
604604
toposort = DAG(nodes, edges).topological_sort()
605605

606-
futures = {}
607-
for d in reversed(toposort):
608-
if set(d._arg_names).intersection(kwargs):
609-
futures.update(d._arg_values(self._dspace[d], args={}, **kwargs))
610-
611606
# Prepare to process data-carriers
612607
args = kwargs['args'] = ReducerMap()
613608

@@ -637,15 +632,12 @@ def _prepare_arguments(self, autotune=None, estimate_memory=False, **kwargs):
637632
for k, v in p._arg_values(estimate_memory=estimate_memory, **kwargs).items():
638633
if k not in args:
639634
args[k] = v
640-
elif k in futures:
641-
# An explicit override is later going to set `args[k]`
642-
pass
643635
elif k in kwargs:
644636
# User is in control
645637
# E.g., given a ConditionalDimension `t_sub` with factor `fact`
646638
# and a TimeFunction `usave(t_sub, x, y)`, an override for
647639
# `fact` is supplied w/o overriding `usave`; that's legal
648-
pass
640+
args[k] = args.unique(k, candidate=v)
649641
elif is_integer(args[k]) and not contains_val(args[k], v):
650642
raise InvalidArgument(
651643
f"Default `{p}` is incompatible with other args as "

devito/tools/data_structures.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ def update(self, values):
139139
else:
140140
self.extend(values)
141141

142-
def unique(self, key):
142+
def unique(self, key, candidate=None):
143143
"""
144144
Returns a unique value for a given key, if such a value
145145
exists, and raises a ``ValueError`` if it does not.
@@ -150,7 +150,7 @@ def unique(self, key):
150150
Key for which to retrieve a unique value.
151151
"""
152152
candidates = self.getall(key)
153-
candidates = [c for c in candidates if c is not None]
153+
candidates = [c for c in candidates + [candidate] if c is not None]
154154
if not candidates:
155155
return None
156156

devito/types/dimension.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1041,7 +1041,8 @@ def _arg_defaults(self, _min=None, size=None, alias=None):
10411041
raise ValueError(f"Incompatible size for ConditionalDimension "
10421042
f"{self.name}: {size} < {size0}")
10431043
else:
1044-
defaults[dim.parent.max_name] = range(d0, d0 + factor*size - 1)
1044+
# Given a factor the last time index is factor*(size - 1)
1045+
defaults[dim.parent.max_name] = range(d0, d0 + factor*(size - 1) + 1)
10451046

10461047
return defaults
10471048

tests/test_operator.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1315,6 +1315,34 @@ def test_loose_kwargs(self):
13151315
# But the following should work perfectly fine
13161316
op.arguments(x_size=2, y_size=2)
13171317

1318+
@pytest.mark.parametrize('vfact', [1, 3, 4])
1319+
def test_apply_args_consitency(self, vfact):
1320+
nt = 201
1321+
grid = Grid(shape=(11, 11, 11))
1322+
time = grid.time_dim
1323+
1324+
u = TimeFunction(name='u', grid=grid, time_order=2, space_order=4)
1325+
rec = SparseTimeFunction(name='rec', grid=grid, npoint=1, nt=nt)
1326+
1327+
factor = Constant(name='factor', value=vfact, dtype=np.int32)
1328+
time_sub = ConditionalDimension(name='t_sub', parent=time, factor=factor)
1329+
usave = TimeFunction(name='usave', grid=grid, space_order=4, time_order=0,
1330+
save=nt, time_dim=time_sub)
1331+
1332+
eqns = [
1333+
Eq(u.forward, u + 1),
1334+
Eq(usave, u),
1335+
] + rec.interpolate(expr=u)
1336+
1337+
op = Operator(eqns, opt='noop')
1338+
args0 = op.arguments(time_m=0, time_M=nt-2)
1339+
args1 = op.arguments(time_m=0, time_M=nt-2, rec=rec, usave=usave)
1340+
1341+
for k, v in args0.items():
1342+
assert k in args1
1343+
if isinstance(v, int):
1344+
assert args1[k] == v
1345+
13181346

13191347
@skipif('device')
13201348
class TestDeclarator:

0 commit comments

Comments
 (0)