Skip to content

Commit d8f3be7

Browse files
authored
[mlir][dialect-conversion] Fix OOB crash in convertFuncOpTypes for funcs with extra block args (llvm#185060)
Some function ops (e.g., gpu.func with workgroup memory arguments) have more entry block arguments than their FunctionType has inputs. The workgroup memory arguments are not part of the public function signature but are present as additional block arguments. `convertFuncOpTypes` previously created a `SignatureConversion` sized only for `type.getNumInputs()`, then called `applySignatureConversion` on the entry block. When the block had more arguments (e.g., workgroup args), the loop in `applySignatureConversion` would call `getInputMapping(i)` with out-of-bounds indices, causing an assertion failure in `SmallVector::operator[]`. Fix this by: 1. Sizing the `SignatureConversion` for all entry block arguments. 2. Adding identity mappings for extra block args beyond the function type inputs. 3. Using only the converted function-type-input types when updating the FunctionType (so extra block arg types are not included in the signature). Fixes llvm#184744 Assisted-by: Claude Code
1 parent b78ceef commit d8f3be7

2 files changed

Lines changed: 47 additions & 8 deletions

File tree

mlir/lib/Transforms/Utils/DialectConversion.cpp

Lines changed: 28 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3843,19 +3843,39 @@ static LogicalResult convertFuncOpTypes(FunctionOpInterface funcOp,
38433843
if (!type)
38443844
return failure();
38453845

3846-
// Convert the original function types.
3847-
TypeConverter::SignatureConversion result(type.getNumInputs());
3846+
// Convert the function signature (inputs and results).
3847+
TypeConverter::SignatureConversion funcConversion(type.getNumInputs());
38483848
SmallVector<Type, 1> newResults;
3849-
if (failed(typeConverter.convertSignatureArgs(type.getInputs(), result)) ||
3849+
if (failed(typeConverter.convertSignatureArgs(type.getInputs(),
3850+
funcConversion)) ||
38503851
failed(typeConverter.convertTypes(type.getResults(), newResults)))
38513852
return failure();
3852-
if (!funcOp.getFunctionBody().empty())
3853-
rewriter.applySignatureConversion(&funcOp.getFunctionBody().front(), result,
3853+
3854+
// If the function has a body, apply a separate signature conversion to the
3855+
// entry block. Some function ops (e.g., gpu.func) have extra block arguments
3856+
// beyond the function type inputs (e.g., workgroup memory arguments that are
3857+
// not part of the public signature). Use a distinct conversion sized for all
3858+
// entry block arguments so that applySignatureConversion does not access
3859+
// out-of-bounds mappings.
3860+
if (!funcOp.getFunctionBody().empty()) {
3861+
Block *entryBlock = &funcOp.getFunctionBody().front();
3862+
unsigned numEntryBlockArgs = entryBlock->getNumArguments();
3863+
unsigned numFuncTypeInputs = type.getNumInputs();
3864+
TypeConverter::SignatureConversion blockConversion(numEntryBlockArgs);
3865+
// Convert the function-type inputs the same way as for the function type.
3866+
if (failed(typeConverter.convertSignatureArgs(type.getInputs(),
3867+
blockConversion)))
3868+
return failure();
3869+
// Add identity mappings for extra block args beyond the function type
3870+
// inputs. These arguments are preserved as-is.
3871+
for (unsigned i = numFuncTypeInputs; i < numEntryBlockArgs; ++i)
3872+
blockConversion.addInputs(i, entryBlock->getArgument(i).getType());
3873+
rewriter.applySignatureConversion(entryBlock, blockConversion,
38543874
&typeConverter);
3875+
}
38553876

3856-
// Update the function signature in-place.
3857-
auto newType = FunctionType::get(rewriter.getContext(),
3858-
result.getConvertedTypes(), newResults);
3877+
auto newType = FunctionType::get(
3878+
rewriter.getContext(), funcConversion.getConvertedTypes(), newResults);
38593879

38603880
rewriter.modifyOpInPlace(funcOp, [&] { funcOp.setType(newType); });
38613881

mlir/test/Transforms/test-legalize-type-conversion.mlir

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,3 +165,22 @@ func.func @test_unstructured_cf_conversion(%arg0: f32, %c: i1) {
165165
"test.bar"(%arg2) : (f32) -> ()
166166
return
167167
}
168+
169+
// -----
170+
171+
// Test that gpu.func with workgroup arguments (which are extra block arguments
172+
// beyond the function type inputs) does not crash during signature conversion.
173+
// See https://github.com/llvm/llvm-project/issues/184744
174+
175+
// CHECK-LABEL: gpu.func @func_with_workgroup_args
176+
// CHECK-SAME: (%{{.*}}: memref<?xi32>, %{{.*}}: i32) workgroup(%{{.*}} : memref<512xi32>)
177+
gpu.module @cuda_events {
178+
gpu.func @func_with_workgroup_args(%arg0: memref<?xi32>, %arg1: i32)
179+
workgroup(%workgroup_mem: memref<512xi32>)
180+
kernel {
181+
%idx = arith.constant 0 : index
182+
%val = memref.load %arg0[%idx] : memref<?xi32>
183+
memref.store %val, %arg0[%idx] : memref<?xi32>
184+
gpu.return
185+
}
186+
}

0 commit comments

Comments
 (0)