@@ -3,6 +3,38 @@ if CUDNN_ROOT ~= nil then
33 add_includedirs (CUDNN_ROOT .. " /include" )
44end
55
6+ local FLASH_ATTN_ROOT = get_config (" flash-attn" )
7+
8+ local INFINI_ROOT = os.getenv (" INFINI_ROOT" ) or (os.getenv (is_host (" windows" ) and " HOMEPATH" or " HOME" ) .. " /.infini" )
9+
10+ function _qy_flash_attn_cuda_so_path ()
11+ -- Highest priority: override the exact `.so` file to link.
12+ local env_path = os.getenv (" FLASH_ATTN_2_CUDA_SO" )
13+ if env_path and env_path ~= " " then
14+ env_path = env_path :trim ()
15+ if os .isfile (env_path ) then
16+ return env_path
17+ end
18+ print (string.format (" warning: qy+flash-attn: FLASH_ATTN_2_CUDA_SO is not a file: %s, fallback to container/default path" , env_path ))
19+ end
20+
21+ -- Second priority: allow overriding the "expected" container path via env.
22+ local container_path = os.getenv (" FLASH_ATTN_QY_CUDA_SO_CONTAINER" )
23+ if not container_path or container_path == " " then
24+ raise (" Error: Flash Attention SO path not specified!\n " )
25+ end
26+
27+ if not os .isfile (container_path ) then
28+ print (
29+ string.format (
30+ " warning: qy+flash-attn: expected %s; install flash-attn in conda env, or export FLASH_ATTN_2_CUDA_SO." ,
31+ container_path
32+ )
33+ )
34+ end
35+ return container_path
36+ end
37+
638add_includedirs (" /usr/local/denglin/sdk/include" , " ../include" )
739add_linkdirs (" /usr/local/denglin/sdk/lib" )
840add_links (" curt" , " cublas" , " cudnn" )
@@ -44,10 +76,20 @@ rule("qy.cuda")
4476 local sdk_path = " /usr/local/denglin/sdk"
4577 local arch = " dlgput64"
4678
47- local relpath = path .relative (sourcefile , project .directory ())
48- local objfile = path .join (config .buildir (), " .objs" , target :name (), " rules" , " qy.cuda" , relpath .. " .o" )
79+
80+ local relpath = path .relative (sourcefile , os .projectdir ())
81+
82+ relpath = relpath :gsub (" %.%." , " __" )
83+
84+ local objfile = path .join (
85+ config .buildir (),
86+ " .objs" ,
87+ target :name (),
88+ " rules" ,
89+ " qy.cuda" ,
90+ relpath .. " .o"
91+ )
4992
50- -- 🟢 强制注册 .o 文件给 target
5193 target :add (" objectfiles" , objfile )
5294 target :set (" buildadd" , true )
5395 local argv = {
@@ -153,3 +195,26 @@ target("infiniccl-qy")
153195 set_languages (" cxx17" )
154196
155197target_end ()
198+
199+ target (" flash-attn-qy" )
200+ set_kind (" phony" )
201+ set_default (false )
202+
203+
204+ if FLASH_ATTN_ROOT and FLASH_ATTN_ROOT ~= " " then
205+ before_build (function (target )
206+ target :add (" includedirs" , " /usr/local/denglin/sdk/include" , {public = true })
207+ local TORCH_DIR = os .iorunv (" python" , {" -c" , " import torch, os; print(os.path.dirname(torch.__file__))" }):trim ()
208+ local PYTHON_INCLUDE = os .iorunv (" python" , {" -c" , " import sysconfig; print(sysconfig.get_paths()['include'])" }):trim ()
209+ local PYTHON_LIB_DIR = os .iorunv (" python" , {" -c" , " import sysconfig; print(sysconfig.get_config_var('LIBDIR'))" }):trim ()
210+
211+ -- Validate build/runtime env in container and keep these paths available for downstream linking.
212+ target :add (" includedirs" , TORCH_DIR .. " /include" , TORCH_DIR .. " /include/torch/csrc/api/include" , PYTHON_INCLUDE , {public = false })
213+ target :add (" linkdirs" , TORCH_DIR .. " /lib" , PYTHON_LIB_DIR , {public = false })
214+ end )
215+ else
216+ before_build (function (target )
217+ print (" Flash Attention not available, skipping flash-attn-qy integration" )
218+ end )
219+ end
220+ target_end ()
0 commit comments