Skip to content

Commit 1d883c6

Browse files
authored
[mlir][linalg] Fix linalg.index handling in partial reduction tiling (llvm#188261)
PartialReduction tiling wasn't handling linalg.index offsets properly. This patch fixes it to do the same thing as TilingInterface.
1 parent d0f5df1 commit 1d883c6

3 files changed

Lines changed: 78 additions & 0 deletions

File tree

mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -634,11 +634,13 @@ struct LinalgOpPartialReductionInterface
634634
IRMapping mapping;
635635
op->getRegion(0).cloneInto(&genericOp.getRegion(),
636636
genericOp.getRegion().begin(), mapping);
637+
offsetIndices(b, genericOp, offsets);
637638
partialReductionOp = genericOp.getOperation();
638639
} else {
639640
SmallVector<Value> operands = std::move(tiledInputs);
640641
llvm::append_range(operands, tiledInits);
641642
partialReductionOp = mlir::clone(b, op, resultTypes, operands);
643+
offsetIndices(b, cast<LinalgOp>(partialReductionOp), offsets);
642644
}
643645
return TilingResult{
644646
{partialReductionOp},

mlir/test/Dialect/Linalg/transform-tile-reduction.mlir

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -692,3 +692,39 @@ module {
692692
// CHECK: %[[R:.*]] = linalg.reduce ins(%[[L]]
693693
// CHECK-SAME: outs(%[[ARG2]] :
694694
// CHECK: return %[[R]]
695+
696+
// -----
697+
698+
// Check that linalg.index is correctly offset after partial reduction tiling.
699+
700+
func.func @reduction_tile_with_linalg_index(%arg0: tensor<8x128xf32>, %out: tensor<8xi32>) -> tensor<8xi32> {
701+
%red = linalg.generic {
702+
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
703+
affine_map<(d0, d1) -> (d0)>],
704+
iterator_types = ["parallel", "reduction"]}
705+
ins(%arg0 : tensor<8x128xf32>)
706+
outs(%out : tensor<8xi32>) {
707+
^bb0(%in: f32, %acc: i32):
708+
%idx = linalg.index 1 : index
709+
%idx_i32 = arith.index_cast %idx : index to i32
710+
%sum = arith.addi %idx_i32, %acc : i32
711+
linalg.yield %sum : i32
712+
} -> tensor<8xi32>
713+
return %red : tensor<8xi32>
714+
}
715+
716+
module attributes {transform.with_named_sequence} {
717+
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
718+
%0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
719+
%1, %2, %3, %loop = transform.structured.tile_reduction_using_for %0
720+
by tile_sizes = [0, 32] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
721+
transform.yield
722+
}
723+
}
724+
725+
// CHECK-DAG: #[[$INDEX_MAP:.*]] = affine_map<(d0)[s0] -> (d0 + s0)>
726+
// CHECK-LABEL: func @reduction_tile_with_linalg_index(
727+
// CHECK: scf.for %[[IV:[a-zA-Z0-9]+]] =
728+
// CHECK: linalg.generic
729+
// CHECK: %[[LOCAL_IDX:.+]] = linalg.index 1 : index
730+
// CHECK: affine.apply #[[$INDEX_MAP]](%[[IV]])[%[[LOCAL_IDX]]]

mlir/test/Interfaces/TilingInterface/tile-and-fuse-with-reduction-tiling.mlir

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,3 +59,43 @@ module attributes {transform.with_named_sequence} {
5959
// CHECK: %[[REDUCE:.+]] = linalg.reduce
6060
// CHECK-SAME: ins(%[[FORALL]] :
6161
// CHECK: return %[[REDUCE]]
62+
63+
// -----
64+
65+
// Check that linalg.index is correctly offset after partial reduction tiling.
66+
67+
module {
68+
func.func @partial_reduction_with_linalg_index(
69+
%arg0 : tensor<8x128xf32>) -> tensor<8xi32> {
70+
%c0_i32 = arith.constant 0 : i32
71+
%empty = tensor.empty() : tensor<8xi32>
72+
%fill = linalg.fill ins(%c0_i32 : i32) outs(%empty : tensor<8xi32>) -> tensor<8xi32>
73+
%generic = linalg.generic {
74+
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
75+
affine_map<(d0, d1) -> (d0)>],
76+
iterator_types = ["parallel", "reduction"]}
77+
ins(%arg0 : tensor<8x128xf32>) outs(%fill : tensor<8xi32>) {
78+
^bb0(%b0 : f32, %b1 : i32):
79+
%idx = linalg.index 1 : index
80+
%idx_i32 = arith.index_cast %idx : index to i32
81+
%0 = arith.addi %idx_i32, %b1 : i32
82+
linalg.yield %0 : i32
83+
} -> tensor<8xi32>
84+
return %generic : tensor<8xi32>
85+
}
86+
}
87+
module attributes {transform.with_named_sequence} {
88+
transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
89+
%generic = transform.structured.match ops{["linalg.generic"]} in %arg1
90+
: (!transform.any_op) -> !transform.any_op
91+
%a, %loop = transform.test.tile_and_fuse_outer_parallel_partial_reduction
92+
%generic tile_sizes = [32]
93+
: (!transform.any_op) -> (!transform.any_op, !transform.any_op)
94+
transform.yield
95+
}
96+
}
97+
// CHECK-LABEL: func @partial_reduction_with_linalg_index(
98+
// CHECK: scf.forall (%[[IV0:[a-zA-Z0-9]+]]) =
99+
// CHECK: %[[GENERIC:.+]] = linalg.generic
100+
// CHECK: %[[LOCAL_IDX:.+]] = linalg.index 1 : index
101+
// CHECK: affine.apply affine_map<(d0)[s0] -> (d0 + s0)>(%[[IV0]])[%[[LOCAL_IDX]]]

0 commit comments

Comments
 (0)