-
Notifications
You must be signed in to change notification settings - Fork 47
Expand file tree
/
Copy pathPTOIRTranslator.h
More file actions
98 lines (76 loc) · 3.21 KB
/
PTOIRTranslator.h
File metadata and controls
98 lines (76 loc) · 3.21 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
#ifndef MLIR_DIALECT_PTO_TRANSFORMS_INJECTSYNC_PTOIRTRANSLATOR_H
#define MLIR_DIALECT_PTO_TRANSFORMS_INJECTSYNC_PTOIRTRANSLATOR_H
#include "PTO/IR/PTO.h"
#include "PTO/Transforms/InsertSync/SyncCommon.h"
#include "PTO/Transforms/InsertSync/MemoryDependentAnalyzer.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/BuiltinOps.h"
#include "llvm/Support/raw_ostream.h"
#include <memory>
#include <string>
#include <vector>
namespace mlir {
namespace pto {
/// Translates the IR of a `func::FuncOp` into the SyncIR node sequence used
/// by the InsertSync analysis, and records buffer -> memory-info (alias)
/// information for the buffers it encounters.
///
/// Lifetime: the translator stores *references* to the SyncIR container, the
/// buffer map and the memory-dependence analyzer given to its constructor, so
/// all three must outlive the translator.
class PTOIRTranslator {
public:
  /// @param syncIR            Container that receives the generated SyncIR.
  /// @param memDepAnalyzer    Memory-dependence analyzer (held by reference).
  /// @param buffer2MemInfoMap Map that receives the buffer analysis results.
  /// @param func              Function whose IR is translated.
  /// @param syncAnalysisMode  Analysis mode, stored for use during Build().
  PTOIRTranslator(SyncIRs &syncIR, MemoryDependentAnalyzer &memDepAnalyzer,
                  Buffer2MemInfoMap &buffer2MemInfoMap, func::FuncOp func,
                  SyncAnalysisMode syncAnalysisMode)
      : func_(func), syncIR_(syncIR), buffer2MemInfoMap_(buffer2MemInfoMap),
        memAnalyzer_(memDepAnalyzer), mode_(syncAnalysisMode) {}

  // The translator owns a move-only pool (std::unique_ptr elements) and binds
  // reference members, so copying can never work. Say so explicitly instead
  // of letting a copy fail deep inside std::vector's copy constructor.
  PTOIRTranslator(const PTOIRTranslator &) = delete;
  PTOIRTranslator &operator=(const PTOIRTranslator &) = delete;
  // Move construction stays available (as it was implicitly). Pool entries
  // live on the heap, so pointers into them remain valid after a move.
  PTOIRTranslator(PTOIRTranslator &&) = default;

  /// Main entry point: runs the IR analysis and translation.
  void Build();

  /// Returns the generated SyncIR (instruction sequence).
  SyncIRs &getSyncIR() { return syncIR_; }

  /// Returns the buffer analysis result (alias map).
  Buffer2MemInfoMap &getBuffer2MemInfoMap() { return buffer2MemInfoMap_; }

  /// Prints debug information (buffer map and SyncIR).
  void print();

private:
  func::FuncOp func_;
  // Index counter for the current SyncIR node.
  unsigned index = 0;

  // Core data structures (types defined in SyncCommon.h). These are
  // references: the objects are owned by the caller.
  SyncIRs &syncIR_;
  Buffer2MemInfoMap &buffer2MemInfoMap_;
  MemoryDependentAnalyzer &memAnalyzer_;
  SyncAnalysisMode mode_;

  // --- Recursive traversal ---
  void RecursionIR(Region *region);

  // --- Memory / alias analysis ---
  void UpdateKernelArgMemInfo();
  LogicalResult UpdateAllocTileOpMemInfo(pto::AllocTileOp op);
  LogicalResult UpdatePointerCastOpMemInfo(pto::PointerCastOp op);
  LogicalResult UpdateMemrefAllocOpMemInfo(memref::AllocOp op);
  // Handles views / aliases (MakeTensorView, Subview, Mov): relates `result`
  // to the buffer of `source`.
  void UpdateAliasBufferInfo(Value result, Value source);

  // --- Control-flow handling (SCF) ---
  void UpdateForOpInfo(scf::ForOp forOp);
  void UpdateWhileOpInfo(scf::WhileOp whileOp);
  void UpdateIfOpInfo(scf::IfOp ifOp);
  void UpdateYieldOpInfo(scf::YieldOp yieldOp);

  // --- Core: compute / data-movement ops (generate Compound nodes) ---
  void UpdatePTOOpInfo(Operation *op);

  // --- Helpers ---
  // Returns the hardware pipeline type that a PTO op maps to.
  PipelineType getOpPipeline(Operation *op);
  // Fills the def/use list `vec` from the memory info of `values`.
  void UpdateDefUseVec(ValueRange values, SmallVector<const BaseMemInfo *> &vec);
  // Scalar access slice modeling: builds access-level dependency info from
  // ptr + offset (with the accessed scalar type) into `vec`.
  void UpdateScalarDefUseVec(Value ptr, Value offset, Type scalarType,
                             SmallVector<const BaseMemInfo *> &vec);

  // --- Debug helpers ---
  std::string getPipelineName(PipelineType pipe);
  void printMemInfoList(llvm::raw_ostream &os,
                        const SmallVector<const BaseMemInfo *> &list,
                        AsmState &state);

  // Owns the memory-info objects created for scalar access slices, so the
  // def/use pointers referring to them stay valid during analysis.
  // NOTE(review): such pointers dangle once this translator is destroyed,
  // while syncIR_ is caller-owned and may outlive it. Confirm that no SyncIR
  // consumer dereferences them after the translator's lifetime ends.
  std::vector<std::unique_ptr<BaseMemInfo>> scalarAccessMemInfoPool_;
};
} // namespace pto
} // namespace mlir
#endif // MLIR_DIALECT_PTO_TRANSFORMS_INJECTSYNC_PTOIRTRANSLATOR_H