Skip to content

Commit 8949c6d

Browse files
authored
[MLIR][OpenMP] Add Taskloop Collapse Support (llvm#175924)
Following the work completed in llvm#174386 and llvm#174623, this patch adds support for the collapse clause to Taskloop. Collapse allows the user to combine multiple nested loops into a single loop; for this to work with Taskloop, some changes are needed to how we process the loops and the tasks that run them. This patch brings Taskloop support in MLIR and Flang up to the equivalent of OpenMP 4.5.
1 parent 39d60bb commit 8949c6d

5 files changed

Lines changed: 468 additions & 54 deletions

File tree

llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1482,6 +1482,9 @@ class OpenMPIRBuilder {
14821482
/// \param Mergeable If the given task is `mergeable`
14831483
/// \param Priority `priority-value` specifies the execution order of the
14841484
/// tasks that are generated by the construct
1485+
/// \param NumOfCollapseLoops Defines the number of loops that are being
1486+
/// collapsed. The default value is 1, as that's the value when collapse is not
1487+
/// used.
14851488
/// \param DupCB The callback to generate the duplication code. See
14861489
/// documentation for \ref TaskDupCallbackTy. This can be nullptr.
14871490
/// \param TaskContextStructPtrVal If non-null, a pointer to be placed
@@ -1494,7 +1497,8 @@ class OpenMPIRBuilder {
14941497
Value *LBVal, Value *UBVal, Value *StepVal, bool Untied = false,
14951498
Value *IfCond = nullptr, Value *GrainSize = nullptr, bool NoGroup = false,
14961499
int Sched = 0, Value *Final = nullptr, bool Mergeable = false,
1497-
Value *Priority = nullptr, TaskDupCallbackTy DupCB = nullptr,
1500+
Value *Priority = nullptr, uint64_t NumOfCollapseLoops = 1,
1501+
TaskDupCallbackTy DupCB = nullptr,
14981502
Value *TaskContextStructPtrVal = nullptr);
14991503

15001504
/// Generator for `#omp task`

llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp

Lines changed: 51 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -2098,7 +2098,8 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTaskloop(
20982098
llvm::function_ref<llvm::Expected<llvm::CanonicalLoopInfo *>()> LoopInfo,
20992099
Value *LBVal, Value *UBVal, Value *StepVal, bool Untied, Value *IfCond,
21002100
Value *GrainSize, bool NoGroup, int Sched, Value *Final, bool Mergeable,
2101-
Value *Priority, TaskDupCallbackTy DupCB, Value *TaskContextStructPtrVal) {
2101+
Value *Priority, uint64_t NumOfCollapseLoops, TaskDupCallbackTy DupCB,
2102+
Value *TaskContextStructPtrVal) {
21022103

21032104
if (!updateToLocation(Loc))
21042105
return InsertPointTy();
@@ -2175,8 +2176,8 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTaskloop(
21752176
OI.PostOutlineCB = [this, Ident, LBVal, UBVal, StepVal, Untied,
21762177
TaskloopAllocaBB, CLI, Loc, TaskDupFn, ToBeDeleted,
21772178
IfCond, GrainSize, NoGroup, Sched, FakeLB, FakeUB,
2178-
FakeStep, Final, Mergeable,
2179-
Priority](Function &OutlinedFn) mutable {
2179+
FakeStep, Final, Mergeable, Priority,
2180+
NumOfCollapseLoops](Function &OutlinedFn) mutable {
21802181
// Replace the Stale CI by appropriate RTL function call.
21812182
assert(OutlinedFn.hasOneUse() &&
21822183
"there must be a single user for the outlined function");
@@ -2359,29 +2360,53 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTaskloop(
23592360
Builder.SetInsertPoint(CLI->getBody(),
23602361
CLI->getBody()->getFirstInsertionPt());
23612362

2362-
// The canonical loop is generated with a fixed lower bound. We need to
2363-
// update the index calculation code to use the task's lower bound. The
2364-
// generated code looks like this:
2365-
// %omp_loop.iv = phi ...
2366-
// ...
2367-
// %tmp = mul [type] %omp_loop.iv, step
2368-
// %user_index = add [type] tmp, lb
2369-
// OpenMPIRBuilder constructs canonical loops to have exactly three uses of
2370-
// the normalised induction variable:
2371-
// 1. This one: converting the normalised IV to the user IV
2372-
// 2. The increment (add)
2373-
// 3. The comparison against the trip count (icmp)
2374-
// (1) is the only use that is a mul followed by an add so this cannot match
2375-
// other IR.
2376-
assert(CLI->getIndVar()->getNumUses() == 3 &&
2377-
"Canonical loop should have exactly three uses of the ind var");
2378-
for (User *IVUser : CLI->getIndVar()->users()) {
2379-
if (auto *Mul = dyn_cast<BinaryOperator>(IVUser)) {
2380-
if (Mul->getOpcode() == Instruction::Mul) {
2381-
for (User *MulUser : Mul->users()) {
2382-
if (auto *Add = dyn_cast<BinaryOperator>(MulUser)) {
2383-
if (Add->getOpcode() == Instruction::Add) {
2384-
Add->setOperand(1, CastedTaskLB);
2363+
if (NumOfCollapseLoops > 1) {
2364+
llvm::SmallVector<User *> UsersToReplace;
2365+
// When using the collapse clause, the bounds of the loop have to be
2366+
// adjusted to properly represent the iterator of the outer loop.
2367+
Value *IVPlusTaskLB = Builder.CreateAdd(
2368+
CLI->getIndVar(),
2369+
Builder.CreateSub(CastedTaskLB, ConstantInt::get(IVTy, 1)));
2370+
// To ensure every Use is correctly captured, we first want to record
2371+
// which users to replace the value in, and then replace the value.
2372+
for (auto IVUse = CLI->getIndVar()->uses().begin();
2373+
IVUse != CLI->getIndVar()->uses().end(); IVUse++) {
2374+
User *IVUser = IVUse->getUser();
2375+
if (auto *Op = dyn_cast<BinaryOperator>(IVUser)) {
2376+
if (Op->getOpcode() == Instruction::URem ||
2377+
Op->getOpcode() == Instruction::UDiv) {
2378+
UsersToReplace.push_back(IVUser);
2379+
}
2380+
}
2381+
}
2382+
for (User *User : UsersToReplace) {
2383+
User->replaceUsesOfWith(CLI->getIndVar(), IVPlusTaskLB);
2384+
}
2385+
} else {
2386+
// The canonical loop is generated with a fixed lower bound. We need to
2387+
// update the index calculation code to use the task's lower bound. The
2388+
// generated code looks like this:
2389+
// %omp_loop.iv = phi ...
2390+
// ...
2391+
// %tmp = mul [type] %omp_loop.iv, step
2392+
// %user_index = add [type] tmp, lb
2393+
// OpenMPIRBuilder constructs canonical loops to have exactly three uses
2394+
// of the normalised induction variable:
2395+
// 1. This one: converting the normalised IV to the user IV
2396+
// 2. The increment (add)
2397+
// 3. The comparison against the trip count (icmp)
2398+
// (1) is the only use that is a mul followed by an add so this cannot
2399+
// match other IR.
2400+
assert(CLI->getIndVar()->getNumUses() == 3 &&
2401+
"Canonical loop should have exactly three uses of the ind var");
2402+
for (User *IVUser : CLI->getIndVar()->users()) {
2403+
if (auto *Mul = dyn_cast<BinaryOperator>(IVUser)) {
2404+
if (Mul->getOpcode() == Instruction::Mul) {
2405+
for (User *MulUser : Mul->users()) {
2406+
if (auto *Add = dyn_cast<BinaryOperator>(MulUser)) {
2407+
if (Add->getOpcode() == Instruction::Add) {
2408+
Add->setOperand(1, CastedTaskLB);
2409+
}
23852410
}
23862411
}
23872412
}

mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp

Lines changed: 81 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -333,10 +333,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
333333
if (op.getBare())
334334
result = todo("ompx_bare");
335335
};
336-
auto checkCollapse = [&todo](auto op, LogicalResult &result) {
337-
if (op.getCollapseNumLoops() > 1)
338-
result = todo("collapse");
339-
};
340336
auto checkDepend = [&todo](auto op, LogicalResult &result) {
341337
if (!op.getDependVars().empty() || op.getDependKinds())
342338
result = todo("depend");
@@ -400,10 +396,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
400396
checkAllocate(op, result);
401397
checkOrder(op, result);
402398
})
403-
.Case([&](omp::LoopNestOp op) {
404-
if (mlir::isa<omp::TaskloopOp>(op.getOperation()->getParentOp()))
405-
checkCollapse(op, result);
406-
})
407399
.Case([&](omp::OrderedRegionOp op) { checkParLevelSimd(op, result); })
408400
.Case([&](omp::SectionsOp op) {
409401
checkAllocate(op, result);
@@ -2805,6 +2797,84 @@ convertOmpTaskloopOp(Operation &opInst, llvm::IRBuilderBase &builder,
28052797
return loopInfo;
28062798
};
28072799

2800+
Operation::operand_range lowerBounds = loopOp.getLoopLowerBounds();
2801+
Operation::operand_range upperBounds = loopOp.getLoopUpperBounds();
2802+
Operation::operand_range steps = loopOp.getLoopSteps();
2803+
llvm::Type *boundType =
2804+
moduleTranslation.lookupValue(lowerBounds[0])->getType();
2805+
llvm::Value *lbVal = nullptr;
2806+
llvm::Value *ubVal = builder.getIntN(boundType->getIntegerBitWidth(), 1);
2807+
llvm::Value *stepVal = nullptr;
2808+
if (loopOp.getCollapseNumLoops() > 1) {
2809+
// In cases where Collapse is used with Taskloop, the upper bound of the
2810+
// iteration space needs to be recalculated to cater for the collapsed loop.
2811+
// The Collapsed Loop UpperBound is the product of all collapsed
2812+
// loop's tripcount.
2813+
// The LowerBound for collapsed loops is always 1. When the loops are
2814+
// collapsed, it will reset the bounds and introduce processing to ensure
2815+
// the indices are presented as expected. As this happens after creating
2816+
// Taskloop, these bounds need predicting. Example:
2817+
// !$omp taskloop collapse(2)
2818+
// do i = 1, 10
2819+
// do j = 1, 5
2820+
// ..
2821+
// end do
2822+
// end do
2823+
// This loop above has a total of 50 iterations, so the lb will be 1, and
2824+
// the ub will be 50. collapseLoops in OMPIRBuilder then handles ensuring
2825+
// that i and j are properly presented when used in the loop.
2826+
for (uint64_t i = 0; i < loopOp.getCollapseNumLoops(); i++) {
2827+
llvm::Value *loopLb = moduleTranslation.lookupValue(lowerBounds[i]);
2828+
llvm::Value *loopUb = moduleTranslation.lookupValue(upperBounds[i]);
2829+
llvm::Value *loopStep = moduleTranslation.lookupValue(steps[i]);
2830+
// In some cases, such as where the ub is less than the lb so the loop
2831+
// steps down, the calculation for the loopTripCount is swapped. To ensure
2832+
// the correct value is found, calculate both UB - LB and LB - UB then
2833+
// select which value to use depending on how the loop has been
2834+
// configured.
2835+
llvm::Value *loopLbMinusOne = builder.CreateSub(
2836+
loopLb, builder.getIntN(boundType->getIntegerBitWidth(), 1));
2837+
llvm::Value *loopUbMinusOne = builder.CreateSub(
2838+
loopUb, builder.getIntN(boundType->getIntegerBitWidth(), 1));
2839+
llvm::Value *boundsCmp = builder.CreateICmpSLT(loopLb, loopUb);
2840+
llvm::Value *ubMinusLb = builder.CreateSub(loopUb, loopLbMinusOne);
2841+
llvm::Value *lbMinusUb = builder.CreateSub(loopLb, loopUbMinusOne);
2842+
llvm::Value *loopTripCount =
2843+
builder.CreateSelect(boundsCmp, ubMinusLb, lbMinusUb);
2844+
loopTripCount = builder.CreateBinaryIntrinsic(
2845+
llvm::Intrinsic::abs, loopTripCount, builder.getFalse());
2846+
// For loops that have a step value not equal to 1, we need to adjust the
2847+
// trip count to ensure the correct number of iterations for the loop is
2848+
// captured.
2849+
llvm::Value *loopTripCountDivStep =
2850+
builder.CreateSDiv(loopTripCount, loopStep);
2851+
loopTripCountDivStep = builder.CreateBinaryIntrinsic(
2852+
llvm::Intrinsic::abs, loopTripCountDivStep, builder.getFalse());
2853+
llvm::Value *loopTripCountRem =
2854+
builder.CreateSRem(loopTripCount, loopStep);
2855+
loopTripCountRem = builder.CreateBinaryIntrinsic(
2856+
llvm::Intrinsic::abs, loopTripCountRem, builder.getFalse());
2857+
llvm::Value *needsRoundUp = builder.CreateICmpNE(
2858+
loopTripCountRem,
2859+
builder.getIntN(loopTripCountRem->getType()->getIntegerBitWidth(),
2860+
0));
2861+
loopTripCount =
2862+
builder.CreateAdd(loopTripCountDivStep,
2863+
builder.CreateZExtOrTrunc(
2864+
needsRoundUp, loopTripCountDivStep->getType()));
2865+
ubVal = builder.CreateMul(ubVal, loopTripCount);
2866+
}
2867+
lbVal = builder.getIntN(boundType->getIntegerBitWidth(), 1);
2868+
stepVal = builder.getIntN(boundType->getIntegerBitWidth(), 1);
2869+
} else {
2870+
lbVal = moduleTranslation.lookupValue(lowerBounds[0]);
2871+
ubVal = moduleTranslation.lookupValue(upperBounds[0]);
2872+
stepVal = moduleTranslation.lookupValue(steps[0]);
2873+
}
2874+
assert(lbVal != nullptr && "Expected value for lbVal");
2875+
assert(ubVal != nullptr && "Expected value for ubVal");
2876+
assert(stepVal != nullptr && "Expected value for stepVal");
2877+
28082878
llvm::Value *ifCond = nullptr;
28092879
llvm::Value *grainsize = nullptr;
28102880
int sched = 0; // default
@@ -2837,15 +2907,13 @@ convertOmpTaskloopOp(Operation &opInst, llvm::IRBuilderBase &builder,
28372907
llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
28382908
llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
28392909
moduleTranslation.getOpenMPBuilder()->createTaskloop(
2840-
ompLoc, allocaIP, bodyCB, loopInfo,
2841-
moduleTranslation.lookupValue(loopOp.getLoopLowerBounds()[0]),
2842-
moduleTranslation.lookupValue(loopOp.getLoopUpperBounds()[0]),
2843-
moduleTranslation.lookupValue(loopOp.getLoopSteps()[0]),
2910+
ompLoc, allocaIP, bodyCB, loopInfo, lbVal, ubVal, stepVal,
28442911
taskloopOp.getUntied(), ifCond, grainsize, taskloopOp.getNogroup(),
28452912
sched, moduleTranslation.lookupValue(taskloopOp.getFinal()),
28462913
taskloopOp.getMergeable(),
28472914
moduleTranslation.lookupValue(taskloopOp.getPriority()),
2848-
taskDupOrNull, taskStructMgr.getStructPtr());
2915+
loopOp.getCollapseNumLoops(), taskDupOrNull,
2916+
taskStructMgr.getStructPtr());
28492917

28502918
if (failed(handleError(afterIP, opInst)))
28512919
return failure();

0 commit comments

Comments
 (0)