Skip to content

Commit 0777e29

Browse files
xysmlxLingxiao Ma
andauthored
Fix AntaresCpuKernelEmitter and add ir_based_fusion in GENERIC_CPU backend (#351)
* Update AntaresCpuKernelEmitter for Antares v0.2.x * Add IRBasedFusion for GENERIC_CPU backend Co-authored-by: Lingxiao Ma <lingm@microsoft.com>
1 parent 31f688f commit 0777e29

2 files changed

Lines changed: 6 additions & 6 deletions

File tree

src/nnfusion/core/kernels/cpu/cpu_kernel_emitter.cpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -151,20 +151,19 @@ LanguageUnit_p cpu::AntaresCpuKernelEmitter::emit_function_body()
151151
auto& lu = *_lu;
152152

153153
// extract kernel code
154-
const char* s_func_pattern = "// [thread_compute]\n";
154+
const char* s_func_pattern = "// [thread_extent] ";
155155
const char* e_func_pattern = "\n}\n";
156156
const char* s_rank_pattern = "__rank__ = ";
157157
const char* e_rank_pattern = "\n";
158158
std::string::size_type s_func_pos = antares_code.find(s_func_pattern);
159159
std::string::size_type e_func_pos = antares_code.rfind(e_func_pattern);
160160

161-
if (s_func_pos != std::string::npos || e_func_pos != std::string::npos)
161+
if (s_func_pos == std::string::npos || e_func_pos == std::string::npos)
162162
return nullptr;
163163

164164
NNFUSION_CHECK(s_func_pos != std::string::npos && e_func_pos != std::string::npos);
165165

166-
std::string func_body = antares_code.substr(s_func_pos + strlen(s_func_pattern),
167-
e_func_pos - s_func_pos - strlen(s_func_pattern));
166+
std::string func_body = antares_code.substr(s_func_pos, e_func_pos - s_func_pos);
168167
std::string::size_type s_rank_pos = func_body.find(s_rank_pattern);
169168
std::string::size_type e_rank_pos = func_body.find(e_rank_pattern);
170169
std::string rank_str = func_body.substr(s_rank_pos + strlen(s_rank_pattern),

src/nnfusion/engine/device/cpu.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include "cpu.hpp"
55
#include "reversed_dfs_visitor.hpp"
66

7+
#include "nnfusion/engine/pass/extract_graph_signature.hpp"
78
#include "nnfusion/engine/pass/graph/assign_async_info_pass.hpp"
89
#include "nnfusion/engine/pass/graph/assign_layout_pass.hpp"
910
#include "nnfusion/engine/pass/graph/autodiff_pass.hpp"
@@ -14,6 +15,7 @@
1415
#include "nnfusion/engine/pass/graph/gemm_fusion_pass.hpp"
1516
#include "nnfusion/engine/pass/graph/gnode_device_dispatcher.hpp"
1617
#include "nnfusion/engine/pass/graph/gradient_weight_mapping_pass.hpp"
18+
#include "nnfusion/engine/pass/graph/ir_based_fusion_pass.hpp"
1719
#include "nnfusion/engine/pass/graph/kernel_fusion_pass.hpp"
1820
#include "nnfusion/engine/pass/graph/kernel_profiling_pass.hpp"
1921
#include "nnfusion/engine/pass/graph/kernel_selection.hpp"
@@ -23,8 +25,6 @@
2325
#include "nnfusion/engine/pass/graph/pattern_substitution.hpp"
2426
#include "nnfusion/engine/pass/graph/runtime_const_folding_pass.hpp"
2527
#include "nnfusion/engine/pass/graph/vector_dot_transpose_pass.hpp"
26-
27-
#include "nnfusion/engine/pass/extract_graph_signature.hpp"
2828
#include "nnfusion/engine/pass/tensor/inplace_tensor_analysis.hpp"
2929
#include "nnfusion/engine/pass/tensor/liveness_analysis.hpp"
3030
#include "nnfusion/engine/pass/tensor/tensor_device_dispatcher.hpp"
@@ -51,6 +51,7 @@ CpuEngine::CpuEngine()
5151
g_passes->push_back(make_shared<AssignLayoutPass>());
5252
g_passes->push_back(make_shared<OpInplacePass>());
5353

54+
g_passes->push_back(make_shared<IRBasedFusionPass>());
5455
g_passes->push_back(make_shared<PatternSubstitutionPass>());
5556

5657
// Kernel selection

0 commit comments

Comments
 (0)