forked from hw-native-sys/PTOAS
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathSyncCodegen.cpp
More file actions
463 lines (400 loc) · 18 KB
/
SyncCodegen.cpp
File metadata and controls
463 lines (400 loc) · 18 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
// Copyright (c) 2026 Huawei Technologies Co., Ltd.
// This program is free software, you can redistribute it and/or modify it under the terms and conditions of
// CANN Open Software License Agreement Version 2.0 (the "License").
// Please refer to the License for details. You may not use this file except in compliance with the License.
// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
// See LICENSE in the root of the software repository for the full text of the License.
// Please refer to the License for details. You may not use this file except in compliance with the License.
// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
// See LICENSE in the root of the software repository for the full text of the License.
#include "PTO/Transforms/InsertSync/SyncCodegen.h"
#include "PTO/IR/PTO.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/BuiltinOps.h"
#include "llvm/ADT/STLExtras.h"
#define DEBUG_TYPE "pto-inject-sync"
using namespace mlir;
using namespace mlir::pto;
// ==============================================================================
// 1. Helper Functions
// ==============================================================================
/// Wraps an analysis-side PipelineType into the PTO dialect's PipeAttr.
/// The conversion is a plain static_cast between the two enums.
static pto::PipeAttr getPipeAttr(Builder &builder, PipelineType pipe) {
  MLIRContext *context = builder.getContext();
  return pto::PipeAttr::get(context, static_cast<pto::PIPE>(pipe));
}
/// Wraps an integer event id into the PTO dialect's EventAttr.
/// The integer is cast directly to the pto::EVENT enum.
static pto::EventAttr getEventAttr(Builder &builder, int id) {
  MLIRContext *context = builder.getContext();
  return pto::EventAttr::get(context, static_cast<pto::EVENT>(id));
}
/// Returns true when `pipe` may take part in low-level set/wait flag sync on
/// A5 targets. Only PIPE_S, PIPE_V, PIPE_MTE2 and PIPE_MTE3 qualify.
static bool isA5LowLevelSyncPipeLegal(PipelineType pipe) {
  return pipe == PipelineType::PIPE_S || pipe == PipelineType::PIPE_V ||
         pipe == PipelineType::PIPE_MTE2 || pipe == PipelineType::PIPE_MTE3;
}
/// Decides whether a set/wait sync must be lowered to a PIPE_ALL barrier.
/// This is the case only on A5 targets, and only when the source or the
/// destination pipe is not legal for low-level flag sync.
static bool shouldUseA5BarrierFallback(func::FuncOp func,
                                       const SyncOperation *sync) {
  if (!isTargetArchA5(func.getOperation()))
    return false;
  // Fallback unless both endpoints support flag sync.
  return !(isA5LowLevelSyncPipeLegal(sync->GetActualSrcPipe()) &&
           isA5LowLevelSyncPipeLegal(sync->GetActualDstPipe()));
}
/// Two syncs share a signature when their type and actual src/dst pipes
/// match. For set/wait syncs (judged by `candidate`), the event id lists must
/// also be equal.
static bool IsSameSyncSignature(const SyncOperation *existing,
                                const SyncOperation *candidate) {
  const bool samePipesAndType =
      existing->GetType() == candidate->GetType() &&
      existing->GetActualSrcPipe() == candidate->GetActualSrcPipe() &&
      existing->GetActualDstPipe() == candidate->GetActualDstPipe();
  if (!samePipesAndType)
    return false;

  const bool carriesEvents =
      candidate->isSyncSetType() || candidate->isSyncWaitType();
  return !carriesEvents || existing->eventIds == candidate->eventIds;
}
static bool IsSyncExist(const SyncOps &list, SyncOperation *newSync) {
// Tombstone entries are soft-deleted and should never participate in
// deduplication; otherwise they can shadow a later live sync with
// the same signature.
if (newSync->uselessSync)
return true;
for (auto *existing : list) {
if (existing == newSync)
return true;
if (existing->uselessSync)
continue;
if (!IsSameSyncSignature(existing, newSync))
continue;
return true;
}
return false;
}
/// Appends every live sync of `srcList` to `dstList`, preserving order and
/// skipping tombstones and entries `dstList` already holds (see IsSyncExist).
static void MergeSyncList(SyncOps &dstList, const SyncOps &srcList) {
  for (auto *candidate : srcList) {
    const bool skip =
        candidate->uselessSync || IsSyncExist(dstList, candidate);
    if (!skip)
      dstList.push_back(candidate);
  }
}
// ==============================================================================
// 2. SyncCodegen Implementation
// ==============================================================================
/// Pass entry point. Steps:
///  1. Fold the sync lists recorded on the sync IR into op2InsertSync.
///  2. Walk the function in pre-order and materialize each op's before/after
///     syncs around it.
///  3. Emit the deferred tail clean barrier (if any) right before each return,
///     so it is not interleaved with other trailing syncs.
void SyncCodegen::Run() {
  IRRewriter rewriter(func_->getContext());
  UpdateOpInsertSync(rewriter);

  func_->walk<WalkOrder::PreOrder>([&](Operation *op) {
    auto entry = op2InsertSync.find(op);
    if (entry == op2InsertSync.end())
      return;
    auto &build = entry->second;

    for (auto &sync : build.pipeBefore)
      SyncInsert(rewriter, op, sync, /*beforeInsert=*/true);

    // In the common case each post-sync is inserted directly after `op`,
    // pushing previously inserted ones down. Visiting the list in reverse
    // therefore leaves them in list order.
    for (auto &sync : llvm::reverse(build.pipeAfter))
      SyncInsert(rewriter, op, sync, /*beforeInsert=*/false);
  });

  AppendAutoSyncTailBarrierIfNeeded(rewriter);
}
/// Visits every sync-IR element and dispatches on its concrete kind, so that
/// its syncs are recorded in op2InsertSync. Elements of any other kind are
/// ignored. `rewriter` is not used by this step.
void SyncCodegen::UpdateOpInsertSync(IRRewriter &rewriter) {
  (void)rewriter;
  for (auto &owned : syncIR_) {
    auto *element = owned.get();
    if (auto *compound = dyn_cast<CompoundInstanceElement>(element))
      UpdateCompoundOpInsertSync(compound);
    else if (auto *holder = dyn_cast<PlaceHolderInstanceElement>(element))
      updatePlaceHolderOpInsertSync(holder);
    else if (auto *loop = dyn_cast<LoopInstanceElement>(element))
      UpdateLoopOpInsertSync(loop);
    else if (auto *branch = dyn_cast<BranchInstanceElement>(element))
      UpdateBranchOpInsertSync(branch);
  }
}
/// Records the before/after syncs of a compound instruction against its
/// anchor op (elementOp), skipping tombstones and duplicates.
void SyncCodegen::UpdateCompoundOpInsertSync(CompoundInstanceElement *nowCompound) {
  auto &slot = op2InsertSync[nowCompound->elementOp];
  MergeSyncList(slot.pipeBefore, nowCompound->pipeBefore);
  MergeSyncList(slot.pipeAfter, nowCompound->pipeAfter);
}
/// Records a loop's syncs against the anchor op of its LOOP_END element.
/// The pre-syncs come from the matching LOOP_BEGIN element (syncIR_[beginId]);
/// the post-syncs come from the LOOP_END element itself. Other loop kinds are
/// ignored.
void SyncCodegen::UpdateLoopOpInsertSync(LoopInstanceElement *nowElement) {
  if (nowElement->getLoopKind() != KindOfLoop::LOOP_END)
    return;

  auto *loopBegin =
      dyn_cast<LoopInstanceElement>(syncIR_[nowElement->beginId].get());
  // The previous code dereferenced this dyn_cast result unconditionally.
  assert(loopBegin && "LOOP_END beginId must reference a LoopInstanceElement");

  auto &pipeBuild = op2InsertSync[nowElement->elementOp];
  if (loopBegin)
    MergeSyncList(pipeBuild.pipeBefore, loopBegin->pipeBefore);
  MergeSyncList(pipeBuild.pipeAfter, nowElement->pipeAfter);
}
/// Records a branch's syncs against the anchor op of its IF_END element.
/// The pre-syncs come from the matching branch-begin element
/// (syncIR_[beginId]); the post-syncs come from the IF_END element itself.
/// Other branch kinds are ignored.
void SyncCodegen::UpdateBranchOpInsertSync(BranchInstanceElement *nowElement) {
  if (nowElement->getBranchKind() != KindOfBranch::IF_END)
    return;

  auto *branchBegin =
      dyn_cast<BranchInstanceElement>(syncIR_[nowElement->beginId].get());
  // The previous code dereferenced this dyn_cast result unconditionally.
  assert(branchBegin &&
         "IF_END beginId must reference a BranchInstanceElement");

  auto &pipeBuild = op2InsertSync[nowElement->elementOp];
  if (branchBegin)
    MergeSyncList(pipeBuild.pipeBefore, branchBegin->pipeBefore);
  MergeSyncList(pipeBuild.pipeAfter, nowElement->pipeAfter);
}
/// Resolves the anchor op of a placeholder element and records its syncs.
///
/// Virtual else (the parent scf.if has no else region): an else block
/// terminated by an operand-less scf.yield is created on demand, and the
/// placeholder is re-anchored on that yield. The block is only created when at
/// least one live (non-tombstoned) sync would actually be placed in it.
/// Otherwise the IR is left untouched and nothing is recorded.
void SyncCodegen::updatePlaceHolderOpInsertSync(PlaceHolderInstanceElement *placeHolder) {
  // True when `list` contains at least one sync that will actually be emitted
  // (MergeSyncList drops tombstoned entries).
  auto hasLiveSync = [](const auto &list) {
    return llvm::any_of(list, [](const SyncOperation *sync) {
      return !sync->uselessSync;
    });
  };

  if (placeHolder->isVirtualElse) {
    auto ifOp = dyn_cast_or_null<scf::IfOp>(placeHolder->parentIfOp);
    if (!ifOp)
      return;

    if (!ifOp.elseBlock() && (hasLiveSync(placeHolder->pipeBefore) ||
                              hasLiveSync(placeHolder->pipeAfter))) {
      // The region takes ownership of the new block.
      Block *newElse = new Block();
      ifOp.getElseRegion().push_back(newElse);
      OpBuilder builder(ifOp.getContext());
      builder.setInsertionPointToEnd(newElse);
      builder.create<scf::YieldOp>(ifOp.getLoc());
    }

    Block *elseBlock = ifOp.elseBlock();
    if (!elseBlock)
      return; // Nothing to place; keep the scf.if without an else region.
    // Anchor the virtual placeholder on the else-branch terminator.
    placeHolder->elementOp = elseBlock->getTerminator();
  }
  // NOTE(review): a non-virtual placeholder whose elementOp still equals
  // parentIfOp is NOT redirected to the matching scf.yield here, so its syncs
  // are placed around the scf.if itself. Redirecting needs to know whether it
  // ends the then- or the else-branch; ideally the translator stores the yield
  // op directly in elementOp. TODO confirm against the translator.

  if (!placeHolder->elementOp)
    return;
  auto &pipeBuild = op2InsertSync[placeHolder->elementOp];
  MergeSyncList(pipeBuild.pipeBefore, placeHolder->pipeBefore);
  MergeSyncList(pipeBuild.pipeAfter, placeHolder->pipeAfter);
}
/// Materializes one sync relative to the anchor `op`.
///
/// Tombstoned syncs are skipped. Compensation syncs are always placed at the
/// end of a control-flow block, i.e. right before that block's terminator:
///  - anchor is an scf.if: the sync goes before the else-branch yield, and an
///    empty else block is created first if the scf.if has none;
///  - anchor is already a terminator: it is used as-is;
///  - any other anchor: redirected to the terminator of its block.
/// The sync is then lowered according to its type (barrier vs. set/wait with
/// one or more event ids).
void SyncCodegen::SyncInsert(IRRewriter &rewriter, Operation *op,
                             SyncOperation *sync, bool beforeInsert) {
  if (sync->uselessSync)
    return;

  Operation *anchor = op;
  bool placeBefore = beforeInsert;

  if (sync->isCompensation) {
    placeBefore = true;
    if (auto ifOp = dyn_cast<scf::IfOp>(op)) {
      if (!ifOp.elseBlock()) {
        OpBuilder elseBuilder(ifOp.getContext());
        Block *freshElse = new Block();
        ifOp.getElseRegion().push_back(freshElse);
        elseBuilder.setInsertionPointToEnd(freshElse);
        elseBuilder.create<scf::YieldOp>(ifOp.getLoc());
      }
      anchor = ifOp.getElseRegion().front().getTerminator();
    } else if (!op->hasTrait<OpTrait::IsTerminator>()) {
      anchor = op->getBlock()->getTerminator();
    }
    // Otherwise `op` already is a terminator and stays the anchor.
  }

  if (sync->GetType() == SyncOperation::TYPE::PIPE_BARRIER) {
    CreateBarrierOp(rewriter, anchor, sync, placeBefore);
    return;
  }
  if (!sync->isSyncSetType() && !sync->isSyncWaitType())
    return;

  if (sync->eventIds.size() == 1)
    CreateSetWaitOpForSingleBuffer(rewriter, anchor, sync, placeBefore);
  else
    CreateSetWaitOpForMultiBuffer(rewriter, anchor, sync, placeBefore);
}
/// Emits a pto.barrier on the sync's source pipe next to `op`. Four cases:
///  - A5 PIPE_V barriers are dropped: hardware already orders PIPE_V against
///    itself, and backend checks reject an explicit vector barrier.
///  - A PIPE_ALL -> PIPE_ALL barrier is the compiler's tail clean barrier. It is
///    only flagged here (pendingAutoSyncTailBarrier_) and later placed at the
///    function tail by AppendAutoSyncTailBarrierIfNeeded.
///  - The barrier goes before `op` when `beforeInsert` is set or `op` is a
///    terminator, otherwise right after `op`.
///  - Peephole: nothing is emitted when the adjacent op on the insertion side
///    is already a barrier on the same pipe.
void SyncCodegen::CreateBarrierOp(IRRewriter &rewriter, Operation *op,
                                  SyncOperation *sync, bool beforeInsert) {
  if (isTargetArchA5(func_.getOperation()) &&
      sync->GetActualSrcPipe() == PipelineType::PIPE_V)
    return;

  if (sync->GetActualSrcPipe() == PipelineType::PIPE_ALL &&
      sync->GetActualDstPipe() == PipelineType::PIPE_ALL) {
    pendingAutoSyncTailBarrier_ = true;
    return;
  }

  const bool placeBefore =
      beforeInsert || op->hasTrait<OpTrait::IsTerminator>();
  if (placeBefore)
    rewriter.setInsertionPoint(op);
  else
    rewriter.setInsertionPointAfter(op);

  Block *block = rewriter.getInsertionBlock();
  Block::iterator ip = rewriter.getInsertionPoint();
  auto pipeAttr = getPipeAttr(rewriter, sync->GetActualSrcPipe());

  // Which op would sit directly next to the new barrier: the previous op for
  // a before-insert, the following op for an after-insert.
  Operation *neighbour = nullptr;
  if (placeBefore) {
    if (ip != block->begin())
      neighbour = &*std::prev(ip);
  } else if (ip != block->end()) {
    neighbour = &*ip;
  }
  if (auto adjacent = dyn_cast_or_null<pto::BarrierOp>(neighbour)) {
    if (adjacent.getPipe() == pipeAttr)
      return;
  }

  rewriter.create<pto::BarrierOp>(op->getLoc(), pipeAttr);
}
/// Emits the deferred tail clean barrier, a PIPE_ALL pto.barrier, right before
/// every func.return of the function. The work is only done when
/// pendingAutoSyncTailBarrier_ is set (CreateBarrierOp sets it for a
/// PIPE_ALL -> PIPE_ALL barrier). Each barrier is tagged with the unit
/// attribute `pto.auto_sync_tail_barrier`. When the function carries a
/// `pto.auto_sync_tail_hint` string attribute, a copy of it is attached too.
/// The pending flag is cleared only if at least one return was found.
void SyncCodegen::AppendAutoSyncTailBarrierIfNeeded(IRRewriter &rewriter) {
  if (!pendingAutoSyncTailBarrier_)
    return;

  // Collect first, mutate afterwards, so the walk never sees new ops.
  SmallVector<func::ReturnOp, 4> exits;
  func_.walk([&](func::ReturnOp exitOp) { exits.push_back(exitOp); });
  if (exits.empty())
    return;

  auto pipeAllAttr = getPipeAttr(rewriter, PipelineType::PIPE_ALL);
  auto hintAttr =
      func_->getAttrOfType<mlir::StringAttr>("pto.auto_sync_tail_hint");
  for (func::ReturnOp exitOp : exits) {
    rewriter.setInsertionPoint(exitOp);
    auto tailBarrier =
        rewriter.create<pto::BarrierOp>(exitOp.getLoc(), pipeAllAttr);
    tailBarrier->setAttr("pto.auto_sync_tail_barrier", rewriter.getUnitAttr());
    if (hintAttr)
      tailBarrier->setAttr("pto.auto_sync_tail_hint", hintAttr);
  }
  pendingAutoSyncTailBarrier_ = false;
}
/// Emits a pto set_flag / wait_flag for a sync with one event id
/// (eventIds[0]). On A5, when either pipe cannot take flag sync, a PIPE_ALL
/// barrier is emitted instead.
/// Placement is before `op` when `beforeInsert` is set or `op` is a terminator
/// (no op may follow a terminator), otherwise right after `op`.
void SyncCodegen::CreateSetWaitOpForSingleBuffer(IRRewriter &rewriter,
                                                 Operation *op,
                                                 SyncOperation *sync,
                                                 bool beforeInsert) {
  if (beforeInsert || op->hasTrait<OpTrait::IsTerminator>())
    rewriter.setInsertionPoint(op);
  else
    rewriter.setInsertionPointAfter(op);

  Location loc = op->getLoc();
  if (shouldUseA5BarrierFallback(func_, sync)) {
    rewriter.create<pto::BarrierOp>(
        loc, getPipeAttr(rewriter, PipelineType::PIPE_ALL));
    return;
  }

  auto src = getPipeAttr(rewriter, sync->GetActualSrcPipe());
  auto dst = getPipeAttr(rewriter, sync->GetActualDstPipe());
  auto event = getEventAttr(rewriter, sync->eventIds[0]);
  if (sync->isSyncWaitType())
    rewriter.create<pto::WaitFlagOp>(loc, src, dst, event);
  else
    rewriter.create<pto::SetFlagOp>(loc, src, dst, event);
}
/// Emits a pto set_flag / wait_flag for a sync that carries several event ids.
/// On A5, when either pipe cannot take flag sync, a PIPE_ALL barrier is emitted
/// instead. Placement follows the single-buffer rule: before `op` when
/// `beforeInsert` is set or `op` is a terminator, otherwise right after it.
void SyncCodegen::CreateSetWaitOpForMultiBuffer(IRRewriter &rewriter,
                                                Operation *op,
                                                SyncOperation *sync,
                                                bool beforeInsert) {
  auto positionAtAnchor = [&] {
    if (beforeInsert || op->hasTrait<OpTrait::IsTerminator>())
      rewriter.setInsertionPoint(op);
    else
      rewriter.setInsertionPointAfter(op);
  };
  Location loc = op->getLoc();

  if (shouldUseA5BarrierFallback(func_, sync)) {
    positionAtAnchor();
    rewriter.create<pto::BarrierOp>(
        loc, getPipeAttr(rewriter, PipelineType::PIPE_ALL));
    return;
  }

  // NOTE(review): the per-iteration event selector is built and cached here,
  // but its value is not consumed yet. The flag ops below take a static
  // EventAttr from eventIds[0]. Consuming it would need a flag-op form that
  // accepts a dynamic event id. TODO confirm the intended op definition.
  (void)GetBufferSelected(rewriter, op, sync);
  // GetBufferSelected moves the insertion point into the loop body, so the
  // anchor position must be re-established afterwards.
  positionAtAnchor();

  auto src = getPipeAttr(rewriter, sync->GetActualSrcPipe());
  auto dst = getPipeAttr(rewriter, sync->GetActualDstPipe());
  auto event = getEventAttr(rewriter, sync->eventIds[0]);
  if (sync->isSyncWaitType())
    rewriter.create<pto::WaitFlagOp>(loc, src, dst, event);
  else
    rewriter.create<pto::SetFlagOp>(loc, src, dst, event);
}
/// Returns an index value that selects sync->eventIds[0] on even iterations
/// and sync->eventIds[1] on odd iterations of the innermost scf.for enclosing
/// `op`. The result is cached per sync index in SyncIndex2SelectBuffer. The
/// per-loop parity counter is cached in loop2BufferCounter.
///
/// Returns a null Value when `op` is not nested in an scf.for.
/// Side effect: moves the rewriter's insertion point into the loop body, so
/// callers must reset it before emitting further ops.
Value SyncCodegen::GetBufferSelected(IRRewriter &rewriter, Operation *op,
                                     SyncOperation *sync) {
  if (SyncIndex2SelectBuffer.count(sync->GetSyncIndex()))
    return SyncIndex2SelectBuffer[sync->GetSyncIndex()];

  auto parentLoop = op->getParentOfType<scf::ForOp>();
  if (!parentLoop)
    return nullptr;

  assert(sync->eventIds.size() >= 2 &&
         "multi-buffer sync needs at least two event ids");

  Location loc = op->getLoc();
  Value counter;
  if (loop2BufferCounter.count(parentLoop)) {
    counter = loop2BufferCounter[parentLoop];
  } else {
    rewriter.setInsertionPointToStart(parentLoop.getBody());
    // Use the parity of the normalized trip index (iv - lb) / step rather
    // than the raw induction variable. With an even step, `iv % 2` is constant
    // and the buffers never alternate. createOrFold collapses the
    // normalization for the common lb == 0, step == 1 case.
    Value iv = parentLoop.getInductionVar();
    Value offset = rewriter.createOrFold<arith::SubIOp>(
        loc, iv, parentLoop.getLowerBound());
    Value tripIndex = rewriter.createOrFold<arith::DivUIOp>(
        loc, offset, parentLoop.getStep());
    Value c2 = rewriter.create<arith::ConstantIndexOp>(loc, 2);
    // Always a real op, so getDefiningOp() below is non-null.
    counter = rewriter.create<arith::RemUIOp>(loc, tripIndex, c2);
    loop2BufferCounter[parentLoop] = counter;
  }

  rewriter.setInsertionPointAfter(counter.getDefiningOp());
  Value id0 = rewriter.create<arith::ConstantIndexOp>(loc, sync->eventIds[0]);
  Value id1 = rewriter.create<arith::ConstantIndexOp>(loc, sync->eventIds[1]);
  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  Value isEven = rewriter.create<arith::CmpIOp>(
      loc, arith::CmpIPredicate::eq, counter, zero);
  Value selected = rewriter.create<arith::SelectOp>(loc, isEven, id0, id1);
  SyncIndex2SelectBuffer[sync->GetSyncIndex()] = selected;
  return selected;
}