Skip to content

Commit c48e54c

Browse files
committed
Fix TCmp (partially)
1 parent 5daa6e5 commit c48e54c

12 files changed

Lines changed: 989 additions & 1038 deletions

include/PTO/IR/PTOOps.td

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1438,6 +1438,35 @@ def SimdReductionOp : PTO_Op<"simd.reduction", [
14381438
}];
14391439
}
14401440

1441+
def SimdStorePredicateOp : PTO_Op<"simd.store_predicate"> {
1442+
let summary = "A5 OP-Lib marker for packed predicate compare+store.";
1443+
let description = [{
1444+
Backend-only OP-Lib bridge op that compares a SIMD vector against either
1445+
another SIMD vector or a scalar under an active-lane count and writes the
1446+
packed predicate result to the destination tile/memref using the hardware
1447+
packed-mask store path.
1448+
}];
1449+
1450+
let arguments = (ins
1451+
AnyVector:$lhs,
1452+
AnyType:$rhs,
1453+
TileBufOrMemRef:$dst,
1454+
Index:$linear_offset,
1455+
Index:$active_count,
1456+
PTO_CmpModeAttr:$cmpMode
1457+
);
1458+
1459+
let results = (outs);
1460+
1461+
let hasVerifier = 1;
1462+
1463+
let assemblyFormat = [{
1464+
$lhs `,` $rhs `,` $dst `,` $linear_offset `,` $active_count attr-dict
1465+
`:` type($lhs) `,` type($rhs) `,` qualified(type($dst)) `,`
1466+
type($linear_offset) `,` type($active_count)
1467+
}];
1468+
}
1469+
14411470
// High-Level Mov: 值语义
14421471
def MovOp : PTO_Op<"mov", [
14431472
SameOperandsAndResultType, // 输入输出类型相同 (!pto.tile -> !pto.tile)

lib/PTO/IR/PTO.cpp

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3829,6 +3829,43 @@ mlir::LogicalResult mlir::pto::SimdReductionOp::verify() {
38293829

38303830
return success();
38313831
}
3832+
3833+
static LogicalResult verifySimdPackedPredicateDst(Operation *op, Type dstTy) {
3834+
Type elemTy;
3835+
if (auto tileTy = dyn_cast<TileBufType>(dstTy)) {
3836+
elemTy = tileTy.getElementType();
3837+
} else if (auto memTy = dyn_cast<MemRefType>(dstTy)) {
3838+
elemTy = memTy.getElementType();
3839+
} else {
3840+
return op->emitOpError("expects dst to be !pto.tile_buf or memref");
3841+
}
3842+
3843+
auto intTy = dyn_cast<IntegerType>(elemTy);
3844+
if (!intTy || intTy.getWidth() != 8)
3845+
return op->emitOpError("expects dst element type to be i8/ui8");
3846+
return success();
3847+
}
3848+
3849+
mlir::LogicalResult mlir::pto::SimdStorePredicateOp::verify() {
3850+
auto lhsTy = dyn_cast<VectorType>(getLhs().getType());
3851+
if (!lhsTy || lhsTy.isScalable())
3852+
return emitOpError("expects lhs to be a fixed-width vector type");
3853+
if (!lhsTy.getElementType().isIntOrFloat())
3854+
return emitOpError("expects lhs element type to be integer or float");
3855+
3856+
Type rhsTy = getRhs().getType();
3857+
if (auto rhsVecTy = dyn_cast<VectorType>(rhsTy)) {
3858+
if (rhsVecTy.isScalable())
3859+
return emitOpError("expects rhs vector to be fixed-width");
3860+
if (rhsVecTy != lhsTy)
3861+
return emitOpError("expects rhs vector type to match lhs");
3862+
} else if (rhsTy != lhsTy.getElementType()) {
3863+
return emitOpError(
3864+
"expects rhs to be either a matching vector type or lhs element type");
3865+
}
3866+
3867+
return verifySimdPackedPredicateDst(*this, getDst().getType());
3868+
}
38323869
//===----------------------------------------------------------------------===//
38333870
// PTO.cpp (add TSYNC DPS/tilebuf implementation)
38343871
//===----------------------------------------------------------------------===//

lib/PTO/Transforms/PTOLowerToOpLibCalls.cpp

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1000,7 +1000,12 @@ static FailureOr<int64_t> getFixedVectorLanes(Type ty) {
10001000

10011001
static bool isSimdBridgeOp(Operation *op) {
10021002
return isa<pto::SimdPredicateOp, pto::SimdLoadOp, pto::SimdStoreOp,
1003-
pto::SimdLoadPUOp, pto::SimdStorePUOp>(op);
1003+
pto::SimdLoadPUOp, pto::SimdStorePUOp,
1004+
pto::SimdStorePredicateOp>(op);
1005+
}
1006+
1007+
static bool isPackedCmpSimdBridgeOp(Operation *op) {
1008+
return isa<pto::SimdStorePredicateOp>(op);
10041009
}
10051010

10061011
static bool isAllowedTemplateBodyOp(Operation *op) {
@@ -1462,6 +1467,22 @@ struct TemplateRegistry {
14621467
std::string inferredVstDist;
14631468
std::string inferredExecMode;
14641469
bool requiresUnifiedLevel3Simd = isLevel3TemplateKind(entry.kind);
1470+
bool hasPackedCmpBridgeOnly = false;
1471+
if (entry.hasSimdBridgeOps) {
1472+
bool sawPackedCmp = false;
1473+
bool sawOtherBridge = false;
1474+
imported.walk([&](Operation *op) {
1475+
if (!isSimdBridgeOp(op))
1476+
return WalkResult::advance();
1477+
if (isPackedCmpSimdBridgeOp(op)) {
1478+
sawPackedCmp = true;
1479+
return WalkResult::advance();
1480+
}
1481+
sawOtherBridge = true;
1482+
return WalkResult::interrupt();
1483+
});
1484+
hasPackedCmpBridgeOnly = sawPackedCmp && !sawOtherBridge;
1485+
}
14651486

14661487
llvm::DenseMap<Operation *, int64_t> preorder;
14671488
int64_t seq = 0;
@@ -1738,23 +1759,26 @@ struct TemplateRegistry {
17381759
return emitFailureWithCode(imported.getLoc(), kErrCoreSlot,
17391760
"seed template must contain exactly one core slot op");
17401761
}
1741-
if (entry.hasSimdBridgeOps && coreCount != 1) {
1762+
if (entry.hasSimdBridgeOps && !hasPackedCmpBridgeOnly && coreCount != 1) {
17421763
return emitFailureWithCode(imported.getLoc(), kErrCoreSlot,
17431764
"template using pto.simd.* must contain exactly one core slot op");
17441765
}
17451766
if (coreCount > 1) {
17461767
return emitFailureWithCode(imported.getLoc(), kErrCoreSlot,
17471768
"template must not contain multiple core slot ops");
17481769
}
1749-
if (entry.hasSimdBridgeOps && firstLoad == std::numeric_limits<int64_t>::max()) {
1770+
if (entry.hasSimdBridgeOps && !hasPackedCmpBridgeOnly &&
1771+
firstLoad == std::numeric_limits<int64_t>::max()) {
17501772
return emitFailureWithCode(imported.getLoc(), kErrCoreSlot,
17511773
"template must contain simd.load/simd.load_pu");
17521774
}
1753-
if (entry.hasSimdBridgeOps && firstStore == std::numeric_limits<int64_t>::max()) {
1775+
if (entry.hasSimdBridgeOps && !hasPackedCmpBridgeOnly &&
1776+
firstStore == std::numeric_limits<int64_t>::max()) {
17541777
return emitFailureWithCode(imported.getLoc(), kErrCoreSlot,
17551778
"template must contain simd.store/simd.store_pu");
17561779
}
1757-
if (entry.hasSimdBridgeOps && !(firstLoad < coreSeq && coreSeq < firstStore)) {
1780+
if (entry.hasSimdBridgeOps && !hasPackedCmpBridgeOnly &&
1781+
!(firstLoad < coreSeq && coreSeq < firstStore)) {
17581782
return emitFailureWithCode(imported.getLoc(), kErrCoreSlot,
17591783
"template ordering must satisfy load -> core -> store");
17601784
}

lib/PTO/Transforms/PTOToEmitC.cpp

Lines changed: 139 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2441,6 +2441,21 @@ buildA5PredicateFromActiveCount(ConversionPatternRewriter &rewriter, Location lo
24412441
return pred.getResult(0);
24422442
}
24432443

2444+
static FailureOr<Value> buildA5PredicateFromActiveCountAndElemToken(
2445+
ConversionPatternRewriter &rewriter, Location loc,
2446+
llvm::StringRef elemTok, Value activeCount) {
2447+
auto *ctx = rewriter.getContext();
2448+
Value activeCountLValue =
2449+
materializeA5PredicateScalarLValue(rewriter, loc, activeCount);
2450+
auto maskRegTy = emitc::OpaqueType::get(ctx, "MaskReg");
2451+
auto templateArgs =
2452+
rewriter.getArrayAttr({emitc::OpaqueAttr::get(ctx, elemTok)});
2453+
auto pred = rewriter.create<emitc::CallOpaqueOp>(
2454+
loc, TypeRange{maskRegTy}, "CreatePredicate",
2455+
ValueRange{activeCountLValue}, ArrayAttr{}, templateArgs);
2456+
return pred.getResult(0);
2457+
}
2458+
24442459
template <typename ArithOp> static llvm::StringRef getA5VectorBinaryCallee();
24452460

24462461
template <> llvm::StringRef getA5VectorBinaryCallee<arith::AddFOp>() {
@@ -2590,6 +2605,8 @@ getA5ReductionCallee(vector::CombiningKind kind, Type elemTy) {
25902605
}
25912606
}
25922607

2608+
static std::string cmpModeTok(pto::CmpModeAttr a);
2609+
25932610
static FailureOr<llvm::StringRef> getA5SimdReductionCallee(StringRef kind) {
25942611
if (kind == "add")
25952612
return llvm::StringRef("vcadd");
@@ -2942,6 +2959,81 @@ struct SimdReductionToEmitC : public OpConversionPattern<pto::SimdReductionOp> {
29422959
}
29432960
};
29442961

2962+
struct SimdStorePredicateToEmitC
2963+
: public OpConversionPattern<pto::SimdStorePredicateOp> {
2964+
using OpConversionPattern<pto::SimdStorePredicateOp>::OpConversionPattern;
2965+
2966+
LogicalResult matchAndRewrite(pto::SimdStorePredicateOp op, OpAdaptor adaptor,
2967+
ConversionPatternRewriter &rewriter) const override {
2968+
auto vecTy = dyn_cast<VectorType>(op.getLhs().getType());
2969+
if (!vecTy)
2970+
return failure();
2971+
if (failed(
2972+
validateA5OplibVectorType(op, vecTy, op->getName().getStringRef())))
2973+
return failure();
2974+
2975+
auto elemTokOr =
2976+
getA5VectorElemToken(op, vecTy, op->getName().getStringRef());
2977+
if (failed(elemTokOr))
2978+
return failure();
2979+
2980+
auto predOr = buildA5PredicateFromActiveCountAndElemToken(
2981+
rewriter, op.getLoc(), *elemTokOr, adaptor.getActiveCount());
2982+
if (failed(predOr))
2983+
return failure();
2984+
2985+
auto *ctx = rewriter.getContext();
2986+
auto maskTy = emitc::OpaqueType::get(ctx, "MaskReg");
2987+
auto maskVar = rewriter.create<emitc::VariableOp>(
2988+
op.getLoc(), maskTy, emitc::OpaqueAttr::get(ctx, ""));
2989+
2990+
auto modeTy = emitc::OpaqueType::get(ctx, "CmpMode");
2991+
auto modeVal = rewriter.create<emitc::ConstantOp>(
2992+
op.getLoc(), modeTy,
2993+
emitc::OpaqueAttr::get(ctx, cmpModeTok(op.getCmpModeAttr())));
2994+
2995+
auto cmpArgs = rewriter.getArrayAttr(
2996+
{rewriter.getIndexAttr(0), rewriter.getIndexAttr(1),
2997+
rewriter.getIndexAttr(2), rewriter.getIndexAttr(3),
2998+
rewriter.getIndexAttr(4)});
2999+
Value rhs = adaptor.getRhs();
3000+
if (isa<VectorType>(op.getRhs().getType())) {
3001+
auto rhsTy = cast<VectorType>(op.getRhs().getType());
3002+
if (failed(validateA5OplibVectorType(op, rhsTy,
3003+
op->getName().getStringRef())))
3004+
return failure();
3005+
rewriter.create<emitc::CallOpaqueOp>(
3006+
op.getLoc(), TypeRange{}, "ptoas_vcmp",
3007+
ValueRange{maskVar.getResult(), adaptor.getLhs(), rhs,
3008+
modeVal.getResult(), *predOr},
3009+
cmpArgs, ArrayAttr{});
3010+
} else {
3011+
rewriter.create<emitc::CallOpaqueOp>(
3012+
op.getLoc(), TypeRange{}, "ptoas_vcmps",
3013+
ValueRange{maskVar.getResult(), adaptor.getLhs(), rhs,
3014+
modeVal.getResult(), *predOr},
3015+
cmpArgs, ArrayAttr{});
3016+
}
3017+
3018+
Value dst = peelUnrealized(adaptor.getDst());
3019+
if (!isEmitCTileOpaqueType(dst.getType()))
3020+
return rewriter.notifyMatchFailure(
3021+
op,
3022+
"simd.store_predicate currently requires tile-like dst in EmitC lowering");
3023+
3024+
auto i32Ty = emitc::OpaqueType::get(ctx, "int32_t");
3025+
Value linearIndex = emitCCast(rewriter, op.getLoc(), i32Ty,
3026+
adaptor.getLinearOffset());
3027+
rewriter.create<emitc::CallOpaqueOp>(
3028+
op.getLoc(), TypeRange{}, "ptoas_pstore",
3029+
ValueRange{maskVar.getResult(), dst, linearIndex}, ArrayAttr{},
3030+
ArrayAttr{});
3031+
3032+
rewriter.eraseOp(op);
3033+
return success();
3034+
}
3035+
};
3036+
29453037
struct OplibMemRefLoadToEmitC : public OpConversionPattern<memref::LoadOp> {
29463038
using OpConversionPattern<memref::LoadOp>::OpConversionPattern;
29473039

@@ -9124,6 +9216,8 @@ struct SimdVecScopeToEmitC : public OpConversionPattern<pto::SimdVecScopeOp> {
91249216
matchAndRewrite(pto::SimdVecScopeOp op, OpAdaptor adaptor,
91259217
ConversionPatternRewriter &rewriter) const override {
91269218
Location loc = op.getLoc();
9219+
rewriter.create<emitc::VerbatimOp>(
9220+
loc, "#if defined(__CCE_AICORE__) || defined(__CPU_SIM)");
91279221
rewriter.create<emitc::VerbatimOp>(loc, "__VEC_SCOPE__ {");
91289222

91299223
Block &innerBlock = op.getBody().front();
@@ -9132,6 +9226,8 @@ struct SimdVecScopeToEmitC : public OpConversionPattern<pto::SimdVecScopeOp> {
91329226
}
91339227

91349228
rewriter.create<emitc::VerbatimOp>(loc, "}");
9229+
rewriter.create<emitc::VerbatimOp>(
9230+
loc, "#endif // __CCE_AICORE__ || __CPU_SIM");
91359231
rewriter.eraseOp(op);
91369232
return success();
91379233
}
@@ -9771,6 +9867,7 @@ static void populatePTOToEmitCPatterns(RewritePatternSet &patterns,
97719867
OplibVectorUnaryToEmitC<math::SqrtOp>,
97729868
OplibVectorUnaryToEmitC<math::RsqrtOp>,
97739869
OplibVectorCmpFToEmitC, OplibVectorCmpIToEmitC,
9870+
SimdStorePredicateToEmitC,
97749871
OplibVectorSelectToEmitC<arith::SelectOp>,
97759872
SimdReductionToEmitC,
97769873
OplibVectorReductionToEmitC,
@@ -10465,7 +10562,8 @@ struct EmitPTOManualPass
1046510562
if (callee == "ptoas_vreduce_add" || callee == "ptoas_vreduce_max" ||
1046610563
callee == "ptoas_vreduce_min")
1046710564
needsVectorReductionHelper = true;
10468-
if (callee == "ptoas_vcmp")
10565+
if (callee == "ptoas_vcmp" || callee == "ptoas_vcmps" ||
10566+
callee == "ptoas_pstore")
1046910567
needsVectorCmpHelper = true;
1047010568
if (callee == "ptoas_vrem")
1047110569
needsVectorRemfHelper = true;
@@ -10494,6 +10592,7 @@ struct EmitPTOManualPass
1049410592
if (needsVectorReductionHelper) {
1049510593
helperBuilder.create<emitc::VerbatimOp>(
1049610594
loc, helperBuilder.getStringAttr(R"cpp(
10595+
#if defined(__CCE_AICORE__) || defined(__CPU_SIM)
1049710596
template <typename T>
1049810597
PTO_INTERNAL T ptoas_vreduce_add(RegTensor<T> src, MaskReg pred) {
1049910598
RegTensor<T> dst;
@@ -10546,11 +10645,13 @@ struct EmitPTOManualPass
1054610645
T red = ptoas_vreduce_min(src, pred);
1054710646
return red < acc ? red : acc;
1054810647
}
10648+
#endif
1054910649
)cpp"));
1055010650
}
1055110651
if (needsVectorCmpHelper) {
1055210652
helperBuilder.create<emitc::VerbatimOp>(
1055310653
loc, helperBuilder.getStringAttr(R"cpp(
10654+
#if defined(__CCE_AICORE__) || defined(__CPU_SIM)
1055410655
template <typename T>
1055510656
PTO_INTERNAL void ptoas_vcmp(MaskReg &dst, RegTensor<T> src0, RegTensor<T> src1,
1055610657
CmpMode mode, MaskReg pred) {
@@ -10578,11 +10679,47 @@ struct EmitPTOManualPass
1057810679
break;
1057910680
}
1058010681
}
10682+
template <typename RegT, typename T>
10683+
PTO_INTERNAL void ptoas_vcmps(MaskReg &dst, RegT src0, T src1,
10684+
CmpMode mode, MaskReg pred) {
10685+
switch (mode) {
10686+
case CmpMode::EQ:
10687+
vcmps_eq(dst, src0, src1, pred);
10688+
break;
10689+
case CmpMode::NE:
10690+
vcmps_ne(dst, src0, src1, pred);
10691+
break;
10692+
case CmpMode::LT:
10693+
vcmps_lt(dst, src0, src1, pred);
10694+
break;
10695+
case CmpMode::LE:
10696+
vcmps_le(dst, src0, src1, pred);
10697+
break;
10698+
case CmpMode::GT:
10699+
vcmps_gt(dst, src0, src1, pred);
10700+
break;
10701+
case CmpMode::GE:
10702+
vcmps_ge(dst, src0, src1, pred);
10703+
break;
10704+
default:
10705+
vcmps_eq(dst, src0, src1, pred);
10706+
break;
10707+
}
10708+
}
10709+
template <typename TileT>
10710+
PTO_INTERNAL void ptoas_pstore(MaskReg src, TileT &dst, int32_t linearIndex) {
10711+
__ubuf__ uint32_t *dstWords =
10712+
reinterpret_cast<__ubuf__ uint32_t *>(dst.data());
10713+
int32_t wordOffset = linearIndex / 32;
10714+
psts(src, dstWords + wordOffset, 0, PK);
10715+
}
10716+
#endif
1058110717
)cpp"));
1058210718
}
1058310719
if (needsVectorRemfHelper) {
1058410720
helperBuilder.create<emitc::VerbatimOp>(
1058510721
loc, helperBuilder.getStringAttr(R"cpp(
10722+
#if defined(__CCE_AICORE__) || defined(__CPU_SIM)
1058610723
template <typename T>
1058710724
PTO_INTERNAL void ptoas_vrem(RegTensor<T> &dst, RegTensor<T> src0,
1058810725
RegTensor<T> src1, MaskReg pred) {
@@ -10611,6 +10748,7 @@ struct EmitPTOManualPass
1061110748
vor(dst, dstEven, dstOdd, pred);
1061210749
}
1061310750
}
10751+
#endif
1061410752
)cpp"));
1061510753
}
1061610754
if (needsBitcastHelper) {

0 commit comments

Comments
 (0)