@@ -2441,6 +2441,21 @@ buildA5PredicateFromActiveCount(ConversionPatternRewriter &rewriter, Location lo
24412441 return pred.getResult (0 );
24422442}
24432443
2444+ static FailureOr<Value> buildA5PredicateFromActiveCountAndElemToken (
2445+ ConversionPatternRewriter &rewriter, Location loc,
2446+ llvm::StringRef elemTok, Value activeCount) {
2447+ auto *ctx = rewriter.getContext ();
2448+ Value activeCountLValue =
2449+ materializeA5PredicateScalarLValue (rewriter, loc, activeCount);
2450+ auto maskRegTy = emitc::OpaqueType::get (ctx, " MaskReg" );
2451+ auto templateArgs =
2452+ rewriter.getArrayAttr ({emitc::OpaqueAttr::get (ctx, elemTok)});
2453+ auto pred = rewriter.create <emitc::CallOpaqueOp>(
2454+ loc, TypeRange{maskRegTy}, " CreatePredicate" ,
2455+ ValueRange{activeCountLValue}, ArrayAttr{}, templateArgs);
2456+ return pred.getResult (0 );
2457+ }
2458+
24442459template <typename ArithOp> static llvm::StringRef getA5VectorBinaryCallee ();
24452460
24462461template <> llvm::StringRef getA5VectorBinaryCallee<arith::AddFOp>() {
@@ -2590,6 +2605,8 @@ getA5ReductionCallee(vector::CombiningKind kind, Type elemTy) {
25902605 }
25912606}
25922607
// Forward declaration of cmpModeTok, which SimdStorePredicateToEmitC's
// matchAndRewrite calls to obtain the text of its CmpMode constant.
// NOTE(review): the definition is not visible in this chunk — presumably it
// sits later in the file, which is why the declaration is needed; confirm.
2608+ static std::string cmpModeTok (pto::CmpModeAttr a);
2609+
25932610static FailureOr<llvm::StringRef> getA5SimdReductionCallee (StringRef kind) {
25942611 if (kind == " add" )
25952612 return llvm::StringRef (" vcadd" );
@@ -2942,6 +2959,81 @@ struct SimdReductionToEmitC : public OpConversionPattern<pto::SimdReductionOp> {
29422959 }
29432960};
29442961
2962+ struct SimdStorePredicateToEmitC
2963+ : public OpConversionPattern<pto::SimdStorePredicateOp> {
2964+ using OpConversionPattern<pto::SimdStorePredicateOp>::OpConversionPattern;
2965+
2966+ LogicalResult matchAndRewrite (pto::SimdStorePredicateOp op, OpAdaptor adaptor,
2967+ ConversionPatternRewriter &rewriter) const override {
2968+ auto vecTy = dyn_cast<VectorType>(op.getLhs ().getType ());
2969+ if (!vecTy)
2970+ return failure ();
2971+ if (failed (
2972+ validateA5OplibVectorType (op, vecTy, op->getName ().getStringRef ())))
2973+ return failure ();
2974+
2975+ auto elemTokOr =
2976+ getA5VectorElemToken (op, vecTy, op->getName ().getStringRef ());
2977+ if (failed (elemTokOr))
2978+ return failure ();
2979+
2980+ auto predOr = buildA5PredicateFromActiveCountAndElemToken (
2981+ rewriter, op.getLoc (), *elemTokOr, adaptor.getActiveCount ());
2982+ if (failed (predOr))
2983+ return failure ();
2984+
2985+ auto *ctx = rewriter.getContext ();
2986+ auto maskTy = emitc::OpaqueType::get (ctx, " MaskReg" );
2987+ auto maskVar = rewriter.create <emitc::VariableOp>(
2988+ op.getLoc (), maskTy, emitc::OpaqueAttr::get (ctx, " " ));
2989+
2990+ auto modeTy = emitc::OpaqueType::get (ctx, " CmpMode" );
2991+ auto modeVal = rewriter.create <emitc::ConstantOp>(
2992+ op.getLoc (), modeTy,
2993+ emitc::OpaqueAttr::get (ctx, cmpModeTok (op.getCmpModeAttr ())));
2994+
2995+ auto cmpArgs = rewriter.getArrayAttr (
2996+ {rewriter.getIndexAttr (0 ), rewriter.getIndexAttr (1 ),
2997+ rewriter.getIndexAttr (2 ), rewriter.getIndexAttr (3 ),
2998+ rewriter.getIndexAttr (4 )});
2999+ Value rhs = adaptor.getRhs ();
3000+ if (isa<VectorType>(op.getRhs ().getType ())) {
3001+ auto rhsTy = cast<VectorType>(op.getRhs ().getType ());
3002+ if (failed (validateA5OplibVectorType (op, rhsTy,
3003+ op->getName ().getStringRef ())))
3004+ return failure ();
3005+ rewriter.create <emitc::CallOpaqueOp>(
3006+ op.getLoc (), TypeRange{}, " ptoas_vcmp" ,
3007+ ValueRange{maskVar.getResult (), adaptor.getLhs (), rhs,
3008+ modeVal.getResult (), *predOr},
3009+ cmpArgs, ArrayAttr{});
3010+ } else {
3011+ rewriter.create <emitc::CallOpaqueOp>(
3012+ op.getLoc (), TypeRange{}, " ptoas_vcmps" ,
3013+ ValueRange{maskVar.getResult (), adaptor.getLhs (), rhs,
3014+ modeVal.getResult (), *predOr},
3015+ cmpArgs, ArrayAttr{});
3016+ }
3017+
3018+ Value dst = peelUnrealized (adaptor.getDst ());
3019+ if (!isEmitCTileOpaqueType (dst.getType ()))
3020+ return rewriter.notifyMatchFailure (
3021+ op,
3022+ " simd.store_predicate currently requires tile-like dst in EmitC lowering" );
3023+
3024+ auto i32Ty = emitc::OpaqueType::get (ctx, " int32_t" );
3025+ Value linearIndex = emitCCast (rewriter, op.getLoc (), i32Ty,
3026+ adaptor.getLinearOffset ());
3027+ rewriter.create <emitc::CallOpaqueOp>(
3028+ op.getLoc (), TypeRange{}, " ptoas_pstore" ,
3029+ ValueRange{maskVar.getResult (), dst, linearIndex}, ArrayAttr{},
3030+ ArrayAttr{});
3031+
3032+ rewriter.eraseOp (op);
3033+ return success ();
3034+ }
3035+ };
3036+
29453037struct OplibMemRefLoadToEmitC : public OpConversionPattern <memref::LoadOp> {
29463038 using OpConversionPattern<memref::LoadOp>::OpConversionPattern;
29473039
@@ -9124,6 +9216,8 @@ struct SimdVecScopeToEmitC : public OpConversionPattern<pto::SimdVecScopeOp> {
91249216 matchAndRewrite (pto::SimdVecScopeOp op, OpAdaptor adaptor,
91259217 ConversionPatternRewriter &rewriter) const override {
91269218 Location loc = op.getLoc ();
9219+ rewriter.create <emitc::VerbatimOp>(
9220+ loc, " #if defined(__CCE_AICORE__) || defined(__CPU_SIM)" );
91279221 rewriter.create <emitc::VerbatimOp>(loc, " __VEC_SCOPE__ {" );
91289222
91299223 Block &innerBlock = op.getBody ().front ();
@@ -9132,6 +9226,8 @@ struct SimdVecScopeToEmitC : public OpConversionPattern<pto::SimdVecScopeOp> {
91329226 }
91339227
91349228 rewriter.create <emitc::VerbatimOp>(loc, " }" );
9229+ rewriter.create <emitc::VerbatimOp>(
9230+ loc, " #endif // __CCE_AICORE__ || __CPU_SIM" );
91359231 rewriter.eraseOp (op);
91369232 return success ();
91379233 }
@@ -9771,6 +9867,7 @@ static void populatePTOToEmitCPatterns(RewritePatternSet &patterns,
97719867 OplibVectorUnaryToEmitC<math::SqrtOp>,
97729868 OplibVectorUnaryToEmitC<math::RsqrtOp>,
97739869 OplibVectorCmpFToEmitC, OplibVectorCmpIToEmitC,
9870+ SimdStorePredicateToEmitC,
97749871 OplibVectorSelectToEmitC<arith::SelectOp>,
97759872 SimdReductionToEmitC,
97769873 OplibVectorReductionToEmitC,
@@ -10465,7 +10562,8 @@ struct EmitPTOManualPass
1046510562 if (callee == " ptoas_vreduce_add" || callee == " ptoas_vreduce_max" ||
1046610563 callee == " ptoas_vreduce_min" )
1046710564 needsVectorReductionHelper = true ;
10468- if (callee == " ptoas_vcmp" )
10565+ if (callee == " ptoas_vcmp" || callee == " ptoas_vcmps" ||
10566+ callee == " ptoas_pstore" )
1046910567 needsVectorCmpHelper = true ;
1047010568 if (callee == " ptoas_vrem" )
1047110569 needsVectorRemfHelper = true ;
@@ -10494,6 +10592,7 @@ struct EmitPTOManualPass
1049410592 if (needsVectorReductionHelper) {
1049510593 helperBuilder.create <emitc::VerbatimOp>(
1049610594 loc, helperBuilder.getStringAttr (R"cpp(
10595+ #if defined(__CCE_AICORE__) || defined(__CPU_SIM)
1049710596 template <typename T>
1049810597 PTO_INTERNAL T ptoas_vreduce_add(RegTensor<T> src, MaskReg pred) {
1049910598 RegTensor<T> dst;
@@ -10546,11 +10645,13 @@ struct EmitPTOManualPass
1054610645 T red = ptoas_vreduce_min(src, pred);
1054710646 return red < acc ? red : acc;
1054810647 }
10648+ #endif
1054910649 )cpp" ));
1055010650 }
1055110651 if (needsVectorCmpHelper) {
1055210652 helperBuilder.create <emitc::VerbatimOp>(
1055310653 loc, helperBuilder.getStringAttr (R"cpp(
10654+ #if defined(__CCE_AICORE__) || defined(__CPU_SIM)
1055410655 template <typename T>
1055510656 PTO_INTERNAL void ptoas_vcmp(MaskReg &dst, RegTensor<T> src0, RegTensor<T> src1,
1055610657 CmpMode mode, MaskReg pred) {
@@ -10578,11 +10679,47 @@ struct EmitPTOManualPass
1057810679 break;
1057910680 }
1058010681 }
10682+ template <typename RegT, typename T>
10683+ PTO_INTERNAL void ptoas_vcmps(MaskReg &dst, RegT src0, T src1,
10684+ CmpMode mode, MaskReg pred) {
10685+ switch (mode) {
10686+ case CmpMode::EQ:
10687+ vcmps_eq(dst, src0, src1, pred);
10688+ break;
10689+ case CmpMode::NE:
10690+ vcmps_ne(dst, src0, src1, pred);
10691+ break;
10692+ case CmpMode::LT:
10693+ vcmps_lt(dst, src0, src1, pred);
10694+ break;
10695+ case CmpMode::LE:
10696+ vcmps_le(dst, src0, src1, pred);
10697+ break;
10698+ case CmpMode::GT:
10699+ vcmps_gt(dst, src0, src1, pred);
10700+ break;
10701+ case CmpMode::GE:
10702+ vcmps_ge(dst, src0, src1, pred);
10703+ break;
10704+ default:
10705+ vcmps_eq(dst, src0, src1, pred);
10706+ break;
10707+ }
10708+ }
10709+ template <typename TileT>
10710+ PTO_INTERNAL void ptoas_pstore(MaskReg src, TileT &dst, int32_t linearIndex) {
10711+ __ubuf__ uint32_t *dstWords =
10712+ reinterpret_cast<__ubuf__ uint32_t *>(dst.data());
10713+ int32_t wordOffset = linearIndex / 32;
10714+ psts(src, dstWords + wordOffset, 0, PK);
10715+ }
10716+ #endif
1058110717 )cpp" ));
1058210718 }
1058310719 if (needsVectorRemfHelper) {
1058410720 helperBuilder.create <emitc::VerbatimOp>(
1058510721 loc, helperBuilder.getStringAttr (R"cpp(
10722+ #if defined(__CCE_AICORE__) || defined(__CPU_SIM)
1058610723 template <typename T>
1058710724 PTO_INTERNAL void ptoas_vrem(RegTensor<T> &dst, RegTensor<T> src0,
1058810725 RegTensor<T> src1, MaskReg pred) {
@@ -10611,6 +10748,7 @@ struct EmitPTOManualPass
1061110748 vor(dst, dstEven, dstOdd, pred);
1061210749 }
1061310750 }
10751+ #endif
1061410752 )cpp" ));
1061510753 }
1061610754 if (needsBitcastHelper) {
0 commit comments