@@ -333,10 +333,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
333333 if (op.getBare ())
334334 result = todo (" ompx_bare" );
335335 };
336- auto checkCollapse = [&todo](auto op, LogicalResult &result) {
337- if (op.getCollapseNumLoops () > 1 )
338- result = todo (" collapse" );
339- };
340336 auto checkDepend = [&todo](auto op, LogicalResult &result) {
341337 if (!op.getDependVars ().empty () || op.getDependKinds ())
342338 result = todo (" depend" );
@@ -400,10 +396,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
400396 checkAllocate (op, result);
401397 checkOrder (op, result);
402398 })
403- .Case ([&](omp::LoopNestOp op) {
404- if (mlir::isa<omp::TaskloopOp>(op.getOperation ()->getParentOp ()))
405- checkCollapse (op, result);
406- })
407399 .Case ([&](omp::OrderedRegionOp op) { checkParLevelSimd (op, result); })
408400 .Case ([&](omp::SectionsOp op) {
409401 checkAllocate (op, result);
@@ -2805,6 +2797,84 @@ convertOmpTaskloopOp(Operation &opInst, llvm::IRBuilderBase &builder,
28052797 return loopInfo;
28062798 };
28072799
2800+ Operation::operand_range lowerBounds = loopOp.getLoopLowerBounds ();
2801+ Operation::operand_range upperBounds = loopOp.getLoopUpperBounds ();
2802+ Operation::operand_range steps = loopOp.getLoopSteps ();
2803+ llvm::Type *boundType =
2804+ moduleTranslation.lookupValue (lowerBounds[0 ])->getType ();
2805+ llvm::Value *lbVal = nullptr ;
2806+ llvm::Value *ubVal = builder.getIntN (boundType->getIntegerBitWidth (), 1 );
2807+ llvm::Value *stepVal = nullptr ;
2808+ if (loopOp.getCollapseNumLoops () > 1 ) {
2809+ // In cases where Collapse is used with Taskloop, the upper bound of the
2810+ // iteration space needs to be recalculated to cater for the collapsed loop.
2811+ // The collapsed loop's upper bound is the product of the trip counts of
2812+ // all the collapsed loops.
2813+ // The lower bound for collapsed loops is always 1. When the loops are
2814+ // collapsed, the bounds are reset and extra processing is introduced to
2815+ // ensure the indices are presented as expected. As this happens after
2816+ // creating the taskloop, these bounds need to be predicted. Example:
2817+ // !$omp taskloop collapse(2)
2818+ // do i = 1, 10
2819+ // do j = 1, 5
2820+ // ..
2821+ // end do
2822+ // end do
2823+ // This loop above has a total of 50 iterations, so the lb will be 1, and
2824+ // the ub will be 50. collapseLoops in OMPIRBuilder then handles ensuring
2825+ // that i and j are properly presented when used in the loop.
2826+ for (uint64_t i = 0 ; i < loopOp.getCollapseNumLoops (); i++) {
2827+ llvm::Value *loopLb = moduleTranslation.lookupValue (lowerBounds[i]);
2828+ llvm::Value *loopUb = moduleTranslation.lookupValue (upperBounds[i]);
2829+ llvm::Value *loopStep = moduleTranslation.lookupValue (steps[i]);
2830+ // In some cases, such as where the ub is less than the lb so the loop
2831+ // steps down, the calculation for the loopTripCount is swapped. To ensure
2832+ // the correct value is found, calculate both UB - LB and LB - UB then
2833+ // select which value to use depending on how the loop has been
2834+ // configured.
2835+ llvm::Value *loopLbMinusOne = builder.CreateSub (
2836+ loopLb, builder.getIntN (boundType->getIntegerBitWidth (), 1 ));
2837+ llvm::Value *loopUbMinusOne = builder.CreateSub (
2838+ loopUb, builder.getIntN (boundType->getIntegerBitWidth (), 1 ));
2839+ llvm::Value *boundsCmp = builder.CreateICmpSLT (loopLb, loopUb);
2840+ llvm::Value *ubMinusLb = builder.CreateSub (loopUb, loopLbMinusOne);
2841+ llvm::Value *lbMinusUb = builder.CreateSub (loopLb, loopUbMinusOne);
2842+ llvm::Value *loopTripCount =
2843+ builder.CreateSelect (boundsCmp, ubMinusLb, lbMinusUb);
2844+ loopTripCount = builder.CreateBinaryIntrinsic (
2845+ llvm::Intrinsic::abs, loopTripCount, builder.getFalse ());
2846+ // For loops that have a step value not equal to 1, we need to adjust the
2847+ // trip count to ensure the correct number of iterations for the loop is
2848+ // captured.
2849+ llvm::Value *loopTripCountDivStep =
2850+ builder.CreateSDiv (loopTripCount, loopStep);
2851+ loopTripCountDivStep = builder.CreateBinaryIntrinsic (
2852+ llvm::Intrinsic::abs, loopTripCountDivStep, builder.getFalse ());
2853+ llvm::Value *loopTripCountRem =
2854+ builder.CreateSRem (loopTripCount, loopStep);
2855+ loopTripCountRem = builder.CreateBinaryIntrinsic (
2856+ llvm::Intrinsic::abs, loopTripCountRem, builder.getFalse ());
2857+ llvm::Value *needsRoundUp = builder.CreateICmpNE (
2858+ loopTripCountRem,
2859+ builder.getIntN (loopTripCountRem->getType ()->getIntegerBitWidth (),
2860+ 0 ));
2861+ loopTripCount =
2862+ builder.CreateAdd (loopTripCountDivStep,
2863+ builder.CreateZExtOrTrunc (
2864+ needsRoundUp, loopTripCountDivStep->getType ()));
2865+ ubVal = builder.CreateMul (ubVal, loopTripCount);
2866+ }
2867+ lbVal = builder.getIntN (boundType->getIntegerBitWidth (), 1 );
2868+ stepVal = builder.getIntN (boundType->getIntegerBitWidth (), 1 );
2869+ } else {
2870+ lbVal = moduleTranslation.lookupValue (lowerBounds[0 ]);
2871+ ubVal = moduleTranslation.lookupValue (upperBounds[0 ]);
2872+ stepVal = moduleTranslation.lookupValue (steps[0 ]);
2873+ }
2874+ assert (lbVal != nullptr && " Expected value for lbVal" );
2875+ assert (ubVal != nullptr && " Expected value for ubVal" );
2876+ assert (stepVal != nullptr && " Expected value for stepVal" );
2877+
28082878 llvm::Value *ifCond = nullptr ;
28092879 llvm::Value *grainsize = nullptr ;
28102880 int sched = 0 ; // default
@@ -2837,15 +2907,13 @@ convertOmpTaskloopOp(Operation &opInst, llvm::IRBuilderBase &builder,
28372907 llvm::OpenMPIRBuilder::LocationDescription ompLoc (builder);
28382908 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
28392909 moduleTranslation.getOpenMPBuilder ()->createTaskloop (
2840- ompLoc, allocaIP, bodyCB, loopInfo,
2841- moduleTranslation.lookupValue (loopOp.getLoopLowerBounds ()[0 ]),
2842- moduleTranslation.lookupValue (loopOp.getLoopUpperBounds ()[0 ]),
2843- moduleTranslation.lookupValue (loopOp.getLoopSteps ()[0 ]),
2910+ ompLoc, allocaIP, bodyCB, loopInfo, lbVal, ubVal, stepVal,
28442911 taskloopOp.getUntied (), ifCond, grainsize, taskloopOp.getNogroup (),
28452912 sched, moduleTranslation.lookupValue (taskloopOp.getFinal ()),
28462913 taskloopOp.getMergeable (),
28472914 moduleTranslation.lookupValue (taskloopOp.getPriority ()),
2848- taskDupOrNull, taskStructMgr.getStructPtr ());
2915+ loopOp.getCollapseNumLoops (), taskDupOrNull,
2916+ taskStructMgr.getStructPtr ());
28492917
28502918 if (failed (handleError (afterIP, opInst)))
28512919 return failure ();
0 commit comments