From 1e0309f34898031cd20c14343675d7487b650796 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Fri, 8 May 2026 13:11:56 -0500 Subject: [PATCH 1/2] attempt to fix get_insn_access_map to handle reductions --- loopy/kernel/tools.py | 105 ++++++++++++++++++++------------- loopy/transform/loop_fusion.py | 26 +++++--- 2 files changed, 82 insertions(+), 49 deletions(-) diff --git a/loopy/kernel/tools.py b/loopy/kernel/tools.py index 4c416ebd5..49f7b4b86 100644 --- a/loopy/kernel/tools.py +++ b/loopy/kernel/tools.py @@ -28,7 +28,6 @@ import itertools import logging import sys -from collections.abc import Set as AbstractSet from functools import reduce from sys import intern from typing import ( @@ -47,7 +46,7 @@ import pymbolic.primitives as p from islpy import dim_type from pymbolic import Expression -from pytools import fset_union, memoize_on_first_arg, natsorted, set_union +from pytools import fset_union, memoize_on_first_arg, natsorted from loopy.diagnostic import LoopyError, warn_with_kernel from loopy.kernel import LoopKernel @@ -59,7 +58,7 @@ MultiAssignmentBase, _DataObliviousInstruction, ) -from loopy.symbolic import CombineMapper +from loopy.symbolic import CombineMapper, Reduction from loopy.translation_unit import ( CallableId, CallablesTable, @@ -70,7 +69,14 @@ if TYPE_CHECKING: - from collections.abc import Callable, Collection, Iterable, Mapping, Sequence + from collections.abc import ( + Callable, + Collection, + Iterable, + Mapping, + Sequence, + Set as AbstractSet, + ) from pymbolic import ArithmeticExpression from pytools.tag import Tag @@ -2246,68 +2252,87 @@ def get_hw_axis_base_for_codegen(kernel: LoopKernel, iname: str) -> isl.Aff: # {{{ get access map from an instruction +def union_amaps(amaps: Sequence[isl.Map]): + import islpy as isl + return reduce(isl.Map.union, amaps[1:], amaps[0]) + + @dataclasses.dataclass -class _IndexCollector(CombineMapper[AbstractSet[tuple[Expression, ...]], []]): +class _InstructionAccessMapCollector( + CombineMapper[dict[frozenset[str], isl.Map], [isl.Set]]): + knl: LoopKernel var: str def __post_init__(self) -> None: super().__init__() @override - def combine(self, - values: Iterable[AbstractSet[tuple[Expression, ...]]] - ) -> AbstractSet[tuple[Expression, ...]]: - return set_union(values) + def combine( + self, + values: Iterable[dict[frozenset[str], isl.Map]] + ) -> dict[frozenset[str], isl.Map]: + result: dict[frozenset[str], isl.Map] = {} + for value in values: + for inames, amap in value.items(): + try: + old_amap = result[inames] + except KeyError: + result[inames] = amap + else: + result[inames] = union_amaps((old_amap, amap)) + return result + + @override + def map_reduction( + self, expr: Reduction, domain: isl.Set) -> dict[frozenset[str], isl.Map]: + new_domain = self.knl.get_inames_domain( + frozenset(domain.get_var_dict(dim_type.set)) + | frozenset(expr.inames)).to_set() + return super().map_reduction(expr, new_domain) @override - def map_subscript(self, expr: p.Subscript) -> AbstractSet[tuple[Expression, ...]]: + def map_subscript( + self, expr: p.Subscript, domain: isl.Set) -> dict[frozenset[str], isl.Map]: + from loopy.symbolic import get_access_map assert isinstance(expr.aggregate, p.Variable) if expr.aggregate.name == self.var: - return (super().map_subscript(expr) | frozenset([expr.index_tuple])) + inames = frozenset(domain.get_var_dict(dim_type.set).keys()) + amap = get_access_map( + domain, expr.index_tuple, self.knl.assumptions) + return self.combine([ + super().map_subscript(expr, domain), {inames: amap}]) else: - return super().map_subscript(expr) + return super().map_subscript(expr, domain) @override def map_algebraic_leaf( - self, expr: p.AlgebraicLeaf, - ) -> frozenset[tuple[Expression, ...]]: - return frozenset() + self, expr: p.AlgebraicLeaf, domain: isl.Set, + ) -> dict[frozenset[str], isl.Map]: + return {} @override def map_constant( - self, expr: object - ) -> frozenset[tuple[Expression, ...]]: - return frozenset() - - -def _union_amaps(amaps: Sequence[isl.Map]): - import islpy as isl - return reduce(isl.Map.union, amaps[1:], amaps[0]) + self, expr: object, domain: isl.Set) -> dict[frozenset[str], isl.Map]: + return {} -def get_insn_access_map(kernel: LoopKernel, insn_id: str, var: str): +def get_insn_access_maps( + kernel: LoopKernel, insn_id: str, var: str) -> list[isl.Map]: from loopy.match import Id - from loopy.symbolic import get_access_map from loopy.transform.subst import expand_subst - insn = kernel.id_to_insn[insn_id] - kernel = expand_subst(kernel, within=Id(insn_id)) - indices = tuple( - _IndexCollector(var)( - (insn.expression, insn.assignees, tuple(insn.predicates)) - ) - ) - amaps = [ - get_access_map( - kernel.get_inames_domain(insn.within_inames).to_set(), - idx, kernel.assumptions - ) - for idx in indices - ] + insn = kernel.id_to_insn[insn_id] + insn_inames = kernel.insn_inames(insn) + inames_domain = kernel.get_inames_domain(insn_inames) + domain = inames_domain.project_out_except( + insn_inames, [dim_type.set]).to_set() + + inames_to_amap = _InstructionAccessMapCollector(kernel, var)( + (insn.expression, insn.assignees, tuple(insn.predicates)), domain) - return _union_amaps(amaps) + return list(inames_to_amap.values()) # }}} diff --git a/loopy/transform/loop_fusion.py b/loopy/transform/loop_fusion.py index 3af8a3bc8..b709cfd58 100644 --- a/loopy/transform/loop_fusion.py +++ b/loopy/transform/loop_fusion.py @@ -390,23 +390,31 @@ def _compute_isinfusible_via_access_map( import pymbolic.primitives as prim from loopy.diagnostic import UnableToDetermineAccessRangeError - from loopy.kernel.tools import get_insn_access_map + from loopy.kernel.tools import get_insn_access_maps from loopy.symbolic import isl_set_from_expr try: - amap_pred = get_insn_access_map(kernel, insn_pred, var) - amap_succ = get_insn_access_map(kernel, insn_succ, var) + amaps_pred = get_insn_access_maps(kernel, insn_pred, var) + amaps_succ = get_insn_access_maps(kernel, insn_succ, var) except UnableToDetermineAccessRangeError: # either predecessors or successors has a non-affine access i.e. # fallback to the safer option => infusible return True - amap_pred = amap_pred.project_out_except( - outer_inames | {candidate_pred}, [isl.dim_type.param, isl.dim_type.in_] - ) - amap_succ = amap_succ.project_out_except( - outer_inames | {candidate_succ}, [isl.dim_type.param, isl.dim_type.in_] - ) + amaps_pred = [ + amap.project_out_except( + outer_inames | {candidate_pred}, [isl.dim_type.param, isl.dim_type.in_]) + for amap in amaps_pred] + amaps_succ = [ + amap.project_out_except( + outer_inames | {candidate_succ}, [isl.dim_type.param, isl.dim_type.in_]) + for amap in amaps_succ] + + # amaps should have the same space after projecting out the inner loops, so they + # can safely be unioned + from loopy.kernel.tools import union_amaps + amap_pred = union_amaps(amaps_pred) + amap_succ = union_amaps(amaps_succ) # move outer inames to param for outer_iname in sorted(outer_inames): From 425a2c1bca652347c07e59a7dc792c9c6eefa697 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Tue, 12 May 2026 16:10:13 -0500 Subject: [PATCH 2/2] add regression test --- test/test_loop_fusion.py | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/test/test_loop_fusion.py b/test/test_loop_fusion.py index 9ffa29ac5..2f3b136bc 100644 --- a/test/test_loop_fusion.py +++ b/test/test_loop_fusion.py @@ -497,6 +497,39 @@ def test_reduction_loop_fusion_with_multiple_redn_in_same_insn( lp.auto_test_vs_ref(ref_t_unit, ctx, t_unit.with_kernel(knl)) +def test_loop_fusion_with_inner_reduction(ctx_factory: cl.CtxFactory): + ctx = ctx_factory() + + t_unit = lp.make_kernel( + ["{[i0, j0]: 0 <= i0, j0 < 10}", + "{[i1]: 0 <= i1 < 10}", + # Intentionally keeping j1 separate from i1 to test for regression. See + # https://github.com/inducer/loopy/pull/1009 for details. + "{[j1]: 0 <= j1 < 10}"], + """ + a[i0, j0] = j0 * 1.0 {id=insn1} + out[i1] = sum(j1, a[i1, j1]) {id=insn2} + """, + ) + ref_t_unit = t_unit + + knl = t_unit.default_entrypoint + + fused_chunks = lp.get_kennedy_unweighted_fusion_candidates( + knl, frozenset(["i0", "i1"]) + ) + knl = lp.rename_inames_in_batch(knl, fused_chunks) + + assert ( + len( + knl.id_to_insn["insn1"].within_inames + & knl.id_to_insn["insn2"].within_inames + ) == 1 + ) + + lp.auto_test_vs_ref(ref_t_unit, ctx, t_unit.with_kernel(knl)) + + if __name__ == "__main__": if len(sys.argv) > 1: exec(sys.argv[1])