
Commit 69f8fed

xgqdut2016 and qinyiqun authored

issue/1090: Add flash attention for QY machines (#1099)

* issue/1090: qy flash-attention
* issue/1090: successfully link flash-attention.so
* issue/1090: qy flash guard
* issue/1090: qy flash working
* issue/1090: remove unnecessary code and .contiguous() function

Co-authored-by: qinyiqun <qinyiqun@outlook.com>

1 parent 9ead056 commit 69f8fed

7 files changed

Lines changed: 110 additions & 11 deletions


include/infinicore/adaptor/aten_adaptor.hpp

Lines changed: 6 additions & 3 deletions
@@ -5,9 +5,10 @@

 #include <ATen/ATen.h>

-#ifdef ENABLE_NVIDIA_API
-#include <ATen/cuda/CUDAContext.h>
+#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_QY_API)
+#include <c10/cuda/CUDAStream.h>
 #include <c10/cuda/CUDAGuard.h>
+#include <ATen/cuda/CUDAContext.h>
 #endif

 namespace infinicore::adaptor {
@@ -33,14 +34,16 @@ inline at::Device to_at_device(const Device &device) {
         return at::Device(at::kCUDA, device.getIndex());
     } else if (device.getType() == Device::Type::CPU) {
         return at::Device(at::kCPU);
+    } else if (device.getType() == Device::Type::QY) {
+        return at::Device(at::kCUDA, device.getIndex());
     } else {
         throw std::runtime_error("Unsupported device type for ATen");
     }
 }

 at::Tensor to_aten_tensor(const infinicore::Tensor &t);

-#ifdef ENABLE_NVIDIA_API
+#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_QY_API)
 c10::cuda::CUDAStream get_cuda_stream();
 #endif
 } // namespace infinicore::adaptor

src/infinicore/adaptor/aten_adaptor.cc

Lines changed: 1 addition & 1 deletion
@@ -32,7 +32,7 @@ at::Tensor to_aten_tensor(const infinicore::Tensor &t) {
         options);
 }

-#ifdef ENABLE_NVIDIA_API
+#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_QY_API)
 c10::cuda::CUDAStream get_cuda_stream() {
     return c10::cuda::getStreamFromExternal(
         cudaStream_t(infinicore::context::getStream()), infinicore::context::getDevice().getIndex());
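
Together, the header and the .cc change extend the previously NVIDIA-only ATen bridge to QY: a QY device is presented to ATen as at::kCUDA, and get_cuda_stream() wraps InfiniCore's current stream as a c10 CUDA stream. Below is a minimal sketch of how a call site can combine the two; the helper name run_aten_softmax and the at::softmax call are illustrative assumptions, not code from this commit.

// Sketch only: illustrates the intended use of the adaptor API above.
// run_aten_softmax and the softmax op are illustrative, not part of this commit.
#include <ATen/ATen.h>
#include <c10/cuda/CUDAGuard.h>
#include "infinicore/adaptor/aten_adaptor.hpp"

at::Tensor run_aten_softmax(const infinicore::Tensor &t) {
    // Convert the InfiniCore tensor to its ATen representation.
    at::Tensor aten_t = infinicore::adaptor::to_aten_tensor(t);

#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_QY_API)
    // Route ATen kernel launches onto InfiniCore's current stream
    // (on QY this works because the device is exposed to ATen as at::kCUDA).
    c10::cuda::CUDAStreamGuard guard(infinicore::adaptor::get_cuda_stream());
#endif
    return at::softmax(aten_t, /*dim=*/-1);
}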

src/infinicore/nn/embedding.cc

Lines changed: 1 addition & 1 deletion
@@ -45,7 +45,7 @@ Embedding::Embedding(size_t num_embeddings,
 Tensor Embedding::forward(const Tensor &indices) const {
     // TODO: Implement on-device embedding for all devices, then remove the condition and the classic approach
     auto device_type = device_.getType();
-    if (device_type == Device::Type::NVIDIA || device_type == Device::Type::ILUVATAR || device_type == Device::Type::METAX || device_type == Device::Type::MOORE || device_type == Device::Type::ALI) {
+    if (device_type == Device::Type::NVIDIA || device_type == Device::Type::ILUVATAR || device_type == Device::Type::METAX || device_type == Device::Type::MOORE || device_type == Device::Type::ALI || device_type == Device::Type::QY) {
         // Use op::embedding which supports device-side input and batch dimension
         return op::embedding(indices->contiguous()->to(device_), weight_);
     }

src/infinicore/ops/mha_kvcache/mha_kvcache_flashattn.cc

Lines changed: 5 additions & 0 deletions
@@ -38,8 +38,13 @@ void run(void *planned_meta) {

     auto out_tensor = infinicore::adaptor::to_aten_tensor(p->out);
     auto q = infinicore::adaptor::to_aten_tensor(p->q);
+#if defined(ENABLE_NVIDIA_API)
     auto k_cache = infinicore::adaptor::to_aten_tensor(p->k_cache);
     auto v_cache = infinicore::adaptor::to_aten_tensor(p->v_cache);
+#elif defined(ENABLE_QY_API)
+    auto k_cache = infinicore::adaptor::to_aten_tensor(p->k_cache).contiguous();
+    auto v_cache = infinicore::adaptor::to_aten_tensor(p->v_cache).contiguous();
+#endif
     auto seqlens_k = std::optional<const at::Tensor>(infinicore::adaptor::to_aten_tensor(p->seqlens_k));
     auto block_table = std::optional<at::Tensor>(infinicore::adaptor::to_aten_tensor(p->block_table));
     auto alibi_slopes = p->alibi_slopes
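
The QY branch differs only in calling .contiguous() on the KV caches before they reach the flash-attention kernel. A minimal, self-contained sketch of what that guard buys, with illustrative shapes not taken from the commit: a strided view is repacked into dense storage, and the call is a no-op when the tensor is already contiguous.

// Sketch only, not part of the commit: .contiguous() materializes a densely
// packed copy when the input is a strided view, and returns the tensor
// unchanged when its layout is already dense.
#include <ATen/ATen.h>

void contiguity_example() {
    at::Tensor cache = at::zeros({8, 16, 128});                               // illustrative cache-like shape
    at::Tensor view  = cache.narrow(/*dim=*/1, /*start=*/0, /*length=*/4);    // strided view, not contiguous
    at::Tensor dense = view.contiguous();                                     // packed copy a dense-layout kernel can consume
    // view.is_contiguous() == false, dense.is_contiguous() == true
}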

src/infiniop/ops/paged_caching/operator.cc

Lines changed: 13 additions & 1 deletion
@@ -2,7 +2,7 @@
 #include "../../handle.h"
 #include "infiniop/ops/paged_caching.h"

-#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ALI_API) || defined(ENABLE_ILUVATAR_API)
+#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ALI_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API)
 #include "nvidia/paged_caching_nvidia.cuh"
 #endif
 #ifdef ENABLE_METAX_API
@@ -43,6 +43,9 @@ __INFINI_C infiniStatus_t infiniopCreatePagedCachingDescriptor(
 #endif
 #ifdef ENABLE_MOORE_API
     CREATE(INFINI_DEVICE_MOORE, moore)
+#endif
+#ifdef ENABLE_QY_API
+    CREATE(INFINI_DEVICE_QY, nvidia)
 #endif
     default:
         return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
@@ -73,6 +76,9 @@ __INFINI_C infiniStatus_t infiniopGetPagedCachingWorkspaceSize(
 #endif
 #ifdef ENABLE_MOORE_API
     GET(INFINI_DEVICE_MOORE, moore)
+#endif
+#ifdef ENABLE_QY_API
+    GET(INFINI_DEVICE_QY, nvidia)
 #endif
     default:
         return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
@@ -107,6 +113,9 @@ __INFINI_C infiniStatus_t infiniopPagedCaching(
 #endif
 #ifdef ENABLE_MOORE_API
     CALCULATE(INFINI_DEVICE_MOORE, moore)
+#endif
+#ifdef ENABLE_QY_API
+    CALCULATE(INFINI_DEVICE_QY, nvidia)
 #endif
     default:
         return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
@@ -136,6 +145,9 @@ __INFINI_C infiniStatus_t infiniopDestroyPagedCachingDescriptor(
 #endif
 #ifdef ENABLE_MOORE_API
     DESTROY(INFINI_DEVICE_MOORE, moore)
+#endif
+#ifdef ENABLE_QY_API
+    DESTROY(INFINI_DEVICE_QY, nvidia)
 #endif
     default:
         return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
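
Each of the four QY entries reuses the nvidia backend, mirroring how ILUVATAR and ALI are routed by the #if at the top of the file. A small self-contained sketch of the dispatch idea follows; the enum values match the diff, but the function and its string return are purely illustrative, and the real CREATE/GET/CALCULATE/DESTROY macros expand to calls into per-backend namespaces instead.

// Sketch only: illustrates that the QY case resolves to the nvidia implementation.
#include <cstdio>

enum InfiniDevice { INFINI_DEVICE_NVIDIA, INFINI_DEVICE_MOORE, INFINI_DEVICE_QY };

const char *paged_caching_backend(InfiniDevice dev) {
    switch (dev) {
    case INFINI_DEVICE_NVIDIA:
        return "nvidia";
    case INFINI_DEVICE_MOORE:
        return "moore";
    case INFINI_DEVICE_QY:
        return "nvidia"; // QY is treated as CUDA-compatible and shares the nvidia kernels
    default:
        return "unsupported";
    }
}

int main() {
    std::printf("QY paged_caching dispatches to: %s\n", paged_caching_backend(INFINI_DEVICE_QY));
    return 0;
}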

xmake.lua

Lines changed: 16 additions & 2 deletions
@@ -247,7 +247,6 @@ if has_config("aten") then
     end
 end

-
 -- cuda graph
 option("graph")
     set_default(false)
@@ -259,7 +258,6 @@ if has_config("graph") then
     add_defines("USE_INFINIRT_GRAPH")
 end

-
 -- InfiniCCL
 option("ccl")
     set_default(false)
@@ -467,6 +465,22 @@ target("infinicore_cpp_api")
     if has_config("nv-gpu") then
         add_deps("flash-attn-nvidia")
     end
+    if has_config("qy-gpu") then
+        add_deps("flash-attn-qy")
+    end
+end
+
+if get_config("flash-attn") and get_config("flash-attn") ~= "" and has_config("qy-gpu") then
+    local flash_so_qy = _qy_flash_attn_cuda_so_path()
+    local flash_dir_qy = path.directory(flash_so_qy)
+    local flash_name_qy = path.filename(flash_so_qy)
+    before_link(function (target)
+        target:add(
+            "shflags",
+            "-Wl,--no-as-needed -L" .. flash_dir_qy .. " -l:" .. flash_name_qy .. " -Wl,-rpath," .. flash_dir_qy,
+            {force = true}
+        )
+    end)
 end

 before_build(function (target)

xmake/qy.lua

Lines changed: 68 additions & 3 deletions
@@ -3,6 +3,38 @@ if CUDNN_ROOT ~= nil then
     add_includedirs(CUDNN_ROOT .. "/include")
 end

+local FLASH_ATTN_ROOT = get_config("flash-attn")
+
+local INFINI_ROOT = os.getenv("INFINI_ROOT") or (os.getenv(is_host("windows") and "HOMEPATH" or "HOME") .. "/.infini")
+
+function _qy_flash_attn_cuda_so_path()
+    -- Highest priority: override the exact `.so` file to link.
+    local env_path = os.getenv("FLASH_ATTN_2_CUDA_SO")
+    if env_path and env_path ~= "" then
+        env_path = env_path:trim()
+        if os.isfile(env_path) then
+            return env_path
+        end
+        print(string.format("warning: qy+flash-attn: FLASH_ATTN_2_CUDA_SO is not a file: %s, fallback to container/default path", env_path))
+    end
+
+    -- Second priority: allow overriding the "expected" container path via env.
+    local container_path = os.getenv("FLASH_ATTN_QY_CUDA_SO_CONTAINER")
+    if not container_path or container_path == "" then
+        raise("Error: Flash Attention SO path not specified!\n")
+    end
+
+    if not os.isfile(container_path) then
+        print(
+            string.format(
+                "warning: qy+flash-attn: expected %s; install flash-attn in conda env, or export FLASH_ATTN_2_CUDA_SO.",
+                container_path
+            )
+        )
+    end
+    return container_path
+end
+
 add_includedirs("/usr/local/denglin/sdk/include", "../include")
 add_linkdirs("/usr/local/denglin/sdk/lib")
 add_links("curt", "cublas", "cudnn")
@@ -44,10 +76,20 @@ rule("qy.cuda")
         local sdk_path = "/usr/local/denglin/sdk"
         local arch = "dlgput64"

-        local relpath = path.relative(sourcefile, project.directory())
-        local objfile = path.join(config.buildir(), ".objs", target:name(), "rules", "qy.cuda", relpath .. ".o")
+
+        local relpath = path.relative(sourcefile, os.projectdir())
+
+        relpath = relpath:gsub("%.%.", "__")
+
+        local objfile = path.join(
+            config.buildir(),
+            ".objs",
+            target:name(),
+            "rules",
+            "qy.cuda",
+            relpath .. ".o"
+        )

-        -- 🟢 Force-register the .o file with the target
         target:add("objectfiles", objfile)
         target:set("buildadd", true)
         local argv = {
@@ -153,3 +195,26 @@ target("infiniccl-qy")
     set_languages("cxx17")

 target_end()
+
+target("flash-attn-qy")
+    set_kind("phony")
+    set_default(false)
+
+
+    if FLASH_ATTN_ROOT and FLASH_ATTN_ROOT ~= "" then
+        before_build(function (target)
+            target:add("includedirs", "/usr/local/denglin/sdk/include", {public = true})
+            local TORCH_DIR = os.iorunv("python", {"-c", "import torch, os; print(os.path.dirname(torch.__file__))"}):trim()
+            local PYTHON_INCLUDE = os.iorunv("python", {"-c", "import sysconfig; print(sysconfig.get_paths()['include'])"}):trim()
+            local PYTHON_LIB_DIR = os.iorunv("python", {"-c", "import sysconfig; print(sysconfig.get_config_var('LIBDIR'))"}):trim()
+
+            -- Validate build/runtime env in container and keep these paths available for downstream linking.
+            target:add("includedirs", TORCH_DIR .. "/include", TORCH_DIR .. "/include/torch/csrc/api/include", PYTHON_INCLUDE, {public = false})
+            target:add("linkdirs", TORCH_DIR .. "/lib", PYTHON_LIB_DIR, {public = false})
+        end)
+    else
+        before_build(function (target)
+            print("Flash Attention not available, skipping flash-attn-qy integration")
+        end)
+    end
+target_end()
