Skip to content

Commit 3a767c1

Browse files
committed
wip: save work
1 parent ce9b9fa commit 3a767c1

1 file changed

Lines changed: 113 additions & 43 deletions

File tree

grudge/pytato_transforms/pytato_indirection_transforms.py

Lines changed: 113 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,12 @@ def _is_materialized(expr: Array) -> bool:
4646

4747

4848
def _can_index_lambda_propagate_indirections_without_changing_axes(
49-
expr: IndexLambda) -> bool:
50-
49+
expr: IndexLambda, iel_axis: Optional[int], idof_axis: Optional[int]
50+
) -> bool:
51+
"""
52+
Returns *True* only if the axes being reindexed appear at the same
53+
positions in the bindings' indexing locations.
54+
"""
5155
from pytato.utils import are_shapes_equal
5256
from pytato.raising import (index_lambda_to_high_level_op,
5357
BinaryOp)
@@ -219,8 +223,8 @@ def _fuse_from_element_indices(from_element_indices: Tuple[Array, ...]):
219223
return result
220224

221225

222-
def _fuse_dof_pick_lists(dof_pick_lists: Tuple[Array, ...], from_element_indices:
223-
Tuple[Array, ...]):
226+
def _fuse_dof_pick_lists(dof_pick_lists: Tuple[Array, ...],
227+
from_element_indices: Tuple[Array, ...]):
224228
assert all(from_el_idx.ndim == 2 for from_el_idx in from_element_indices)
225229
assert all(dof_pick_list.ndim == 2 for dof_pick_list in dof_pick_lists)
226230
assert all(from_el_idx.shape[1] == 1 for from_el_idx in from_element_indices)
@@ -239,7 +243,10 @@ def _pick_list_fusers_map_materialized_node(rec_expr: Array,
239243
from_element_indices: Tuple[Array, ...],
240244
dof_pick_lists: Tuple[Array, ...]
241245
) -> Array:
242-
246+
raise NotImplementedError("We still need to port this from"
247+
" the previous version, where only"
248+
" indirections only along the element"
249+
" axes.")
243250
if iel_axis is not None:
244251
assert idof_axis is not None
245252
assert len(from_element_indices) != 0
@@ -263,6 +270,56 @@ def _pick_list_fusers_map_materialized_node(rec_expr: Array,
263270
return rec_expr
264271

265272

273+
def _is_iel_idof_picking(expr: AdvancedIndexInContiguousAxes,
274+
iel_axis: Optional[int],
275+
idof_axis: Optional[int],
276+
) -> bool:
277+
if expr.ndim != 2:
278+
return False
279+
280+
if expr.array.ndim != 2:
281+
return False
282+
283+
if not ((iel_axis is None and idof_axis is None)
284+
or (iel_axis == 0 and idof_axis == 1)):
285+
return False
286+
287+
if (isinstance(expr.indices[0], Array)
288+
and isinstance(expr.indices[1], Array)):
289+
from pytato.utils import are_shape_components_equal
290+
from_el_indices, dof_pick_lists = expr.indices
291+
assert isinstance(from_el_indices, Array)
292+
assert isinstance(dof_pick_lists, Array)
293+
294+
if dof_pick_lists.ndim != 1:
295+
return False
296+
if from_el_indices.ndim != 2:
297+
return False
298+
if are_shape_components_equal(from_el_indices.shape[1], 1):
299+
return False
300+
301+
return True
302+
else:
303+
return False
304+
305+
306+
def _is_iel_only_picking(expr: AdvancedIndexInContiguousAxes,
307+
iel_axis: Optional[int]) -> bool:
308+
if expr.ndim != 1:
309+
return False
310+
311+
if expr.array.ndim != 1:
312+
return False
313+
314+
if not isinstance(expr.indices[0], Array):
315+
return False
316+
317+
if iel_axis not in [0, None]:
318+
return False
319+
320+
return True
321+
322+
266323
class PickListFusers(Mapper):
267324
def __init__(self) -> None:
268325
self.can_pick_indirections_be_propagated = _CanPickIndirectionsBePropagated()
@@ -283,18 +340,22 @@ def rec(self, # type: ignore[override]
283340
" is illegal for PickListFusers. Pass arrays"
284341
" instead.")
285342

286-
if iel_axis is not None:
287-
assert idof_axis is not None
343+
if idof_axis is not None:
344+
assert iel_axis is not None
288345
assert 0 <= iel_axis < expr.ndim
289346
assert 0 <= idof_axis < expr.ndim
290347
# the condition below ensures that we are only dealing with indirections
291348
# appearing at contiguous locations.
292349
assert abs(iel_axis-idof_axis) == 1
293-
else:
350+
assert len(dof_pick_lists) == len(from_element_indices)
351+
elif iel_axis is not None:
294352
assert idof_axis is None
353+
assert len(dof_pick_lists) == 0
354+
assert len(from_element_indices) > 0
355+
else:
356+
assert iel_axis is None
295357
assert len(from_element_indices) == 0
296-
297-
assert len(dof_pick_lists) == len(from_element_indices)
358+
assert len(dof_pick_lists) == 0
298359

299360
key = (expr, iel_axis, idof_axis, from_element_indices, dof_pick_lists)
300361
try:
@@ -318,8 +379,8 @@ def __call__(self, # type: ignore[override]
318379

319380
def _map_input_base(self,
320381
expr: InputArgumentBase,
321-
iel_axis: int,
322-
idof_axis: int,
382+
iel_axis: Optional[int],
383+
idof_axis: Optional[int],
323384
from_element_indices: Tuple[Array, ...],
324385
dof_pick_lists: Tuple[Array, ...]) -> Array:
325386
return _pick_list_fusers_map_materialized_node(
@@ -351,30 +412,36 @@ def map_index_lambda(self,
351412
rec_expr, iel_axis, idof_axis, from_element_indices, dof_pick_lists)
352413

353414
if iel_axis is not None:
354-
assert idof_axis is not None
355415
assert _can_index_lambda_propagate_indirections_without_changing_axes(
356-
expr)
357-
from pytato.utils import are_shapes_equal
358-
new_el_dim, new_dofs_dim = dof_pick_lists[0].shape
359-
assert are_shapes_equal(from_element_indices[0].shape, (new_el_dim, 1))
360-
361-
new_shape = tuple(
362-
new_el_dim if idim == iel_axis else (
363-
new_dofs_dim if idim == idof_axis else dim)
364-
for idim, dim in enumerate(expr.shape))
365-
366-
return IndexLambda(
367-
expr.expr,
368-
new_shape,
369-
expr.dtype,
370-
Map({name: self.rec(bnd, iel_axis, idof_axis,
371-
from_element_indices,
372-
dof_pick_lists)
373-
for name, bnd in expr.bindings.items()}),
374-
var_to_reduction_descr=expr.var_to_reduction_descr,
375-
tags=expr.tags,
376-
axes=expr.axes
377-
)
416+
expr, iel_axis, idof_axis)
417+
if idof_axis is None:
418+
# TODO: Not encountered any practical DAGs that take this code path.
419+
# Implement this branch only if seen in any practical applications.
420+
raise NotImplementedError
421+
else:
422+
assert idof_axis is not None
423+
from pytato.utils import are_shapes_equal
424+
new_el_dim, new_dofs_dim = dof_pick_lists[0].shape
425+
assert are_shapes_equal(from_element_indices[0].shape,
426+
(new_el_dim, 1))
427+
428+
new_shape = tuple(
429+
new_el_dim if idim == iel_axis else (
430+
new_dofs_dim if idim == idof_axis else dim)
431+
for idim, dim in enumerate(expr.shape))
432+
433+
return IndexLambda(
434+
expr.expr,
435+
new_shape,
436+
expr.dtype,
437+
Map({name: self.rec(bnd, iel_axis, idof_axis,
438+
from_element_indices,
439+
dof_pick_lists)
440+
for name, bnd in expr.bindings.items()}),
441+
var_to_reduction_descr=expr.var_to_reduction_descr,
442+
tags=expr.tags,
443+
axes=expr.axes
444+
)
378445
else:
379446
return IndexLambda(
380447
expr.expr,
@@ -405,14 +472,17 @@ def map_contiguous_advanced_index(self,
405472
return _pick_list_fusers_map_materialized_node(
406473
rec_expr, iel_axis, idof_axis, from_element_indices, dof_pick_lists)
407474

408-
if self.can_pick_indirections_be_propagated(expr,
409-
iel_axis or 0,
410-
idof_axis or 1):
411-
idx1, idx2 = expr.indices
412-
assert isinstance(idx1, Array) and isinstance(idx2, Array)
413-
return self.rec(expr.array, 0, 1,
414-
from_element_indices + (idx1,),
415-
dof_pick_lists + (idx2,))
475+
if (_is_iel_idof_picking(expr, iel_axis, idof_axis)
476+
and self.can_pick_indirections_be_propagated(expr,
477+
iel_axis or 0,
478+
idof_axis or 1)):
479+
raise NotImplementedError
480+
elif (_is_iel_only_picking(expr, iel_axis)
481+
and self.can_pick_indirections_be_propagated(expr,
482+
iel_axis or 0,
483+
None)):
484+
assert idof_axis is None
485+
raise NotImplementedError
416486
else:
417487
assert iel_axis is None and idof_axis is None
418488
return AdvancedIndexInContiguousAxes(

0 commit comments

Comments
 (0)