@@ -46,8 +46,12 @@ def _is_materialized(expr: Array) -> bool:
4646
4747
4848def _can_index_lambda_propagate_indirections_without_changing_axes (
49- expr : IndexLambda ) -> bool :
50-
49+ expr : IndexLambda , iel_axis : Optional [int ], idof_axis : Optional [int ]
50+ ) -> bool :
51+ """
52+ Returns *True* only if the axes being reindexed appear at the same
53+ positions in the bindings' indexing locations.
54+ """
5155 from pytato .utils import are_shapes_equal
5256 from pytato .raising import (index_lambda_to_high_level_op ,
5357 BinaryOp )
@@ -219,8 +223,8 @@ def _fuse_from_element_indices(from_element_indices: Tuple[Array, ...]):
219223 return result
220224
221225
222- def _fuse_dof_pick_lists (dof_pick_lists : Tuple [Array , ...], from_element_indices :
223- Tuple [Array , ...]):
226+ def _fuse_dof_pick_lists (dof_pick_lists : Tuple [Array , ...],
227+ from_element_indices : Tuple [Array , ...]):
224228 assert all (from_el_idx .ndim == 2 for from_el_idx in from_element_indices )
225229 assert all (dof_pick_list .ndim == 2 for dof_pick_list in dof_pick_lists )
226230 assert all (from_el_idx .shape [1 ] == 1 for from_el_idx in from_element_indices )
@@ -239,7 +243,10 @@ def _pick_list_fusers_map_materialized_node(rec_expr: Array,
239243 from_element_indices : Tuple [Array , ...],
240244 dof_pick_lists : Tuple [Array , ...]
241245 ) -> Array :
242-
246+ raise NotImplementedError ("We still need to port this from"
247+ " the previous version, where only"
248+ " indirections only along the element"
249+ " axes." )
243250 if iel_axis is not None :
244251 assert idof_axis is not None
245252 assert len (from_element_indices ) != 0
@@ -263,6 +270,56 @@ def _pick_list_fusers_map_materialized_node(rec_expr: Array,
263270 return rec_expr
264271
265272
273+ def _is_iel_idof_picking (expr : AdvancedIndexInContiguousAxes ,
274+ iel_axis : Optional [int ],
275+ idof_axis : Optional [int ],
276+ ) -> bool :
277+ if expr .ndim != 2 :
278+ return False
279+
280+ if expr .array .ndim != 2 :
281+ return False
282+
283+ if not ((iel_axis is None and idof_axis is None )
284+ or (iel_axis == 0 and idof_axis == 1 )):
285+ return False
286+
287+ if (isinstance (expr .indices [0 ], Array )
288+ and isinstance (expr .indices [1 ], Array )):
289+ from pytato .utils import are_shape_components_equal
290+ from_el_indices , dof_pick_lists = expr .indices
291+ assert isinstance (from_el_indices , Array )
292+ assert isinstance (dof_pick_lists , Array )
293+
294+ if dof_pick_lists .ndim != 1 :
295+ return False
296+ if from_el_indices .ndim != 2 :
297+ return False
298+ if are_shape_components_equal (from_el_indices .shape [1 ], 1 ):
299+ return False
300+
301+ return True
302+ else :
303+ return False
304+
305+
306+ def _is_iel_only_picking (expr : AdvancedIndexInContiguousAxes ,
307+ iel_axis : Optional [int ]) -> bool :
308+ if expr .ndim != 1 :
309+ return False
310+
311+ if expr .array .ndim != 1 :
312+ return False
313+
314+ if not isinstance (expr .indices [0 ], Array ):
315+ return False
316+
317+ if iel_axis not in [0 , None ]:
318+ return False
319+
320+ return True
321+
322+
266323class PickListFusers (Mapper ):
267324 def __init__ (self ) -> None :
268325 self .can_pick_indirections_be_propagated = _CanPickIndirectionsBePropagated ()
@@ -283,18 +340,22 @@ def rec(self, # type: ignore[override]
283340 " is illegal for PickListFusers. Pass arrays"
284341 " instead." )
285342
286- if iel_axis is not None :
287- assert idof_axis is not None
343+ if idof_axis is not None :
344+ assert iel_axis is not None
288345 assert 0 <= iel_axis < expr .ndim
289346 assert 0 <= idof_axis < expr .ndim
290347 # the condition below ensures that we are only dealing with indirections
291348 # appearing at contiguous locations.
292349 assert abs (iel_axis - idof_axis ) == 1
293- else :
350+ assert len (dof_pick_lists ) == len (from_element_indices )
351+ elif iel_axis is not None :
294352 assert idof_axis is None
353+ assert len (dof_pick_lists ) == 0
354+ assert len (from_element_indices ) > 0
355+ else :
356+ assert iel_axis is None
295357 assert len (from_element_indices ) == 0
296-
297- assert len (dof_pick_lists ) == len (from_element_indices )
358+ assert len (dof_pick_lists ) == 0
298359
299360 key = (expr , iel_axis , idof_axis , from_element_indices , dof_pick_lists )
300361 try :
@@ -318,8 +379,8 @@ def __call__(self, # type: ignore[override]
318379
319380 def _map_input_base (self ,
320381 expr : InputArgumentBase ,
321- iel_axis : int ,
322- idof_axis : int ,
382+ iel_axis : Optional [ int ] ,
383+ idof_axis : Optional [ int ] ,
323384 from_element_indices : Tuple [Array , ...],
324385 dof_pick_lists : Tuple [Array , ...]) -> Array :
325386 return _pick_list_fusers_map_materialized_node (
@@ -351,30 +412,36 @@ def map_index_lambda(self,
351412 rec_expr , iel_axis , idof_axis , from_element_indices , dof_pick_lists )
352413
353414 if iel_axis is not None :
354- assert idof_axis is not None
355415 assert _can_index_lambda_propagate_indirections_without_changing_axes (
356- expr )
357- from pytato .utils import are_shapes_equal
358- new_el_dim , new_dofs_dim = dof_pick_lists [0 ].shape
359- assert are_shapes_equal (from_element_indices [0 ].shape , (new_el_dim , 1 ))
360-
361- new_shape = tuple (
362- new_el_dim if idim == iel_axis else (
363- new_dofs_dim if idim == idof_axis else dim )
364- for idim , dim in enumerate (expr .shape ))
365-
366- return IndexLambda (
367- expr .expr ,
368- new_shape ,
369- expr .dtype ,
370- Map ({name : self .rec (bnd , iel_axis , idof_axis ,
371- from_element_indices ,
372- dof_pick_lists )
373- for name , bnd in expr .bindings .items ()}),
374- var_to_reduction_descr = expr .var_to_reduction_descr ,
375- tags = expr .tags ,
376- axes = expr .axes
377- )
416+ expr , iel_axis , idof_axis )
417+ if idof_axis is None :
418+ # TODO: Not encountered any practical DAGs that take this code path.
419+ # Implement this branch only if seen in any practical applications.
420+ raise NotImplementedError
421+ else :
422+ assert idof_axis is not None
423+ from pytato .utils import are_shapes_equal
424+ new_el_dim , new_dofs_dim = dof_pick_lists [0 ].shape
425+ assert are_shapes_equal (from_element_indices [0 ].shape ,
426+ (new_el_dim , 1 ))
427+
428+ new_shape = tuple (
429+ new_el_dim if idim == iel_axis else (
430+ new_dofs_dim if idim == idof_axis else dim )
431+ for idim , dim in enumerate (expr .shape ))
432+
433+ return IndexLambda (
434+ expr .expr ,
435+ new_shape ,
436+ expr .dtype ,
437+ Map ({name : self .rec (bnd , iel_axis , idof_axis ,
438+ from_element_indices ,
439+ dof_pick_lists )
440+ for name , bnd in expr .bindings .items ()}),
441+ var_to_reduction_descr = expr .var_to_reduction_descr ,
442+ tags = expr .tags ,
443+ axes = expr .axes
444+ )
378445 else :
379446 return IndexLambda (
380447 expr .expr ,
@@ -405,14 +472,17 @@ def map_contiguous_advanced_index(self,
405472 return _pick_list_fusers_map_materialized_node (
406473 rec_expr , iel_axis , idof_axis , from_element_indices , dof_pick_lists )
407474
408- if self .can_pick_indirections_be_propagated (expr ,
409- iel_axis or 0 ,
410- idof_axis or 1 ):
411- idx1 , idx2 = expr .indices
412- assert isinstance (idx1 , Array ) and isinstance (idx2 , Array )
413- return self .rec (expr .array , 0 , 1 ,
414- from_element_indices + (idx1 ,),
415- dof_pick_lists + (idx2 ,))
475+ if (_is_iel_idof_picking (expr , iel_axis , idof_axis )
476+ and self .can_pick_indirections_be_propagated (expr ,
477+ iel_axis or 0 ,
478+ idof_axis or 1 )):
479+ raise NotImplementedError
480+ elif (_is_iel_only_picking (expr , iel_axis )
481+ and self .can_pick_indirections_be_propagated (expr ,
482+ iel_axis or 0 ,
483+ None )):
484+ assert idof_axis is None
485+ raise NotImplementedError
416486 else :
417487 assert iel_axis is None and idof_axis is None
418488 return AdvancedIndexInContiguousAxes (
0 commit comments